diff --git a/README.md b/README.md index 079aebb06..9231eeee5 100644 --- a/README.md +++ b/README.md @@ -12,10 +12,119 @@ The OpenAI Agents SDK is a lightweight yet powerful framework for building multi 1. [**Agents**](https://openai.github.io/openai-agents-python/agents): LLMs configured with instructions, tools, guardrails, and handoffs 2. [**Handoffs**](https://openai.github.io/openai-agents-python/handoffs/): A specialized tool call used by the Agents SDK for transferring control between agents 3. [**Guardrails**](https://openai.github.io/openai-agents-python/guardrails/): Configurable safety checks for input and output validation -4. [**Tracing**](https://openai.github.io/openai-agents-python/tracing/): Built-in tracking of agent runs, allowing you to view, debug and optimize your workflows +4. [**Sessions**](#sessions): Automatic conversation history management across agent runs +5. [**Tracing**](https://openai.github.io/openai-agents-python/tracing/): Built-in tracking of agent runs, allowing you to view, debug and optimize your workflows Explore the [examples](examples) directory to see the SDK in action, and read our [documentation](https://openai.github.io/openai-agents-python/) for more details. +## Sessions + +The Agents SDK provides built-in session memory to automatically maintain conversation history across multiple agent runs, eliminating the need to manually handle `.to_input_list()` between turns. + +### Quick start + +```python +from agents import Agent, Runner, SQLiteSession + +# Create agent +agent = Agent( + name="Assistant", + instructions="Reply very concisely.", +) + +# Create a session instance +session = SQLiteSession("conversation_123") + +# First turn +result = await Runner.run( + agent, + "What city is the Golden Gate Bridge in?", + session=session +) +print(result.final_output) # "San Francisco" + +# Second turn - agent automatically remembers previous context +result = await Runner.run( + agent, + "What state is it in?", + session=session +) +print(result.final_output) # "California" + +# Also works with synchronous runner +result = Runner.run_sync( + agent, + "What's the population?", + session=session +) +print(result.final_output) # "Approximately 39 million" +``` + +### Session options + +- **No memory** (default): No session memory when session parameter is omitted +- **`session: Session = DatabaseSession(...)`**: Use a Session instance to manage conversation history + +```python +from agents import Agent, Runner, SQLiteSession + +# Custom SQLite database file +session = SQLiteSession("user_123", "conversations.db") +agent = Agent(name="Assistant") + +# Different session IDs maintain separate conversation histories +result1 = await Runner.run( + agent, + "Hello", + session=session +) +result2 = await Runner.run( + agent, + "Hello", + session=SQLiteSession("user_456", "conversations.db") +) +``` + +### Custom session implementations + +You can implement your own session memory by creating a class that follows the `Session` protocol: + +```python +from agents.memory import Session +from typing import List + +class MyCustomSession: + """Custom session implementation following the Session protocol.""" + + def __init__(self, session_id: str): + self.session_id = session_id + # Your initialization here + + async def get_items(self, limit: int | None = None) -> List[dict]: + # Retrieve conversation history for the session + pass + + async def add_items(self, items: List[dict]) -> None: + # Store new items for the session + pass + + async def pop_item(self) -> dict | None: + # Remove and return the most recent item from the session + pass + + async def clear_session(self) -> None: + # Clear all items for the session + pass + +# Use your custom session +agent = Agent(name="Assistant") +result = await Runner.run( + agent, + "Hello", + session=MyCustomSession("my_session") +) +``` + ## Get started 1. Set up your Python environment diff --git a/docs/index.md b/docs/index.md index 8aef6574e..935c4be5b 100644 --- a/docs/index.md +++ b/docs/index.md @@ -5,6 +5,7 @@ The [OpenAI Agents SDK](https://github.com/openai/openai-agents-python) enables - **Agents**, which are LLMs equipped with instructions and tools - **Handoffs**, which allow agents to delegate to other agents for specific tasks - **Guardrails**, which enable the inputs to agents to be validated +- **Sessions**, which automatically maintains conversation history across agent runs In combination with Python, these primitives are powerful enough to express complex relationships between tools and agents, and allow you to build real-world applications without a steep learning curve. In addition, the SDK comes with built-in **tracing** that lets you visualize and debug your agentic flows, as well as evaluate them and even fine-tune models for your application. @@ -21,6 +22,7 @@ Here are the main features of the SDK: - Python-first: Use built-in language features to orchestrate and chain agents, rather than needing to learn new abstractions. - Handoffs: A powerful feature to coordinate and delegate between multiple agents. - Guardrails: Run input validations and checks in parallel to your agents, breaking early if the checks fail. +- Sessions: Automatic conversation history management across agent runs, eliminating manual state handling. - Function tools: Turn any Python function into a tool, with automatic schema generation and Pydantic-powered validation. - Tracing: Built-in tracing that lets you visualize, debug and monitor your workflows, as well as use the OpenAI suite of evaluation, fine-tuning and distillation tools. diff --git a/docs/ref/memory.md b/docs/ref/memory.md new file mode 100644 index 000000000..04a2258bf --- /dev/null +++ b/docs/ref/memory.md @@ -0,0 +1,8 @@ +# Memory + +::: agents.memory + + options: + members: + - Session + - SQLiteSession diff --git a/docs/running_agents.md b/docs/running_agents.md index f631cf46f..6898f5101 100644 --- a/docs/running_agents.md +++ b/docs/running_agents.md @@ -65,7 +65,9 @@ Calling any of the run methods can result in one or more agents running (and hen At the end of the agent run, you can choose what to show to the user. For example, you might show the user every new item generated by the agents, or just the final output. Either way, the user might then ask a followup question, in which case you can call the run method again. -You can use the base [`RunResultBase.to_input_list()`][agents.result.RunResultBase.to_input_list] method to get the inputs for the next turn. +### Manual conversation management + +You can manually manage conversation history using the [`RunResultBase.to_input_list()`][agents.result.RunResultBase.to_input_list] method to get the inputs for the next turn: ```python async def main(): @@ -84,6 +86,39 @@ async def main(): # California ``` +### Automatic conversation management with Sessions + +For a simpler approach, you can use [Sessions](sessions.md) to automatically handle conversation history without manually calling `.to_input_list()`: + +```python +from agents import Agent, Runner, SQLiteSession + +async def main(): + agent = Agent(name="Assistant", instructions="Reply very concisely.") + + # Create session instance + session = SQLiteSession("conversation_123") + + with trace(workflow_name="Conversation", group_id=thread_id): + # First turn + result = await Runner.run(agent, "What city is the Golden Gate Bridge in?", session=session) + print(result.final_output) + # San Francisco + + # Second turn - agent automatically remembers previous context + result = await Runner.run(agent, "What state is it in?", session=session) + print(result.final_output) + # California +``` + +Sessions automatically: + +- Retrieves conversation history before each run +- Stores new messages after each run +- Maintains separate conversations for different session IDs + +See the [Sessions documentation](sessions.md) for more details. + ## Exceptions The SDK raises exceptions in certain cases. The full list is in [`agents.exceptions`][]. As an overview: diff --git a/docs/scripts/translate_docs.py b/docs/scripts/translate_docs.py index 5dada2681..ac40b6fa8 100644 --- a/docs/scripts/translate_docs.py +++ b/docs/scripts/translate_docs.py @@ -266,7 +266,9 @@ def translate_single_source_file(file_path: str) -> None: def main(): parser = argparse.ArgumentParser(description="Translate documentation files") - parser.add_argument("--file", type=str, help="Specific file to translate (relative to docs directory)") + parser.add_argument( + "--file", type=str, help="Specific file to translate (relative to docs directory)" + ) args = parser.parse_args() if args.file: diff --git a/docs/sessions.md b/docs/sessions.md new file mode 100644 index 000000000..8906f0e6d --- /dev/null +++ b/docs/sessions.md @@ -0,0 +1,319 @@ +# Sessions + +The Agents SDK provides built-in session memory to automatically maintain conversation history across multiple agent runs, eliminating the need to manually handle `.to_input_list()` between turns. + +Sessions stores conversation history for a specific session, allowing agents to maintain context without requiring explicit manual memory management. This is particularly useful for building chat applications or multi-turn conversations where you want the agent to remember previous interactions. + +## Quick start + +```python +from agents import Agent, Runner, SQLiteSession + +# Create agent +agent = Agent( + name="Assistant", + instructions="Reply very concisely.", +) + +# Create a session instance with a session ID +session = SQLiteSession("conversation_123") + +# First turn +result = await Runner.run( + agent, + "What city is the Golden Gate Bridge in?", + session=session +) +print(result.final_output) # "San Francisco" + +# Second turn - agent automatically remembers previous context +result = await Runner.run( + agent, + "What state is it in?", + session=session +) +print(result.final_output) # "California" + +# Also works with synchronous runner +result = Runner.run_sync( + agent, + "What's the population?", + session=session +) +print(result.final_output) # "Approximately 39 million" +``` + +## How it works + +When session memory is enabled: + +1. **Before each run**: The runner automatically retrieves the conversation history for the session and prepends it to the input items. +2. **After each run**: All new items generated during the run (user input, assistant responses, tool calls, etc.) are automatically stored in the session. +3. **Context preservation**: Each subsequent run with the same session includes the full conversation history, allowing the agent to maintain context. + +This eliminates the need to manually call `.to_input_list()` and manage conversation state between runs. + +## Memory operations + +### Basic operations + +Sessions supports several operations for managing conversation history: + +```python +from agents import SQLiteSession + +session = SQLiteSession("user_123", "conversations.db") + +# Get all items in a session +items = await session.get_items() + +# Add new items to a session +new_items = [ + {"role": "user", "content": "Hello"}, + {"role": "assistant", "content": "Hi there!"} +] +await session.add_items(new_items) + +# Remove and return the most recent item +last_item = await session.pop_item() +print(last_item) # {"role": "assistant", "content": "Hi there!"} + +# Clear all items from a session +await session.clear_session() +``` + +### Using pop_item for corrections + +The `pop_item` method is particularly useful when you want to undo or modify the last item in a conversation: + +```python +from agents import Agent, Runner, SQLiteSession + +agent = Agent(name="Assistant") +session = SQLiteSession("correction_example") + +# Initial conversation +result = await Runner.run( + agent, + "What's 2 + 2?", + session=session +) +print(f"Agent: {result.final_output}") + +# User wants to correct their question +user_item = await session.pop_item() # Remove user's question +assistant_item = await session.pop_item() # Remove agent's response + +# Ask a corrected question +result = await Runner.run( + agent, + "What's 2 + 3?", + session=session +) +print(f"Agent: {result.final_output}") +``` + +## Memory options + +### No memory (default) + +```python +# Default behavior - no session memory +result = await Runner.run(agent, "Hello") +``` + +### SQLite memory + +```python +from agents import SQLiteSession + +# In-memory database (lost when process ends) +session = SQLiteSession("user_123") + +# Persistent file-based database +session = SQLiteSession("user_123", "conversations.db") + +# Use the session +result = await Runner.run( + agent, + "Hello", + session=session +) +``` + +### Multiple sessions + +```python +from agents import Agent, Runner, SQLiteSession + +agent = Agent(name="Assistant") + +# Different sessions maintain separate conversation histories +session_1 = SQLiteSession("user_123", "conversations.db") +session_2 = SQLiteSession("user_456", "conversations.db") + +result1 = await Runner.run( + agent, + "Hello", + session=session_1 +) +result2 = await Runner.run( + agent, + "Hello", + session=session_2 +) +``` + +## Custom memory implementations + +You can implement your own session memory by creating a class that follows the [`Session`][agents.memory.session.Session] protocol: + +````python +from agents.memory import Session +from typing import List + +class MyCustomSession: + """Custom session implementation following the Session protocol.""" + + def __init__(self, session_id: str): + self.session_id = session_id + # Your initialization here + + async def get_items(self, limit: int | None = None) -> List[dict]: + """Retrieve conversation history for this session.""" + # Your implementation here + pass + + async def add_items(self, items: List[dict]) -> None: + """Store new items for this session.""" + # Your implementation here + pass + + async def pop_item(self) -> dict | None: + """Remove and return the most recent item from this session.""" + # Your implementation here + pass + + async def clear_session(self) -> None: + """Clear all items for this session.""" + # Your implementation here + pass + +# Use your custom session +agent = Agent(name="Assistant") +result = await Runner.run( + agent, + "Hello", + session=MyCustomSession("my_session") +) + +## Session management + +### Session ID naming + +Use meaningful session IDs that help you organize conversations: + +- User-based: `"user_12345"` +- Thread-based: `"thread_abc123"` +- Context-based: `"support_ticket_456"` + +### Memory persistence + +- Use in-memory SQLite (`SQLiteSession("session_id")`) for temporary conversations +- Use file-based SQLite (`SQLiteSession("session_id", "path/to/db.sqlite")`) for persistent conversations +- Consider implementing custom session backends for production systems (Redis, PostgreSQL, etc.) + +### Session management + +```python +# Clear a session when conversation should start fresh +await session.clear_session() + +# Different agents can share the same session +support_agent = Agent(name="Support") +billing_agent = Agent(name="Billing") +session = SQLiteSession("user_123") + +# Both agents will see the same conversation history +result1 = await Runner.run( + support_agent, + "Help me with my account", + session=session +) +result2 = await Runner.run( + billing_agent, + "What are my charges?", + session=session +) +```` + +## Complete example + +Here's a complete example showing session memory in action: + +```python +import asyncio +from agents import Agent, Runner, SQLiteSession + + +async def main(): + # Create an agent + agent = Agent( + name="Assistant", + instructions="Reply very concisely.", + ) + + # Create a session instance that will persist across runs + session = SQLiteSession("conversation_123", "conversation_history.db") + + print("=== Sessions Example ===") + print("The agent will remember previous messages automatically.\n") + + # First turn + print("First turn:") + print("User: What city is the Golden Gate Bridge in?") + result = await Runner.run( + agent, + "What city is the Golden Gate Bridge in?", + session=session + ) + print(f"Assistant: {result.final_output}") + print() + + # Second turn - the agent will remember the previous conversation + print("Second turn:") + print("User: What state is it in?") + result = await Runner.run( + agent, + "What state is it in?", + session=session + ) + print(f"Assistant: {result.final_output}") + print() + + # Third turn - continuing the conversation + print("Third turn:") + print("User: What's the population of that state?") + result = await Runner.run( + agent, + "What's the population of that state?", + session=session + ) + print(f"Assistant: {result.final_output}") + print() + + print("=== Conversation Complete ===") + print("Notice how the agent remembered the context from previous turns!") + print("Sessions automatically handles conversation history.") + + +if __name__ == "__main__": + asyncio.run(main()) +``` + +## API Reference + +For detailed API documentation, see: + +- [`Session`][agents.memory.Session] - Protocol interface +- [`SQLiteSession`][agents.memory.SQLiteSession] - SQLite implementation diff --git a/examples/basic/hello_world_jupyter.ipynb b/examples/basic/hello_world_jupyter.ipynb index 42ee8e6a2..8dd3bb379 100644 --- a/examples/basic/hello_world_jupyter.ipynb +++ b/examples/basic/hello_world_jupyter.ipynb @@ -30,7 +30,7 @@ "agent = Agent(name=\"Assistant\", instructions=\"You are a helpful assistant\")\n", "\n", "# Intended for Jupyter notebooks where there's an existing event loop\n", - "result = await Runner.run(agent, \"Write a haiku about recursion in programming.\") # type: ignore[top-level-await] # noqa: F704\n", + "result = await Runner.run(agent, \"Write a haiku about recursion in programming.\") # type: ignore[top-level-await] # noqa: F704\n", "print(result.final_output)" ] } diff --git a/examples/basic/session_example.py b/examples/basic/session_example.py new file mode 100644 index 000000000..63d1d1b7c --- /dev/null +++ b/examples/basic/session_example.py @@ -0,0 +1,77 @@ +""" +Example demonstrating session memory functionality. + +This example shows how to use session memory to maintain conversation history +across multiple agent runs without manually handling .to_input_list(). +""" + +import asyncio + +from agents import Agent, Runner, SQLiteSession + + +async def main(): + # Create an agent + agent = Agent( + name="Assistant", + instructions="Reply very concisely.", + ) + + # Create a session instance that will persist across runs + session_id = "conversation_123" + session = SQLiteSession(session_id) + + print("=== Session Example ===") + print("The agent will remember previous messages automatically.\n") + + # First turn + print("First turn:") + print("User: What city is the Golden Gate Bridge in?") + result = await Runner.run( + agent, + "What city is the Golden Gate Bridge in?", + session=session, + ) + print(f"Assistant: {result.final_output}") + print() + + # Second turn - the agent will remember the previous conversation + print("Second turn:") + print("User: What state is it in?") + result = await Runner.run(agent, "What state is it in?", session=session) + print(f"Assistant: {result.final_output}") + print() + + # Third turn - continuing the conversation + print("Third turn:") + print("User: What's the population of that state?") + result = await Runner.run( + agent, + "What's the population of that state?", + session=session, + ) + print(f"Assistant: {result.final_output}") + print() + + print("=== Conversation Complete ===") + print("Notice how the agent remembered the context from previous turns!") + print("Sessions automatically handles conversation history.") + + # Demonstrate the limit parameter - get only the latest 2 items + print("\n=== Latest Items Demo ===") + latest_items = await session.get_items(limit=2) + print("Latest 2 items:") + for i, msg in enumerate(latest_items, 1): + role = msg.get("role", "unknown") + content = msg.get("content", "") + print(f" {i}. {role}: {content}") + + print(f"\nFetched {len(latest_items)} out of total conversation history.") + + # Get all items to show the difference + all_items = await session.get_items() + print(f"Total items in session: {len(all_items)}") + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/examples/reasoning_content/main.py b/examples/reasoning_content/main.py index 5f67e1779..c23b04254 100644 --- a/examples/reasoning_content/main.py +++ b/examples/reasoning_content/main.py @@ -44,7 +44,7 @@ async def stream_with_reasoning_content(): handoffs=[], tracing=ModelTracing.DISABLED, previous_response_id=None, - prompt=None + prompt=None, ): if event.type == "response.reasoning_summary_text.delta": print( @@ -82,7 +82,7 @@ async def get_response_with_reasoning_content(): handoffs=[], tracing=ModelTracing.DISABLED, previous_response_id=None, - prompt=None + prompt=None, ) # Extract reasoning content and regular content from the response diff --git a/mkdocs.yml b/mkdocs.yml index b79e6454f..a465bfefa 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -57,6 +57,7 @@ plugins: - Documentation: - agents.md - running_agents.md + - sessions.md - results.md - streaming.md - repl.md @@ -82,6 +83,7 @@ plugins: - ref/index.md - ref/agent.md - ref/run.md + - ref/memory.md - ref/repl.md - ref/tool.md - ref/result.md @@ -140,6 +142,7 @@ plugins: - ドキュメント: - agents.md - running_agents.md + - sessions.md - results.md - streaming.md - repl.md diff --git a/src/agents/__init__.py b/src/agents/__init__.py index 1296b72be..1bcb06e9f 100644 --- a/src/agents/__init__.py +++ b/src/agents/__init__.py @@ -40,6 +40,7 @@ TResponseInputItem, ) from .lifecycle import AgentHooks, RunHooks +from .memory import Session, SQLiteSession from .model_settings import ModelSettings from .models.interface import Model, ModelProvider, ModelTracing from .models.openai_chatcompletions import OpenAIChatCompletionsModel @@ -209,6 +210,8 @@ def enable_verbose_stdout_logging(): "ItemHelpers", "RunHooks", "AgentHooks", + "Session", + "SQLiteSession", "RunContextWrapper", "TContext", "RunErrorDetails", diff --git a/src/agents/agent.py b/src/agents/agent.py index 6c87297f1..2f77f5f2a 100644 --- a/src/agents/agent.py +++ b/src/agents/agent.py @@ -256,9 +256,7 @@ async def get_prompt( """Get the prompt for the agent.""" return await PromptUtil.to_model_input(self.prompt, run_context, self) - async def get_mcp_tools( - self, run_context: RunContextWrapper[TContext] - ) -> list[Tool]: + async def get_mcp_tools(self, run_context: RunContextWrapper[TContext]) -> list[Tool]: """Fetches the available tools from the MCP servers.""" convert_schemas_to_strict = self.mcp_config.get("convert_schemas_to_strict", False) return await MCPUtil.get_all_function_tools( diff --git a/src/agents/mcp/server.py b/src/agents/mcp/server.py index f6c2b58ef..222d75b55 100644 --- a/src/agents/mcp/server.py +++ b/src/agents/mcp/server.py @@ -118,9 +118,7 @@ async def _apply_tool_filter( return await self._apply_dynamic_tool_filter(tools, run_context, agent) def _apply_static_tool_filter( - self, - tools: list[MCPTool], - static_filter: ToolFilterStatic + self, tools: list[MCPTool], static_filter: ToolFilterStatic ) -> list[MCPTool]: """Apply static tool filtering based on allowlist and blocklist.""" filtered_tools = tools diff --git a/src/agents/memory/__init__.py b/src/agents/memory/__init__.py new file mode 100644 index 000000000..059ca57ab --- /dev/null +++ b/src/agents/memory/__init__.py @@ -0,0 +1,3 @@ +from .session import Session, SQLiteSession + +__all__ = ["Session", "SQLiteSession"] diff --git a/src/agents/memory/session.py b/src/agents/memory/session.py new file mode 100644 index 000000000..8db0971eb --- /dev/null +++ b/src/agents/memory/session.py @@ -0,0 +1,369 @@ +from __future__ import annotations + +import asyncio +import json +import sqlite3 +import threading +from abc import ABC, abstractmethod +from pathlib import Path +from typing import TYPE_CHECKING, Protocol, runtime_checkable + +if TYPE_CHECKING: + from ..items import TResponseInputItem + + +@runtime_checkable +class Session(Protocol): + """Protocol for session implementations. + + Session stores conversation history for a specific session, allowing + agents to maintain context without requiring explicit manual memory management. + """ + + session_id: str + + async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]: + """Retrieve the conversation history for this session. + + Args: + limit: Maximum number of items to retrieve. If None, retrieves all items. + When specified, returns the latest N items in chronological order. + + Returns: + List of input items representing the conversation history + """ + ... + + async def add_items(self, items: list[TResponseInputItem]) -> None: + """Add new items to the conversation history. + + Args: + items: List of input items to add to the history + """ + ... + + async def pop_item(self) -> TResponseInputItem | None: + """Remove and return the most recent item from the session. + + Returns: + The most recent item if it exists, None if the session is empty + """ + ... + + async def clear_session(self) -> None: + """Clear all items for this session.""" + ... + + +class SessionABC(ABC): + """Abstract base class for session implementations. + + Session stores conversation history for a specific session, allowing + agents to maintain context without requiring explicit manual memory management. + + This ABC is intended for internal use and as a base class for concrete implementations. + Third-party libraries should implement the Session protocol instead. + """ + + session_id: str + + @abstractmethod + async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]: + """Retrieve the conversation history for this session. + + Args: + limit: Maximum number of items to retrieve. If None, retrieves all items. + When specified, returns the latest N items in chronological order. + + Returns: + List of input items representing the conversation history + """ + ... + + @abstractmethod + async def add_items(self, items: list[TResponseInputItem]) -> None: + """Add new items to the conversation history. + + Args: + items: List of input items to add to the history + """ + ... + + @abstractmethod + async def pop_item(self) -> TResponseInputItem | None: + """Remove and return the most recent item from the session. + + Returns: + The most recent item if it exists, None if the session is empty + """ + ... + + @abstractmethod + async def clear_session(self) -> None: + """Clear all items for this session.""" + ... + + +class SQLiteSession(SessionABC): + """SQLite-based implementation of session storage. + + This implementation stores conversation history in a SQLite database. + By default, uses an in-memory database that is lost when the process ends. + For persistent storage, provide a file path. + """ + + def __init__( + self, + session_id: str, + db_path: str | Path = ":memory:", + sessions_table: str = "agent_sessions", + messages_table: str = "agent_messages", + ): + """Initialize the SQLite session. + + Args: + session_id: Unique identifier for the conversation session + db_path: Path to the SQLite database file. Defaults to ':memory:' (in-memory database) + sessions_table: Name of the table to store session metadata. Defaults to + 'agent_sessions' + messages_table: Name of the table to store message data. Defaults to 'agent_messages' + """ + self.session_id = session_id + self.db_path = db_path + self.sessions_table = sessions_table + self.messages_table = messages_table + self._local = threading.local() + self._lock = threading.Lock() + + # For in-memory databases, we need a shared connection to avoid thread isolation + # For file databases, we use thread-local connections for better concurrency + self._is_memory_db = str(db_path) == ":memory:" + if self._is_memory_db: + self._shared_connection = sqlite3.connect(":memory:", check_same_thread=False) + self._shared_connection.execute("PRAGMA journal_mode=WAL") + self._init_db_for_connection(self._shared_connection) + else: + # For file databases, initialize the schema once since it persists + init_conn = sqlite3.connect(str(self.db_path), check_same_thread=False) + init_conn.execute("PRAGMA journal_mode=WAL") + self._init_db_for_connection(init_conn) + init_conn.close() + + def _get_connection(self) -> sqlite3.Connection: + """Get a database connection.""" + if self._is_memory_db: + # Use shared connection for in-memory database to avoid thread isolation + return self._shared_connection + else: + # Use thread-local connections for file databases + if not hasattr(self._local, "connection"): + self._local.connection = sqlite3.connect( + str(self.db_path), + check_same_thread=False, + ) + self._local.connection.execute("PRAGMA journal_mode=WAL") + assert isinstance(self._local.connection, sqlite3.Connection), ( + f"Expected sqlite3.Connection, got {type(self._local.connection)}" + ) + return self._local.connection + + def _init_db_for_connection(self, conn: sqlite3.Connection) -> None: + """Initialize the database schema for a specific connection.""" + conn.execute( + f""" + CREATE TABLE IF NOT EXISTS {self.sessions_table} ( + session_id TEXT PRIMARY KEY, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP + ) + """ + ) + + conn.execute( + f""" + CREATE TABLE IF NOT EXISTS {self.messages_table} ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + session_id TEXT NOT NULL, + message_data TEXT NOT NULL, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + FOREIGN KEY (session_id) REFERENCES {self.sessions_table} (session_id) + ON DELETE CASCADE + ) + """ + ) + + conn.execute( + f""" + CREATE INDEX IF NOT EXISTS idx_{self.messages_table}_session_id + ON {self.messages_table} (session_id, created_at) + """ + ) + + conn.commit() + + async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]: + """Retrieve the conversation history for this session. + + Args: + limit: Maximum number of items to retrieve. If None, retrieves all items. + When specified, returns the latest N items in chronological order. + + Returns: + List of input items representing the conversation history + """ + + def _get_items_sync(): + conn = self._get_connection() + with self._lock if self._is_memory_db else threading.Lock(): + if limit is None: + # Fetch all items in chronological order + cursor = conn.execute( + f""" + SELECT message_data FROM {self.messages_table} + WHERE session_id = ? + ORDER BY created_at ASC + """, + (self.session_id,), + ) + else: + # Fetch the latest N items in chronological order + cursor = conn.execute( + f""" + SELECT message_data FROM {self.messages_table} + WHERE session_id = ? + ORDER BY created_at DESC + LIMIT ? + """, + (self.session_id, limit), + ) + + rows = cursor.fetchall() + + # Reverse to get chronological order when using DESC + if limit is not None: + rows = list(reversed(rows)) + + items = [] + for (message_data,) in rows: + try: + item = json.loads(message_data) + items.append(item) + except json.JSONDecodeError: + # Skip invalid JSON entries + continue + + return items + + return await asyncio.to_thread(_get_items_sync) + + async def add_items(self, items: list[TResponseInputItem]) -> None: + """Add new items to the conversation history. + + Args: + items: List of input items to add to the history + """ + if not items: + return + + def _add_items_sync(): + conn = self._get_connection() + + with self._lock if self._is_memory_db else threading.Lock(): + # Ensure session exists + conn.execute( + f""" + INSERT OR IGNORE INTO {self.sessions_table} (session_id) VALUES (?) + """, + (self.session_id,), + ) + + # Add items + message_data = [(self.session_id, json.dumps(item)) for item in items] + conn.executemany( + f""" + INSERT INTO {self.messages_table} (session_id, message_data) VALUES (?, ?) + """, + message_data, + ) + + # Update session timestamp + conn.execute( + f""" + UPDATE {self.sessions_table} + SET updated_at = CURRENT_TIMESTAMP + WHERE session_id = ? + """, + (self.session_id,), + ) + + conn.commit() + + await asyncio.to_thread(_add_items_sync) + + async def pop_item(self) -> TResponseInputItem | None: + """Remove and return the most recent item from the session. + + Returns: + The most recent item if it exists, None if the session is empty + """ + + def _pop_item_sync(): + conn = self._get_connection() + with self._lock if self._is_memory_db else threading.Lock(): + # Use DELETE with RETURNING to atomically delete and return the most recent item + cursor = conn.execute( + f""" + DELETE FROM {self.messages_table} + WHERE id = ( + SELECT id FROM {self.messages_table} + WHERE session_id = ? + ORDER BY created_at DESC + LIMIT 1 + ) + RETURNING message_data + """, + (self.session_id,), + ) + + result = cursor.fetchone() + conn.commit() + + if result: + message_data = result[0] + try: + item = json.loads(message_data) + return item + except json.JSONDecodeError: + # Return None for corrupted JSON entries (already deleted) + return None + + return None + + return await asyncio.to_thread(_pop_item_sync) + + async def clear_session(self) -> None: + """Clear all items for this session.""" + + def _clear_session_sync(): + conn = self._get_connection() + with self._lock if self._is_memory_db else threading.Lock(): + conn.execute( + f"DELETE FROM {self.messages_table} WHERE session_id = ?", + (self.session_id,), + ) + conn.execute( + f"DELETE FROM {self.sessions_table} WHERE session_id = ?", + (self.session_id,), + ) + conn.commit() + + await asyncio.to_thread(_clear_session_sync) + + def close(self) -> None: + """Close the database connection.""" + if self._is_memory_db: + if hasattr(self, "_shared_connection"): + self._shared_connection.close() + else: + if hasattr(self._local, "connection"): + self._local.connection.close() diff --git a/src/agents/model_settings.py b/src/agents/model_settings.py index 26af94ba3..f81f10989 100644 --- a/src/agents/model_settings.py +++ b/src/agents/model_settings.py @@ -17,9 +17,9 @@ class _OmitTypeAnnotation: @classmethod def __get_pydantic_core_schema__( - cls, - _source_type: Any, - _handler: GetCoreSchemaHandler, + cls, + _source_type: Any, + _handler: GetCoreSchemaHandler, ) -> core_schema.CoreSchema: def validate_from_none(value: None) -> _Omit: return _Omit() @@ -39,13 +39,14 @@ def validate_from_none(value: None) -> _Omit: from_none_schema, ] ), - serialization=core_schema.plain_serializer_function_ser_schema( - lambda instance: None - ), + serialization=core_schema.plain_serializer_function_ser_schema(lambda instance: None), ) + + Omit = Annotated[_Omit, _OmitTypeAnnotation] Headers: TypeAlias = Mapping[str, Union[str, Omit]] + @dataclass class ModelSettings: """Settings to use when calling an LLM. diff --git a/src/agents/run.py b/src/agents/run.py index e5f9378ec..2dd9524bb 100644 --- a/src/agents/run.py +++ b/src/agents/run.py @@ -32,6 +32,7 @@ ModelBehaviorError, OutputGuardrailTripwireTriggered, RunErrorDetails, + UserError, ) from .guardrail import ( InputGuardrail, @@ -43,6 +44,7 @@ from .items import ItemHelpers, ModelResponse, RunItem, TResponseInputItem from .lifecycle import RunHooks from .logger import logger +from .memory import Session from .model_settings import ModelSettings from .models.interface import Model, ModelProvider from .models.multi_provider import MultiProvider @@ -156,6 +158,9 @@ class RunOptions(TypedDict, Generic[TContext]): previous_response_id: NotRequired[str | None] """The ID of the previous response, if any.""" + session: NotRequired[Session | None] + """The session for the run.""" + class Runner: @classmethod @@ -169,6 +174,7 @@ async def run( hooks: RunHooks[TContext] | None = None, run_config: RunConfig | None = None, previous_response_id: str | None = None, + session: Session | None = None, ) -> RunResult: """Run a workflow starting at the given agent. The agent will run in a loop until a final output is generated. The loop runs like so: @@ -205,6 +211,7 @@ async def run( hooks=hooks, run_config=run_config, previous_response_id=previous_response_id, + session=session, ) @classmethod @@ -218,6 +225,7 @@ def run_sync( hooks: RunHooks[TContext] | None = None, run_config: RunConfig | None = None, previous_response_id: str | None = None, + session: Session | None = None, ) -> RunResult: """Run a workflow synchronously, starting at the given agent. Note that this just wraps the `run` method, so it will not work if there's already an event loop (e.g. inside an async @@ -257,6 +265,7 @@ def run_sync( hooks=hooks, run_config=run_config, previous_response_id=previous_response_id, + session=session, ) @classmethod @@ -269,6 +278,7 @@ def run_streamed( hooks: RunHooks[TContext] | None = None, run_config: RunConfig | None = None, previous_response_id: str | None = None, + session: Session | None = None, ) -> RunResultStreaming: """Run a workflow starting at the given agent in streaming mode. The returned result object contains a method you can use to stream semantic events as they are generated. @@ -305,6 +315,7 @@ def run_streamed( hooks=hooks, run_config=run_config, previous_response_id=previous_response_id, + session=session, ) @@ -325,11 +336,15 @@ async def run( hooks = kwargs.get("hooks") run_config = kwargs.get("run_config") previous_response_id = kwargs.get("previous_response_id") + session = kwargs.get("session") if hooks is None: hooks = RunHooks[Any]() if run_config is None: run_config = RunConfig() + # Prepare input with session if enabled + prepared_input = await self._prepare_input_with_session(input, session) + tool_use_tracker = AgentToolUseTracker() with TraceCtxManager( @@ -340,7 +355,7 @@ async def run( disabled=run_config.tracing_disabled, ): current_turn = 0 - original_input: str | list[TResponseInputItem] = copy.deepcopy(input) + original_input: str | list[TResponseInputItem] = copy.deepcopy(prepared_input) generated_items: list[RunItem] = [] model_responses: list[ModelResponse] = [] @@ -399,7 +414,7 @@ async def run( starting_agent, starting_agent.input_guardrails + (run_config.input_guardrails or []), - copy.deepcopy(input), + copy.deepcopy(prepared_input), context_wrapper, ), self._run_single_turn( @@ -441,7 +456,7 @@ async def run( turn_result.next_step.output, context_wrapper, ) - return RunResult( + result = RunResult( input=original_input, new_items=generated_items, raw_responses=model_responses, @@ -451,6 +466,11 @@ async def run( output_guardrail_results=output_guardrail_results, context_wrapper=context_wrapper, ) + + # Save the conversation to session if enabled + await self._save_result_to_session(session, input, result) + + return result elif isinstance(turn_result.next_step, NextStepHandoff): current_agent = cast(Agent[TContext], turn_result.next_step.new_agent) current_span.finish(reset_current=True) @@ -488,10 +508,13 @@ def run_sync( hooks = kwargs.get("hooks") run_config = kwargs.get("run_config") previous_response_id = kwargs.get("previous_response_id") + session = kwargs.get("session") + return asyncio.get_event_loop().run_until_complete( self.run( starting_agent, input, + session=session, context=context, max_turns=max_turns, hooks=hooks, @@ -511,6 +534,8 @@ def run_streamed( hooks = kwargs.get("hooks") run_config = kwargs.get("run_config") previous_response_id = kwargs.get("previous_response_id") + session = kwargs.get("session") + if hooks is None: hooks = RunHooks[Any]() if run_config is None: @@ -563,6 +588,7 @@ def run_streamed( context_wrapper=context_wrapper, run_config=run_config, previous_response_id=previous_response_id, + session=session, ) ) return streamed_result @@ -621,6 +647,7 @@ async def _start_streaming( context_wrapper: RunContextWrapper[TContext], run_config: RunConfig, previous_response_id: str | None, + session: Session | None, ): if streamed_result.trace: streamed_result.trace.start(mark_as_current=True) @@ -634,6 +661,12 @@ async def _start_streaming( streamed_result._event_queue.put_nowait(AgentUpdatedStreamEvent(new_agent=current_agent)) try: + # Prepare input with session if enabled + prepared_input = await AgentRunner._prepare_input_with_session(starting_input, session) + + # Update the streamed result with the prepared input + streamed_result.input = prepared_input + while True: if streamed_result.is_complete: break @@ -680,7 +713,7 @@ async def _start_streaming( cls._run_input_guardrails_with_queue( starting_agent, starting_agent.input_guardrails + (run_config.input_guardrails or []), - copy.deepcopy(ItemHelpers.input_to_new_input_list(starting_input)), + copy.deepcopy(ItemHelpers.input_to_new_input_list(prepared_input)), context_wrapper, streamed_result, current_span, @@ -734,6 +767,23 @@ async def _start_streaming( streamed_result.output_guardrail_results = output_guardrail_results streamed_result.final_output = turn_result.next_step.output streamed_result.is_complete = True + + # Save the conversation to session if enabled + # Create a temporary RunResult for session saving + temp_result = RunResult( + input=streamed_result.input, + new_items=streamed_result.new_items, + raw_responses=streamed_result.raw_responses, + final_output=streamed_result.final_output, + _last_agent=current_agent, + input_guardrail_results=streamed_result.input_guardrail_results, + output_guardrail_results=streamed_result.output_guardrail_results, + context_wrapper=context_wrapper, + ) + await AgentRunner._save_result_to_session( + session, starting_input, temp_result + ) + streamed_result._event_queue.put_nowait(QueueCompleteSentinel()) elif isinstance(turn_result.next_step, NextStepRunAgain): pass @@ -1136,5 +1186,57 @@ def _get_model(cls, agent: Agent[Any], run_config: RunConfig) -> Model: return run_config.model_provider.get_model(agent.model) + @classmethod + async def _prepare_input_with_session( + cls, + input: str | list[TResponseInputItem], + session: Session | None, + ) -> str | list[TResponseInputItem]: + """Prepare input by combining it with session history if enabled.""" + if session is None: + return input + + # Validate that we don't have both a session and a list input, as this creates + # ambiguity about whether the list should append to or replace existing session history + if isinstance(input, list): + raise UserError( + "Cannot provide both a session and a list of input items. " + "When using session memory, provide only a string input to append to the " + "conversation, or use session=None and provide a list to manually manage " + "conversation history." + ) + + # Get previous conversation history + history = await session.get_items() + + # Convert input to list format + new_input_list = ItemHelpers.input_to_new_input_list(input) + + # Combine history with new input + combined_input = history + new_input_list + + return combined_input + + @classmethod + async def _save_result_to_session( + cls, + session: Session | None, + original_input: str | list[TResponseInputItem], + result: RunResult, + ) -> None: + """Save the conversation turn to session.""" + if session is None: + return + + # Convert original input to list format if needed + input_list = ItemHelpers.input_to_new_input_list(original_input) + + # Convert new items to input format + new_items_as_input = [item.to_input_item() for item in result.new_items] + + # Save all items from this turn + items_to_save = input_list + new_items_as_input + await session.add_items(items_to_save) + DEFAULT_AGENT_RUNNER = AgentRunner() diff --git a/src/agents/tool.py b/src/agents/tool.py index 3aab47752..f2861b936 100644 --- a/src/agents/tool.py +++ b/src/agents/tool.py @@ -30,7 +30,6 @@ from .util._types import MaybeAwaitable if TYPE_CHECKING: - from .agent import Agent ToolParams = ParamSpec("ToolParams") diff --git a/tests/mcp/test_tool_filtering.py b/tests/mcp/test_tool_filtering.py index c1ffff4b8..0127df806 100644 --- a/tests/mcp/test_tool_filtering.py +++ b/tests/mcp/test_tool_filtering.py @@ -3,6 +3,7 @@ external dependencies (processes, network connections) and ensure fast, reliable unit tests. FakeMCPServer delegates filtering logic to the real _MCPServerWithClientSession implementation. """ + import asyncio import pytest @@ -27,6 +28,7 @@ def create_test_context() -> RunContextWrapper: # === Static Tool Filtering Tests === + @pytest.mark.asyncio async def test_static_tool_filtering(): """Test all static tool filtering scenarios: allowed, blocked, both, none, etc.""" @@ -55,7 +57,7 @@ async def test_static_tool_filtering(): # Test both filters together (allowed first, then blocked) server.tool_filter = { "allowed_tool_names": ["tool1", "tool2", "tool3"], - "blocked_tool_names": ["tool3"] + "blocked_tool_names": ["tool3"], } tools = await server.list_tools(run_context, agent) assert len(tools) == 2 @@ -68,8 +70,7 @@ async def test_static_tool_filtering(): # Test helper function server.tool_filter = create_static_tool_filter( - allowed_tool_names=["tool1", "tool2"], - blocked_tool_names=["tool2"] + allowed_tool_names=["tool1", "tool2"], blocked_tool_names=["tool2"] ) tools = await server.list_tools(run_context, agent) assert len(tools) == 1 @@ -78,6 +79,7 @@ async def test_static_tool_filtering(): # === Dynamic Tool Filtering Core Tests === + @pytest.mark.asyncio async def test_dynamic_filter_sync_and_async(): """Test both synchronous and asynchronous dynamic filters""" @@ -181,6 +183,7 @@ def error_prone_filter(context: ToolFilterContext, tool: MCPTool) -> bool: # === Integration Tests === + @pytest.mark.asyncio async def test_agent_dynamic_filtering_integration(): """Test dynamic filtering integration with Agent methods""" diff --git a/tests/model_settings/test_serialization.py b/tests/model_settings/test_serialization.py index 94d11def3..23ea5359c 100644 --- a/tests/model_settings/test_serialization.py +++ b/tests/model_settings/test_serialization.py @@ -135,8 +135,8 @@ def test_extra_args_resolve_both_none() -> None: assert resolved.temperature == 0.5 assert resolved.top_p == 0.9 -def test_pydantic_serialization() -> None: +def test_pydantic_serialization() -> None: """Tests whether ModelSettings can be serialized with Pydantic.""" # First, lets create a ModelSettings instance diff --git a/tests/test_reasoning_content.py b/tests/test_reasoning_content.py index 5160e09c2..69e9a7d0c 100644 --- a/tests/test_reasoning_content.py +++ b/tests/test_reasoning_content.py @@ -27,12 +27,8 @@ # Helper functions to create test objects consistently def create_content_delta(content: str) -> dict[str, Any]: """Create a delta dictionary with regular content""" - return { - "content": content, - "role": None, - "function_call": None, - "tool_calls": None - } + return {"content": content, "role": None, "function_call": None, "tool_calls": None} + def create_reasoning_delta(content: str) -> dict[str, Any]: """Create a delta dictionary with reasoning content. The Only difference is reasoning_content""" @@ -41,7 +37,7 @@ def create_reasoning_delta(content: str) -> dict[str, Any]: "role": None, "function_call": None, "tool_calls": None, - "reasoning_content": content + "reasoning_content": content, } @@ -188,7 +184,7 @@ async def test_get_response_with_reasoning_content(monkeypatch) -> None: "index": 0, "finish_reason": "stop", "message": msg_with_reasoning, - "delta": None + "delta": None, } chat = ChatCompletion( @@ -274,7 +270,7 @@ async def patched_fetch_response(self, *args, **kwargs): handoffs=[], tracing=ModelTracing.DISABLED, previous_response_id=None, - prompt=None + prompt=None, ): output_events.append(event) diff --git a/tests/test_session.py b/tests/test_session.py new file mode 100644 index 000000000..032f2bb38 --- /dev/null +++ b/tests/test_session.py @@ -0,0 +1,400 @@ +"""Tests for session memory functionality.""" + +import asyncio +import tempfile +from pathlib import Path + +import pytest + +from agents import Agent, Runner, SQLiteSession, TResponseInputItem +from agents.exceptions import UserError + +from .fake_model import FakeModel +from .test_responses import get_text_message + + +# Helper functions for parametrized testing of different Runner methods +def _run_sync_wrapper(agent, input_data, **kwargs): + """Wrapper for run_sync that properly sets up an event loop.""" + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + try: + return Runner.run_sync(agent, input_data, **kwargs) + finally: + loop.close() + + +async def run_agent_async(runner_method: str, agent, input_data, **kwargs): + """Helper function to run agent with different methods.""" + if runner_method == "run": + return await Runner.run(agent, input_data, **kwargs) + elif runner_method == "run_sync": + # For run_sync, we need to run it in a thread with its own event loop + return await asyncio.to_thread(_run_sync_wrapper, agent, input_data, **kwargs) + elif runner_method == "run_streamed": + result = Runner.run_streamed(agent, input_data, **kwargs) + # For streaming, we first try to get at least one event to trigger any early exceptions + # If there's an exception in setup (like memory validation), it will be raised here + try: + first_event = None + async for event in result.stream_events(): + if first_event is None: + first_event = event + # Continue consuming all events + pass + except Exception: + # If an exception occurs during streaming, we let it propagate up + raise + return result + else: + raise ValueError(f"Unknown runner method: {runner_method}") + + +# Parametrized tests for different runner methods +@pytest.mark.parametrize("runner_method", ["run", "run_sync", "run_streamed"]) +@pytest.mark.asyncio +async def test_session_memory_basic_functionality_parametrized(runner_method): + """Test basic session memory functionality with SQLite backend across all runner methods.""" + with tempfile.TemporaryDirectory() as temp_dir: + db_path = Path(temp_dir) / "test_memory.db" + session_id = "test_session_123" + session = SQLiteSession(session_id, db_path) + + model = FakeModel() + agent = Agent(name="test", model=model) + + # First turn + model.set_next_output([get_text_message("San Francisco")]) + result1 = await run_agent_async( + runner_method, + agent, + "What city is the Golden Gate Bridge in?", + session=session, + ) + assert result1.final_output == "San Francisco" + + # Second turn - should have conversation history + model.set_next_output([get_text_message("California")]) + result2 = await run_agent_async( + runner_method, + agent, + "What state is it in?", + session=session, + ) + assert result2.final_output == "California" + + # Verify that the input to the second turn includes the previous conversation + # The model should have received the full conversation history + last_input = model.last_turn_args["input"] + assert len(last_input) > 1 # Should have more than just the current message + + session.close() + + +@pytest.mark.parametrize("runner_method", ["run", "run_sync", "run_streamed"]) +@pytest.mark.asyncio +async def test_session_memory_with_explicit_instance_parametrized(runner_method): + """Test session memory with an explicit SQLiteSession instance across all runner methods.""" + with tempfile.TemporaryDirectory() as temp_dir: + db_path = Path(temp_dir) / "test_memory.db" + session_id = "test_session_456" + session = SQLiteSession(session_id, db_path) + + model = FakeModel() + agent = Agent(name="test", model=model) + + # First turn + model.set_next_output([get_text_message("Hello")]) + result1 = await run_agent_async(runner_method, agent, "Hi there", session=session) + assert result1.final_output == "Hello" + + # Second turn + model.set_next_output([get_text_message("I remember you said hi")]) + result2 = await run_agent_async( + runner_method, + agent, + "Do you remember what I said?", + session=session, + ) + assert result2.final_output == "I remember you said hi" + + session.close() + + +@pytest.mark.parametrize("runner_method", ["run", "run_sync", "run_streamed"]) +@pytest.mark.asyncio +async def test_session_memory_disabled_parametrized(runner_method): + """Test that session memory is disabled when session=None across all runner methods.""" + model = FakeModel() + agent = Agent(name="test", model=model) + + # First turn (no session parameters = disabled) + model.set_next_output([get_text_message("Hello")]) + result1 = await run_agent_async(runner_method, agent, "Hi there") + assert result1.final_output == "Hello" + + # Second turn - should NOT have conversation history + model.set_next_output([get_text_message("I don't remember")]) + result2 = await run_agent_async(runner_method, agent, "Do you remember what I said?") + assert result2.final_output == "I don't remember" + + # Verify that the input to the second turn is just the current message + last_input = model.last_turn_args["input"] + assert len(last_input) == 1 # Should only have the current message + + +@pytest.mark.parametrize("runner_method", ["run", "run_sync", "run_streamed"]) +@pytest.mark.asyncio +async def test_session_memory_different_sessions_parametrized(runner_method): + """Test that different session IDs maintain separate conversation histories across all runner + methods.""" + with tempfile.TemporaryDirectory() as temp_dir: + db_path = Path(temp_dir) / "test_memory.db" + + model = FakeModel() + agent = Agent(name="test", model=model) + + # Session 1 + session_id_1 = "session_1" + session_1 = SQLiteSession(session_id_1, db_path) + + model.set_next_output([get_text_message("I like cats")]) + result1 = await run_agent_async(runner_method, agent, "I like cats", session=session_1) + assert result1.final_output == "I like cats" + + # Session 2 - different session + session_id_2 = "session_2" + session_2 = SQLiteSession(session_id_2, db_path) + + model.set_next_output([get_text_message("I like dogs")]) + result2 = await run_agent_async(runner_method, agent, "I like dogs", session=session_2) + assert result2.final_output == "I like dogs" + + # Back to Session 1 - should remember cats, not dogs + model.set_next_output([get_text_message("Yes, you mentioned cats")]) + result3 = await run_agent_async( + runner_method, + agent, + "What did I say I like?", + session=session_1, + ) + assert result3.final_output == "Yes, you mentioned cats" + + session_1.close() + session_2.close() + + +@pytest.mark.asyncio +async def test_sqlite_session_memory_direct(): + """Test SQLiteSession class directly.""" + with tempfile.TemporaryDirectory() as temp_dir: + db_path = Path(temp_dir) / "test_direct.db" + session_id = "direct_test" + session = SQLiteSession(session_id, db_path) + + # Test adding and retrieving items + items: list[TResponseInputItem] = [ + {"role": "user", "content": "Hello"}, + {"role": "assistant", "content": "Hi there!"}, + ] + + await session.add_items(items) + retrieved = await session.get_items() + + assert len(retrieved) == 2 + assert retrieved[0].get("role") == "user" + assert retrieved[0].get("content") == "Hello" + assert retrieved[1].get("role") == "assistant" + assert retrieved[1].get("content") == "Hi there!" + + # Test clearing session + await session.clear_session() + retrieved_after_clear = await session.get_items() + assert len(retrieved_after_clear) == 0 + + session.close() + + +@pytest.mark.asyncio +async def test_sqlite_session_memory_pop_item(): + """Test SQLiteSession pop_item functionality.""" + with tempfile.TemporaryDirectory() as temp_dir: + db_path = Path(temp_dir) / "test_pop.db" + session_id = "pop_test" + session = SQLiteSession(session_id, db_path) + + # Test popping from empty session + popped = await session.pop_item() + assert popped is None + + # Add items + items: list[TResponseInputItem] = [ + {"role": "user", "content": "Hello"}, + {"role": "assistant", "content": "Hi there!"}, + {"role": "user", "content": "How are you?"}, + ] + + await session.add_items(items) + + # Verify all items are there + retrieved = await session.get_items() + assert len(retrieved) == 3 + + # Pop the most recent item + popped = await session.pop_item() + assert popped is not None + assert popped.get("role") == "user" + assert popped.get("content") == "How are you?" + + # Verify item was removed + retrieved_after_pop = await session.get_items() + assert len(retrieved_after_pop) == 2 + assert retrieved_after_pop[-1].get("content") == "Hi there!" + + # Pop another item + popped2 = await session.pop_item() + assert popped2 is not None + assert popped2.get("role") == "assistant" + assert popped2.get("content") == "Hi there!" + + # Pop the last item + popped3 = await session.pop_item() + assert popped3 is not None + assert popped3.get("role") == "user" + assert popped3.get("content") == "Hello" + + # Try to pop from empty session again + popped4 = await session.pop_item() + assert popped4 is None + + # Verify session is empty + final_items = await session.get_items() + assert len(final_items) == 0 + + session.close() + + +@pytest.mark.asyncio +async def test_session_memory_pop_different_sessions(): + """Test that pop_item only affects the specified session.""" + with tempfile.TemporaryDirectory() as temp_dir: + db_path = Path(temp_dir) / "test_pop_sessions.db" + + session_1_id = "session_1" + session_2_id = "session_2" + session_1 = SQLiteSession(session_1_id, db_path) + session_2 = SQLiteSession(session_2_id, db_path) + + # Add items to both sessions + items_1: list[TResponseInputItem] = [ + {"role": "user", "content": "Session 1 message"}, + ] + items_2: list[TResponseInputItem] = [ + {"role": "user", "content": "Session 2 message 1"}, + {"role": "user", "content": "Session 2 message 2"}, + ] + + await session_1.add_items(items_1) + await session_2.add_items(items_2) + + # Pop from session 2 + popped = await session_2.pop_item() + assert popped is not None + assert popped.get("content") == "Session 2 message 2" + + # Verify session 1 is unaffected + session_1_items = await session_1.get_items() + assert len(session_1_items) == 1 + assert session_1_items[0].get("content") == "Session 1 message" + + # Verify session 2 has one item left + session_2_items = await session_2.get_items() + assert len(session_2_items) == 1 + assert session_2_items[0].get("content") == "Session 2 message 1" + + session_1.close() + session_2.close() + + +@pytest.mark.asyncio +async def test_sqlite_session_get_items_with_limit(): + """Test SQLiteSession get_items with limit parameter.""" + with tempfile.TemporaryDirectory() as temp_dir: + db_path = Path(temp_dir) / "test_count.db" + session_id = "count_test" + session = SQLiteSession(session_id, db_path) + + # Add multiple items + items: list[TResponseInputItem] = [ + {"role": "user", "content": "Message 1"}, + {"role": "assistant", "content": "Response 1"}, + {"role": "user", "content": "Message 2"}, + {"role": "assistant", "content": "Response 2"}, + {"role": "user", "content": "Message 3"}, + {"role": "assistant", "content": "Response 3"}, + ] + + await session.add_items(items) + + # Test getting all items (default behavior) + all_items = await session.get_items() + assert len(all_items) == 6 + assert all_items[0].get("content") == "Message 1" + assert all_items[-1].get("content") == "Response 3" + + # Test getting latest 2 items + latest_2 = await session.get_items(limit=2) + assert len(latest_2) == 2 + assert latest_2[0].get("content") == "Message 3" + assert latest_2[1].get("content") == "Response 3" + + # Test getting latest 4 items + latest_4 = await session.get_items(limit=4) + assert len(latest_4) == 4 + assert latest_4[0].get("content") == "Message 2" + assert latest_4[1].get("content") == "Response 2" + assert latest_4[2].get("content") == "Message 3" + assert latest_4[3].get("content") == "Response 3" + + # Test getting more items than available + latest_10 = await session.get_items(limit=10) + assert len(latest_10) == 6 # Should return all available items + assert latest_10[0].get("content") == "Message 1" + assert latest_10[-1].get("content") == "Response 3" + + # Test getting 0 items + latest_0 = await session.get_items(limit=0) + assert len(latest_0) == 0 + + session.close() + + +@pytest.mark.parametrize("runner_method", ["run", "run_sync", "run_streamed"]) +@pytest.mark.asyncio +async def test_session_memory_rejects_both_session_and_list_input(runner_method): + """Test that passing both a session and list input raises a UserError across all runner + methods. + """ + with tempfile.TemporaryDirectory() as temp_dir: + db_path = Path(temp_dir) / "test_validation.db" + session_id = "test_validation_parametrized" + session = SQLiteSession(session_id, db_path) + + model = FakeModel() + agent = Agent(name="test", model=model) + + # Test that providing both a session and a list input raises a UserError + model.set_next_output([get_text_message("This shouldn't run")]) + + list_input = [ + {"role": "user", "content": "Test message"}, + ] + + with pytest.raises(UserError) as exc_info: + await run_agent_async(runner_method, agent, list_input, session=session) + + # Verify the error message explains the issue + assert "Cannot provide both a session and a list of input items" in str(exc_info.value) + assert "manually manage conversation history" in str(exc_info.value) + + session.close()