diff --git a/AGENTS.md b/AGENTS.md
index ff37db32..9e857801 100644
--- a/AGENTS.md
+++ b/AGENTS.md
@@ -35,6 +35,8 @@ Welcome to the OpenAI Agents SDK repository. This file contains the main points
 
 Coverage can be generated with `make coverage`.
 
+*Note*: Use `uv run ...` to execute arbitrary CLI commands within the project's virtual environment.
+
 ## Snapshot tests
 
 Some tests rely on inline snapshots. See `tests/README.md` for details on updating them:
diff --git a/README.md b/README.md
index 7dcd97b3..0e8eff64 100644
--- a/README.md
+++ b/README.md
@@ -9,10 +9,119 @@ The OpenAI Agents SDK is a lightweight yet powerful framework for building multi
 1. [**Agents**](https://openai.github.io/openai-agents-python/agents): LLMs configured with instructions, tools, guardrails, and handoffs
 2. [**Handoffs**](https://openai.github.io/openai-agents-python/handoffs/): A specialized tool call used by the Agents SDK for transferring control between agents
 3. [**Guardrails**](https://openai.github.io/openai-agents-python/guardrails/): Configurable safety checks for input and output validation
-4. [**Tracing**](https://openai.github.io/openai-agents-python/tracing/): Built-in tracking of agent runs, allowing you to view, debug and optimize your workflows
+4. [**Sessions**](#sessions): Automatic conversation history management across agent runs
+5. [**Tracing**](https://openai.github.io/openai-agents-python/tracing/): Built-in tracking of agent runs, allowing you to view, debug and optimize your workflows
 
 Explore the [examples](examples) directory to see the SDK in action, and read our [documentation](https://openai.github.io/openai-agents-python/) for more details.
 
+## Sessions
+
+The Agents SDK provides built-in session memory to automatically maintain conversation history across multiple agent runs, eliminating the need to manually handle `.to_input_list()` between turns.
+
+### Quick start
+
+```python
+from agents import Agent, Runner, SQLiteSession
+
+# Create agent
+agent = Agent(
+    name="Assistant",
+    instructions="Reply very concisely.",
+)
+
+# Create a session instance
+session = SQLiteSession("conversation_123")
+
+# First turn
+result = await Runner.run(
+    agent,
+    "What city is the Golden Gate Bridge in?",
+    session=session
+)
+print(result.final_output)  # "San Francisco"
+
+# Second turn - agent automatically remembers previous context
+result = await Runner.run(
+    agent,
+    "What state is it in?",
+    session=session
+)
+print(result.final_output)  # "California"
+
+# Also works with synchronous runner
+result = Runner.run_sync(
+    agent,
+    "What's the population?",
+    session=session
+)
+print(result.final_output)  # "Approximately 39 million"
+```
+
+### Session options
+
+- **No memory** (default): No session memory when the `session` parameter is omitted
+- **`session: Session = SQLiteSession(...)`**: Use a `Session` implementation (such as the built-in `SQLiteSession`) to manage conversation history
+
+```python
+from agents import Agent, Runner, SQLiteSession
+
+# Custom SQLite database file
+session = SQLiteSession("user_123", "conversations.db")
+agent = Agent(name="Assistant")
+
+# Different session IDs maintain separate conversation histories
+result1 = await Runner.run(
+    agent,
+    "Hello",
+    session=session
+)
+result2 = await Runner.run(
+    agent,
+    "Hello",
+    session=SQLiteSession("user_456", "conversations.db")
+)
+```
+
+### Custom session implementations
+
+You can implement your own session memory by creating a class that follows the `Session` protocol:
+
+```python
+from agents.memory import Session
+from typing import List
+
+class MyCustomSession:
+    """Custom session implementation following the Session protocol."""
+
+    def __init__(self, session_id: str):
+        self.session_id = session_id
+        # Your initialization here
+
+    async def get_messages(self) -> List[dict]:
+        # Retrieve conversation history for the session
+        pass
+
+    async def add_messages(self, messages: List[dict]) -> None:
+        # Store new messages for the session
+        pass
+
+    async def pop_message(self) -> dict | None:
+        # Remove and return the most recent message from the session
+        pass
+
+    async def clear_session(self) -> None:
+        # Clear all messages for the session
+        pass
+
+# Use your custom session
+agent = Agent(name="Assistant")
+result = await Runner.run(
+    agent,
+    "Hello",
+    session=MyCustomSession("my_session")
+)
+```
+
 ## Get started
 
 1. Set up your Python environment
diff --git a/docs/index.md b/docs/index.md
index 8aef6574..935c4be5 100644
--- a/docs/index.md
+++ b/docs/index.md
@@ -5,6 +5,7 @@ The [OpenAI Agents SDK](https://github.com/openai/openai-agents-python) enables
 - **Agents**, which are LLMs equipped with instructions and tools
 - **Handoffs**, which allow agents to delegate to other agents for specific tasks
 - **Guardrails**, which enable the inputs to agents to be validated
+- **Sessions**, which automatically maintain conversation history across agent runs
 
 In combination with Python, these primitives are powerful enough to express complex relationships between tools and agents, and allow you to build real-world applications without a steep learning curve. In addition, the SDK comes with built-in **tracing** that lets you visualize and debug your agentic flows, as well as evaluate them and even fine-tune models for your application.
@@ -21,6 +22,7 @@ Here are the main features of the SDK: - Python-first: Use built-in language features to orchestrate and chain agents, rather than needing to learn new abstractions. - Handoffs: A powerful feature to coordinate and delegate between multiple agents. - Guardrails: Run input validations and checks in parallel to your agents, breaking early if the checks fail. +- Sessions: Automatic conversation history management across agent runs, eliminating manual state handling. - Function tools: Turn any Python function into a tool, with automatic schema generation and Pydantic-powered validation. - Tracing: Built-in tracing that lets you visualize, debug and monitor your workflows, as well as use the OpenAI suite of evaluation, fine-tuning and distillation tools. diff --git a/docs/ref/memory.md b/docs/ref/memory.md new file mode 100644 index 00000000..04a2258b --- /dev/null +++ b/docs/ref/memory.md @@ -0,0 +1,8 @@ +# Memory + +::: agents.memory + + options: + members: + - Session + - SQLiteSession diff --git a/docs/running_agents.md b/docs/running_agents.md index f631cf46..6898f510 100644 --- a/docs/running_agents.md +++ b/docs/running_agents.md @@ -65,7 +65,9 @@ Calling any of the run methods can result in one or more agents running (and hen At the end of the agent run, you can choose what to show to the user. For example, you might show the user every new item generated by the agents, or just the final output. Either way, the user might then ask a followup question, in which case you can call the run method again. -You can use the base [`RunResultBase.to_input_list()`][agents.result.RunResultBase.to_input_list] method to get the inputs for the next turn. +### Manual conversation management + +You can manually manage conversation history using the [`RunResultBase.to_input_list()`][agents.result.RunResultBase.to_input_list] method to get the inputs for the next turn: ```python async def main(): @@ -84,6 +86,39 @@ async def main(): # California ``` +### Automatic conversation management with Sessions + +For a simpler approach, you can use [Sessions](sessions.md) to automatically handle conversation history without manually calling `.to_input_list()`: + +```python +from agents import Agent, Runner, SQLiteSession + +async def main(): + agent = Agent(name="Assistant", instructions="Reply very concisely.") + + # Create session instance + session = SQLiteSession("conversation_123") + + with trace(workflow_name="Conversation", group_id=thread_id): + # First turn + result = await Runner.run(agent, "What city is the Golden Gate Bridge in?", session=session) + print(result.final_output) + # San Francisco + + # Second turn - agent automatically remembers previous context + result = await Runner.run(agent, "What state is it in?", session=session) + print(result.final_output) + # California +``` + +Sessions automatically: + +- Retrieves conversation history before each run +- Stores new messages after each run +- Maintains separate conversations for different session IDs + +See the [Sessions documentation](sessions.md) for more details. + ## Exceptions The SDK raises exceptions in certain cases. The full list is in [`agents.exceptions`][]. 
As an overview:
diff --git a/docs/sessions.md b/docs/sessions.md
new file mode 100644
index 00000000..36f89c63
--- /dev/null
+++ b/docs/sessions.md
@@ -0,0 +1,319 @@
+# Sessions
+
+The Agents SDK provides built-in session memory to automatically maintain conversation history across multiple agent runs, eliminating the need to manually handle `.to_input_list()` between turns.
+
+Session memory stores the conversation history for a specific session, allowing agents to maintain context without requiring explicit manual memory management. This is particularly useful for building chat applications or multi-turn conversations where you want the agent to remember previous interactions.
+
+## Quick start
+
+```python
+from agents import Agent, Runner, SQLiteSession
+
+# Create agent
+agent = Agent(
+    name="Assistant",
+    instructions="Reply very concisely.",
+)
+
+# Create a session instance with a session ID
+session = SQLiteSession("conversation_123")
+
+# First turn
+result = await Runner.run(
+    agent,
+    "What city is the Golden Gate Bridge in?",
+    session=session
+)
+print(result.final_output)  # "San Francisco"
+
+# Second turn - agent automatically remembers previous context
+result = await Runner.run(
+    agent,
+    "What state is it in?",
+    session=session
+)
+print(result.final_output)  # "California"
+
+# Also works with synchronous runner
+result = Runner.run_sync(
+    agent,
+    "What's the population?",
+    session=session
+)
+print(result.final_output)  # "Approximately 39 million"
+```
+
+## How it works
+
+When session memory is enabled:
+
+1. **Before each run**: The runner automatically retrieves the conversation history for the session and prepends it to the input messages.
+2. **After each run**: All new messages generated during the run (user input, assistant responses, tool calls, etc.) are automatically stored in the session.
+3. **Context preservation**: Each subsequent run with the same session includes the full conversation history, allowing the agent to maintain context.
+
+This eliminates the need to manually call `.to_input_list()` and manage conversation state between runs.
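+
+Conceptually, a turn that uses a session is roughly equivalent to the following manual bookkeeping. This is a simplified sketch of what the runner does for you (reusing the `agent` and `session` objects from the Quick start above), not its exact internals:
+
+```python
+user_turn = [{"role": "user", "content": "What state is it in?"}]
+
+# 1. Load the stored history and prepend it to the new input.
+history = await session.get_messages()
+result = await Runner.run(agent, history + user_turn)
+
+# 2. Persist this turn: the new input plus everything the agent produced.
+await session.add_messages(
+    user_turn + [item.to_input_item() for item in result.new_items]
+)
+print(result.final_output)
+```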
+ +## Memory operations + +### Basic operations + +Sessions supports several operations for managing conversation history: + +```python +from agents import SQLiteSession + +session = SQLiteSession("user_123", "conversations.db") + +# Get all messages in a session +messages = await session.get_messages() + +# Add new messages to a session +new_messages = [ + {"role": "user", "content": "Hello"}, + {"role": "assistant", "content": "Hi there!"} +] +await session.add_messages(new_messages) + +# Remove and return the most recent message +last_message = await session.pop_message() +print(last_message) # {"role": "assistant", "content": "Hi there!"} + +# Clear all messages from a session +await session.clear_session() +``` + +### Using pop_message for corrections + +The `pop_message` method is particularly useful when you want to undo or modify the last message in a conversation: + +```python +from agents import Agent, Runner, SQLiteSession + +agent = Agent(name="Assistant") +session = SQLiteSession("correction_example") + +# Initial conversation +result = await Runner.run( + agent, + "What's 2 + 2?", + session=session +) +print(f"Agent: {result.final_output}") + +# User wants to correct their question +user_message = await session.pop_message() # Remove user's question +assistant_message = await session.pop_message() # Remove agent's response + +# Ask a corrected question +result = await Runner.run( + agent, + "What's 2 + 3?", + session=session +) +print(f"Agent: {result.final_output}") +``` + +## Memory options + +### No memory (default) + +```python +# Default behavior - no session memory +result = await Runner.run(agent, "Hello") +``` + +### SQLite memory + +```python +from agents import SQLiteSession + +# In-memory database (lost when process ends) +session = SQLiteSession("user_123") + +# Persistent file-based database +session = SQLiteSession("user_123", "conversations.db") + +# Use the session +result = await Runner.run( + agent, + "Hello", + session=session +) +``` + +### Multiple sessions + +```python +from agents import Agent, Runner, SQLiteSession + +agent = Agent(name="Assistant") + +# Different sessions maintain separate conversation histories +session_1 = SQLiteSession("user_123", "conversations.db") +session_2 = SQLiteSession("user_456", "conversations.db") + +result1 = await Runner.run( + agent, + "Hello", + session=session_1 +) +result2 = await Runner.run( + agent, + "Hello", + session=session_2 +) +``` + +## Custom memory implementations + +You can implement your own session memory by creating a class that follows the [`Session`][agents.memory.session.Session] protocol: + +````python +from agents.memory import Session +from typing import List + +class MyCustomSession: + """Custom session implementation following the Session protocol.""" + + def __init__(self, session_id: str): + self.session_id = session_id + # Your initialization here + + async def get_messages(self) -> List[dict]: + """Retrieve conversation history for this session.""" + # Your implementation here + pass + + async def add_messages(self, messages: List[dict]) -> None: + """Store new messages for this session.""" + # Your implementation here + pass + + async def pop_message(self) -> dict | None: + """Remove and return the most recent message from this session.""" + # Your implementation here + pass + + async def clear_session(self) -> None: + """Clear all messages for this session.""" + # Your implementation here + pass + +# Use your custom session +agent = Agent(name="Assistant") +result = await Runner.run( + 
    agent,
+    "Hello",
+    session=MyCustomSession("my_session")
+)
+````
+
+## Session management
+
+### Session ID naming
+
+Use meaningful session IDs that help you organize conversations:
+
+- User-based: `"user_12345"`
+- Thread-based: `"thread_abc123"`
+- Context-based: `"support_ticket_456"`
+
+### Memory persistence
+
+- Use in-memory SQLite (`SQLiteSession("session_id")`) for temporary conversations
+- Use file-based SQLite (`SQLiteSession("session_id", "path/to/db.sqlite")`) for persistent conversations
+- Consider implementing custom session backends for production systems (Redis, PostgreSQL, etc.)
+
+### Clearing and sharing sessions
+
+```python
+# Clear a session when the conversation should start fresh
+await session.clear_session()
+
+# Different agents can share the same session
+support_agent = Agent(name="Support")
+billing_agent = Agent(name="Billing")
+session = SQLiteSession("user_123")
+
+# Both agents will see the same conversation history
+result1 = await Runner.run(
+    support_agent,
+    "Help me with my account",
+    session=session
+)
+result2 = await Runner.run(
+    billing_agent,
+    "What are my charges?",
+    session=session
+)
+````
+
+## Complete example
+
+Here's a complete example showing session memory in action:
+
+```python
+import asyncio
+from agents import Agent, Runner, SQLiteSession
+
+
+async def main():
+    # Create an agent
+    agent = Agent(
+        name="Assistant",
+        instructions="Reply very concisely.",
+    )
+
+    # Create a session instance that will persist across runs
+    session = SQLiteSession("conversation_123", "conversation_history.db")
+
+    print("=== Sessions Example ===")
+    print("The agent will remember previous messages automatically.\n")
+
+    # First turn
+    print("First turn:")
+    print("User: What city is the Golden Gate Bridge in?")
+    result = await Runner.run(
+        agent,
+        "What city is the Golden Gate Bridge in?",
+        session=session
+    )
+    print(f"Assistant: {result.final_output}")
+    print()
+
+    # Second turn - the agent will remember the previous conversation
+    print("Second turn:")
+    print("User: What state is it in?")
+    result = await Runner.run(
+        agent,
+        "What state is it in?",
+        session=session
+    )
+    print(f"Assistant: {result.final_output}")
+    print()
+
+    # Third turn - continuing the conversation
+    print("Third turn:")
+    print("User: What's the population of that state?")
+    result = await Runner.run(
+        agent,
+        "What's the population of that state?",
+        session=session
+    )
+    print(f"Assistant: {result.final_output}")
+    print()
+
+    print("=== Conversation Complete ===")
+    print("Notice how the agent remembered the context from previous turns!")
+    print("Sessions automatically handles conversation history.")
+
+
+if __name__ == "__main__":
+    asyncio.run(main())
+```
+
+## API Reference
+
+For detailed API documentation, see:
+
+- [`Session`][agents.memory.Session] - Protocol interface
+- [`SQLiteSession`][agents.memory.SQLiteSession] - SQLite implementation
diff --git a/examples/basic/session_example.py b/examples/basic/session_example.py
new file mode 100644
index 00000000..56431721
--- /dev/null
+++ b/examples/basic/session_example.py
@@ -0,0 +1,76 @@
+"""
+Example demonstrating session memory functionality.
+
+This example shows how to use session memory to maintain conversation history
+across multiple agent runs without manually handling .to_input_list().
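+
+To try it, run it from the repository root (this assumes an OpenAI API key is configured in your environment), e.g.:
+
+    uv run python examples/basic/session_example.py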
+""" + +import asyncio +from agents import Agent, Runner, SQLiteSession + + +async def main(): + # Create an agent + agent = Agent( + name="Assistant", + instructions="Reply very concisely.", + ) + + # Create a session instance that will persist across runs + session_id = "conversation_123" + session = SQLiteSession(session_id) + + print("=== Session Example ===") + print("The agent will remember previous messages automatically.\n") + + # First turn + print("First turn:") + print("User: What city is the Golden Gate Bridge in?") + result = await Runner.run( + agent, + "What city is the Golden Gate Bridge in?", + session=session, + ) + print(f"Assistant: {result.final_output}") + print() + + # Second turn - the agent will remember the previous conversation + print("Second turn:") + print("User: What state is it in?") + result = await Runner.run(agent, "What state is it in?", session=session) + print(f"Assistant: {result.final_output}") + print() + + # Third turn - continuing the conversation + print("Third turn:") + print("User: What's the population of that state?") + result = await Runner.run( + agent, + "What's the population of that state?", + session=session, + ) + print(f"Assistant: {result.final_output}") + print() + + print("=== Conversation Complete ===") + print("Notice how the agent remembered the context from previous turns!") + print("Sessions automatically handles conversation history.") + + # Demonstrate the amount parameter - get only the latest 2 messages + print("\n=== Latest Messages Demo ===") + latest_messages = await session.get_messages(amount=2) + print("Latest 2 messages:") + for i, msg in enumerate(latest_messages, 1): + role = msg.get("role", "unknown") + content = msg.get("content", "") + print(f" {i}. {role}: {content}") + + print(f"\nFetched {len(latest_messages)} out of total conversation history.") + + # Get all messages to show the difference + all_messages = await session.get_messages() + print(f"Total messages in session: {len(all_messages)}") + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/mkdocs.yml b/mkdocs.yml index ad719670..8bc3b57f 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -57,6 +57,7 @@ plugins: - Documentation: - agents.md - running_agents.md + - sessions.md - results.md - streaming.md - tools.md @@ -80,6 +81,7 @@ plugins: - ref/index.md - ref/agent.md - ref/run.md + - ref/memory.md - ref/tool.md - ref/result.md - ref/stream_events.md @@ -137,6 +139,7 @@ plugins: - ドキュメント: - agents.md - running_agents.md + - sessions.md - results.md - streaming.md - tools.md diff --git a/src/agents/__init__.py b/src/agents/__init__.py index 58949157..6ee871b2 100644 --- a/src/agents/__init__.py +++ b/src/agents/__init__.py @@ -39,6 +39,7 @@ TResponseInputItem, ) from .lifecycle import AgentHooks, RunHooks +from .memory import Session, SQLiteSession from .model_settings import ModelSettings from .models.interface import Model, ModelProvider, ModelTracing from .models.openai_chatcompletions import OpenAIChatCompletionsModel @@ -127,7 +128,9 @@ def set_default_openai_key(key: str, use_for_tracing: bool = True) -> None: _config.set_default_openai_key(key, use_for_tracing) -def set_default_openai_client(client: AsyncOpenAI, use_for_tracing: bool = True) -> None: +def set_default_openai_client( + client: AsyncOpenAI, use_for_tracing: bool = True +) -> None: """Set the default OpenAI client to use for LLM requests and/or tracing. If provided, this client will be used instead of the default OpenAI client. 
@@ -202,6 +205,8 @@ def enable_verbose_stdout_logging(): "ItemHelpers", "RunHooks", "AgentHooks", + "Session", + "SQLiteSession", "RunContextWrapper", "TContext", "RunResult", diff --git a/src/agents/memory/__init__.py b/src/agents/memory/__init__.py new file mode 100644 index 00000000..059ca57a --- /dev/null +++ b/src/agents/memory/__init__.py @@ -0,0 +1,3 @@ +from .session import Session, SQLiteSession + +__all__ = ["Session", "SQLiteSession"] diff --git a/src/agents/memory/session.py b/src/agents/memory/session.py new file mode 100644 index 00000000..230981e0 --- /dev/null +++ b/src/agents/memory/session.py @@ -0,0 +1,331 @@ +from __future__ import annotations + +import asyncio +import json +import sqlite3 +import threading +from pathlib import Path +from typing import Any, TYPE_CHECKING, Protocol, runtime_checkable + +if TYPE_CHECKING: + from ..items import TResponseInputItem + + +@runtime_checkable +class Session(Protocol): + """Protocol for session implementations. + + Session stores conversation history for a specific session, allowing + agents to maintain context without requiring explicit manual memory management. + """ + + session_id: str + + async def get_messages(self, amount: int | None = None) -> list[TResponseInputItem]: + """Retrieve the conversation history for this session. + + Args: + amount: Maximum number of messages to retrieve. If None, retrieves all messages. + When specified, returns the latest N messages in chronological order. + + Returns: + List of input items representing the conversation history + """ + ... + + async def add_messages(self, messages: list[TResponseInputItem]) -> None: + """Add new messages to the conversation history. + + Args: + messages: List of input items to add to the history + """ + ... + + async def pop_message(self) -> TResponseInputItem | None: + """Remove and return the most recent message from the session. + + Returns: + The most recent message if it exists, None if the session is empty + """ + ... + + async def clear_session(self) -> None: + """Clear all messages for this session.""" + ... + + +class SQLiteSession(Session): + """SQLite-based implementation of session storage. + + This implementation stores conversation history in a SQLite database. + By default, uses an in-memory database that is lost when the process ends. + For persistent storage, provide a file path. + """ + + def __init__( + self, + session_id: str, + db_path: str | Path = ":memory:", + sessions_table: str = "agent_sessions", + messages_table: str = "agent_messages", + ): + """Initialize the SQLite session. + + Args: + session_id: Unique identifier for the conversation session + db_path: Path to the SQLite database file. Defaults to ':memory:' (in-memory database) + sessions_table: Name of the table to store session metadata. Defaults to 'agent_sessions' + messages_table: Name of the table to store message data. 
Defaults to 'agent_messages' + """ + self.session_id = session_id + self.db_path = db_path + self.sessions_table = sessions_table + self.messages_table = messages_table + self._local = threading.local() + self._lock = threading.Lock() + + # For in-memory databases, we need a shared connection to avoid thread isolation + # For file databases, we use thread-local connections for better concurrency + self._is_memory_db = str(db_path) == ":memory:" + if self._is_memory_db: + self._shared_connection = sqlite3.connect( + ":memory:", check_same_thread=False + ) + self._shared_connection.execute("PRAGMA journal_mode=WAL") + self._init_db_for_connection(self._shared_connection) + else: + # For file databases, initialize the schema once since it persists + init_conn = sqlite3.connect(str(self.db_path), check_same_thread=False) + init_conn.execute("PRAGMA journal_mode=WAL") + self._init_db_for_connection(init_conn) + init_conn.close() + + def _get_connection(self) -> sqlite3.Connection: + """Get a database connection.""" + if self._is_memory_db: + # Use shared connection for in-memory database to avoid thread isolation + return self._shared_connection + else: + # Use thread-local connections for file databases + if not hasattr(self._local, "connection"): + self._local.connection = sqlite3.connect( + str(self.db_path), + check_same_thread=False, + ) + self._local.connection.execute("PRAGMA journal_mode=WAL") + return self._local.connection + + def _init_db_for_connection(self, conn: sqlite3.Connection) -> None: + """Initialize the database schema for a specific connection.""" + conn.execute( + f""" + CREATE TABLE IF NOT EXISTS {self.sessions_table} ( + session_id TEXT PRIMARY KEY, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP + ) + """ + ) + + conn.execute( + f""" + CREATE TABLE IF NOT EXISTS {self.messages_table} ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + session_id TEXT NOT NULL, + message_data TEXT NOT NULL, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + FOREIGN KEY (session_id) REFERENCES {self.sessions_table} (session_id) ON DELETE CASCADE + ) + """ + ) + + conn.execute( + f""" + CREATE INDEX IF NOT EXISTS idx_{self.messages_table}_session_id + ON {self.messages_table} (session_id, created_at) + """ + ) + + conn.commit() + + async def get_messages(self, amount: int | None = None) -> list[TResponseInputItem]: + """Retrieve the conversation history for this session. + + Args: + amount: Maximum number of messages to retrieve. If None, retrieves all messages. + When specified, returns the latest N messages in chronological order. + + Returns: + List of input items representing the conversation history + """ + + def _get_messages_sync(): + conn = self._get_connection() + with self._lock if self._is_memory_db else threading.Lock(): + if amount is None: + # Fetch all messages in chronological order + cursor = conn.execute( + f""" + SELECT message_data FROM {self.messages_table} + WHERE session_id = ? + ORDER BY created_at ASC + """, + (self.session_id,), + ) + else: + # Fetch the latest N messages in chronological order + # First get the total count to calculate offset + count_cursor = conn.execute( + f""" + SELECT COUNT(*) FROM {self.messages_table} + WHERE session_id = ? + """, + (self.session_id,), + ) + total_count = count_cursor.fetchone()[0] + + # Calculate offset to get the latest N messages + offset = max(0, total_count - amount) + + cursor = conn.execute( + f""" + SELECT message_data FROM {self.messages_table} + WHERE session_id = ? 
+ ORDER BY created_at ASC + LIMIT ? OFFSET ? + """, + (self.session_id, amount, offset), + ) + + messages = [] + for (message_data,) in cursor.fetchall(): + try: + message = json.loads(message_data) + messages.append(message) + except json.JSONDecodeError: + # Skip invalid JSON entries + continue + + return messages + + return await asyncio.to_thread(_get_messages_sync) + + async def add_messages(self, messages: list[TResponseInputItem]) -> None: + """Add new messages to the conversation history. + + Args: + messages: List of input items to add to the history + """ + if not messages: + return + + def _add_messages_sync(): + conn = self._get_connection() + + with self._lock if self._is_memory_db else threading.Lock(): + # Ensure session exists + conn.execute( + f""" + INSERT OR IGNORE INTO {self.sessions_table} (session_id) VALUES (?) + """, + (self.session_id,), + ) + + # Add messages + message_data = [ + (self.session_id, json.dumps(message)) for message in messages + ] + conn.executemany( + f""" + INSERT INTO {self.messages_table} (session_id, message_data) VALUES (?, ?) + """, + message_data, + ) + + # Update session timestamp + conn.execute( + f""" + UPDATE {self.sessions_table} SET updated_at = CURRENT_TIMESTAMP WHERE session_id = ? + """, + (self.session_id,), + ) + + conn.commit() + + await asyncio.to_thread(_add_messages_sync) + + async def pop_message(self) -> TResponseInputItem | None: + """Remove and return the most recent message from the session. + + Returns: + The most recent message if it exists, None if the session is empty + """ + + def _pop_message_sync(): + conn = self._get_connection() + with self._lock if self._is_memory_db else threading.Lock(): + cursor = conn.execute( + f""" + SELECT id, message_data FROM {self.messages_table} + WHERE session_id = ? + ORDER BY created_at DESC + LIMIT 1 + """, + (self.session_id,), + ) + + result = cursor.fetchone() + if result: + message_id, message_data = result + try: + message = json.loads(message_data) + # Delete the message by ID + conn.execute( + f""" + DELETE FROM {self.messages_table} WHERE id = ? + """, + (message_id,), + ) + conn.commit() + return message + except json.JSONDecodeError: + # Skip invalid JSON entries, but still delete the corrupted record + conn.execute( + f""" + DELETE FROM {self.messages_table} WHERE id = ? 
+ """, + (message_id,), + ) + conn.commit() + return None + + return None + + return await asyncio.to_thread(_pop_message_sync) + + async def clear_session(self) -> None: + """Clear all messages for this session.""" + + def _clear_session_sync(): + conn = self._get_connection() + with self._lock if self._is_memory_db else threading.Lock(): + conn.execute( + f"DELETE FROM {self.messages_table} WHERE session_id = ?", + (self.session_id,), + ) + conn.execute( + f"DELETE FROM {self.sessions_table} WHERE session_id = ?", + (self.session_id,), + ) + conn.commit() + + await asyncio.to_thread(_clear_session_sync) + + def close(self) -> None: + """Close the database connection.""" + if self._is_memory_db: + if hasattr(self, "_shared_connection"): + self._shared_connection.close() + else: + if hasattr(self._local, "connection"): + self._local.connection.close() diff --git a/src/agents/run.py b/src/agents/run.py index b196c3bf..2a96c82f 100644 --- a/src/agents/run.py +++ b/src/agents/run.py @@ -27,11 +27,17 @@ ModelBehaviorError, OutputGuardrailTripwireTriggered, ) -from .guardrail import InputGuardrail, InputGuardrailResult, OutputGuardrail, OutputGuardrailResult +from .guardrail import ( + InputGuardrail, + InputGuardrailResult, + OutputGuardrail, + OutputGuardrailResult, +) from .handoffs import Handoff, HandoffInputFilter, handoff from .items import ItemHelpers, ModelResponse, RunItem, TResponseInputItem from .lifecycle import RunHooks from .logger import logger +from .memory import Session from .model_settings import ModelSettings from .models.interface import Model, ModelProvider from .models.multi_provider import MultiProvider @@ -118,6 +124,7 @@ async def run( hooks: RunHooks[TContext] | None = None, run_config: RunConfig | None = None, previous_response_id: str | None = None, + session: Session | None = None, ) -> RunResult: """Run a workflow starting at the given agent. The agent will run in a loop until a final output is generated. The loop runs like so: @@ -144,6 +151,8 @@ async def run( run_config: Global settings for the entire agent run. previous_response_id: The ID of the previous response, if using OpenAI models via the Responses API, this allows you to skip passing in input from the previous turn. + session: Session instance for conversation history persistence. + If None, no conversation history will be maintained. Returns: A run result containing all the inputs, guardrail results and the output of the last @@ -154,6 +163,9 @@ async def run( if run_config is None: run_config = RunConfig() + # Prepare input with session if enabled + prepared_input = await cls._prepare_input_with_session(input, session) + tool_use_tracker = AgentToolUseTracker() with TraceCtxManager( @@ -164,7 +176,9 @@ async def run( disabled=run_config.tracing_disabled, ): current_turn = 0 - original_input: str | list[TResponseInputItem] = copy.deepcopy(input) + original_input: str | list[TResponseInputItem] = copy.deepcopy( + prepared_input + ) generated_items: list[RunItem] = [] model_responses: list[ModelResponse] = [] @@ -183,7 +197,9 @@ async def run( # Start an agent span if we don't have one. This span is ended if the current # agent changes, or if the agent loop ends. 
if current_span is None: - handoff_names = [h.agent_name for h in cls._get_handoffs(current_agent)] + handoff_names = [ + h.agent_name for h in cls._get_handoffs(current_agent) + ] if output_schema := cls._get_output_schema(current_agent): output_type_name = output_schema.name() else: @@ -220,7 +236,7 @@ async def run( starting_agent, starting_agent.input_guardrails + (run_config.input_guardrails or []), - copy.deepcopy(input), + copy.deepcopy(prepared_input), context_wrapper, ), cls._run_single_turn( @@ -257,12 +273,13 @@ async def run( if isinstance(turn_result.next_step, NextStepFinalOutput): output_guardrail_results = await cls._run_output_guardrails( - current_agent.output_guardrails + (run_config.output_guardrails or []), + current_agent.output_guardrails + + (run_config.output_guardrails or []), current_agent, turn_result.next_step.output, context_wrapper, ) - return RunResult( + result = RunResult( input=original_input, new_items=generated_items, raw_responses=model_responses, @@ -272,8 +289,15 @@ async def run( output_guardrail_results=output_guardrail_results, context_wrapper=context_wrapper, ) + + # Save the conversation to session if enabled + await cls._save_result_to_session(session, input, result) + + return result elif isinstance(turn_result.next_step, NextStepHandoff): - current_agent = cast(Agent[TContext], turn_result.next_step.new_agent) + current_agent = cast( + Agent[TContext], turn_result.next_step.new_agent + ) current_span.finish(reset_current=True) current_span = None should_run_agent_start_hooks = True @@ -298,6 +322,7 @@ def run_sync( hooks: RunHooks[TContext] | None = None, run_config: RunConfig | None = None, previous_response_id: str | None = None, + session: Session | None = None, ) -> RunResult: """Run a workflow synchronously, starting at the given agent. Note that this just wraps the `run` method, so it will not work if there's already an event loop (e.g. inside an async @@ -328,6 +353,8 @@ def run_sync( run_config: Global settings for the entire agent run. previous_response_id: The ID of the previous response, if using OpenAI models via the Responses API, this allows you to skip passing in input from the previous turn. + session: Session instance for conversation history persistence. + If None, no conversation history will be maintained. Returns: A run result containing all the inputs, guardrail results and the output of the last @@ -342,6 +369,7 @@ def run_sync( hooks=hooks, run_config=run_config, previous_response_id=previous_response_id, + session=session, ) ) @@ -355,6 +383,7 @@ def run_streamed( hooks: RunHooks[TContext] | None = None, run_config: RunConfig | None = None, previous_response_id: str | None = None, + session: Session | None = None, ) -> RunResultStreaming: """Run a workflow starting at the given agent in streaming mode. The returned result object contains a method you can use to stream semantic events as they are generated. @@ -383,6 +412,8 @@ def run_streamed( run_config: Global settings for the entire agent run. previous_response_id: The ID of the previous response, if using OpenAI models via the Responses API, this allows you to skip passing in input from the previous turn. + session: Session instance for conversation history persistence. + If None, no conversation history will be maintained. Returns: A result object that contains data about the run, as well as a method to stream events. 
""" @@ -438,6 +469,7 @@ def run_streamed( context_wrapper=context_wrapper, run_config=run_config, previous_response_id=previous_response_id, + session=session, ) ) return streamed_result @@ -496,138 +528,188 @@ async def _run_streamed_impl( context_wrapper: RunContextWrapper[TContext], run_config: RunConfig, previous_response_id: str | None, + session: Session | None, ): - if streamed_result.trace: - streamed_result.trace.start(mark_as_current=True) - current_span: Span[AgentSpanData] | None = None - current_agent = starting_agent - current_turn = 0 - should_run_agent_start_hooks = True - tool_use_tracker = AgentToolUseTracker() - - streamed_result._event_queue.put_nowait(AgentUpdatedStreamEvent(new_agent=current_agent)) try: - while True: - if streamed_result.is_complete: - break - - # Start an agent span if we don't have one. This span is ended if the current - # agent changes, or if the agent loop ends. - if current_span is None: - handoff_names = [h.agent_name for h in cls._get_handoffs(current_agent)] - if output_schema := cls._get_output_schema(current_agent): - output_type_name = output_schema.name() - else: - output_type_name = "str" + if streamed_result.trace: + streamed_result.trace.start(mark_as_current=True) - current_span = agent_span( - name=current_agent.name, - handoffs=handoff_names, - output_type=output_type_name, - ) - current_span.start(mark_as_current=True) + # Prepare input with session if enabled + prepared_input = await cls._prepare_input_with_session( + starting_input, session + ) - all_tools = await cls._get_all_tools(current_agent) - tool_names = [t.name for t in all_tools] - current_span.span_data.tools = tool_names - current_turn += 1 - streamed_result.current_turn = current_turn + # Update the streamed result with the prepared input + streamed_result.input = prepared_input - if current_turn > max_turns: - _error_tracing.attach_error_to_span( - current_span, - SpanError( - message="Max turns exceeded", - data={"max_turns": max_turns}, - ), - ) - streamed_result._event_queue.put_nowait(QueueCompleteSentinel()) - break - - if current_turn == 1: - # Run the input guardrails in the background and put the results on the queue - streamed_result._input_guardrails_task = asyncio.create_task( - cls._run_input_guardrails_with_queue( - starting_agent, - starting_agent.input_guardrails + (run_config.input_guardrails or []), - copy.deepcopy(ItemHelpers.input_to_new_input_list(starting_input)), - context_wrapper, - streamed_result, - current_span, + current_agent = starting_agent + current_turn = 0 + should_run_agent_start_hooks = True + tool_use_tracker = AgentToolUseTracker() + + streamed_result._event_queue.put_nowait( + AgentUpdatedStreamEvent(new_agent=current_agent) + ) + + try: + while True: + if streamed_result.is_complete: + break + + # Start an agent span if we don't have one. This span is ended if the current + # agent changes, or if the agent loop ends. 
+ if current_span is None: + handoff_names = [ + h.agent_name for h in cls._get_handoffs(current_agent) + ] + if output_schema := cls._get_output_schema(current_agent): + output_type_name = output_schema.name() + else: + output_type_name = "str" + + current_span = agent_span( + name=current_agent.name, + handoffs=handoff_names, + output_type=output_type_name, ) - ) - try: - turn_result = await cls._run_single_turn_streamed( - streamed_result, - current_agent, - hooks, - context_wrapper, - run_config, - should_run_agent_start_hooks, - tool_use_tracker, - all_tools, - previous_response_id, - ) - should_run_agent_start_hooks = False + current_span.start(mark_as_current=True) - streamed_result.raw_responses = streamed_result.raw_responses + [ - turn_result.model_response - ] - streamed_result.input = turn_result.original_input - streamed_result.new_items = turn_result.generated_items + all_tools = await cls._get_all_tools(current_agent) + tool_names = [t.name for t in all_tools] + current_span.span_data.tools = tool_names + current_turn += 1 + streamed_result.current_turn = current_turn - if isinstance(turn_result.next_step, NextStepHandoff): - current_agent = turn_result.next_step.new_agent - current_span.finish(reset_current=True) - current_span = None - should_run_agent_start_hooks = True - streamed_result._event_queue.put_nowait( - AgentUpdatedStreamEvent(new_agent=current_agent) + if current_turn > max_turns: + _error_tracing.attach_error_to_span( + current_span, + SpanError( + message="Max turns exceeded", + data={"max_turns": max_turns}, + ), ) - elif isinstance(turn_result.next_step, NextStepFinalOutput): - streamed_result._output_guardrails_task = asyncio.create_task( - cls._run_output_guardrails( - current_agent.output_guardrails - + (run_config.output_guardrails or []), - current_agent, - turn_result.next_step.output, + streamed_result._event_queue.put_nowait(QueueCompleteSentinel()) + break + + if current_turn == 1: + # Run the input guardrails in the background and put the results on the queue + streamed_result._input_guardrails_task = asyncio.create_task( + cls._run_input_guardrails_with_queue( + starting_agent, + starting_agent.input_guardrails + + (run_config.input_guardrails or []), + copy.deepcopy( + ItemHelpers.input_to_new_input_list(prepared_input) + ), context_wrapper, + streamed_result, + current_span, ) ) + try: + turn_result = await cls._run_single_turn_streamed( + streamed_result, + current_agent, + hooks, + context_wrapper, + run_config, + should_run_agent_start_hooks, + tool_use_tracker, + all_tools, + previous_response_id, + ) + should_run_agent_start_hooks = False + + streamed_result.raw_responses = ( + streamed_result.raw_responses + [turn_result.model_response] + ) + streamed_result.input = turn_result.original_input + streamed_result.new_items = turn_result.generated_items + + if isinstance(turn_result.next_step, NextStepHandoff): + current_agent = turn_result.next_step.new_agent + current_span.finish(reset_current=True) + current_span = None + should_run_agent_start_hooks = True + streamed_result._event_queue.put_nowait( + AgentUpdatedStreamEvent(new_agent=current_agent) + ) + elif isinstance(turn_result.next_step, NextStepFinalOutput): + streamed_result._output_guardrails_task = ( + asyncio.create_task( + cls._run_output_guardrails( + current_agent.output_guardrails + + (run_config.output_guardrails or []), + current_agent, + turn_result.next_step.output, + context_wrapper, + ) + ) + ) - try: - output_guardrail_results = await 
streamed_result._output_guardrails_task - except Exception: - # Exceptions will be checked in the stream_events loop - output_guardrail_results = [] + try: + output_guardrail_results = ( + await streamed_result._output_guardrails_task + ) + except Exception: + # Exceptions will be checked in the stream_events loop + output_guardrail_results = [] - streamed_result.output_guardrail_results = output_guardrail_results - streamed_result.final_output = turn_result.next_step.output + streamed_result.output_guardrail_results = ( + output_guardrail_results + ) + streamed_result.final_output = turn_result.next_step.output + streamed_result.is_complete = True + + # Save the conversation to session if enabled + # Create a temporary RunResult for session saving + temp_result = RunResult( + input=streamed_result.input, + new_items=streamed_result.new_items, + raw_responses=streamed_result.raw_responses, + final_output=streamed_result.final_output, + _last_agent=current_agent, + input_guardrail_results=streamed_result.input_guardrail_results, + output_guardrail_results=streamed_result.output_guardrail_results, + context_wrapper=context_wrapper, + ) + await cls._save_result_to_session( + session, starting_input, temp_result + ) + + streamed_result._event_queue.put_nowait( + QueueCompleteSentinel() + ) + elif isinstance(turn_result.next_step, NextStepRunAgain): + pass + except Exception as e: + if current_span: + _error_tracing.attach_error_to_span( + current_span, + SpanError( + message="Error in agent run", + data={"error": str(e)}, + ), + ) streamed_result.is_complete = True streamed_result._event_queue.put_nowait(QueueCompleteSentinel()) - elif isinstance(turn_result.next_step, NextStepRunAgain): - pass - except Exception as e: - if current_span: - _error_tracing.attach_error_to_span( - current_span, - SpanError( - message="Error in agent run", - data={"error": str(e)}, - ), - ) - streamed_result.is_complete = True - streamed_result._event_queue.put_nowait(QueueCompleteSentinel()) - raise + raise + streamed_result.is_complete = True + finally: + if current_span: + current_span.finish(reset_current=True) + if streamed_result.trace: + streamed_result.trace.finish(reset_current=True) + + except Exception: + # Ensure that any exception (including those during setup) results in a completion sentinel + # being put in the queue so that stream_events() doesn't hang streamed_result.is_complete = True - finally: - if current_span: - current_span.finish(reset_current=True) - if streamed_result.trace: - streamed_result.trace.finish(reset_current=True) + streamed_result._event_queue.put_nowait(QueueCompleteSentinel()) + raise @classmethod async def _run_single_turn_streamed( @@ -662,7 +744,9 @@ async def _run_single_turn_streamed( handoffs = cls._get_handoffs(agent) model = cls._get_model(agent, run_config) model_settings = agent.model_settings.resolve(run_config.model_settings) - model_settings = RunImpl.maybe_reset_tool_choice(agent, tool_use_tracker, model_settings) + model_settings = RunImpl.maybe_reset_tool_choice( + agent, tool_use_tracker, model_settings + ) final_response: ModelResponse | None = None @@ -723,7 +807,9 @@ async def _run_single_turn_streamed( tool_use_tracker=tool_use_tracker, ) - RunImpl.stream_step_result_to_queue(single_step_result, streamed_result._event_queue) + RunImpl.stream_step_result_to_queue( + single_step_result, streamed_result._event_queue + ) return single_step_result @classmethod @@ -757,7 +843,9 @@ async def _run_single_turn( output_schema = cls._get_output_schema(agent) 
handoffs = cls._get_handoffs(agent) input = ItemHelpers.input_to_new_input_list(original_input) - input.extend([generated_item.to_input_item() for generated_item in generated_items]) + input.extend( + [generated_item.to_input_item() for generated_item in generated_items] + ) new_response = await cls._get_new_response( agent, @@ -875,7 +963,9 @@ async def _run_output_guardrails( guardrail_tasks = [ asyncio.create_task( - RunImpl.run_single_output_guardrail(guardrail, agent, agent_output, context) + RunImpl.run_single_output_guardrail( + guardrail, agent, agent_output, context + ) ) for guardrail in guardrails ] @@ -916,7 +1006,9 @@ async def _get_new_response( ) -> ModelResponse: model = cls._get_model(agent, run_config) model_settings = agent.model_settings.resolve(run_config.model_settings) - model_settings = RunImpl.maybe_reset_tool_choice(agent, tool_use_tracker, model_settings) + model_settings = RunImpl.maybe_reset_tool_choice( + agent, tool_use_tracker, model_settings + ) new_response = await model.get_response( system_instructions=system_prompt, @@ -968,3 +1060,45 @@ def _get_model(cls, agent: Agent[Any], run_config: RunConfig) -> Model: return agent.model return run_config.model_provider.get_model(agent.model) + + @classmethod + async def _prepare_input_with_session( + cls, + input: str | list[TResponseInputItem], + session: Session | None, + ) -> str | list[TResponseInputItem]: + """Prepare input by combining it with session history if enabled.""" + if session is None: + return input + + # Get previous conversation history + history = await session.get_messages() + + # Convert input to list format + new_input_list = ItemHelpers.input_to_new_input_list(input) + + # Combine history with new input + combined_input = history + new_input_list + + return combined_input + + @classmethod + async def _save_result_to_session( + cls, + session: Session | None, + original_input: str | list[TResponseInputItem], + result: RunResult, + ) -> None: + """Save the conversation turn to session.""" + if session is None: + return + + # Convert original input to list format if needed + input_list = ItemHelpers.input_to_new_input_list(original_input) + + # Convert new items to input format + new_items_as_input = [item.to_input_item() for item in result.new_items] + + # Save all messages from this turn + messages_to_save = input_list + new_items_as_input + await session.add_messages(messages_to_save) diff --git a/tests/test_session.py b/tests/test_session.py new file mode 100644 index 00000000..41418a72 --- /dev/null +++ b/tests/test_session.py @@ -0,0 +1,489 @@ +"""Tests for session memory functionality.""" + +import pytest +import tempfile +from pathlib import Path +import asyncio + +from agents import Agent, Runner, SQLiteSession + +from .fake_model import FakeModel +from .test_responses import get_text_message + + +# Helper functions for parametrized testing of different Runner methods +def _run_sync_wrapper(agent, input_data, **kwargs): + """Wrapper for run_sync that properly sets up an event loop.""" + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + try: + return Runner.run_sync(agent, input_data, **kwargs) + finally: + loop.close() + + +async def run_agent_async(runner_method: str, agent, input_data, **kwargs): + """Helper function to run agent with different methods.""" + if runner_method == "run": + return await Runner.run(agent, input_data, **kwargs) + elif runner_method == "run_sync": + # For run_sync, we need to run it in a thread with its own event loop + return await 
asyncio.to_thread(_run_sync_wrapper, agent, input_data, **kwargs) + elif runner_method == "run_streamed": + result = Runner.run_streamed(agent, input_data, **kwargs) + # For streaming, we first try to get at least one event to trigger any early exceptions + # If there's an exception in setup (like memory validation), it will be raised here + try: + first_event = None + async for event in result.stream_events(): + if first_event is None: + first_event = event + # Continue consuming all events + pass + except Exception: + # If an exception occurs during streaming, we let it propagate up + raise + return result + else: + raise ValueError(f"Unknown runner method: {runner_method}") + + +# Parametrized tests for different runner methods +@pytest.mark.parametrize("runner_method", ["run", "run_sync", "run_streamed"]) +@pytest.mark.asyncio +async def test_session_memory_basic_functionality_parametrized(runner_method): + """Test basic session memory functionality with SQLite backend across all runner methods.""" + with tempfile.TemporaryDirectory() as temp_dir: + db_path = Path(temp_dir) / "test_memory.db" + session_id = "test_session_123" + session = SQLiteSession(session_id, db_path) + + model = FakeModel() + agent = Agent(name="test", model=model) + + # First turn + model.set_next_output([get_text_message("San Francisco")]) + result1 = await run_agent_async( + runner_method, + agent, + "What city is the Golden Gate Bridge in?", + session=session, + ) + assert result1.final_output == "San Francisco" + + # Second turn - should have conversation history + model.set_next_output([get_text_message("California")]) + result2 = await run_agent_async( + runner_method, + agent, + "What state is it in?", + session=session, + ) + assert result2.final_output == "California" + + # Verify that the input to the second turn includes the previous conversation + # The model should have received the full conversation history + last_input = model.last_turn_args["input"] + assert len(last_input) > 1 # Should have more than just the current message + + session.close() + + +@pytest.mark.parametrize("runner_method", ["run", "run_sync", "run_streamed"]) +@pytest.mark.asyncio +async def test_session_memory_with_explicit_instance_parametrized(runner_method): + """Test session memory with an explicit SQLiteSession instance across all runner methods.""" + with tempfile.TemporaryDirectory() as temp_dir: + db_path = Path(temp_dir) / "test_memory.db" + session_id = "test_session_456" + session = SQLiteSession(session_id, db_path) + + model = FakeModel() + agent = Agent(name="test", model=model) + + # First turn + model.set_next_output([get_text_message("Hello")]) + result1 = await run_agent_async( + runner_method, agent, "Hi there", session=session + ) + assert result1.final_output == "Hello" + + # Second turn + model.set_next_output([get_text_message("I remember you said hi")]) + result2 = await run_agent_async( + runner_method, + agent, + "Do you remember what I said?", + session=session, + ) + assert result2.final_output == "I remember you said hi" + + session.close() + + +@pytest.mark.parametrize("runner_method", ["run", "run_sync", "run_streamed"]) +@pytest.mark.asyncio +async def test_session_memory_disabled_parametrized(runner_method): + """Test that session memory is disabled when session=None across all runner methods.""" + model = FakeModel() + agent = Agent(name="test", model=model) + + # First turn (no session parameters = disabled) + model.set_next_output([get_text_message("Hello")]) + result1 = await 
run_agent_async(runner_method, agent, "Hi there")
+    assert result1.final_output == "Hello"
+
+    # Second turn - should NOT have conversation history
+    model.set_next_output([get_text_message("I don't remember")])
+    result2 = await run_agent_async(
+        runner_method, agent, "Do you remember what I said?"
+    )
+    assert result2.final_output == "I don't remember"
+
+    # Verify that the input to the second turn is just the current message
+    last_input = model.last_turn_args["input"]
+    assert len(last_input) == 1  # Should only have the current message
+
+
+@pytest.mark.parametrize("runner_method", ["run", "run_sync", "run_streamed"])
+@pytest.mark.asyncio
+async def test_session_memory_different_sessions_parametrized(runner_method):
+    """Test that different session IDs maintain separate conversation histories across all runner methods."""
+    with tempfile.TemporaryDirectory() as temp_dir:
+        db_path = Path(temp_dir) / "test_memory.db"
+
+        model = FakeModel()
+        agent = Agent(name="test", model=model)
+
+        # Session 1
+        session_id_1 = "session_1"
+        session_1 = SQLiteSession(session_id_1, db_path)
+
+        model.set_next_output([get_text_message("I like cats")])
+        result1 = await run_agent_async(
+            runner_method, agent, "I like cats", session=session_1
+        )
+        assert result1.final_output == "I like cats"
+
+        # Session 2 - different session
+        session_id_2 = "session_2"
+        session_2 = SQLiteSession(session_id_2, db_path)
+
+        model.set_next_output([get_text_message("I like dogs")])
+        result2 = await run_agent_async(
+            runner_method, agent, "I like dogs", session=session_2
+        )
+        assert result2.final_output == "I like dogs"
+
+        # Back to Session 1 - should remember cats, not dogs
+        model.set_next_output([get_text_message("Yes, you mentioned cats")])
+        result3 = await run_agent_async(
+            runner_method,
+            agent,
+            "What did I say I like?",
+            session=session_1,
+        )
+        assert result3.final_output == "Yes, you mentioned cats"
+
+        session_1.close()
+        session_2.close()
+
+
+@pytest.mark.asyncio
+async def test_sqlite_session_memory_direct():
+    """Test SQLiteSession class directly."""
+    with tempfile.TemporaryDirectory() as temp_dir:
+        db_path = Path(temp_dir) / "test_direct.db"
+        session_id = "direct_test"
+        session = SQLiteSession(session_id, db_path)
+
+        # Test adding and retrieving messages
+        messages = [
+            {"role": "user", "content": "Hello"},
+            {"role": "assistant", "content": "Hi there!"},
+        ]
+
+        await session.add_messages(messages)
+        retrieved = await session.get_messages()
+
+        assert len(retrieved) == 2
+        assert retrieved[0]["role"] == "user"
+        assert retrieved[0]["content"] == "Hello"
+        assert retrieved[1]["role"] == "assistant"
+        assert retrieved[1]["content"] == "Hi there!"
+
+        # Test clearing session
+        await session.clear_session()
+        retrieved_after_clear = await session.get_messages()
+        assert len(retrieved_after_clear) == 0
+
+        session.close()
+
+
+@pytest.mark.asyncio
+async def test_sqlite_session_memory_pop_message():
+    """Test SQLiteSession pop_message functionality."""
+    with tempfile.TemporaryDirectory() as temp_dir:
+        db_path = Path(temp_dir) / "test_pop.db"
+        session_id = "pop_test"
+        session = SQLiteSession(session_id, db_path)
+
+        # Test popping from empty session
+        popped = await session.pop_message()
+        assert popped is None
+
+        # Add messages
+        messages = [
+            {"role": "user", "content": "Hello"},
+            {"role": "assistant", "content": "Hi there!"},
+            {"role": "user", "content": "How are you?"},
+        ]
+
+        await session.add_messages(messages)
+
+        # Verify all messages are there
+        retrieved = await session.get_messages()
+        assert len(retrieved) == 3
+
+        # Pop the most recent message
+        popped = await session.pop_message()
+        assert popped is not None
+        assert popped["role"] == "user"
+        assert popped["content"] == "How are you?"
+
+        # Verify message was removed
+        retrieved_after_pop = await session.get_messages()
+        assert len(retrieved_after_pop) == 2
+        assert retrieved_after_pop[-1]["content"] == "Hi there!"
+
+        # Pop another message
+        popped2 = await session.pop_message()
+        assert popped2 is not None
+        assert popped2["role"] == "assistant"
+        assert popped2["content"] == "Hi there!"
+
+        # Pop the last message
+        popped3 = await session.pop_message()
+        assert popped3 is not None
+        assert popped3["role"] == "user"
+        assert popped3["content"] == "Hello"
+
+        # Try to pop from empty session again
+        popped4 = await session.pop_message()
+        assert popped4 is None
+
+        # Verify session is empty
+        final_messages = await session.get_messages()
+        assert len(final_messages) == 0
+
+        session.close()
+
+
+@pytest.mark.asyncio
+async def test_session_memory_pop_different_sessions():
+    """Test that pop_message only affects the specified session."""
+    with tempfile.TemporaryDirectory() as temp_dir:
+        db_path = Path(temp_dir) / "test_pop_sessions.db"
+
+        session_1_id = "session_1"
+        session_2_id = "session_2"
+        session_1 = SQLiteSession(session_1_id, db_path)
+        session_2 = SQLiteSession(session_2_id, db_path)
+
+        # Add messages to both sessions
+        messages_1 = [
+            {"role": "user", "content": "Session 1 message"},
+        ]
+        messages_2 = [
+            {"role": "user", "content": "Session 2 message 1"},
+            {"role": "user", "content": "Session 2 message 2"},
+        ]
+
+        await session_1.add_messages(messages_1)
+        await session_2.add_messages(messages_2)
+
+        # Pop from session 2
+        popped = await session_2.pop_message()
+        assert popped is not None
+        assert popped["content"] == "Session 2 message 2"
+
+        # Verify session 1 is unaffected
+        session_1_messages = await session_1.get_messages()
+        assert len(session_1_messages) == 1
+        assert session_1_messages[0]["content"] == "Session 1 message"
+
+        # Verify session 2 has one message left
+        session_2_messages = await session_2.get_messages()
+        assert len(session_2_messages) == 1
+        assert session_2_messages[0]["content"] == "Session 2 message 1"
+
+        session_1.close()
+        session_2.close()
+
+
+@pytest.mark.asyncio
+async def test_sqlite_session_get_messages_with_amount():
+    """Test SQLiteSession get_messages with amount parameter."""
+    with tempfile.TemporaryDirectory() as temp_dir:
+        db_path = Path(temp_dir) / "test_amount.db"
+        session_id = "amount_test"
+        session = SQLiteSession(session_id, db_path)
+
+        # Add multiple messages
+        messages = [
+            {"role": "user", "content": "Message 1"},
+            {"role": "assistant", "content": "Response 1"},
+            {"role": "user", "content": "Message 2"},
+            {"role": "assistant", "content": "Response 2"},
+            {"role": "user", "content": "Message 3"},
+            {"role": "assistant", "content": "Response 3"},
+        ]
+
+        await session.add_messages(messages)
+
+        # Test getting all messages (default behavior)
+        all_messages = await session.get_messages()
+        assert len(all_messages) == 6
+        assert all_messages[0]["content"] == "Message 1"
+        assert all_messages[-1]["content"] == "Response 3"
+
+        # Test getting latest 2 messages
+        latest_2 = await session.get_messages(amount=2)
+        assert len(latest_2) == 2
+        assert latest_2[0]["content"] == "Message 3"
+        assert latest_2[1]["content"] == "Response 3"
+
+        # Test getting latest 4 messages
+        latest_4 = await session.get_messages(amount=4)
+        assert len(latest_4) == 4
+        assert latest_4[0]["content"] == "Message 2"
+        assert latest_4[1]["content"] == "Response 2"
+        assert latest_4[2]["content"] == "Message 3"
+        assert latest_4[3]["content"] == "Response 3"
+
+        # Test getting more messages than available
+        latest_10 = await session.get_messages(amount=10)
+        assert len(latest_10) == 6  # Should return all available messages
+        assert latest_10[0]["content"] == "Message 1"
+        assert latest_10[-1]["content"] == "Response 3"
+
+        # Test getting 0 messages
+        latest_0 = await session.get_messages(amount=0)
+        assert len(latest_0) == 0
+
+        session.close()
+
+
+# Original non-parametrized tests for backwards compatibility
+@pytest.mark.asyncio
+async def test_session_memory_basic_functionality():
+    """Test basic session memory functionality with SQLite backend."""
+    with tempfile.TemporaryDirectory() as temp_dir:
+        db_path = Path(temp_dir) / "test_memory.db"
+        session_id = "test_session_123"
+        session = SQLiteSession(session_id, db_path)
+
+        model = FakeModel()
+        agent = Agent(name="test", model=model)
+
+        # First turn
+        model.set_next_output([get_text_message("San Francisco")])
+        result1 = await Runner.run(
+            agent,
+            "What city is the Golden Gate Bridge in?",
+            session=session,
+        )
+        assert result1.final_output == "San Francisco"
+
+        # Second turn - should have conversation history
+        model.set_next_output([get_text_message("California")])
+        result2 = await Runner.run(agent, "What state is it in?", session=session)
+        assert result2.final_output == "California"
+
+        # Verify that the input to the second turn includes the previous conversation
+        # The model should have received the full conversation history
+        last_input = model.last_turn_args["input"]
+        assert len(last_input) > 1  # Should have more than just the current message
+
+        session.close()
+
+
+@pytest.mark.asyncio
+async def test_session_memory_with_explicit_instance():
+    """Test session memory with an explicit SQLiteSession instance."""
+    with tempfile.TemporaryDirectory() as temp_dir:
+        db_path = Path(temp_dir) / "test_memory.db"
+        session_id = "test_session_456"
+        session = SQLiteSession(session_id, db_path)
+
+        model = FakeModel()
+        agent = Agent(name="test", model=model)
+
+        # First turn
+        model.set_next_output([get_text_message("Hello")])
+        result1 = await Runner.run(agent, "Hi there", session=session)
+        assert result1.final_output == "Hello"
+
+        # Second turn
+        model.set_next_output([get_text_message("I remember you said hi")])
+        result2 = await Runner.run(
+            agent, "Do you remember what I said?", session=session
+        )
+        assert result2.final_output == "I remember you said hi"
+
+        session.close()
+
+
+@pytest.mark.asyncio
+async def test_session_memory_disabled():
+    """Test that session memory is disabled when session=None."""
+    model = FakeModel()
+    agent = Agent(name="test", model=model)
+
+    # First turn (no session parameters = disabled)
+    model.set_next_output([get_text_message("Hello")])
+    result1 = await Runner.run(agent, "Hi there")
+    assert result1.final_output == "Hello"
+
+    # Second turn - should NOT have conversation history
+    model.set_next_output([get_text_message("I don't remember")])
+    result2 = await Runner.run(agent, "Do you remember what I said?")
+    assert result2.final_output == "I don't remember"
+
+    # Verify that the input to the second turn is just the current message
+    last_input = model.last_turn_args["input"]
+    assert len(last_input) == 1  # Should only have the current message
+
+
+@pytest.mark.asyncio
+async def test_session_memory_different_sessions():
+    """Test that different session IDs maintain separate conversation histories."""
+    with tempfile.TemporaryDirectory() as temp_dir:
+        db_path = Path(temp_dir) / "test_memory.db"
+
+        model = FakeModel()
+        agent = Agent(name="test", model=model)
+
+        # Session 1
+        session_id_1 = "session_1"
+        session_1 = SQLiteSession(session_id_1, db_path)
+
+        model.set_next_output([get_text_message("I like cats")])
+        result1 = await Runner.run(agent, "I like cats", session=session_1)
+        assert result1.final_output == "I like cats"
+
+        # Session 2 - different session
+        session_id_2 = "session_2"
+        session_2 = SQLiteSession(session_id_2, db_path)
+
+        model.set_next_output([get_text_message("I like dogs")])
+        result2 = await Runner.run(agent, "I like dogs", session=session_2)
+        assert result2.final_output == "I like dogs"
+
+        # Back to Session 1 - should remember cats, not dogs
+        model.set_next_output([get_text_message("Yes, you mentioned cats")])
+        result3 = await Runner.run(agent, "What did I say I like?", session=session_1)
+        assert result3.final_output == "Yes, you mentioned cats"
+
+        session_1.close()
+        session_2.close()