From d849c939d44ee7040f12012edd636791edf524ba Mon Sep 17 00:00:00 2001 From: Stephan Fitzpatrick Date: Fri, 23 May 2025 22:13:46 -0700 Subject: [PATCH 01/26] Add session memory functionality to the Agents SDK ## Summary - Introduced `SessionMemory` and `SQLiteSessionMemory` classes for automatic conversation history management. - Updated `Agent` class to support session memory configuration. - Enhanced `Runner` class to handle input preparation and result saving with session memory. - Added example demonstrating session memory usage. - Implemented tests for session memory functionality. ## Testing - `make format` - `make lint` - `make mypy` - `make tests` --- README.md | 53 +++++- examples/basic/session_memory_example.py | 62 +++++++ src/agents/__init__.py | 7 +- src/agents/agent.py | 35 +++- src/agents/memory/__init__.py | 3 + src/agents/memory/session_memory.py | 202 +++++++++++++++++++++++ src/agents/run.py | 173 +++++++++++++++++-- tests/test_session_memory.py | 191 +++++++++++++++++++++ 8 files changed, 700 insertions(+), 26 deletions(-) create mode 100644 examples/basic/session_memory_example.py create mode 100644 src/agents/memory/__init__.py create mode 100644 src/agents/memory/session_memory.py create mode 100644 tests/test_session_memory.py diff --git a/README.md b/README.md index 7dcd97b3..25a72be0 100644 --- a/README.md +++ b/README.md @@ -9,10 +9,61 @@ The OpenAI Agents SDK is a lightweight yet powerful framework for building multi 1. [**Agents**](https://openai.github.io/openai-agents-python/agents): LLMs configured with instructions, tools, guardrails, and handoffs 2. [**Handoffs**](https://openai.github.io/openai-agents-python/handoffs/): A specialized tool call used by the Agents SDK for transferring control between agents 3. [**Guardrails**](https://openai.github.io/openai-agents-python/guardrails/): Configurable safety checks for input and output validation -4. 
[**Tracing**](https://openai.github.io/openai-agents-python/tracing/): Built-in tracking of agent runs, allowing you to view, debug and optimize your workflows +4. [**Session Memory**](#session-memory): Automatic conversation history management across agent runs +5. [**Tracing**](https://openai.github.io/openai-agents-python/tracing/): Built-in tracking of agent runs, allowing you to view, debug and optimize your workflows Explore the [examples](examples) directory to see the SDK in action, and read our [documentation](https://openai.github.io/openai-agents-python/) for more details. +## Session Memory + +The Agents SDK provides built-in session memory to automatically maintain conversation history across multiple agent runs, eliminating the need to manually handle `.to_input_list()` between turns. + +### Quick start + +```python +from agents import Agent, Runner, RunConfig + +# Create agent with session memory enabled +agent = Agent( + name="Assistant", + instructions="Reply very concisely.", + memory=True # Enable automatic session memory +) + +# Use session ID to maintain conversation history +run_config = RunConfig(session_id="conversation_123") + +# First turn +result = await Runner.run(agent, "What city is the Golden Gate Bridge in?", run_config=run_config) +print(result.final_output) # "San Francisco" + +# Second turn - agent automatically remembers previous context +result = await Runner.run(agent, "What state is it in?", run_config=run_config) +print(result.final_output) # "California" + +# Also works with synchronous runner +result = Runner.run_sync(agent, "What's the population?", run_config=run_config) +print(result.final_output) # "Approximately 39 million" +``` + +### Memory options + +- **`memory=None`** (default): No session memory +- **`memory=True`**: Use default in-memory SQLite session memory +- **`memory=SessionMemory`**: Use custom session memory implementation + +```python +from agents import SQLiteSessionMemory + +# Custom SQLite database 
file +memory = SQLiteSessionMemory("conversations.db") +agent = Agent(name="Assistant", memory=memory) + +# Different session IDs maintain separate conversation histories +run_config_1 = RunConfig(session_id="user_123") +run_config_2 = RunConfig(session_id="user_456") +``` + ## Get started 1. Set up your Python environment diff --git a/examples/basic/session_memory_example.py b/examples/basic/session_memory_example.py new file mode 100644 index 00000000..e3c9e467 --- /dev/null +++ b/examples/basic/session_memory_example.py @@ -0,0 +1,62 @@ +""" +Example demonstrating session memory functionality. + +This example shows how to use session memory to maintain conversation history +across multiple agent runs without manually handling .to_input_list(). +""" + +import asyncio +from agents import Agent, Runner, RunConfig + + +async def main(): + # Create an agent with session memory enabled + agent = Agent( + name="Assistant", + instructions="Reply very concisely.", + memory=True, # Enable default SQLite session memory + ) + + # Define a session ID for this conversation + session_id = "conversation_123" + + # Create run config with session ID + run_config = RunConfig(session_id=session_id) + + print("=== Session Memory Example ===") + print("The agent will remember previous messages automatically.\n") + + # First turn + print("First turn:") + print("User: What city is the Golden Gate Bridge in?") + result = await Runner.run( + agent, "What city is the Golden Gate Bridge in?", run_config=run_config + ) + print(f"Assistant: {result.final_output}") + print() + + # Second turn - the agent will remember the previous conversation + print("Second turn:") + print("User: What state is it in?") + result = await Runner.run(agent, "What state is it in?", run_config=run_config) + print(f"Assistant: {result.final_output}") + print() + + # Third turn - continuing the conversation + print("Third turn:") + print("User: What's the population of that state?") + result = await Runner.run( + 
agent, "What's the population of that state?", run_config=run_config + ) + print(f"Assistant: {result.final_output}") + print() + + print("=== Conversation Complete ===") + print("Notice how the agent remembered the context from previous turns!") + print( + "No need to manually handle .to_input_list() - session memory handles it automatically." + ) + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/src/agents/__init__.py b/src/agents/__init__.py index 58949157..93653ef3 100644 --- a/src/agents/__init__.py +++ b/src/agents/__init__.py @@ -39,6 +39,7 @@ TResponseInputItem, ) from .lifecycle import AgentHooks, RunHooks +from .memory import SessionMemory, SQLiteSessionMemory from .model_settings import ModelSettings from .models.interface import Model, ModelProvider, ModelTracing from .models.openai_chatcompletions import OpenAIChatCompletionsModel @@ -127,7 +128,9 @@ def set_default_openai_key(key: str, use_for_tracing: bool = True) -> None: _config.set_default_openai_key(key, use_for_tracing) -def set_default_openai_client(client: AsyncOpenAI, use_for_tracing: bool = True) -> None: +def set_default_openai_client( + client: AsyncOpenAI, use_for_tracing: bool = True +) -> None: """Set the default OpenAI client to use for LLM requests and/or tracing. If provided, this client will be used instead of the default OpenAI client. 
@@ -202,6 +205,8 @@ def enable_verbose_stdout_logging(): "ItemHelpers", "RunHooks", "AgentHooks", + "SessionMemory", + "SQLiteSessionMemory", "RunContextWrapper", "TContext", "RunResult", diff --git a/src/agents/agent.py b/src/agents/agent.py index e22f579f..c9d3a50d 100644 --- a/src/agents/agent.py +++ b/src/agents/agent.py @@ -14,6 +14,7 @@ from .items import ItemHelpers from .logger import logger from .mcp import MCPUtil +from .memory import SessionMemory from .model_settings import ModelSettings from .models.interface import Model from .run_context import RunContextWrapper, TContext @@ -156,7 +157,9 @@ class Agent(Generic[TContext]): """ tool_use_behavior: ( - Literal["run_llm_again", "stop_on_first_tool"] | StopAtTools | ToolsToFinalOutputFunction + Literal["run_llm_again", "stop_on_first_tool"] + | StopAtTools + | ToolsToFinalOutputFunction ) = "run_llm_again" """This lets you configure how tool use is handled. - "run_llm_again": The default behavior. Tools are run, and then the LLM receives the results @@ -178,6 +181,17 @@ class Agent(Generic[TContext]): """Whether to reset the tool choice to the default value after a tool has been called. Defaults to True. This ensures that the agent doesn't enter an infinite loop of tool usage.""" + memory: bool | SessionMemory | None = None + """Session memory for maintaining conversation history across runs. + + - None: No session memory (default behavior) + - True: Use default SQLite-based session memory + - SessionMemory instance: Use custom session memory implementation + + When memory is enabled, the agent will automatically maintain conversation history + and you won't need to manually handle .to_input_list() between runs. + """ + def clone(self, **kwargs: Any) -> Agent[TContext]: """Make a copy of the agent, with the given arguments changed. 
For example, you could do: ``` @@ -209,7 +223,8 @@ def as_tool( """ @function_tool( - name_override=tool_name or _transforms.transform_string_function_style(self.name), + name_override=tool_name + or _transforms.transform_string_function_style(self.name), description_override=tool_description or "", ) async def run_agent(context: RunContextWrapper, input: str) -> str: @@ -227,7 +242,9 @@ async def run_agent(context: RunContextWrapper, input: str) -> str: return run_agent - async def get_system_prompt(self, run_context: RunContextWrapper[TContext]) -> str | None: + async def get_system_prompt( + self, run_context: RunContextWrapper[TContext] + ) -> str | None: """Get the system prompt for the agent.""" if isinstance(self.instructions, str): return self.instructions @@ -237,14 +254,20 @@ async def get_system_prompt(self, run_context: RunContextWrapper[TContext]) -> s else: return cast(str, self.instructions(run_context, self)) elif self.instructions is not None: - logger.error(f"Instructions must be a string or a function, got {self.instructions}") + logger.error( + f"Instructions must be a string or a function, got {self.instructions}" + ) return None async def get_mcp_tools(self) -> list[Tool]: """Fetches the available tools from the MCP servers.""" - convert_schemas_to_strict = self.mcp_config.get("convert_schemas_to_strict", False) - return await MCPUtil.get_all_function_tools(self.mcp_servers, convert_schemas_to_strict) + convert_schemas_to_strict = self.mcp_config.get( + "convert_schemas_to_strict", False + ) + return await MCPUtil.get_all_function_tools( + self.mcp_servers, convert_schemas_to_strict + ) async def get_all_tools(self) -> list[Tool]: """All agent tools, including MCP tools and function tools.""" diff --git a/src/agents/memory/__init__.py b/src/agents/memory/__init__.py new file mode 100644 index 00000000..59a83d2d --- /dev/null +++ b/src/agents/memory/__init__.py @@ -0,0 +1,3 @@ +from .session_memory import SessionMemory, SQLiteSessionMemory + 
class SessionMemory(abc.ABC):
    """Abstract base class for session memory implementations.

    Session memory stores conversation history across agent runs, allowing
    agents to maintain context without requiring explicit manual memory management.
    """

    @abc.abstractmethod
    async def get_messages(self, session_id: str) -> list[TResponseInputItem]:
        """Retrieve the conversation history for a given session.

        Args:
            session_id: Unique identifier for the conversation session

        Returns:
            List of input items representing the conversation history
        """

    @abc.abstractmethod
    async def add_messages(
        self, session_id: str, messages: list[TResponseInputItem]
    ) -> None:
        """Add new messages to the conversation history.

        Args:
            session_id: Unique identifier for the conversation session
            messages: List of input items to add to the history
        """

    @abc.abstractmethod
    async def clear_session(self, session_id: str) -> None:
        """Clear all messages for a given session.

        Args:
            session_id: Unique identifier for the conversation session
        """


class SQLiteSessionMemory(SessionMemory):
    """SQLite-based implementation of session memory.

    This implementation stores conversation history in a SQLite database.
    By default, uses an in-memory database that is lost when the process ends.
    For persistent storage, provide a file path.

    Connection handling:
      * File databases use one connection per thread (thread-local).
      * The ``:memory:`` database uses a single connection shared by all
        threads, because every ``sqlite3.connect(":memory:")`` call creates a
        brand-new, empty database. A per-thread connection would therefore see
        neither the schema nor the data written by other threads.

    All operations are serialized with an internal lock so the shared
    connection is never used by two threads at once.
    """

    def __init__(self, db_path: str | Path = ":memory:"):
        """Initialize the SQLite session memory.

        Args:
            db_path: Path to the SQLite database file. Defaults to ':memory:'
                (in-memory database)
        """
        self.db_path = db_path
        self._is_memory_db = str(db_path) == ":memory:"
        self._local = threading.local()
        self._lock = threading.RLock()
        # Only used for ':memory:'; sqlite3.Connection or None.
        self._shared_connection = None
        self._init_db()

    def _open_connection(self) -> sqlite3.Connection:
        """Open a new connection and make sure the schema exists on it.

        The DDL uses ``IF NOT EXISTS`` so running it for every new connection
        is idempotent; it is what makes a re-opened in-memory database usable
        after ``close()``.
        """
        conn = sqlite3.connect(str(self.db_path), check_same_thread=False)
        conn.execute("PRAGMA journal_mode=WAL")
        conn.execute(
            """
            CREATE TABLE IF NOT EXISTS sessions (
                session_id TEXT PRIMARY KEY,
                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
            )
            """
        )
        conn.execute(
            """
            CREATE TABLE IF NOT EXISTS messages (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                session_id TEXT NOT NULL,
                message_data TEXT NOT NULL,
                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                FOREIGN KEY (session_id) REFERENCES sessions (session_id) ON DELETE CASCADE
            )
            """
        )
        conn.execute(
            """
            CREATE INDEX IF NOT EXISTS idx_messages_session_id
            ON messages (session_id, created_at)
            """
        )
        conn.commit()
        return conn

    def _get_connection(self) -> sqlite3.Connection:
        """Return the connection to use from the calling thread.

        Callers must hold ``self._lock``.
        """
        if self._is_memory_db:
            if self._shared_connection is None:
                self._shared_connection = self._open_connection()
            return self._shared_connection

        conn = getattr(self._local, "connection", None)
        if conn is None:
            conn = self._open_connection()
            self._local.connection = conn
        return conn

    def _init_db(self) -> None:
        """Initialize the database schema (done as part of opening a connection)."""
        with self._lock:
            self._get_connection()

    async def get_messages(self, session_id: str) -> list[TResponseInputItem]:
        """Retrieve the conversation history for a given session.

        Args:
            session_id: Unique identifier for the conversation session

        Returns:
            List of input items representing the conversation history, in the
            order they were added.
        """
        with self._lock:
            # Order by the autoincrement id, not created_at: CURRENT_TIMESTAMP
            # only has one-second resolution, so every message written within
            # the same second would otherwise have no guaranteed order.
            rows = (
                self._get_connection()
                .execute(
                    """
                    SELECT message_data FROM messages
                    WHERE session_id = ?
                    ORDER BY id ASC
                    """,
                    (session_id,),
                )
                .fetchall()
            )

        messages = []
        for (message_data,) in rows:
            try:
                messages.append(json.loads(message_data))
            except json.JSONDecodeError:
                # Best effort: a corrupt row must not make the whole history
                # unreadable, so it is skipped.
                continue
        return messages

    async def add_messages(
        self, session_id: str, messages: list[TResponseInputItem]
    ) -> None:
        """Add new messages to the conversation history.

        The session row, the message rows and the timestamp update are written
        in a single transaction: either everything is stored or nothing is.

        Args:
            session_id: Unique identifier for the conversation session
            messages: List of input items to add to the history

        Raises:
            TypeError: If a message is not JSON serializable (nothing is written).
        """
        if not messages:
            return

        # Serialize before touching the database so a bad message cannot
        # leave a half-finished transaction behind.
        rows = [(session_id, json.dumps(message)) for message in messages]

        with self._lock:
            conn = self._get_connection()
            with conn:  # commit on success, roll back on error
                conn.execute(
                    "INSERT OR IGNORE INTO sessions (session_id) VALUES (?)",
                    (session_id,),
                )
                conn.executemany(
                    "INSERT INTO messages (session_id, message_data) VALUES (?, ?)",
                    rows,
                )
                conn.execute(
                    "UPDATE sessions SET updated_at = CURRENT_TIMESTAMP WHERE session_id = ?",
                    (session_id,),
                )

    async def clear_session(self, session_id: str) -> None:
        """Clear all messages for a given session.

        Args:
            session_id: Unique identifier for the conversation session
        """
        with self._lock:
            conn = self._get_connection()
            with conn:  # commit on success, roll back on error
                # Messages are deleted explicitly: SQLite does not enforce
                # ON DELETE CASCADE unless foreign keys are switched on.
                conn.execute("DELETE FROM messages WHERE session_id = ?", (session_id,))
                conn.execute("DELETE FROM sessions WHERE session_id = ?", (session_id,))

    def close(self) -> None:
        """Close the database connection.

        For file databases this closes the calling thread's connection; for
        the in-memory database it closes the shared connection (discarding its
        data). Calling ``close()`` more than once is safe, and the object may
        be used again afterwards - a new connection is opened on demand.
        """
        with self._lock:
            if self._is_memory_db:
                conn, self._shared_connection = self._shared_connection, None
            else:
                conn = getattr(self._local, "connection", None)
                if conn is not None:
                    # Drop the cached handle so it is never reused once closed.
                    del self._local.connection
            if conn is not None:
                conn.close()
+ """ + class Runner: @classmethod @@ -154,6 +166,11 @@ async def run( if run_config is None: run_config = RunConfig() + # Prepare input with session memory if enabled + prepared_input, session_memory = await cls._prepare_input_with_memory( + starting_agent, input, run_config + ) + tool_use_tracker = AgentToolUseTracker() with TraceCtxManager( @@ -164,7 +181,9 @@ async def run( disabled=run_config.tracing_disabled, ): current_turn = 0 - original_input: str | list[TResponseInputItem] = copy.deepcopy(input) + original_input: str | list[TResponseInputItem] = copy.deepcopy( + prepared_input + ) generated_items: list[RunItem] = [] model_responses: list[ModelResponse] = [] @@ -183,7 +202,9 @@ async def run( # Start an agent span if we don't have one. This span is ended if the current # agent changes, or if the agent loop ends. if current_span is None: - handoff_names = [h.agent_name for h in cls._get_handoffs(current_agent)] + handoff_names = [ + h.agent_name for h in cls._get_handoffs(current_agent) + ] if output_schema := cls._get_output_schema(current_agent): output_type_name = output_schema.name() else: @@ -220,7 +241,7 @@ async def run( starting_agent, starting_agent.input_guardrails + (run_config.input_guardrails or []), - copy.deepcopy(input), + copy.deepcopy(prepared_input), context_wrapper, ), cls._run_single_turn( @@ -257,12 +278,13 @@ async def run( if isinstance(turn_result.next_step, NextStepFinalOutput): output_guardrail_results = await cls._run_output_guardrails( - current_agent.output_guardrails + (run_config.output_guardrails or []), + current_agent.output_guardrails + + (run_config.output_guardrails or []), current_agent, turn_result.next_step.output, context_wrapper, ) - return RunResult( + result = RunResult( input=original_input, new_items=generated_items, raw_responses=model_responses, @@ -272,8 +294,17 @@ async def run( output_guardrail_results=output_guardrail_results, context_wrapper=context_wrapper, ) + + # Save the conversation to session 
memory if enabled + await cls._save_result_to_memory( + session_memory, run_config.session_id, input, result + ) + + return result elif isinstance(turn_result.next_step, NextStepHandoff): - current_agent = cast(Agent[TContext], turn_result.next_step.new_agent) + current_agent = cast( + Agent[TContext], turn_result.next_step.new_agent + ) current_span.finish(reset_current=True) current_span = None should_run_agent_start_hooks = True @@ -500,13 +531,23 @@ async def _run_streamed_impl( if streamed_result.trace: streamed_result.trace.start(mark_as_current=True) + # Prepare input with session memory if enabled + prepared_input, session_memory = await cls._prepare_input_with_memory( + starting_agent, starting_input, run_config + ) + + # Update the streamed result with the prepared input + streamed_result.input = prepared_input + current_span: Span[AgentSpanData] | None = None current_agent = starting_agent current_turn = 0 should_run_agent_start_hooks = True tool_use_tracker = AgentToolUseTracker() - streamed_result._event_queue.put_nowait(AgentUpdatedStreamEvent(new_agent=current_agent)) + streamed_result._event_queue.put_nowait( + AgentUpdatedStreamEvent(new_agent=current_agent) + ) try: while True: @@ -516,7 +557,9 @@ async def _run_streamed_impl( # Start an agent span if we don't have one. This span is ended if the current # agent changes, or if the agent loop ends. 
if current_span is None: - handoff_names = [h.agent_name for h in cls._get_handoffs(current_agent)] + handoff_names = [ + h.agent_name for h in cls._get_handoffs(current_agent) + ] if output_schema := cls._get_output_schema(current_agent): output_type_name = output_schema.name() else: @@ -551,8 +594,11 @@ async def _run_streamed_impl( streamed_result._input_guardrails_task = asyncio.create_task( cls._run_input_guardrails_with_queue( starting_agent, - starting_agent.input_guardrails + (run_config.input_guardrails or []), - copy.deepcopy(ItemHelpers.input_to_new_input_list(starting_input)), + starting_agent.input_guardrails + + (run_config.input_guardrails or []), + copy.deepcopy( + ItemHelpers.input_to_new_input_list(prepared_input) + ), context_wrapper, streamed_result, current_span, @@ -598,14 +644,38 @@ async def _run_streamed_impl( ) try: - output_guardrail_results = await streamed_result._output_guardrails_task + output_guardrail_results = ( + await streamed_result._output_guardrails_task + ) except Exception: # Exceptions will be checked in the stream_events loop output_guardrail_results = [] - streamed_result.output_guardrail_results = output_guardrail_results + streamed_result.output_guardrail_results = ( + output_guardrail_results + ) streamed_result.final_output = turn_result.next_step.output streamed_result.is_complete = True + + # Save the conversation to session memory if enabled + # Create a temporary RunResult for memory saving + temp_result = RunResult( + input=streamed_result.input, + new_items=streamed_result.new_items, + raw_responses=streamed_result.raw_responses, + final_output=streamed_result.final_output, + _last_agent=current_agent, + input_guardrail_results=streamed_result.input_guardrail_results, + output_guardrail_results=streamed_result.output_guardrail_results, + context_wrapper=context_wrapper, + ) + await cls._save_result_to_memory( + session_memory, + run_config.session_id, + starting_input, + temp_result, + ) + 
streamed_result._event_queue.put_nowait(QueueCompleteSentinel()) elif isinstance(turn_result.next_step, NextStepRunAgain): pass @@ -662,7 +732,9 @@ async def _run_single_turn_streamed( handoffs = cls._get_handoffs(agent) model = cls._get_model(agent, run_config) model_settings = agent.model_settings.resolve(run_config.model_settings) - model_settings = RunImpl.maybe_reset_tool_choice(agent, tool_use_tracker, model_settings) + model_settings = RunImpl.maybe_reset_tool_choice( + agent, tool_use_tracker, model_settings + ) final_response: ModelResponse | None = None @@ -723,7 +795,9 @@ async def _run_single_turn_streamed( tool_use_tracker=tool_use_tracker, ) - RunImpl.stream_step_result_to_queue(single_step_result, streamed_result._event_queue) + RunImpl.stream_step_result_to_queue( + single_step_result, streamed_result._event_queue + ) return single_step_result @classmethod @@ -757,7 +831,9 @@ async def _run_single_turn( output_schema = cls._get_output_schema(agent) handoffs = cls._get_handoffs(agent) input = ItemHelpers.input_to_new_input_list(original_input) - input.extend([generated_item.to_input_item() for generated_item in generated_items]) + input.extend( + [generated_item.to_input_item() for generated_item in generated_items] + ) new_response = await cls._get_new_response( agent, @@ -875,7 +951,9 @@ async def _run_output_guardrails( guardrail_tasks = [ asyncio.create_task( - RunImpl.run_single_output_guardrail(guardrail, agent, agent_output, context) + RunImpl.run_single_output_guardrail( + guardrail, agent, agent_output, context + ) ) for guardrail in guardrails ] @@ -916,7 +994,9 @@ async def _get_new_response( ) -> ModelResponse: model = cls._get_model(agent, run_config) model_settings = agent.model_settings.resolve(run_config.model_settings) - model_settings = RunImpl.maybe_reset_tool_choice(agent, tool_use_tracker, model_settings) + model_settings = RunImpl.maybe_reset_tool_choice( + agent, tool_use_tracker, model_settings + ) new_response = await 
@classmethod
def _get_session_memory(cls, agent: Agent[Any]) -> SessionMemory | None:
    """Resolve the agent's ``memory`` setting to a session memory instance.

    Args:
        agent: The agent whose ``memory`` attribute is inspected.

    Returns:
        ``None`` when memory is disabled (``None`` or ``False``), the agent's
        own ``SessionMemory`` instance when one was supplied, or the shared
        default ``SQLiteSessionMemory`` when ``memory=True``.

    Raises:
        ValueError: If ``agent.memory`` is none of the supported values.
    """
    memory = agent.memory
    # ``False`` is a legal value of the ``bool | SessionMemory | None``
    # annotation and means the same as ``None``: memory disabled.
    if memory is None or memory is False:
        return None

    if memory is True:
        # The default store must outlive a single run. Creating a fresh
        # SQLiteSessionMemory() here would hand every run a new, empty
        # in-memory database and no history would ever be found again.
        # One lazily created store is therefore cached on the class and
        # reused; conversations stay separated by session_id.
        default_memory = getattr(cls, "_default_session_memory", None)
        if default_memory is None:
            default_memory = SQLiteSessionMemory()
            setattr(cls, "_default_session_memory", default_memory)
        return default_memory

    if isinstance(memory, SessionMemory):
        return memory

    raise ValueError(f"Invalid memory configuration: {agent.memory}")

@classmethod
async def _prepare_input_with_memory(
    cls,
    agent: Agent[Any],
    input: str | list[TResponseInputItem],
    run_config: RunConfig,
) -> tuple[str | list[TResponseInputItem], SessionMemory | None]:
    """Prepare input by combining it with session memory if enabled.

    Args:
        agent: The starting agent; its ``memory`` setting selects the store.
        input: The new user input for this run.
        run_config: Run configuration; ``session_id`` selects the conversation.

    Returns:
        A ``(prepared_input, memory)`` tuple. When memory is disabled or no
        ``session_id`` is set, ``prepared_input`` is ``input`` unchanged.
        Otherwise it is the stored history followed by the new input, as a list.
    """
    memory = cls._get_session_memory(agent)
    session_id = run_config.session_id
    if memory is None or session_id is None:
        return input, memory

    # Stored history first, then this turn's input converted to list form.
    history = await memory.get_messages(session_id)
    combined_input = history + ItemHelpers.input_to_new_input_list(input)
    return combined_input, memory

@classmethod
async def _save_result_to_memory(
    cls,
    memory: SessionMemory | None,
    session_id: str | None,
    original_input: str | list[TResponseInputItem],
    result: RunResult,
) -> None:
    """Save the conversation turn to session memory.

    Args:
        memory: The session memory to write to; nothing is saved when ``None``.
        session_id: The conversation to append to; nothing is saved when ``None``.
        original_input: This turn's input. Callers should pass the input
            WITHOUT previously loaded history, since everything passed here
            is appended to the store again.
        result: The finished run; its ``new_items`` are stored after the input.
    """
    if memory is None or session_id is None:
        return

    # Persist this turn only: the caller's input plus the items generated
    # during the run, each converted to input-item form.
    messages_to_save = ItemHelpers.input_to_new_input_list(original_input)
    messages_to_save.extend(item.to_input_item() for item in result.new_items)
    await memory.add_messages(session_id, messages_to_save)
SQLite memory + + session_id = "test_session_456" + run_config = RunConfig(session_id=session_id) + + # First turn + model.set_next_output([get_text_message("Hello")]) + result1 = await Runner.run(agent, "Hi there", run_config=run_config) + assert result1.final_output == "Hello" + + # Second turn + model.set_next_output([get_text_message("I remember you said hi")]) + result2 = await Runner.run( + agent, "Do you remember what I said?", run_config=run_config + ) + assert result2.final_output == "I remember you said hi" + + +@pytest.mark.asyncio +async def test_session_memory_disabled(): + """Test that session memory is disabled when memory=None.""" + model = FakeModel() + agent = Agent(name="test", model=model, memory=None) # No session memory + + session_id = "test_session_789" + run_config = RunConfig(session_id=session_id) + + # First turn + model.set_next_output([get_text_message("Hello")]) + result1 = await Runner.run(agent, "Hi there", run_config=run_config) + assert result1.final_output == "Hello" + + # Second turn - should NOT have conversation history + model.set_next_output([get_text_message("I don't remember")]) + result2 = await Runner.run( + agent, "Do you remember what I said?", run_config=run_config + ) + assert result2.final_output == "I don't remember" + + # Verify that the input to the second turn is just the current message + last_input = model.last_turn_args["input"] + assert len(last_input) == 1 # Should only have the current message + + +@pytest.mark.asyncio +async def test_session_memory_different_sessions(): + """Test that different session IDs maintain separate conversation histories.""" + with tempfile.TemporaryDirectory() as temp_dir: + db_path = Path(temp_dir) / "test_memory.db" + memory = SQLiteSessionMemory(db_path) + + model = FakeModel() + agent = Agent(name="test", model=model, memory=memory) + + # Session 1 + session_id_1 = "session_1" + run_config_1 = RunConfig(session_id=session_id_1) + + model.set_next_output([get_text_message("I 
like cats")]) + result1 = await Runner.run(agent, "I like cats", run_config=run_config_1) + assert result1.final_output == "I like cats" + + # Session 2 - different session + session_id_2 = "session_2" + run_config_2 = RunConfig(session_id=session_id_2) + + model.set_next_output([get_text_message("I like dogs")]) + result2 = await Runner.run(agent, "I like dogs", run_config=run_config_2) + assert result2.final_output == "I like dogs" + + # Back to Session 1 - should remember cats, not dogs + model.set_next_output([get_text_message("Yes, you mentioned cats")]) + result3 = await Runner.run( + agent, "What did I say I like?", run_config=run_config_1 + ) + assert result3.final_output == "Yes, you mentioned cats" + + memory.close() + + +@pytest.mark.asyncio +async def test_session_memory_no_session_id(): + """Test that session memory is disabled when no session_id is provided.""" + model = FakeModel() + agent = Agent( + name="test", model=model, memory=True # Memory enabled but no session_id + ) + + run_config = RunConfig() # No session_id + + # First turn + model.set_next_output([get_text_message("Hello")]) + result1 = await Runner.run(agent, "Hi there", run_config=run_config) + assert result1.final_output == "Hello" + + # Second turn - should NOT have conversation history since no session_id + model.set_next_output([get_text_message("I don't remember")]) + result2 = await Runner.run( + agent, "Do you remember what I said?", run_config=run_config + ) + assert result2.final_output == "I don't remember" + + +@pytest.mark.asyncio +async def test_sqlite_session_memory_direct(): + """Test SQLiteSessionMemory class directly.""" + with tempfile.TemporaryDirectory() as temp_dir: + db_path = Path(temp_dir) / "test_direct.db" + memory = SQLiteSessionMemory(db_path) + + session_id = "direct_test" + + # Test adding and retrieving messages + messages = [ + {"role": "user", "content": "Hello"}, + {"role": "assistant", "content": "Hi there!"}, + ] + + await 
memory.add_messages(session_id, messages) + retrieved = await memory.get_messages(session_id) + + assert len(retrieved) == 2 + assert retrieved[0]["role"] == "user" + assert retrieved[0]["content"] == "Hello" + assert retrieved[1]["role"] == "assistant" + assert retrieved[1]["content"] == "Hi there!" + + # Test clearing session + await memory.clear_session(session_id) + retrieved_after_clear = await memory.get_messages(session_id) + assert len(retrieved_after_clear) == 0 + + memory.close() + + +def test_session_memory_invalid_config(): + """Test that invalid memory configuration raises ValueError.""" + with pytest.raises(ValueError, match="Invalid memory configuration"): + agent = Agent(name="test", memory="invalid") + Runner._get_session_memory(agent) From 63a786f102871f408fa3f41e4bdb55970859f0e1 Mon Sep 17 00:00:00 2001 From: Stephan Fitzpatrick Date: Fri, 23 May 2025 22:23:46 -0700 Subject: [PATCH 02/26] Enhance session memory validation in Runner class - Added a check to raise a ValueError if `session_id` is not provided when session memory is enabled. - Updated the `SessionMemory` class to use a Protocol instead of an abstract base class, simplifying the implementation. - Modified tests to ensure an exception is raised when attempting to run with memory enabled but no session_id is provided. 
--- src/agents/memory/session_memory.py | 17 +++++++---------- src/agents/run.py | 8 +++++++- tests/test_session_memory.py | 18 ++++++------------ 3 files changed, 20 insertions(+), 23 deletions(-) diff --git a/src/agents/memory/session_memory.py b/src/agents/memory/session_memory.py index 78a393dc..376115bb 100644 --- a/src/agents/memory/session_memory.py +++ b/src/agents/memory/session_memory.py @@ -1,24 +1,23 @@ from __future__ import annotations -import abc import json import sqlite3 import threading from pathlib import Path -from typing import Any, TYPE_CHECKING +from typing import Any, TYPE_CHECKING, Protocol, runtime_checkable if TYPE_CHECKING: from ..items import TResponseInputItem -class SessionMemory(abc.ABC): - """Abstract base class for session memory implementations. +@runtime_checkable +class SessionMemory(Protocol): + """Protocol for session memory implementations. Session memory stores conversation history across agent runs, allowing agents to maintain context without requiring explicit manual memory management. """ - @abc.abstractmethod async def get_messages(self, session_id: str) -> list[TResponseInputItem]: """Retrieve the conversation history for a given session. @@ -28,9 +27,8 @@ async def get_messages(self, session_id: str) -> list[TResponseInputItem]: Returns: List of input items representing the conversation history """ - pass + ... - @abc.abstractmethod async def add_messages( self, session_id: str, messages: list[TResponseInputItem] ) -> None: @@ -40,16 +38,15 @@ async def add_messages( session_id: Unique identifier for the conversation session messages: List of input items to add to the history """ - pass + ... - @abc.abstractmethod async def clear_session(self, session_id: str) -> None: """Clear all messages for a given session. Args: session_id: Unique identifier for the conversation session """ - pass + ... 
class SQLiteSessionMemory(SessionMemory): diff --git a/src/agents/run.py b/src/agents/run.py index 0bcc0cae..213d5d3a 100644 --- a/src/agents/run.py +++ b/src/agents/run.py @@ -1070,9 +1070,15 @@ async def _prepare_input_with_memory( ) -> tuple[str | list[TResponseInputItem], SessionMemory | None]: """Prepare input by combining it with session memory if enabled.""" memory = cls._get_session_memory(agent) - if memory is None or run_config.session_id is None: + if memory is None: return input, memory + if run_config.session_id is None: + raise ValueError( + "session_id is required when memory is enabled. " + "Please provide a session_id in the RunConfig." + ) + # Get previous conversation history history = await memory.get_messages(run_config.session_id) diff --git a/tests/test_session_memory.py b/tests/test_session_memory.py index 5493c125..cb271d01 100644 --- a/tests/test_session_memory.py +++ b/tests/test_session_memory.py @@ -131,7 +131,7 @@ async def test_session_memory_different_sessions(): @pytest.mark.asyncio async def test_session_memory_no_session_id(): - """Test that session memory is disabled when no session_id is provided.""" + """Test that session memory raises an exception when no session_id is provided.""" model = FakeModel() agent = Agent( name="test", model=model, memory=True # Memory enabled but no session_id @@ -139,17 +139,11 @@ async def test_session_memory_no_session_id(): run_config = RunConfig() # No session_id - # First turn - model.set_next_output([get_text_message("Hello")]) - result1 = await Runner.run(agent, "Hi there", run_config=run_config) - assert result1.final_output == "Hello" - - # Second turn - should NOT have conversation history since no session_id - model.set_next_output([get_text_message("I don't remember")]) - result2 = await Runner.run( - agent, "Do you remember what I said?", run_config=run_config - ) - assert result2.final_output == "I don't remember" + # Should raise ValueError when trying to run with memory enabled but 
no session_id + with pytest.raises( + ValueError, match="session_id is required when memory is enabled" + ): + await Runner.run(agent, "Hi there", run_config=run_config) @pytest.mark.asyncio From 6aab1e8b30ace06d2900d405b1cee4b216c8bc7c Mon Sep 17 00:00:00 2001 From: Stephan Fitzpatrick Date: Fri, 23 May 2025 22:29:13 -0700 Subject: [PATCH 03/26] Add custom session memory implementation examples to README - Introduced a section on creating custom memory implementations following the `SessionMemory` protocol. - Added code examples demonstrating how to implement and use a custom memory class. - Highlighted the requirement for `session_id` when session memory is enabled, with examples illustrating correct usage. --- README.md | 41 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 41 insertions(+) diff --git a/README.md b/README.md index 25a72be0..4346b4e6 100644 --- a/README.md +++ b/README.md @@ -64,6 +64,47 @@ run_config_1 = RunConfig(session_id="user_123") run_config_2 = RunConfig(session_id="user_456") ``` +### Custom memory implementations + +You can implement your own session memory by creating a class that follows the `SessionMemory` protocol: + +```python +from agents.memory import SessionMemory +from typing import List + +class MyCustomMemory: + """Custom memory implementation following the SessionMemory protocol.""" + + async def get_messages(self, session_id: str) -> List[dict]: + # Retrieve conversation history for the session + pass + + async def add_messages(self, session_id: str, messages: List[dict]) -> None: + # Store new messages for the session + pass + + async def clear_session(self, session_id: str) -> None: + # Clear all messages for the session + pass + +# Use your custom memory +agent = Agent(name="Assistant", memory=MyCustomMemory()) +``` + +### Important: session_id requirement + +When session memory is enabled (either with `memory=True` or a custom `SessionMemory` implementation), you **must** provide a `session_id` in the 
`RunConfig`. If you don't, the runner will raise a `ValueError`: + +```python +agent = Agent(name="Assistant", memory=True) + +# This will raise ValueError: "session_id is required when memory is enabled" +result = await Runner.run(agent, "Hello", run_config=RunConfig()) + +# This works correctly +result = await Runner.run(agent, "Hello", run_config=RunConfig(session_id="my_session")) +``` + ## Get started 1. Set up your Python environment From 53d4132e32529be06feea73c8f52ba9306916731 Mon Sep 17 00:00:00 2001 From: Stephan Fitzpatrick Date: Fri, 23 May 2025 22:37:05 -0700 Subject: [PATCH 04/26] Implement session memory instance reuse in Runner class - Updated the Runner class to ensure that when memory=True, a single instance of SQLiteSessionMemory is created and reused across runs. - Added a test to verify that the same memory instance is returned for multiple calls when memory is enabled. - Ensured the agent stores the memory instance for consistency. --- src/agents/run.py | 6 +++++- tests/test_session_memory.py | 19 +++++++++++++++++++ 2 files changed, 24 insertions(+), 1 deletion(-) diff --git a/src/agents/run.py b/src/agents/run.py index 213d5d3a..79613087 100644 --- a/src/agents/run.py +++ b/src/agents/run.py @@ -1055,7 +1055,11 @@ def _get_session_memory(cls, agent: Agent[Any]) -> SessionMemory | None: if agent.memory is None: return None elif agent.memory is True: - return SQLiteSessionMemory() + # For memory=True, we need to create a memory instance if it doesn't exist + # and store it on the agent to ensure consistency across runs + if not hasattr(agent, '_session_memory_instance'): + agent._session_memory_instance = SQLiteSessionMemory() + return agent._session_memory_instance elif isinstance(agent.memory, SessionMemory): return agent.memory else: diff --git a/tests/test_session_memory.py b/tests/test_session_memory.py index cb271d01..f568f0aa 100644 --- a/tests/test_session_memory.py +++ b/tests/test_session_memory.py @@ -67,6 +67,25 @@ async def 
test_session_memory_with_boolean_true(): assert result2.final_output == "I remember you said hi" +@pytest.mark.asyncio +async def test_session_memory_instance_reuse(): + """Test that when memory=True, the same memory instance is reused across runs.""" + agent = Agent(name="test", memory=True) + + # Get memory instance for the first time + memory1 = Runner._get_session_memory(agent) + + # Get memory instance for the second time + memory2 = Runner._get_session_memory(agent) + + # Should be the exact same instance + assert memory1 is memory2 + + # Should have created the _session_memory_instance attribute + assert hasattr(agent, "_session_memory_instance") + assert agent._session_memory_instance is memory1 + + @pytest.mark.asyncio async def test_session_memory_disabled(): """Test that session memory is disabled when memory=None.""" From 14600ee2a7f67403308470e83eb1ef4008ff9572 Mon Sep 17 00:00:00 2001 From: Stephan Fitzpatrick Date: Fri, 23 May 2025 22:53:18 -0700 Subject: [PATCH 05/26] Refactor session memory usage in README and examples - Updated README and example scripts to utilize `SQLiteSessionMemory` explicitly instead of using a boolean flag for memory. - Modified `RunConfig` to accept a memory instance directly, enhancing clarity and flexibility in session management. - Adjusted tests to reflect the new memory handling approach, ensuring consistent behavior across different configurations. 
--- README.md | 40 +++++++----- examples/basic/session_memory_example.py | 18 +++--- src/agents/run.py | 36 +++++------ tests/test_session_memory.py | 80 ++++++++++-------------- 4 files changed, 85 insertions(+), 89 deletions(-) diff --git a/README.md b/README.md index 4346b4e6..d63f864f 100644 --- a/README.md +++ b/README.md @@ -21,17 +21,22 @@ The Agents SDK provides built-in session memory to automatically maintain conver ### Quick start ```python -from agents import Agent, Runner, RunConfig +from agents import Agent, Runner, RunConfig, SQLiteSessionMemory -# Create agent with session memory enabled +# Create agent agent = Agent( name="Assistant", instructions="Reply very concisely.", - memory=True # Enable automatic session memory ) -# Use session ID to maintain conversation history -run_config = RunConfig(session_id="conversation_123") +# Create a session memory instance +memory = SQLiteSessionMemory() + +# Configure run with session memory and session ID +run_config = RunConfig( + memory=memory, # Use our session memory instance + session_id="conversation_123" +) # First turn result = await Runner.run(agent, "What city is the Golden Gate Bridge in?", run_config=run_config) @@ -49,19 +54,18 @@ print(result.final_output) # "Approximately 39 million" ### Memory options - **`memory=None`** (default): No session memory -- **`memory=True`**: Use default in-memory SQLite session memory -- **`memory=SessionMemory`**: Use custom session memory implementation +- **`memory=SessionMemory`**: Use the provided session memory implementation ```python -from agents import SQLiteSessionMemory +from agents import Agent, Runner, RunConfig, SQLiteSessionMemory # Custom SQLite database file memory = SQLiteSessionMemory("conversations.db") -agent = Agent(name="Assistant", memory=memory) +agent = Agent(name="Assistant") # Different session IDs maintain separate conversation histories -run_config_1 = RunConfig(session_id="user_123") -run_config_2 = RunConfig(session_id="user_456") 
+run_config_1 = RunConfig(memory=memory, session_id="user_123") +run_config_2 = RunConfig(memory=memory, session_id="user_456") ``` ### Custom memory implementations @@ -88,21 +92,25 @@ class MyCustomMemory: pass # Use your custom memory -agent = Agent(name="Assistant", memory=MyCustomMemory()) +agent = Agent(name="Assistant") +run_config = RunConfig(memory=MyCustomMemory(), session_id="my_session") ``` ### Important: session_id requirement -When session memory is enabled (either with `memory=True` or a custom `SessionMemory` implementation), you **must** provide a `session_id` in the `RunConfig`. If you don't, the runner will raise a `ValueError`: +When session memory is enabled, you **must** provide a `session_id` in the `RunConfig`. If you don't, the runner will raise a `ValueError`: ```python -agent = Agent(name="Assistant", memory=True) +from agents import Agent, Runner, RunConfig, SQLiteSessionMemory + +agent = Agent(name="Assistant") +memory = SQLiteSessionMemory() # This will raise ValueError: "session_id is required when memory is enabled" -result = await Runner.run(agent, "Hello", run_config=RunConfig()) +result = await Runner.run(agent, "Hello", run_config=RunConfig(memory=memory)) # This works correctly -result = await Runner.run(agent, "Hello", run_config=RunConfig(session_id="my_session")) +result = await Runner.run(agent, "Hello", run_config=RunConfig(memory=memory, session_id="my_session")) ``` ## Get started diff --git a/examples/basic/session_memory_example.py b/examples/basic/session_memory_example.py index e3c9e467..3fe41d63 100644 --- a/examples/basic/session_memory_example.py +++ b/examples/basic/session_memory_example.py @@ -6,22 +6,26 @@ """ import asyncio -from agents import Agent, Runner, RunConfig +from agents import Agent, Runner, RunConfig, SQLiteSessionMemory async def main(): - # Create an agent with session memory enabled + # Create an agent agent = Agent( name="Assistant", instructions="Reply very concisely.", - memory=True, # 
Enable default SQLite session memory ) + # Create a session memory instance that will persist across runs + memory = SQLiteSessionMemory() + # Define a session ID for this conversation session_id = "conversation_123" - # Create run config with session ID - run_config = RunConfig(session_id=session_id) + # Create run config with session memory and session ID + run_config = RunConfig( + memory=memory, session_id=session_id # Use our session memory instance + ) print("=== Session Memory Example ===") print("The agent will remember previous messages automatically.\n") @@ -53,9 +57,7 @@ async def main(): print("=== Conversation Complete ===") print("Notice how the agent remembered the context from previous turns!") - print( - "No need to manually handle .to_input_list() - session memory handles it automatically." - ) + print("Session memory in RunConfig handles conversation history automatically.") if __name__ == "__main__": diff --git a/src/agents/run.py b/src/agents/run.py index 79613087..26c4a66b 100644 --- a/src/agents/run.py +++ b/src/agents/run.py @@ -111,10 +111,17 @@ class RunConfig: An optional dictionary of additional metadata to include with the trace. """ + memory: SessionMemory | None = None + """ + Session memory instance for conversation history persistence. + - None (default): No session memory + - SessionMemory instance: Use the provided session memory implementation + """ + session_id: str | None = None """ - A session identifier for memory persistence. If provided and the agent has memory enabled, - conversation history will be automatically managed using this session ID. + A session identifier for memory persistence. Required when memory is enabled. + Conversation history will be automatically managed using this session ID. 
""" @@ -168,7 +175,7 @@ async def run( # Prepare input with session memory if enabled prepared_input, session_memory = await cls._prepare_input_with_memory( - starting_agent, input, run_config + input, run_config ) tool_use_tracker = AgentToolUseTracker() @@ -533,7 +540,7 @@ async def _run_streamed_impl( # Prepare input with session memory if enabled prepared_input, session_memory = await cls._prepare_input_with_memory( - starting_agent, starting_input, run_config + starting_input, run_config ) # Update the streamed result with the prepared input @@ -1050,30 +1057,23 @@ def _get_model(cls, agent: Agent[Any], run_config: RunConfig) -> Model: return run_config.model_provider.get_model(agent.model) @classmethod - def _get_session_memory(cls, agent: Agent[Any]) -> SessionMemory | None: - """Get the session memory instance for the agent, if any.""" - if agent.memory is None: + def _get_session_memory(cls, run_config: RunConfig) -> SessionMemory | None: + """Get the session memory instance from run config, if any.""" + if run_config.memory is None: return None - elif agent.memory is True: - # For memory=True, we need to create a memory instance if it doesn't exist - # and store it on the agent to ensure consistency across runs - if not hasattr(agent, '_session_memory_instance'): - agent._session_memory_instance = SQLiteSessionMemory() - return agent._session_memory_instance - elif isinstance(agent.memory, SessionMemory): - return agent.memory + elif isinstance(run_config.memory, SessionMemory): + return run_config.memory else: - raise ValueError(f"Invalid memory configuration: {agent.memory}") + raise ValueError(f"Invalid memory configuration: {run_config.memory}") @classmethod async def _prepare_input_with_memory( cls, - agent: Agent[Any], input: str | list[TResponseInputItem], run_config: RunConfig, ) -> tuple[str | list[TResponseInputItem], SessionMemory | None]: """Prepare input by combining it with session memory if enabled.""" - memory = 
cls._get_session_memory(agent) + memory = cls._get_session_memory(run_config) if memory is None: return input, memory diff --git a/tests/test_session_memory.py b/tests/test_session_memory.py index f568f0aa..49f8b748 100644 --- a/tests/test_session_memory.py +++ b/tests/test_session_memory.py @@ -20,10 +20,10 @@ async def test_session_memory_basic_functionality(): memory = SQLiteSessionMemory(db_path) model = FakeModel() - agent = Agent(name="test", model=model, memory=memory) + agent = Agent(name="test", model=model) session_id = "test_session_123" - run_config = RunConfig(session_id=session_id) + run_config = RunConfig(memory=memory, session_id=session_id) # First turn model.set_next_output([get_text_message("San Francisco")]) @@ -46,54 +46,41 @@ async def test_session_memory_basic_functionality(): @pytest.mark.asyncio -async def test_session_memory_with_boolean_true(): - """Test session memory when agent.memory=True (default SQLite).""" - model = FakeModel() - agent = Agent(name="test", model=model, memory=True) # Use default SQLite memory - - session_id = "test_session_456" - run_config = RunConfig(session_id=session_id) - - # First turn - model.set_next_output([get_text_message("Hello")]) - result1 = await Runner.run(agent, "Hi there", run_config=run_config) - assert result1.final_output == "Hello" - - # Second turn - model.set_next_output([get_text_message("I remember you said hi")]) - result2 = await Runner.run( - agent, "Do you remember what I said?", run_config=run_config - ) - assert result2.final_output == "I remember you said hi" - - -@pytest.mark.asyncio -async def test_session_memory_instance_reuse(): - """Test that when memory=True, the same memory instance is reused across runs.""" - agent = Agent(name="test", memory=True) +async def test_session_memory_with_explicit_instance(): + """Test session memory with an explicit SQLiteSessionMemory instance.""" + with tempfile.TemporaryDirectory() as temp_dir: + db_path = Path(temp_dir) / "test_memory.db" + 
memory = SQLiteSessionMemory(db_path) - # Get memory instance for the first time - memory1 = Runner._get_session_memory(agent) + model = FakeModel() + agent = Agent(name="test", model=model) - # Get memory instance for the second time - memory2 = Runner._get_session_memory(agent) + session_id = "test_session_456" + run_config = RunConfig(memory=memory, session_id=session_id) - # Should be the exact same instance - assert memory1 is memory2 + # First turn + model.set_next_output([get_text_message("Hello")]) + result1 = await Runner.run(agent, "Hi there", run_config=run_config) + assert result1.final_output == "Hello" + + # Second turn + model.set_next_output([get_text_message("I remember you said hi")]) + result2 = await Runner.run( + agent, "Do you remember what I said?", run_config=run_config + ) + assert result2.final_output == "I remember you said hi" - # Should have created the _session_memory_instance attribute - assert hasattr(agent, "_session_memory_instance") - assert agent._session_memory_instance is memory1 + memory.close() @pytest.mark.asyncio async def test_session_memory_disabled(): """Test that session memory is disabled when memory=None.""" model = FakeModel() - agent = Agent(name="test", model=model, memory=None) # No session memory + agent = Agent(name="test", model=model) session_id = "test_session_789" - run_config = RunConfig(session_id=session_id) + run_config = RunConfig(memory=None, session_id=session_id) # No session memory # First turn model.set_next_output([get_text_message("Hello")]) @@ -120,11 +107,11 @@ async def test_session_memory_different_sessions(): memory = SQLiteSessionMemory(db_path) model = FakeModel() - agent = Agent(name="test", model=model, memory=memory) + agent = Agent(name="test", model=model) # Session 1 session_id_1 = "session_1" - run_config_1 = RunConfig(session_id=session_id_1) + run_config_1 = RunConfig(memory=memory, session_id=session_id_1) model.set_next_output([get_text_message("I like cats")]) result1 = await 
Runner.run(agent, "I like cats", run_config=run_config_1) @@ -132,7 +119,7 @@ async def test_session_memory_different_sessions(): # Session 2 - different session session_id_2 = "session_2" - run_config_2 = RunConfig(session_id=session_id_2) + run_config_2 = RunConfig(memory=memory, session_id=session_id_2) model.set_next_output([get_text_message("I like dogs")]) result2 = await Runner.run(agent, "I like dogs", run_config=run_config_2) @@ -152,11 +139,10 @@ async def test_session_memory_different_sessions(): async def test_session_memory_no_session_id(): """Test that session memory raises an exception when no session_id is provided.""" model = FakeModel() - agent = Agent( - name="test", model=model, memory=True # Memory enabled but no session_id - ) + agent = Agent(name="test", model=model) + memory = SQLiteSessionMemory() - run_config = RunConfig() # No session_id + run_config = RunConfig(memory=memory) # Memory enabled but no session_id # Should raise ValueError when trying to run with memory enabled but no session_id with pytest.raises( @@ -200,5 +186,5 @@ async def test_sqlite_session_memory_direct(): def test_session_memory_invalid_config(): """Test that invalid memory configuration raises ValueError.""" with pytest.raises(ValueError, match="Invalid memory configuration"): - agent = Agent(name="test", memory="invalid") - Runner._get_session_memory(agent) + run_config = RunConfig(memory="invalid") + Runner._get_session_memory(run_config) From 54e9dfac557c4839d97c39b7cacfc60d38f4e652 Mon Sep 17 00:00:00 2001 From: Stephan Fitzpatrick Date: Fri, 23 May 2025 23:00:47 -0700 Subject: [PATCH 06/26] Add session memory documentation and examples - Updated `mkdocs.yml` to include `session_memory.md` in the documentation. - Enhanced `index.md` to highlight the new **Session Memory** feature for automatic conversation history management. - Modified `running_agents.md` to include details about the `memory` and `session_id` parameters in `RunConfig`. 
- Added comprehensive documentation for session memory functionality in the new `session_memory.md` file, including usage examples and best practices. --- docs/index.md | 2 + docs/running_agents.md | 40 ++++++- docs/session_memory.md | 243 +++++++++++++++++++++++++++++++++++++++++ mkdocs.yml | 2 + 4 files changed, 286 insertions(+), 1 deletion(-) create mode 100644 docs/session_memory.md diff --git a/docs/index.md b/docs/index.md index 8aef6574..5111dda7 100644 --- a/docs/index.md +++ b/docs/index.md @@ -5,6 +5,7 @@ The [OpenAI Agents SDK](https://github.com/openai/openai-agents-python) enables - **Agents**, which are LLMs equipped with instructions and tools - **Handoffs**, which allow agents to delegate to other agents for specific tasks - **Guardrails**, which enable the inputs to agents to be validated +- **Session Memory**, which automatically maintains conversation history across agent runs In combination with Python, these primitives are powerful enough to express complex relationships between tools and agents, and allow you to build real-world applications without a steep learning curve. In addition, the SDK comes with built-in **tracing** that lets you visualize and debug your agentic flows, as well as evaluate them and even fine-tune models for your application. @@ -21,6 +22,7 @@ Here are the main features of the SDK: - Python-first: Use built-in language features to orchestrate and chain agents, rather than needing to learn new abstractions. - Handoffs: A powerful feature to coordinate and delegate between multiple agents. - Guardrails: Run input validations and checks in parallel to your agents, breaking early if the checks fail. +- Session Memory: Automatic conversation history management across agent runs, eliminating manual state handling. - Function tools: Turn any Python function into a tool, with automatic schema generation and Pydantic-powered validation. 
- Tracing: Built-in tracing that lets you visualize, debug and monitor your workflows, as well as use the OpenAI suite of evaluation, fine-tuning and distillation tools. diff --git a/docs/running_agents.md b/docs/running_agents.md index f631cf46..5c30a34b 100644 --- a/docs/running_agents.md +++ b/docs/running_agents.md @@ -49,6 +49,8 @@ The `run_config` parameter lets you configure some global settings for the agent - [`model`][agents.run.RunConfig.model]: Allows setting a global LLM model to use, irrespective of what `model` each Agent has. - [`model_provider`][agents.run.RunConfig.model_provider]: A model provider for looking up model names, which defaults to OpenAI. - [`model_settings`][agents.run.RunConfig.model_settings]: Overrides agent-specific settings. For example, you can set a global `temperature` or `top_p`. +- [`memory`][agents.run.RunConfig.memory]: Session memory instance for automatic conversation history management. See [Session Memory](session_memory.md) for details. +- [`session_id`][agents.run.RunConfig.session_id]: Unique identifier for the conversation session. Required when `memory` is enabled. - [`input_guardrails`][agents.run.RunConfig.input_guardrails], [`output_guardrails`][agents.run.RunConfig.output_guardrails]: A list of input or output guardrails to include on all runs. - [`handoff_input_filter`][agents.run.RunConfig.handoff_input_filter]: A global input filter to apply to all handoffs, if the handoff doesn't already have one. The input filter allows you to edit the inputs that are sent to the new agent. See the documentation in [`Handoff.input_filter`][agents.handoffs.Handoff.input_filter] for more details. - [`tracing_disabled`][agents.run.RunConfig.tracing_disabled]: Allows you to disable [tracing](tracing.md) for the entire run. @@ -65,7 +67,9 @@ Calling any of the run methods can result in one or more agents running (and hen At the end of the agent run, you can choose what to show to the user. 
For example, you might show the user every new item generated by the agents, or just the final output. Either way, the user might then ask a followup question, in which case you can call the run method again. -You can use the base [`RunResultBase.to_input_list()`][agents.result.RunResultBase.to_input_list] method to get the inputs for the next turn. +### Manual conversation management + +You can manually manage conversation history using the [`RunResultBase.to_input_list()`][agents.result.RunResultBase.to_input_list] method to get the inputs for the next turn: ```python async def main(): @@ -84,6 +88,40 @@ async def main(): # California ``` +### Automatic conversation management with Session Memory + +For a simpler approach, you can use [Session Memory](session_memory.md) to automatically handle conversation history without manually calling `.to_input_list()`: + +```python +from agents import Agent, Runner, RunConfig, SQLiteSessionMemory, trace + +async def main(): + agent = Agent(name="Assistant", instructions="Reply very concisely.") + + # Create session memory and run config + memory = SQLiteSessionMemory() + run_config = RunConfig(memory=memory, session_id="conversation_123") + + with trace(workflow_name="Conversation", group_id="conversation_123"): + # First turn + result = await Runner.run(agent, "What city is the Golden Gate Bridge in?", run_config=run_config) + print(result.final_output) + # San Francisco + + # Second turn - agent automatically remembers previous context + result = await Runner.run(agent, "What state is it in?", run_config=run_config) + print(result.final_output) + # California +``` + +Session memory automatically: + +- Retrieves conversation history before each run +- Stores new messages after each run +- Maintains separate conversations for different session IDs + +See the [Session Memory documentation](session_memory.md) for more details. + ## Exceptions The SDK raises exceptions in certain cases. The full list is in [`agents.exceptions`][].
As an overview: diff --git a/docs/session_memory.md b/docs/session_memory.md new file mode 100644 index 00000000..11b7fc22 --- /dev/null +++ b/docs/session_memory.md @@ -0,0 +1,243 @@ +# Session Memory + +The Agents SDK provides built-in session memory to automatically maintain conversation history across multiple agent runs, eliminating the need to manually handle `.to_input_list()` between turns. + +Session memory stores conversation history across agent runs, allowing agents to maintain context without requiring explicit manual memory management. This is particularly useful for building chat applications or multi-turn conversations where you want the agent to remember previous interactions. + +## Quick start + +```python +from agents import Agent, Runner, RunConfig, SQLiteSessionMemory + +# Create agent +agent = Agent( + name="Assistant", + instructions="Reply very concisely.", +) + +# Create a session memory instance +memory = SQLiteSessionMemory() + +# Configure run with session memory and session ID +run_config = RunConfig( + memory=memory, # Use our session memory instance + session_id="conversation_123" +) + +# First turn +result = await Runner.run(agent, "What city is the Golden Gate Bridge in?", run_config=run_config) +print(result.final_output) # "San Francisco" + +# Second turn - agent automatically remembers previous context +result = await Runner.run(agent, "What state is it in?", run_config=run_config) +print(result.final_output) # "California" + +# Also works with synchronous runner +result = Runner.run_sync(agent, "What's the population?", run_config=run_config) +print(result.final_output) # "Approximately 39 million" +``` + +## How it works + +When session memory is enabled: + +1. **Before each run**: The runner automatically retrieves the conversation history for the given `session_id` and prepends it to the input messages. +2. **After each run**: All new messages generated during the run (user input, assistant responses, tool calls, etc.) 
are automatically stored in the session memory. +3. **Context preservation**: Each subsequent run in the same session includes the full conversation history, allowing the agent to maintain context. + +This eliminates the need to manually call `.to_input_list()` and manage conversation state between runs. + +## Memory options + +### No memory (default) + +```python +# Default behavior - no session memory +result = await Runner.run(agent, "Hello") +``` + +### SQLite memory + +```python +from agents import SQLiteSessionMemory + +# In-memory database (lost when process ends) +memory = SQLiteSessionMemory() + +# Persistent file-based database +memory = SQLiteSessionMemory("conversations.db") + +run_config = RunConfig(memory=memory, session_id="user_123") +``` + +### Multiple sessions + +```python +from agents import Agent, Runner, RunConfig, SQLiteSessionMemory + +memory = SQLiteSessionMemory("conversations.db") +agent = Agent(name="Assistant") + +# Different session IDs maintain separate conversation histories +run_config_user1 = RunConfig(memory=memory, session_id="user_123") +run_config_user2 = RunConfig(memory=memory, session_id="user_456") + +# These will have completely separate conversation histories +result1 = await Runner.run(agent, "Hello", run_config=run_config_user1) +result2 = await Runner.run(agent, "Hello", run_config=run_config_user2) +``` + +## Custom memory implementations + +You can implement your own session memory by creating a class that follows the [`SessionMemory`][agents.memory.session_memory.SessionMemory] protocol: + +```python +from agents.memory import SessionMemory +from typing import List + +class MyCustomMemory: + """Custom memory implementation following the SessionMemory protocol.""" + + async def get_messages(self, session_id: str) -> List[dict]: + """Retrieve conversation history for the session.""" + # Your implementation here + pass + + async def add_messages(self, session_id: str, messages: List[dict]) -> None: + """Store new 
messages for the session.""" + # Your implementation here + pass + + async def clear_session(self, session_id: str) -> None: + """Clear all messages for the session.""" + # Your implementation here + pass + +# Use your custom memory +agent = Agent(name="Assistant") +run_config = RunConfig(memory=MyCustomMemory(), session_id="my_session") +``` + +## Requirements and validation + +### session_id requirement + +When session memory is enabled, you **must** provide a `session_id` in the `RunConfig`. If you don't, the runner will raise a `ValueError`: + +```python +from agents import Agent, Runner, RunConfig, SQLiteSessionMemory + +agent = Agent(name="Assistant") +memory = SQLiteSessionMemory() + +# This will raise ValueError: "session_id is required when memory is enabled" +result = await Runner.run(agent, "Hello", run_config=RunConfig(memory=memory)) + +# This works correctly +result = await Runner.run(agent, "Hello", run_config=RunConfig(memory=memory, session_id="my_session")) +``` + +## Best practices + +### Session ID naming + +Use meaningful session IDs that help you organize conversations: + +- User-based: `"user_12345"` +- Thread-based: `"thread_abc123"` +- Context-based: `"support_ticket_456"` + +### Memory persistence + +- Use in-memory SQLite (`SQLiteSessionMemory()`) for temporary conversations +- Use file-based SQLite (`SQLiteSessionMemory("path/to/db.sqlite")`) for persistent conversations +- Consider implementing custom memory backends for production systems (Redis, PostgreSQL, etc.) 
+ +### Session management + +```python +# Clear a session when conversation should start fresh +await memory.clear_session("user_123") + +# Different agents can share the same session memory +support_agent = Agent(name="Support") +billing_agent = Agent(name="Billing") + +# Both agents will see the same conversation history +support_config = RunConfig(memory=memory, session_id="user_123") +billing_config = RunConfig(memory=memory, session_id="user_123") +``` + +## Complete example + +Here's a complete example showing session memory in action: + +```python +import asyncio +from agents import Agent, Runner, RunConfig, SQLiteSessionMemory + + +async def main(): + # Create an agent + agent = Agent( + name="Assistant", + instructions="Reply very concisely.", + ) + + # Create a session memory instance that will persist across runs + memory = SQLiteSessionMemory("conversation_history.db") + + # Define a session ID for this conversation + session_id = "conversation_123" + + # Create run config with session memory and session ID + run_config = RunConfig( + memory=memory, + session_id=session_id + ) + + print("=== Session Memory Example ===") + print("The agent will remember previous messages automatically.\n") + + # First turn + print("First turn:") + print("User: What city is the Golden Gate Bridge in?") + result = await Runner.run( + agent, "What city is the Golden Gate Bridge in?", run_config=run_config + ) + print(f"Assistant: {result.final_output}") + print() + + # Second turn - the agent will remember the previous conversation + print("Second turn:") + print("User: What state is it in?") + result = await Runner.run(agent, "What state is it in?", run_config=run_config) + print(f"Assistant: {result.final_output}") + print() + + # Third turn - continuing the conversation + print("Third turn:") + print("User: What's the population of that state?") + result = await Runner.run( + agent, "What's the population of that state?", run_config=run_config + ) + print(f"Assistant: 
{result.final_output}") + print() + + print("=== Conversation Complete ===") + print("Notice how the agent remembered the context from previous turns!") + print("Session memory in RunConfig handles conversation history automatically.") + + +if __name__ == "__main__": + asyncio.run(main()) +``` + +## API Reference + +For detailed API documentation, see: + +- [`SessionMemory`][agents.memory.session_memory.SessionMemory] - Protocol interface +- [`SQLiteSessionMemory`][agents.memory.session_memory.SQLiteSessionMemory] - SQLite implementation +- [`RunConfig.memory`][agents.run.RunConfig.memory] - Run configuration +- [`RunConfig.session_id`][agents.run.RunConfig.session_id] - Session identifier diff --git a/mkdocs.yml b/mkdocs.yml index ad719670..eff3bcbf 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -57,6 +57,7 @@ plugins: - Documentation: - agents.md - running_agents.md + - session_memory.md - results.md - streaming.md - tools.md @@ -137,6 +138,7 @@ plugins: - ドキュメント: - agents.md - running_agents.md + - session_memory.md - results.md - streaming.md - tools.md From 5e4f456a31921452ea310b6a78400ad0606161d9 Mon Sep 17 00:00:00 2001 From: Stephan Fitzpatrick Date: Fri, 23 May 2025 23:07:26 -0700 Subject: [PATCH 07/26] Add memory documentation and update references - Included `memory.md` in the documentation by updating `mkdocs.yml`. - Corrected links in `session_memory.md` to point to the appropriate memory classes. - Created a new `memory.md` file detailing the `SessionMemory` and `SQLiteSessionMemory` classes. 
--- docs/ref/memory.md | 8 ++++++++ docs/session_memory.md | 4 ++-- mkdocs.yml | 1 + 3 files changed, 11 insertions(+), 2 deletions(-) create mode 100644 docs/ref/memory.md diff --git a/docs/ref/memory.md b/docs/ref/memory.md new file mode 100644 index 00000000..a72612f9 --- /dev/null +++ b/docs/ref/memory.md @@ -0,0 +1,8 @@ +# Memory + +::: agents.memory + + options: + members: + - SessionMemory + - SQLiteSessionMemory diff --git a/docs/session_memory.md b/docs/session_memory.md index 11b7fc22..5c924a64 100644 --- a/docs/session_memory.md +++ b/docs/session_memory.md @@ -237,7 +237,7 @@ if __name__ == "__main__": For detailed API documentation, see: -- [`SessionMemory`][agents.memory.session_memory.SessionMemory] - Protocol interface -- [`SQLiteSessionMemory`][agents.memory.session_memory.SQLiteSessionMemory] - SQLite implementation +- [`SessionMemory`][agents.memory.SessionMemory] - Protocol interface +- [`SQLiteSessionMemory`][agents.memory.SQLiteSessionMemory] - SQLite implementation - [`RunConfig.memory`][agents.run.RunConfig.memory] - Run configuration - [`RunConfig.session_id`][agents.run.RunConfig.session_id] - Session identifier diff --git a/mkdocs.yml b/mkdocs.yml index eff3bcbf..cfdac7dc 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -81,6 +81,7 @@ plugins: - ref/index.md - ref/agent.md - ref/run.md + - ref/memory.md - ref/tool.md - ref/result.md - ref/stream_events.md From 14edb3cd226c539836ac2ed2372d5750a11d145e Mon Sep 17 00:00:00 2001 From: Stephan Fitzpatrick Date: Fri, 23 May 2025 23:14:02 -0700 Subject: [PATCH 08/26] revert changes to Agent. 
memory is a concern of the runner --- src/agents/agent.py | 35 ++++++----------------------------- 1 file changed, 6 insertions(+), 29 deletions(-) diff --git a/src/agents/agent.py b/src/agents/agent.py index c9d3a50d..e22f579f 100644 --- a/src/agents/agent.py +++ b/src/agents/agent.py @@ -14,7 +14,6 @@ from .items import ItemHelpers from .logger import logger from .mcp import MCPUtil -from .memory import SessionMemory from .model_settings import ModelSettings from .models.interface import Model from .run_context import RunContextWrapper, TContext @@ -157,9 +156,7 @@ class Agent(Generic[TContext]): """ tool_use_behavior: ( - Literal["run_llm_again", "stop_on_first_tool"] - | StopAtTools - | ToolsToFinalOutputFunction + Literal["run_llm_again", "stop_on_first_tool"] | StopAtTools | ToolsToFinalOutputFunction ) = "run_llm_again" """This lets you configure how tool use is handled. - "run_llm_again": The default behavior. Tools are run, and then the LLM receives the results @@ -181,17 +178,6 @@ class Agent(Generic[TContext]): """Whether to reset the tool choice to the default value after a tool has been called. Defaults to True. This ensures that the agent doesn't enter an infinite loop of tool usage.""" - memory: bool | SessionMemory | None = None - """Session memory for maintaining conversation history across runs. - - - None: No session memory (default behavior) - - True: Use default SQLite-based session memory - - SessionMemory instance: Use custom session memory implementation - - When memory is enabled, the agent will automatically maintain conversation history - and you won't need to manually handle .to_input_list() between runs. - """ - def clone(self, **kwargs: Any) -> Agent[TContext]: """Make a copy of the agent, with the given arguments changed. 
For example, you could do: ``` @@ -223,8 +209,7 @@ def as_tool( """ @function_tool( - name_override=tool_name - or _transforms.transform_string_function_style(self.name), + name_override=tool_name or _transforms.transform_string_function_style(self.name), description_override=tool_description or "", ) async def run_agent(context: RunContextWrapper, input: str) -> str: @@ -242,9 +227,7 @@ async def run_agent(context: RunContextWrapper, input: str) -> str: return run_agent - async def get_system_prompt( - self, run_context: RunContextWrapper[TContext] - ) -> str | None: + async def get_system_prompt(self, run_context: RunContextWrapper[TContext]) -> str | None: """Get the system prompt for the agent.""" if isinstance(self.instructions, str): return self.instructions @@ -254,20 +237,14 @@ async def get_system_prompt( else: return cast(str, self.instructions(run_context, self)) elif self.instructions is not None: - logger.error( - f"Instructions must be a string or a function, got {self.instructions}" - ) + logger.error(f"Instructions must be a string or a function, got {self.instructions}") return None async def get_mcp_tools(self) -> list[Tool]: """Fetches the available tools from the MCP servers.""" - convert_schemas_to_strict = self.mcp_config.get( - "convert_schemas_to_strict", False - ) - return await MCPUtil.get_all_function_tools( - self.mcp_servers, convert_schemas_to_strict - ) + convert_schemas_to_strict = self.mcp_config.get("convert_schemas_to_strict", False) + return await MCPUtil.get_all_function_tools(self.mcp_servers, convert_schemas_to_strict) async def get_all_tools(self) -> list[Tool]: """All agent tools, including MCP tools and function tools.""" From b422736cf4dde8450db0c1dfcde36e1b92d75c42 Mon Sep 17 00:00:00 2001 From: Stephan Fitzpatrick Date: Fri, 23 May 2025 23:18:36 -0700 Subject: [PATCH 09/26] Enhance SQLiteSessionMemory to support customizable table names for sessions and messages - Updated the constructor to accept `sessions_table` and 
`messages_table` parameters, allowing users to specify custom table names. - Modified SQL queries to utilize the provided table names, ensuring flexibility in database schema. - Adjusted index creation and deletion queries to reflect the new table name parameters. --- src/agents/memory/session_memory.py | 51 ++++++++++++++++++----------- 1 file changed, 32 insertions(+), 19 deletions(-) diff --git a/src/agents/memory/session_memory.py b/src/agents/memory/session_memory.py index 376115bb..9bad42d3 100644 --- a/src/agents/memory/session_memory.py +++ b/src/agents/memory/session_memory.py @@ -57,13 +57,22 @@ class SQLiteSessionMemory(SessionMemory): For persistent storage, provide a file path. """ - def __init__(self, db_path: str | Path = ":memory:"): + def __init__( + self, + db_path: str | Path = ":memory:", + sessions_table: str = "agent_sessions", + messages_table: str = "agent_messages", + ): """Initialize the SQLite session memory. Args: db_path: Path to the SQLite database file. Defaults to ':memory:' (in-memory database) + sessions_table: Name of the table to store session metadata. Defaults to 'agent_sessions' + messages_table: Name of the table to store message data. 
Defaults to 'agent_messages' """ self.db_path = db_path + self.sessions_table = sessions_table + self.messages_table = messages_table self._local = threading.local() self._init_db() @@ -81,8 +90,8 @@ def _init_db(self) -> None: """Initialize the database schema.""" conn = self._get_connection() conn.execute( - """ - CREATE TABLE IF NOT EXISTS sessions ( + f""" + CREATE TABLE IF NOT EXISTS {self.sessions_table} ( session_id TEXT PRIMARY KEY, created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP @@ -91,21 +100,21 @@ def _init_db(self) -> None: ) conn.execute( - """ - CREATE TABLE IF NOT EXISTS messages ( + f""" + CREATE TABLE IF NOT EXISTS {self.messages_table} ( id INTEGER PRIMARY KEY AUTOINCREMENT, session_id TEXT NOT NULL, message_data TEXT NOT NULL, created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, - FOREIGN KEY (session_id) REFERENCES sessions (session_id) ON DELETE CASCADE + FOREIGN KEY (session_id) REFERENCES {self.sessions_table} (session_id) ON DELETE CASCADE ) """ ) conn.execute( - """ - CREATE INDEX IF NOT EXISTS idx_messages_session_id - ON messages (session_id, created_at) + f""" + CREATE INDEX IF NOT EXISTS idx_{self.messages_table}_session_id + ON {self.messages_table} (session_id, created_at) """ ) @@ -122,8 +131,8 @@ async def get_messages(self, session_id: str) -> list[TResponseInputItem]: """ conn = self._get_connection() cursor = conn.execute( - """ - SELECT message_data FROM messages + f""" + SELECT message_data FROM {self.messages_table} WHERE session_id = ? ORDER BY created_at ASC """, @@ -157,8 +166,8 @@ async def add_messages( # Ensure session exists conn.execute( - """ - INSERT OR IGNORE INTO sessions (session_id) VALUES (?) + f""" + INSERT OR IGNORE INTO {self.sessions_table} (session_id) VALUES (?) 
""", (session_id,), ) @@ -166,16 +175,16 @@ async def add_messages( # Add messages message_data = [(session_id, json.dumps(message)) for message in messages] conn.executemany( - """ - INSERT INTO messages (session_id, message_data) VALUES (?, ?) + f""" + INSERT INTO {self.messages_table} (session_id, message_data) VALUES (?, ?) """, message_data, ) # Update session timestamp conn.execute( - """ - UPDATE sessions SET updated_at = CURRENT_TIMESTAMP WHERE session_id = ? + f""" + UPDATE {self.sessions_table} SET updated_at = CURRENT_TIMESTAMP WHERE session_id = ? """, (session_id,), ) @@ -189,8 +198,12 @@ async def clear_session(self, session_id: str) -> None: session_id: Unique identifier for the conversation session """ conn = self._get_connection() - conn.execute("DELETE FROM messages WHERE session_id = ?", (session_id,)) - conn.execute("DELETE FROM sessions WHERE session_id = ?", (session_id,)) + conn.execute( + f"DELETE FROM {self.messages_table} WHERE session_id = ?", (session_id,) + ) + conn.execute( + f"DELETE FROM {self.sessions_table} WHERE session_id = ?", (session_id,) + ) conn.commit() def close(self) -> None: From ad03ae7bf9f897d07fc1cb0fd52e3b6d3397ec7e Mon Sep 17 00:00:00 2001 From: Stephan Fitzpatrick Date: Sat, 24 May 2025 10:39:31 -0700 Subject: [PATCH 10/26] Add pop_message functionality to SQLiteSessionMemory - Implemented the `pop_message` method to remove and return the most recent message from a session. - Updated the `SessionMemory` protocol to include the new method signature. - Enhanced documentation in `session_memory.md` with examples demonstrating the usage of `pop_message`. - Added tests to verify the functionality of `pop_message`, including edge cases for empty sessions and multiple sessions. 
--- docs/session_memory.md | 59 ++++++++++++++++ src/agents/memory/session_memory.py | 58 ++++++++++++++++ tests/test_session_memory.py | 100 ++++++++++++++++++++++++++++ 3 files changed, 217 insertions(+) diff --git a/docs/session_memory.md b/docs/session_memory.md index 5c924a64..15ba917f 100644 --- a/docs/session_memory.md +++ b/docs/session_memory.md @@ -47,6 +47,60 @@ When session memory is enabled: This eliminates the need to manually call `.to_input_list()` and manage conversation state between runs. +## Memory operations + +### Basic operations + +Session memory supports several operations for managing conversation history: + +```python +from agents import SQLiteSessionMemory + +memory = SQLiteSessionMemory("conversations.db") +session_id = "user_123" + +# Get all messages in a session +messages = await memory.get_messages(session_id) + +# Add new messages to a session +new_messages = [ + {"role": "user", "content": "Hello"}, + {"role": "assistant", "content": "Hi there!"} +] +await memory.add_messages(session_id, new_messages) + +# Remove and return the most recent message +last_message = await memory.pop_message(session_id) +print(last_message) # {"role": "assistant", "content": "Hi there!"} + +# Clear all messages from a session +await memory.clear_session(session_id) +``` + +### Using pop_message for corrections + +The `pop_message` method is particularly useful when you want to undo or modify the last message in a conversation: + +```python +from agents import Agent, Runner, RunConfig, SQLiteSessionMemory + +agent = Agent(name="Assistant") +memory = SQLiteSessionMemory() +run_config = RunConfig(memory=memory, session_id="correction_example") + +# Initial conversation +result = await Runner.run(agent, "What's 2 + 2?", run_config=run_config) +print(f"Agent: {result.final_output}") + +# User wants to correct their question (pop_message returns the newest message first) +assistant_message = await memory.pop_message("correction_example") # Remove agent's response +user_message = await
memory.pop_message("correction_example") # Remove user's question + +# Ask a corrected question +result = await Runner.run(agent, "What's 2 + 3?", run_config=run_config) +print(f"Agent: {result.final_output}") +``` + ## Memory options ### No memory (default) @@ -108,6 +162,11 @@ class MyCustomMemory: # Your implementation here pass + async def pop_message(self, session_id: str) -> dict | None: + """Remove and return the most recent message from the session.""" + # Your implementation here + pass + async def clear_session(self, session_id: str) -> None: """Clear all messages for the session.""" # Your implementation here diff --git a/src/agents/memory/session_memory.py b/src/agents/memory/session_memory.py index 9bad42d3..fed50caf 100644 --- a/src/agents/memory/session_memory.py +++ b/src/agents/memory/session_memory.py @@ -40,6 +40,17 @@ async def add_messages( """ ... + async def pop_message(self, session_id: str) -> TResponseInputItem | None: + """Remove and return the most recent message from the session. + + Args: + session_id: Unique identifier for the conversation session + + Returns: + The most recent message if it exists, None if the session is empty + """ + ... + async def clear_session(self, session_id: str) -> None: """Clear all messages for a given session. @@ -191,6 +202,53 @@ async def add_messages( conn.commit() + async def pop_message(self, session_id: str) -> TResponseInputItem | None: + """Remove and return the most recent message from the session. + + Args: + session_id: Unique identifier for the conversation session + + Returns: + The most recent message if it exists, None if the session is empty + """ + conn = self._get_connection() + cursor = conn.execute( + f""" + SELECT id, message_data FROM {self.messages_table} + WHERE session_id = ?
+ ORDER BY created_at DESC + LIMIT 1 + """, + (session_id,), + ) + + result = cursor.fetchone() + if result: + message_id, message_data = result + try: + message = json.loads(message_data) + # Delete the message by ID + conn.execute( + f""" + DELETE FROM {self.messages_table} WHERE id = ? + """, + (message_id,), + ) + conn.commit() + return message + except json.JSONDecodeError: + # Skip invalid JSON entries, but still delete the corrupted record + conn.execute( + f""" + DELETE FROM {self.messages_table} WHERE id = ? + """, + (message_id,), + ) + conn.commit() + return None + + return None + async def clear_session(self, session_id: str) -> None: """Clear all messages for a given session. diff --git a/tests/test_session_memory.py b/tests/test_session_memory.py index 49f8b748..c90b137f 100644 --- a/tests/test_session_memory.py +++ b/tests/test_session_memory.py @@ -183,6 +183,106 @@ async def test_sqlite_session_memory_direct(): memory.close() +@pytest.mark.asyncio +async def test_sqlite_session_memory_pop_message(): + """Test SQLiteSessionMemory pop_message functionality.""" + with tempfile.TemporaryDirectory() as temp_dir: + db_path = Path(temp_dir) / "test_pop.db" + memory = SQLiteSessionMemory(db_path) + + session_id = "pop_test" + + # Test popping from empty session + popped = await memory.pop_message(session_id) + assert popped is None + + # Add messages + messages = [ + {"role": "user", "content": "Hello"}, + {"role": "assistant", "content": "Hi there!"}, + {"role": "user", "content": "How are you?"}, + ] + + await memory.add_messages(session_id, messages) + + # Verify all messages are there + retrieved = await memory.get_messages(session_id) + assert len(retrieved) == 3 + + # Pop the most recent message + popped = await memory.pop_message(session_id) + assert popped is not None + assert popped["role"] == "user" + assert popped["content"] == "How are you?" 
+ + # Verify message was removed + retrieved_after_pop = await memory.get_messages(session_id) + assert len(retrieved_after_pop) == 2 + assert retrieved_after_pop[-1]["content"] == "Hi there!" + + # Pop another message + popped2 = await memory.pop_message(session_id) + assert popped2 is not None + assert popped2["role"] == "assistant" + assert popped2["content"] == "Hi there!" + + # Pop the last message + popped3 = await memory.pop_message(session_id) + assert popped3 is not None + assert popped3["role"] == "user" + assert popped3["content"] == "Hello" + + # Try to pop from empty session again + popped4 = await memory.pop_message(session_id) + assert popped4 is None + + # Verify session is empty + final_messages = await memory.get_messages(session_id) + assert len(final_messages) == 0 + + memory.close() + + +@pytest.mark.asyncio +async def test_session_memory_pop_different_sessions(): + """Test that pop_message only affects the specified session.""" + with tempfile.TemporaryDirectory() as temp_dir: + db_path = Path(temp_dir) / "test_pop_sessions.db" + memory = SQLiteSessionMemory(db_path) + + session_1 = "session_1" + session_2 = "session_2" + + # Add messages to both sessions + messages_1 = [ + {"role": "user", "content": "Session 1 message"}, + ] + messages_2 = [ + {"role": "user", "content": "Session 2 message 1"}, + {"role": "user", "content": "Session 2 message 2"}, + ] + + await memory.add_messages(session_1, messages_1) + await memory.add_messages(session_2, messages_2) + + # Pop from session 2 + popped = await memory.pop_message(session_2) + assert popped is not None + assert popped["content"] == "Session 2 message 2" + + # Verify session 1 is unaffected + session_1_messages = await memory.get_messages(session_1) + assert len(session_1_messages) == 1 + assert session_1_messages[0]["content"] == "Session 1 message" + + # Verify session 2 has one message left + session_2_messages = await memory.get_messages(session_2) + assert len(session_2_messages) == 1 + 
assert session_2_messages[0]["content"] == "Session 2 message 1" + + memory.close() + + def test_session_memory_invalid_config(): """Test that invalid memory configuration raises ValueError.""" with pytest.raises(ValueError, match="Invalid memory configuration"): From f0241dbeb0d94fd3e1ea187599ec13a3244b404e Mon Sep 17 00:00:00 2001 From: Stephan Fitzpatrick Date: Sat, 24 May 2025 13:00:52 -0700 Subject: [PATCH 11/26] Refactor SQLiteSessionMemory methods to support asynchronous execution - Converted synchronous database operations in `get_messages`, `add_messages`, `pop_message`, and `clear_session` methods to asynchronous using `asyncio.to_thread`. - Improved performance and responsiveness of the session memory handling by allowing non-blocking database interactions. --- src/agents/memory/session_memory.py | 188 +++++++++++++++------------- 1 file changed, 102 insertions(+), 86 deletions(-) diff --git a/src/agents/memory/session_memory.py b/src/agents/memory/session_memory.py index fed50caf..b74fe601 100644 --- a/src/agents/memory/session_memory.py +++ b/src/agents/memory/session_memory.py @@ -1,5 +1,6 @@ from __future__ import annotations +import asyncio import json import sqlite3 import threading @@ -140,26 +141,30 @@ async def get_messages(self, session_id: str) -> list[TResponseInputItem]: Returns: List of input items representing the conversation history """ - conn = self._get_connection() - cursor = conn.execute( - f""" - SELECT message_data FROM {self.messages_table} - WHERE session_id = ? - ORDER BY created_at ASC - """, - (session_id,), - ) - messages = [] - for (message_data,) in cursor.fetchall(): - try: - message = json.loads(message_data) - messages.append(message) - except json.JSONDecodeError: - # Skip invalid JSON entries - continue + def _get_messages_sync(): + conn = self._get_connection() + cursor = conn.execute( + f""" + SELECT message_data FROM {self.messages_table} + WHERE session_id = ? 
+ ORDER BY created_at ASC + """, + (session_id,), + ) + + messages = [] + for (message_data,) in cursor.fetchall(): + try: + message = json.loads(message_data) + messages.append(message) + except json.JSONDecodeError: + # Skip invalid JSON entries + continue - return messages + return messages + + return await asyncio.to_thread(_get_messages_sync) async def add_messages( self, session_id: str, messages: list[TResponseInputItem] @@ -173,34 +178,37 @@ async def add_messages( if not messages: return - conn = self._get_connection() + def _add_messages_sync(): + conn = self._get_connection() - # Ensure session exists - conn.execute( - f""" - INSERT OR IGNORE INTO {self.sessions_table} (session_id) VALUES (?) - """, - (session_id,), - ) + # Ensure session exists + conn.execute( + f""" + INSERT OR IGNORE INTO {self.sessions_table} (session_id) VALUES (?) + """, + (session_id,), + ) - # Add messages - message_data = [(session_id, json.dumps(message)) for message in messages] - conn.executemany( - f""" - INSERT INTO {self.messages_table} (session_id, message_data) VALUES (?, ?) - """, - message_data, - ) + # Add messages + message_data = [(session_id, json.dumps(message)) for message in messages] + conn.executemany( + f""" + INSERT INTO {self.messages_table} (session_id, message_data) VALUES (?, ?) + """, + message_data, + ) - # Update session timestamp - conn.execute( - f""" - UPDATE {self.sessions_table} SET updated_at = CURRENT_TIMESTAMP WHERE session_id = ? - """, - (session_id,), - ) + # Update session timestamp + conn.execute( + f""" + UPDATE {self.sessions_table} SET updated_at = CURRENT_TIMESTAMP WHERE session_id = ? + """, + (session_id,), + ) - conn.commit() + conn.commit() + + await asyncio.to_thread(_add_messages_sync) async def pop_message(self, session_id: str) -> TResponseInputItem | None: """Remove and return the most recent message from the session. 
@@ -211,43 +219,47 @@ async def pop_message(self, session_id: str) -> TResponseInputItem | None: Returns: The most recent message if it exists, None if the session is empty """ - conn = self._get_connection() - cursor = conn.execute( - f""" - SELECT id, message_data FROM {self.messages_table} - WHERE session_id = ? - ORDER BY created_at DESC - LIMIT 1 - """, - (session_id,), - ) - result = cursor.fetchone() - if result: - message_id, message_data = result - try: - message = json.loads(message_data) - # Delete the message by ID - conn.execute( - f""" - DELETE FROM {self.messages_table} WHERE id = ? - """, - (message_id,), - ) - conn.commit() - return message - except json.JSONDecodeError: - # Skip invalid JSON entries, but still delete the corrupted record - conn.execute( - f""" - DELETE FROM {self.messages_table} WHERE id = ? - """, - (message_id,), - ) - conn.commit() - return None - - return None + def _pop_message_sync(): + conn = self._get_connection() + cursor = conn.execute( + f""" + SELECT id, message_data FROM {self.messages_table} + WHERE session_id = ? + ORDER BY created_at DESC + LIMIT 1 + """, + (session_id,), + ) + + result = cursor.fetchone() + if result: + message_id, message_data = result + try: + message = json.loads(message_data) + # Delete the message by ID + conn.execute( + f""" + DELETE FROM {self.messages_table} WHERE id = ? + """, + (message_id,), + ) + conn.commit() + return message + except json.JSONDecodeError: + # Skip invalid JSON entries, but still delete the corrupted record + conn.execute( + f""" + DELETE FROM {self.messages_table} WHERE id = ? + """, + (message_id,), + ) + conn.commit() + return None + + return None + + return await asyncio.to_thread(_pop_message_sync) async def clear_session(self, session_id: str) -> None: """Clear all messages for a given session. 
@@ -255,14 +267,18 @@ async def clear_session(self, session_id: str) -> None: Args: session_id: Unique identifier for the conversation session """ - conn = self._get_connection() - conn.execute( - f"DELETE FROM {self.messages_table} WHERE session_id = ?", (session_id,) - ) - conn.execute( - f"DELETE FROM {self.sessions_table} WHERE session_id = ?", (session_id,) - ) - conn.commit() + + def _clear_session_sync(): + conn = self._get_connection() + conn.execute( + f"DELETE FROM {self.messages_table} WHERE session_id = ?", (session_id,) + ) + conn.execute( + f"DELETE FROM {self.sessions_table} WHERE session_id = ?", (session_id,) + ) + conn.commit() + + await asyncio.to_thread(_clear_session_sync) def close(self) -> None: """Close the database connection.""" From dcae4c706f57848b2dd3f4d93a91129ebe37d72b Mon Sep 17 00:00:00 2001 From: Stephan Fitzpatrick Date: Sat, 24 May 2025 13:07:27 -0700 Subject: [PATCH 12/26] Add validation for session_id when memory is disabled in Runner class - Implemented a check in the Runner class to raise a ValueError if a session_id is provided without enabling memory in the RunConfig. - Updated tests to verify that the appropriate exception is raised when session_id is used without memory. --- src/agents/run.py | 6 ++++++ tests/test_session_memory.py | 19 +++++++++++++++++-- 2 files changed, 23 insertions(+), 2 deletions(-) diff --git a/src/agents/run.py b/src/agents/run.py index 26c4a66b..8f0cdd2b 100644 --- a/src/agents/run.py +++ b/src/agents/run.py @@ -1075,6 +1075,12 @@ async def _prepare_input_with_memory( """Prepare input by combining it with session memory if enabled.""" memory = cls._get_session_memory(run_config) if memory is None: + # Check if session_id is provided without memory + if run_config.session_id is not None: + raise ValueError( + "session_id provided but memory is disabled. " + "Please enable memory in the RunConfig or remove session_id." 
+ ) return input, memory if run_config.session_id is None: diff --git a/tests/test_session_memory.py b/tests/test_session_memory.py index c90b137f..3855b41a 100644 --- a/tests/test_session_memory.py +++ b/tests/test_session_memory.py @@ -79,8 +79,7 @@ async def test_session_memory_disabled(): model = FakeModel() agent = Agent(name="test", model=model) - session_id = "test_session_789" - run_config = RunConfig(memory=None, session_id=session_id) # No session memory + run_config = RunConfig(memory=None) # No session memory # First turn model.set_next_output([get_text_message("Hello")]) @@ -151,6 +150,22 @@ async def test_session_memory_no_session_id(): await Runner.run(agent, "Hi there", run_config=run_config) +@pytest.mark.asyncio +async def test_session_id_without_memory(): + """Test that providing session_id without memory raises an exception.""" + model = FakeModel() + agent = Agent(name="test", model=model) + + session_id = "test_session_without_memory" + run_config = RunConfig( + memory=None, session_id=session_id + ) # session_id but no memory + + # Should raise ValueError when trying to run with session_id but no memory + with pytest.raises(ValueError, match="session_id provided but memory is disabled"): + await Runner.run(agent, "Hi there", run_config=run_config) + + @pytest.mark.asyncio async def test_sqlite_session_memory_direct(): """Test SQLiteSessionMemory class directly.""" From 4af2192232e7873b6a9aea15f57506442b756eed Mon Sep 17 00:00:00 2001 From: Stephan Fitzpatrick Date: Sat, 24 May 2025 13:22:48 -0700 Subject: [PATCH 13/26] Refactor session memory handling in documentation and examples - Updated README, session_memory.md, and example scripts to remove the use of RunConfig for session memory configuration, directly passing memory and session_id parameters to the Runner.run method. - Enhanced clarity in documentation regarding the requirement of session_id when memory is enabled. 
- Adjusted tests to reflect the new approach, ensuring consistent behavior across different configurations. --- README.md | 65 ++++++++--- docs/session_memory.md | 142 ++++++++++++++++------- examples/basic/session_memory_example.py | 23 ++-- src/agents/run.py | 66 ++++++----- tests/test_session_memory.py | 55 ++++----- 5 files changed, 217 insertions(+), 134 deletions(-) diff --git a/README.md b/README.md index d63f864f..3a96fc43 100644 --- a/README.md +++ b/README.md @@ -21,7 +21,7 @@ The Agents SDK provides built-in session memory to automatically maintain conver ### Quick start ```python -from agents import Agent, Runner, RunConfig, SQLiteSessionMemory +from agents import Agent, Runner, SQLiteSessionMemory # Create agent agent = Agent( @@ -32,40 +32,59 @@ agent = Agent( # Create a session memory instance memory = SQLiteSessionMemory() -# Configure run with session memory and session ID -run_config = RunConfig( - memory=memory, # Use our session memory instance +# First turn +result = await Runner.run( + agent, + "What city is the Golden Gate Bridge in?", + memory=memory, session_id="conversation_123" ) - -# First turn -result = await Runner.run(agent, "What city is the Golden Gate Bridge in?", run_config=run_config) print(result.final_output) # "San Francisco" # Second turn - agent automatically remembers previous context -result = await Runner.run(agent, "What state is it in?", run_config=run_config) +result = await Runner.run( + agent, + "What state is it in?", + memory=memory, + session_id="conversation_123" +) print(result.final_output) # "California" # Also works with synchronous runner -result = Runner.run_sync(agent, "What's the population?", run_config=run_config) +result = Runner.run_sync( + agent, + "What's the population?", + memory=memory, + session_id="conversation_123" +) print(result.final_output) # "Approximately 39 million" ``` ### Memory options -- **`memory=None`** (default): No session memory +- **No memory** (default): No session memory 
when memory parameter is omitted - **`memory=SessionMemory`**: Use the provided session memory implementation ```python -from agents import Agent, Runner, RunConfig, SQLiteSessionMemory +from agents import Agent, Runner, SQLiteSessionMemory # Custom SQLite database file memory = SQLiteSessionMemory("conversations.db") agent = Agent(name="Assistant") # Different session IDs maintain separate conversation histories -run_config_1 = RunConfig(memory=memory, session_id="user_123") -run_config_2 = RunConfig(memory=memory, session_id="user_456") +result1 = await Runner.run( + agent, + "Hello", + memory=memory, + session_id="user_123" +) +result2 = await Runner.run( + agent, + "Hello", + memory=memory, + session_id="user_456" +) ``` ### Custom memory implementations @@ -93,24 +112,34 @@ class MyCustomMemory: # Use your custom memory agent = Agent(name="Assistant") -run_config = RunConfig(memory=MyCustomMemory(), session_id="my_session") +result = await Runner.run( + agent, + "Hello", + memory=MyCustomMemory(), + session_id="my_session" +) ``` ### Important: session_id requirement -When session memory is enabled, you **must** provide a `session_id` in the `RunConfig`. If you don't, the runner will raise a `ValueError`: +When session memory is enabled, you **must** provide a `session_id`. 
If you don't, the runner will raise a `ValueError`: ```python -from agents import Agent, Runner, RunConfig, SQLiteSessionMemory +from agents import Agent, Runner, SQLiteSessionMemory agent = Agent(name="Assistant") memory = SQLiteSessionMemory() # This will raise ValueError: "session_id is required when memory is enabled" -result = await Runner.run(agent, "Hello", run_config=RunConfig(memory=memory)) +result = await Runner.run(agent, "Hello", memory=memory) # This works correctly -result = await Runner.run(agent, "Hello", run_config=RunConfig(memory=memory, session_id="my_session")) +result = await Runner.run( + agent, + "Hello", + memory=memory, + session_id="my_session" +) ``` ## Get started diff --git a/docs/session_memory.md b/docs/session_memory.md index 15ba917f..5ecaee4f 100644 --- a/docs/session_memory.md +++ b/docs/session_memory.md @@ -7,7 +7,7 @@ Session memory stores conversation history across agent runs, allowing agents to ## Quick start ```python -from agents import Agent, Runner, RunConfig, SQLiteSessionMemory +from agents import Agent, Runner, SQLiteSessionMemory # Create agent agent = Agent( @@ -18,22 +18,31 @@ agent = Agent( # Create a session memory instance memory = SQLiteSessionMemory() -# Configure run with session memory and session ID -run_config = RunConfig( - memory=memory, # Use our session memory instance +# First turn +result = await Runner.run( + agent, + "What city is the Golden Gate Bridge in?", + memory=memory, session_id="conversation_123" ) - -# First turn -result = await Runner.run(agent, "What city is the Golden Gate Bridge in?", run_config=run_config) print(result.final_output) # "San Francisco" # Second turn - agent automatically remembers previous context -result = await Runner.run(agent, "What state is it in?", run_config=run_config) +result = await Runner.run( + agent, + "What state is it in?", + memory=memory, + session_id="conversation_123" +) print(result.final_output) # "California" # Also works with synchronous runner 
-result = Runner.run_sync(agent, "What's the population?", run_config=run_config) +result = Runner.run_sync( + agent, + "What's the population?", + memory=memory, + session_id="conversation_123" +) print(result.final_output) # "Approximately 39 million" ``` @@ -82,22 +91,32 @@ await memory.clear_session(session_id) The `pop_message` method is particularly useful when you want to undo or modify the last message in a conversation: ```python -from agents import Agent, Runner, RunConfig, SQLiteSessionMemory +from agents import Agent, Runner, SQLiteSessionMemory agent = Agent(name="Assistant") memory = SQLiteSessionMemory() -run_config = RunConfig(memory=memory, session_id="correction_example") +session_id = "correction_example" # Initial conversation -result = await Runner.run(agent, "What's 2 + 2?", run_config=run_config) +result = await Runner.run( + agent, + "What's 2 + 2?", + memory=memory, + session_id=session_id +) print(f"Agent: {result.final_output}") # User wants to correct their question -user_message = await memory.pop_message("correction_example") # Remove user's question -assistant_message = await memory.pop_message("correction_example") # Remove agent's response +user_message = await memory.pop_message(session_id) # Remove user's question +assistant_message = await memory.pop_message(session_id) # Remove agent's response # Ask a corrected question -result = await Runner.run(agent, "What's 2 + 3?", run_config=run_config) +result = await Runner.run( + agent, + "What's 2 + 3?", + memory=memory, + session_id=session_id +) print(f"Agent: {result.final_output}") ``` @@ -121,31 +140,43 @@ memory = SQLiteSessionMemory() # Persistent file-based database memory = SQLiteSessionMemory("conversations.db") -run_config = RunConfig(memory=memory, session_id="user_123") +# Use the memory with session IDs +result = await Runner.run( + agent, + "Hello", + memory=memory, + session_id="user_123" +) ``` ### Multiple sessions ```python -from agents import Agent, Runner, 
RunConfig, SQLiteSessionMemory +from agents import Agent, Runner, SQLiteSessionMemory memory = SQLiteSessionMemory("conversations.db") agent = Agent(name="Assistant") # Different session IDs maintain separate conversation histories -run_config_user1 = RunConfig(memory=memory, session_id="user_123") -run_config_user2 = RunConfig(memory=memory, session_id="user_456") - -# These will have completely separate conversation histories -result1 = await Runner.run(agent, "Hello", run_config=run_config_user1) -result2 = await Runner.run(agent, "Hello", run_config=run_config_user2) +result1 = await Runner.run( + agent, + "Hello", + memory=memory, + session_id="user_123" +) +result2 = await Runner.run( + agent, + "Hello", + memory=memory, + session_id="user_456" +) ``` ## Custom memory implementations You can implement your own session memory by creating a class that follows the [`SessionMemory`][agents.memory.session_memory.SessionMemory] protocol: -```python +````python from agents.memory import SessionMemory from typing import List @@ -174,26 +205,35 @@ class MyCustomMemory: # Use your custom memory agent = Agent(name="Assistant") -run_config = RunConfig(memory=MyCustomMemory(), session_id="my_session") -``` +result = await Runner.run( + agent, + "Hello", + memory=MyCustomMemory(), + session_id="my_session" +) ## Requirements and validation ### session_id requirement -When session memory is enabled, you **must** provide a `session_id` in the `RunConfig`. If you don't, the runner will raise a `ValueError`: +When session memory is enabled, you **must** provide a `session_id`. 
If you don't, the runner will raise a `ValueError`: ```python -from agents import Agent, Runner, RunConfig, SQLiteSessionMemory +from agents import Agent, Runner, SQLiteSessionMemory agent = Agent(name="Assistant") memory = SQLiteSessionMemory() # This will raise ValueError: "session_id is required when memory is enabled" -result = await Runner.run(agent, "Hello", run_config=RunConfig(memory=memory)) +result = await Runner.run(agent, "Hello", memory=memory) # This works correctly -result = await Runner.run(agent, "Hello", run_config=RunConfig(memory=memory, session_id="my_session")) +result = await Runner.run( + agent, + "Hello", + memory=memory, + session_id="my_session" +) ``` ## Best practices @@ -223,8 +263,18 @@ support_agent = Agent(name="Support") billing_agent = Agent(name="Billing") # Both agents will see the same conversation history -support_config = RunConfig(memory=memory, session_id="user_123") -billing_config = RunConfig(memory=memory, session_id="user_123") +result1 = await Runner.run( + support_agent, + "Help me with my account", + memory=memory, + session_id="user_123" +) +result2 = await Runner.run( + billing_agent, + "What are my charges?", + memory=memory, + session_id="user_123" +) ``` ## Complete example @@ -233,7 +283,7 @@ Here's a complete example showing session memory in action: ```python import asyncio -from agents import Agent, Runner, RunConfig, SQLiteSessionMemory +from agents import Agent, Runner, SQLiteSessionMemory async def main(): @@ -249,12 +299,6 @@ async def main(): # Define a session ID for this conversation session_id = "conversation_123" - # Create run config with session memory and session ID - run_config = RunConfig( - memory=memory, - session_id=session_id - ) - print("=== Session Memory Example ===") print("The agent will remember previous messages automatically.\n") @@ -262,7 +306,10 @@ async def main(): print("First turn:") print("User: What city is the Golden Gate Bridge in?") result = await Runner.run( - agent, 
"What city is the Golden Gate Bridge in?", run_config=run_config + agent, + "What city is the Golden Gate Bridge in?", + memory=memory, + session_id=session_id ) print(f"Assistant: {result.final_output}") print() @@ -270,7 +317,12 @@ async def main(): # Second turn - the agent will remember the previous conversation print("Second turn:") print("User: What state is it in?") - result = await Runner.run(agent, "What state is it in?", run_config=run_config) + result = await Runner.run( + agent, + "What state is it in?", + memory=memory, + session_id=session_id + ) print(f"Assistant: {result.final_output}") print() @@ -278,14 +330,17 @@ async def main(): print("Third turn:") print("User: What's the population of that state?") result = await Runner.run( - agent, "What's the population of that state?", run_config=run_config + agent, + "What's the population of that state?", + memory=memory, + session_id=session_id ) print(f"Assistant: {result.final_output}") print() print("=== Conversation Complete ===") print("Notice how the agent remembered the context from previous turns!") - print("Session memory in RunConfig handles conversation history automatically.") + print("Session memory automatically handles conversation history.") if __name__ == "__main__": @@ -300,3 +355,4 @@ For detailed API documentation, see: - [`SQLiteSessionMemory`][agents.memory.SQLiteSessionMemory] - SQLite implementation - [`RunConfig.memory`][agents.run.RunConfig.memory] - Run configuration - [`RunConfig.session_id`][agents.run.RunConfig.session_id] - Session identifier +```` diff --git a/examples/basic/session_memory_example.py b/examples/basic/session_memory_example.py index 3fe41d63..0c4feef1 100644 --- a/examples/basic/session_memory_example.py +++ b/examples/basic/session_memory_example.py @@ -6,7 +6,7 @@ """ import asyncio -from agents import Agent, Runner, RunConfig, SQLiteSessionMemory +from agents import Agent, Runner, SQLiteSessionMemory async def main(): @@ -22,11 +22,6 @@ async def 
main(): # Define a session ID for this conversation session_id = "conversation_123" - # Create run config with session memory and session ID - run_config = RunConfig( - memory=memory, session_id=session_id # Use our session memory instance - ) - print("=== Session Memory Example ===") print("The agent will remember previous messages automatically.\n") @@ -34,7 +29,10 @@ async def main(): print("First turn:") print("User: What city is the Golden Gate Bridge in?") result = await Runner.run( - agent, "What city is the Golden Gate Bridge in?", run_config=run_config + agent, + "What city is the Golden Gate Bridge in?", + memory=memory, + session_id=session_id, ) print(f"Assistant: {result.final_output}") print() @@ -42,7 +40,9 @@ async def main(): # Second turn - the agent will remember the previous conversation print("Second turn:") print("User: What state is it in?") - result = await Runner.run(agent, "What state is it in?", run_config=run_config) + result = await Runner.run( + agent, "What state is it in?", memory=memory, session_id=session_id + ) print(f"Assistant: {result.final_output}") print() @@ -50,14 +50,17 @@ async def main(): print("Third turn:") print("User: What's the population of that state?") result = await Runner.run( - agent, "What's the population of that state?", run_config=run_config + agent, + "What's the population of that state?", + memory=memory, + session_id=session_id, ) print(f"Assistant: {result.final_output}") print() print("=== Conversation Complete ===") print("Notice how the agent remembered the context from previous turns!") - print("Session memory in RunConfig handles conversation history automatically.") + print("Session memory automatically handles conversation history.") if __name__ == "__main__": diff --git a/src/agents/run.py b/src/agents/run.py index 8f0cdd2b..3e01aeb1 100644 --- a/src/agents/run.py +++ b/src/agents/run.py @@ -111,13 +111,6 @@ class RunConfig: An optional dictionary of additional metadata to include with the 
trace. """ - memory: SessionMemory | None = None - """ - Session memory instance for conversation history persistence. - - None (default): No session memory - - SessionMemory instance: Use the provided session memory implementation - """ - session_id: str | None = None """ A session identifier for memory persistence. Required when memory is enabled. @@ -137,6 +130,8 @@ async def run( hooks: RunHooks[TContext] | None = None, run_config: RunConfig | None = None, previous_response_id: str | None = None, + memory: SessionMemory | None = None, + session_id: str | None = None, ) -> RunResult: """Run a workflow starting at the given agent. The agent will run in a loop until a final output is generated. The loop runs like so: @@ -163,6 +158,10 @@ async def run( run_config: Global settings for the entire agent run. previous_response_id: The ID of the previous response, if using OpenAI models via the Responses API, this allows you to skip passing in input from the previous turn. + memory: Session memory instance for conversation history persistence. + If None, no conversation history will be maintained. + session_id: A session identifier for memory persistence. Required when memory is provided. + Conversation history will be automatically managed using this session ID. 
Returns: A run result containing all the inputs, guardrail results and the output of the last @@ -175,7 +174,7 @@ async def run( # Prepare input with session memory if enabled prepared_input, session_memory = await cls._prepare_input_with_memory( - input, run_config + input, memory, session_id ) tool_use_tracker = AgentToolUseTracker() @@ -304,7 +303,7 @@ async def run( # Save the conversation to session memory if enabled await cls._save_result_to_memory( - session_memory, run_config.session_id, input, result + session_memory, session_id, input, result ) return result @@ -336,6 +335,8 @@ def run_sync( hooks: RunHooks[TContext] | None = None, run_config: RunConfig | None = None, previous_response_id: str | None = None, + memory: SessionMemory | None = None, + session_id: str | None = None, ) -> RunResult: """Run a workflow synchronously, starting at the given agent. Note that this just wraps the `run` method, so it will not work if there's already an event loop (e.g. inside an async @@ -366,6 +367,10 @@ def run_sync( run_config: Global settings for the entire agent run. previous_response_id: The ID of the previous response, if using OpenAI models via the Responses API, this allows you to skip passing in input from the previous turn. + memory: Session memory instance for conversation history persistence. + If None, no conversation history will be maintained. + session_id: A session identifier for memory persistence. Required when memory is provided. + Conversation history will be automatically managed using this session ID. 
Returns: A run result containing all the inputs, guardrail results and the output of the last @@ -380,6 +385,8 @@ def run_sync( hooks=hooks, run_config=run_config, previous_response_id=previous_response_id, + memory=memory, + session_id=session_id, ) ) @@ -393,6 +400,8 @@ def run_streamed( hooks: RunHooks[TContext] | None = None, run_config: RunConfig | None = None, previous_response_id: str | None = None, + memory: SessionMemory | None = None, + session_id: str | None = None, ) -> RunResultStreaming: """Run a workflow starting at the given agent in streaming mode. The returned result object contains a method you can use to stream semantic events as they are generated. @@ -421,6 +430,10 @@ def run_streamed( run_config: Global settings for the entire agent run. previous_response_id: The ID of the previous response, if using OpenAI models via the Responses API, this allows you to skip passing in input from the previous turn. + memory: Session memory instance for conversation history persistence. + If None, no conversation history will be maintained. + session_id: A session identifier for memory persistence. Required when memory is provided. + Conversation history will be automatically managed using this session ID. Returns: A result object that contains data about the run, as well as a method to stream events. 
""" @@ -476,6 +489,8 @@ def run_streamed( context_wrapper=context_wrapper, run_config=run_config, previous_response_id=previous_response_id, + memory=memory, + session_id=session_id, ) ) return streamed_result @@ -534,13 +549,15 @@ async def _run_streamed_impl( context_wrapper: RunContextWrapper[TContext], run_config: RunConfig, previous_response_id: str | None, + memory: SessionMemory | None, + session_id: str | None, ): if streamed_result.trace: streamed_result.trace.start(mark_as_current=True) # Prepare input with session memory if enabled prepared_input, session_memory = await cls._prepare_input_with_memory( - starting_input, run_config + starting_input, memory, session_id ) # Update the streamed result with the prepared input @@ -677,10 +694,7 @@ async def _run_streamed_impl( context_wrapper=context_wrapper, ) await cls._save_result_to_memory( - session_memory, - run_config.session_id, - starting_input, - temp_result, + session_memory, session_id, starting_input, temp_result ) streamed_result._event_queue.put_nowait(QueueCompleteSentinel()) @@ -1056,41 +1070,31 @@ def _get_model(cls, agent: Agent[Any], run_config: RunConfig) -> Model: return run_config.model_provider.get_model(agent.model) - @classmethod - def _get_session_memory(cls, run_config: RunConfig) -> SessionMemory | None: - """Get the session memory instance from run config, if any.""" - if run_config.memory is None: - return None - elif isinstance(run_config.memory, SessionMemory): - return run_config.memory - else: - raise ValueError(f"Invalid memory configuration: {run_config.memory}") - @classmethod async def _prepare_input_with_memory( cls, input: str | list[TResponseInputItem], - run_config: RunConfig, + memory: SessionMemory | None, + session_id: str | None, ) -> tuple[str | list[TResponseInputItem], SessionMemory | None]: """Prepare input by combining it with session memory if enabled.""" - memory = cls._get_session_memory(run_config) if memory is None: # Check if session_id is provided 
without memory - if run_config.session_id is not None: + if session_id is not None: raise ValueError( "session_id provided but memory is disabled. " - "Please enable memory in the RunConfig or remove session_id." + "Please provide a memory instance or remove session_id." ) return input, memory - if run_config.session_id is None: + if session_id is None: raise ValueError( "session_id is required when memory is enabled. " - "Please provide a session_id in the RunConfig." + "Please provide a session_id." ) # Get previous conversation history - history = await memory.get_messages(run_config.session_id) + history = await memory.get_messages(session_id) # Convert input to list format new_input_list = ItemHelpers.input_to_new_input_list(input) diff --git a/tests/test_session_memory.py b/tests/test_session_memory.py index 3855b41a..b76e8fa7 100644 --- a/tests/test_session_memory.py +++ b/tests/test_session_memory.py @@ -23,18 +23,22 @@ async def test_session_memory_basic_functionality(): agent = Agent(name="test", model=model) session_id = "test_session_123" - run_config = RunConfig(memory=memory, session_id=session_id) # First turn model.set_next_output([get_text_message("San Francisco")]) result1 = await Runner.run( - agent, "What city is the Golden Gate Bridge in?", run_config=run_config + agent, + "What city is the Golden Gate Bridge in?", + memory=memory, + session_id=session_id, ) assert result1.final_output == "San Francisco" # Second turn - should have conversation history model.set_next_output([get_text_message("California")]) - result2 = await Runner.run(agent, "What state is it in?", run_config=run_config) + result2 = await Runner.run( + agent, "What state is it in?", memory=memory, session_id=session_id + ) assert result2.final_output == "California" # Verify that the input to the second turn includes the previous conversation @@ -56,17 +60,18 @@ async def test_session_memory_with_explicit_instance(): agent = Agent(name="test", model=model) session_id = 
"test_session_456" - run_config = RunConfig(memory=memory, session_id=session_id) # First turn model.set_next_output([get_text_message("Hello")]) - result1 = await Runner.run(agent, "Hi there", run_config=run_config) + result1 = await Runner.run( + agent, "Hi there", memory=memory, session_id=session_id + ) assert result1.final_output == "Hello" # Second turn model.set_next_output([get_text_message("I remember you said hi")]) result2 = await Runner.run( - agent, "Do you remember what I said?", run_config=run_config + agent, "Do you remember what I said?", memory=memory, session_id=session_id ) assert result2.final_output == "I remember you said hi" @@ -79,18 +84,14 @@ async def test_session_memory_disabled(): model = FakeModel() agent = Agent(name="test", model=model) - run_config = RunConfig(memory=None) # No session memory - - # First turn + # First turn (no memory parameters = disabled) model.set_next_output([get_text_message("Hello")]) - result1 = await Runner.run(agent, "Hi there", run_config=run_config) + result1 = await Runner.run(agent, "Hi there") assert result1.final_output == "Hello" # Second turn - should NOT have conversation history model.set_next_output([get_text_message("I don't remember")]) - result2 = await Runner.run( - agent, "Do you remember what I said?", run_config=run_config - ) + result2 = await Runner.run(agent, "Do you remember what I said?") assert result2.final_output == "I don't remember" # Verify that the input to the second turn is just the current message @@ -110,24 +111,26 @@ async def test_session_memory_different_sessions(): # Session 1 session_id_1 = "session_1" - run_config_1 = RunConfig(memory=memory, session_id=session_id_1) model.set_next_output([get_text_message("I like cats")]) - result1 = await Runner.run(agent, "I like cats", run_config=run_config_1) + result1 = await Runner.run( + agent, "I like cats", memory=memory, session_id=session_id_1 + ) assert result1.final_output == "I like cats" # Session 2 - different session 
session_id_2 = "session_2" - run_config_2 = RunConfig(memory=memory, session_id=session_id_2) model.set_next_output([get_text_message("I like dogs")]) - result2 = await Runner.run(agent, "I like dogs", run_config=run_config_2) + result2 = await Runner.run( + agent, "I like dogs", memory=memory, session_id=session_id_2 + ) assert result2.final_output == "I like dogs" # Back to Session 1 - should remember cats, not dogs model.set_next_output([get_text_message("Yes, you mentioned cats")]) result3 = await Runner.run( - agent, "What did I say I like?", run_config=run_config_1 + agent, "What did I say I like?", memory=memory, session_id=session_id_1 ) assert result3.final_output == "Yes, you mentioned cats" @@ -141,13 +144,11 @@ async def test_session_memory_no_session_id(): agent = Agent(name="test", model=model) memory = SQLiteSessionMemory() - run_config = RunConfig(memory=memory) # Memory enabled but no session_id - # Should raise ValueError when trying to run with memory enabled but no session_id with pytest.raises( ValueError, match="session_id is required when memory is enabled" ): - await Runner.run(agent, "Hi there", run_config=run_config) + await Runner.run(agent, "Hi there", memory=memory) @pytest.mark.asyncio @@ -157,13 +158,10 @@ async def test_session_id_without_memory(): agent = Agent(name="test", model=model) session_id = "test_session_without_memory" - run_config = RunConfig( - memory=None, session_id=session_id - ) # session_id but no memory # Should raise ValueError when trying to run with session_id but no memory with pytest.raises(ValueError, match="session_id provided but memory is disabled"): - await Runner.run(agent, "Hi there", run_config=run_config) + await Runner.run(agent, "Hi there", session_id=session_id) @pytest.mark.asyncio @@ -296,10 +294,3 @@ async def test_session_memory_pop_different_sessions(): assert session_2_messages[0]["content"] == "Session 2 message 1" memory.close() - - -def test_session_memory_invalid_config(): - """Test that 
invalid memory configuration raises ValueError.""" - with pytest.raises(ValueError, match="Invalid memory configuration"): - run_config = RunConfig(memory="invalid") - Runner._get_session_memory(run_config) From 57b4f5e8b83b041789b0612fed8607fc4696e5c3 Mon Sep 17 00:00:00 2001 From: Stephan Fitzpatrick Date: Sat, 24 May 2025 13:24:53 -0700 Subject: [PATCH 14/26] Refactor database initialization in SQLiteSessionMemory - Introduced a new method `_init_db_for_connection` to handle database schema initialization for a specific connection. - Updated the `_init_db` method to call the new method, improving clarity and separation of concerns. - Added a comment to indicate the initialization of the database schema for the connection. --- src/agents/memory/session_memory.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/src/agents/memory/session_memory.py b/src/agents/memory/session_memory.py index b74fe601..5eb5144a 100644 --- a/src/agents/memory/session_memory.py +++ b/src/agents/memory/session_memory.py @@ -96,11 +96,12 @@ def _get_connection(self) -> sqlite3.Connection: check_same_thread=False, ) self._local.connection.execute("PRAGMA journal_mode=WAL") + # Initialize the database schema for this connection + self._init_db_for_connection(self._local.connection) return self._local.connection - def _init_db(self) -> None: - """Initialize the database schema.""" - conn = self._get_connection() + def _init_db_for_connection(self, conn: sqlite3.Connection) -> None: + """Initialize the database schema for a specific connection.""" conn.execute( f""" CREATE TABLE IF NOT EXISTS {self.sessions_table} ( @@ -132,6 +133,11 @@ def _init_db(self) -> None: conn.commit() + def _init_db(self) -> None: + """Initialize the database schema.""" + conn = self._get_connection() + # The schema initialization is now handled in _init_db_for_connection + async def get_messages(self, session_id: str) -> list[TResponseInputItem]: """Retrieve the conversation history 
for a given session. From 764e82c10961e63911db9b0987014121b7bc9493 Mon Sep 17 00:00:00 2001 From: Stephan Fitzpatrick Date: Sat, 24 May 2025 13:53:07 -0700 Subject: [PATCH 15/26] Remove redundant database initialization method in SQLiteSessionMemory - Deleted the `_init_db` method as database schema initialization is now handled in `_init_db_for_connection`. - This change simplifies the class structure and improves clarity in the database connection management. --- src/agents/memory/session_memory.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/src/agents/memory/session_memory.py b/src/agents/memory/session_memory.py index 5eb5144a..2b9c8c0c 100644 --- a/src/agents/memory/session_memory.py +++ b/src/agents/memory/session_memory.py @@ -86,7 +86,6 @@ def __init__( self.sessions_table = sessions_table self.messages_table = messages_table self._local = threading.local() - self._init_db() def _get_connection(self) -> sqlite3.Connection: """Get a thread-local database connection.""" @@ -133,11 +132,6 @@ def _init_db_for_connection(self, conn: sqlite3.Connection) -> None: conn.commit() - def _init_db(self) -> None: - """Initialize the database schema.""" - conn = self._get_connection() - # The schema initialization is now handled in _init_db_for_connection - async def get_messages(self, session_id: str) -> list[TResponseInputItem]: """Retrieve the conversation history for a given session. From 1827673c1e88a562d11b7d16a6509f2b5b43845f Mon Sep 17 00:00:00 2001 From: Stephan Fitzpatrick Date: Sat, 24 May 2025 14:02:29 -0700 Subject: [PATCH 16/26] Enhance SQLiteSessionMemory for thread safety and connection management - Introduced a shared connection for in-memory databases so that all threads operate on the same database instead of each thread getting its own isolated, empty one. - Guarded operations on the shared in-memory connection with a lock; file databases continue to use thread-local connections, which are not serialized by that lock.
- Updated the `_get_connection`, `_add_messages_sync`, `_pop_message_sync`, and `_clear_session_sync` methods to utilize the new locking and connection management logic. --- src/agents/memory/session_memory.py | 228 ++++++++++++++++------------ 1 file changed, 128 insertions(+), 100 deletions(-) diff --git a/src/agents/memory/session_memory.py b/src/agents/memory/session_memory.py index 2b9c8c0c..3bae4e24 100644 --- a/src/agents/memory/session_memory.py +++ b/src/agents/memory/session_memory.py @@ -86,18 +86,34 @@ def __init__( self.sessions_table = sessions_table self.messages_table = messages_table self._local = threading.local() + self._lock = threading.Lock() + + # For in-memory databases, we need a shared connection to avoid thread isolation + # For file databases, we use thread-local connections for better concurrency + self._is_memory_db = str(db_path) == ":memory:" + if self._is_memory_db: + self._shared_connection = sqlite3.connect( + ":memory:", check_same_thread=False + ) + self._shared_connection.execute("PRAGMA journal_mode=WAL") + self._init_db_for_connection(self._shared_connection) def _get_connection(self) -> sqlite3.Connection: - """Get a thread-local database connection.""" - if not hasattr(self._local, "connection"): - self._local.connection = sqlite3.connect( - str(self.db_path) if self.db_path != ":memory:" else self.db_path, - check_same_thread=False, - ) - self._local.connection.execute("PRAGMA journal_mode=WAL") - # Initialize the database schema for this connection - self._init_db_for_connection(self._local.connection) - return self._local.connection + """Get a database connection.""" + if self._is_memory_db: + # Use shared connection for in-memory database to avoid thread isolation + return self._shared_connection + else: + # Use thread-local connections for file databases + if not hasattr(self._local, "connection"): + self._local.connection = sqlite3.connect( + str(self.db_path), + check_same_thread=False, + ) + 
self._local.connection.execute("PRAGMA journal_mode=WAL") + # Initialize the database schema for this connection + self._init_db_for_connection(self._local.connection) + return self._local.connection def _init_db_for_connection(self, conn: sqlite3.Connection) -> None: """Initialize the database schema for a specific connection.""" @@ -144,25 +160,26 @@ async def get_messages(self, session_id: str) -> list[TResponseInputItem]: def _get_messages_sync(): conn = self._get_connection() - cursor = conn.execute( - f""" - SELECT message_data FROM {self.messages_table} - WHERE session_id = ? - ORDER BY created_at ASC - """, - (session_id,), - ) - - messages = [] - for (message_data,) in cursor.fetchall(): - try: - message = json.loads(message_data) - messages.append(message) - except json.JSONDecodeError: - # Skip invalid JSON entries - continue - - return messages + with self._lock if self._is_memory_db else threading.Lock(): + cursor = conn.execute( + f""" + SELECT message_data FROM {self.messages_table} + WHERE session_id = ? + ORDER BY created_at ASC + """, + (session_id,), + ) + + messages = [] + for (message_data,) in cursor.fetchall(): + try: + message = json.loads(message_data) + messages.append(message) + except json.JSONDecodeError: + # Skip invalid JSON entries + continue + + return messages return await asyncio.to_thread(_get_messages_sync) @@ -181,32 +198,35 @@ async def add_messages( def _add_messages_sync(): conn = self._get_connection() - # Ensure session exists - conn.execute( - f""" - INSERT OR IGNORE INTO {self.sessions_table} (session_id) VALUES (?) - """, - (session_id,), - ) - - # Add messages - message_data = [(session_id, json.dumps(message)) for message in messages] - conn.executemany( - f""" - INSERT INTO {self.messages_table} (session_id, message_data) VALUES (?, ?) - """, - message_data, - ) - - # Update session timestamp - conn.execute( - f""" - UPDATE {self.sessions_table} SET updated_at = CURRENT_TIMESTAMP WHERE session_id = ? 
- """, - (session_id,), - ) - - conn.commit() + with self._lock if self._is_memory_db else threading.Lock(): + # Ensure session exists + conn.execute( + f""" + INSERT OR IGNORE INTO {self.sessions_table} (session_id) VALUES (?) + """, + (session_id,), + ) + + # Add messages + message_data = [ + (session_id, json.dumps(message)) for message in messages + ] + conn.executemany( + f""" + INSERT INTO {self.messages_table} (session_id, message_data) VALUES (?, ?) + """, + message_data, + ) + + # Update session timestamp + conn.execute( + f""" + UPDATE {self.sessions_table} SET updated_at = CURRENT_TIMESTAMP WHERE session_id = ? + """, + (session_id,), + ) + + conn.commit() await asyncio.to_thread(_add_messages_sync) @@ -222,42 +242,43 @@ async def pop_message(self, session_id: str) -> TResponseInputItem | None: def _pop_message_sync(): conn = self._get_connection() - cursor = conn.execute( - f""" - SELECT id, message_data FROM {self.messages_table} - WHERE session_id = ? - ORDER BY created_at DESC - LIMIT 1 - """, - (session_id,), - ) - - result = cursor.fetchone() - if result: - message_id, message_data = result - try: - message = json.loads(message_data) - # Delete the message by ID - conn.execute( - f""" - DELETE FROM {self.messages_table} WHERE id = ? - """, - (message_id,), - ) - conn.commit() - return message - except json.JSONDecodeError: - # Skip invalid JSON entries, but still delete the corrupted record - conn.execute( - f""" - DELETE FROM {self.messages_table} WHERE id = ? - """, - (message_id,), - ) - conn.commit() - return None - - return None + with self._lock if self._is_memory_db else threading.Lock(): + cursor = conn.execute( + f""" + SELECT id, message_data FROM {self.messages_table} + WHERE session_id = ? 
+ ORDER BY created_at DESC + LIMIT 1 + """, + (session_id,), + ) + + result = cursor.fetchone() + if result: + message_id, message_data = result + try: + message = json.loads(message_data) + # Delete the message by ID + conn.execute( + f""" + DELETE FROM {self.messages_table} WHERE id = ? + """, + (message_id,), + ) + conn.commit() + return message + except json.JSONDecodeError: + # Skip invalid JSON entries, but still delete the corrupted record + conn.execute( + f""" + DELETE FROM {self.messages_table} WHERE id = ? + """, + (message_id,), + ) + conn.commit() + return None + + return None return await asyncio.to_thread(_pop_message_sync) @@ -270,17 +291,24 @@ async def clear_session(self, session_id: str) -> None: def _clear_session_sync(): conn = self._get_connection() - conn.execute( - f"DELETE FROM {self.messages_table} WHERE session_id = ?", (session_id,) - ) - conn.execute( - f"DELETE FROM {self.sessions_table} WHERE session_id = ?", (session_id,) - ) - conn.commit() + with self._lock if self._is_memory_db else threading.Lock(): + conn.execute( + f"DELETE FROM {self.messages_table} WHERE session_id = ?", + (session_id,), + ) + conn.execute( + f"DELETE FROM {self.sessions_table} WHERE session_id = ?", + (session_id,), + ) + conn.commit() await asyncio.to_thread(_clear_session_sync) def close(self) -> None: """Close the database connection.""" - if hasattr(self._local, "connection"): - self._local.connection.close() + if self._is_memory_db: + if hasattr(self, "_shared_connection"): + self._shared_connection.close() + else: + if hasattr(self._local, "connection"): + self._local.connection.close() From 131c5dccb3c7d7cca8fbdcd61e7c50ff59edd1b3 Mon Sep 17 00:00:00 2001 From: Stephan Fitzpatrick Date: Sat, 24 May 2025 14:09:11 -0700 Subject: [PATCH 17/26] Initialize database schema for file databases in SQLiteSessionMemory - Added logic to initialize the database schema for file databases during connection setup. 
- Ensured that the schema is only initialized once, improving efficiency and clarity in connection management. --- src/agents/memory/session_memory.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/src/agents/memory/session_memory.py b/src/agents/memory/session_memory.py index 3bae4e24..d034aab5 100644 --- a/src/agents/memory/session_memory.py +++ b/src/agents/memory/session_memory.py @@ -97,6 +97,12 @@ def __init__( ) self._shared_connection.execute("PRAGMA journal_mode=WAL") self._init_db_for_connection(self._shared_connection) + else: + # For file databases, initialize the schema once since it persists + init_conn = sqlite3.connect(str(self.db_path), check_same_thread=False) + init_conn.execute("PRAGMA journal_mode=WAL") + self._init_db_for_connection(init_conn) + init_conn.close() def _get_connection(self) -> sqlite3.Connection: """Get a database connection.""" @@ -111,8 +117,6 @@ def _get_connection(self) -> sqlite3.Connection: check_same_thread=False, ) self._local.connection.execute("PRAGMA journal_mode=WAL") - # Initialize the database schema for this connection - self._init_db_for_connection(self._local.connection) return self._local.connection def _init_db_for_connection(self, conn: sqlite3.Connection) -> None: From b8109921d4b0afa9d34240e60cea7b1aee89e4f0 Mon Sep 17 00:00:00 2001 From: Stephan Fitzpatrick Date: Sat, 24 May 2025 14:31:46 -0700 Subject: [PATCH 18/26] Refactor Runner class to improve error handling and input preparation - Enhanced the error handling mechanism in the Runner class to ensure that exceptions during setup result in a completion sentinel being placed in the event queue. - Streamlined the input preparation process by consolidating the logic for handling session memory and updating the streamed result. - Improved clarity and maintainability of the code by restructuring the try-except blocks and ensuring proper resource management for spans and traces. 
--- src/agents/run.py | 301 ++++++++++++++++++----------------- tests/test_session_memory.py | 271 +++++++++++++++++++++++++++---- 2 files changed, 399 insertions(+), 173 deletions(-) diff --git a/src/agents/run.py b/src/agents/run.py index 3e01aeb1..1dd2d258 100644 --- a/src/agents/run.py +++ b/src/agents/run.py @@ -552,173 +552,186 @@ async def _run_streamed_impl( memory: SessionMemory | None, session_id: str | None, ): - if streamed_result.trace: - streamed_result.trace.start(mark_as_current=True) + current_span: Span[AgentSpanData] | None = None - # Prepare input with session memory if enabled - prepared_input, session_memory = await cls._prepare_input_with_memory( - starting_input, memory, session_id - ) + try: + if streamed_result.trace: + streamed_result.trace.start(mark_as_current=True) - # Update the streamed result with the prepared input - streamed_result.input = prepared_input + # Prepare input with session memory if enabled + prepared_input, session_memory = await cls._prepare_input_with_memory( + starting_input, memory, session_id + ) - current_span: Span[AgentSpanData] | None = None - current_agent = starting_agent - current_turn = 0 - should_run_agent_start_hooks = True - tool_use_tracker = AgentToolUseTracker() + # Update the streamed result with the prepared input + streamed_result.input = prepared_input - streamed_result._event_queue.put_nowait( - AgentUpdatedStreamEvent(new_agent=current_agent) - ) + current_agent = starting_agent + current_turn = 0 + should_run_agent_start_hooks = True + tool_use_tracker = AgentToolUseTracker() - try: - while True: - if streamed_result.is_complete: - break - - # Start an agent span if we don't have one. This span is ended if the current - # agent changes, or if the agent loop ends. 
- if current_span is None: - handoff_names = [ - h.agent_name for h in cls._get_handoffs(current_agent) - ] - if output_schema := cls._get_output_schema(current_agent): - output_type_name = output_schema.name() - else: - output_type_name = "str" + streamed_result._event_queue.put_nowait( + AgentUpdatedStreamEvent(new_agent=current_agent) + ) - current_span = agent_span( - name=current_agent.name, - handoffs=handoff_names, - output_type=output_type_name, - ) - current_span.start(mark_as_current=True) + try: + while True: + if streamed_result.is_complete: + break - all_tools = await cls._get_all_tools(current_agent) - tool_names = [t.name for t in all_tools] - current_span.span_data.tools = tool_names - current_turn += 1 - streamed_result.current_turn = current_turn + # Start an agent span if we don't have one. This span is ended if the current + # agent changes, or if the agent loop ends. + if current_span is None: + handoff_names = [ + h.agent_name for h in cls._get_handoffs(current_agent) + ] + if output_schema := cls._get_output_schema(current_agent): + output_type_name = output_schema.name() + else: + output_type_name = "str" - if current_turn > max_turns: - _error_tracing.attach_error_to_span( - current_span, - SpanError( - message="Max turns exceeded", - data={"max_turns": max_turns}, - ), - ) - streamed_result._event_queue.put_nowait(QueueCompleteSentinel()) - break - - if current_turn == 1: - # Run the input guardrails in the background and put the results on the queue - streamed_result._input_guardrails_task = asyncio.create_task( - cls._run_input_guardrails_with_queue( - starting_agent, - starting_agent.input_guardrails - + (run_config.input_guardrails or []), - copy.deepcopy( - ItemHelpers.input_to_new_input_list(prepared_input) - ), - context_wrapper, - streamed_result, - current_span, + current_span = agent_span( + name=current_agent.name, + handoffs=handoff_names, + output_type=output_type_name, ) - ) - try: - turn_result = await 
cls._run_single_turn_streamed( - streamed_result, - current_agent, - hooks, - context_wrapper, - run_config, - should_run_agent_start_hooks, - tool_use_tracker, - all_tools, - previous_response_id, - ) - should_run_agent_start_hooks = False + current_span.start(mark_as_current=True) - streamed_result.raw_responses = streamed_result.raw_responses + [ - turn_result.model_response - ] - streamed_result.input = turn_result.original_input - streamed_result.new_items = turn_result.generated_items + all_tools = await cls._get_all_tools(current_agent) + tool_names = [t.name for t in all_tools] + current_span.span_data.tools = tool_names + current_turn += 1 + streamed_result.current_turn = current_turn - if isinstance(turn_result.next_step, NextStepHandoff): - current_agent = turn_result.next_step.new_agent - current_span.finish(reset_current=True) - current_span = None - should_run_agent_start_hooks = True - streamed_result._event_queue.put_nowait( - AgentUpdatedStreamEvent(new_agent=current_agent) + if current_turn > max_turns: + _error_tracing.attach_error_to_span( + current_span, + SpanError( + message="Max turns exceeded", + data={"max_turns": max_turns}, + ), ) - elif isinstance(turn_result.next_step, NextStepFinalOutput): - streamed_result._output_guardrails_task = asyncio.create_task( - cls._run_output_guardrails( - current_agent.output_guardrails - + (run_config.output_guardrails or []), - current_agent, - turn_result.next_step.output, + streamed_result._event_queue.put_nowait(QueueCompleteSentinel()) + break + + if current_turn == 1: + # Run the input guardrails in the background and put the results on the queue + streamed_result._input_guardrails_task = asyncio.create_task( + cls._run_input_guardrails_with_queue( + starting_agent, + starting_agent.input_guardrails + + (run_config.input_guardrails or []), + copy.deepcopy( + ItemHelpers.input_to_new_input_list(prepared_input) + ), context_wrapper, + streamed_result, + current_span, ) ) + try: + turn_result = await 
cls._run_single_turn_streamed( + streamed_result, + current_agent, + hooks, + context_wrapper, + run_config, + should_run_agent_start_hooks, + tool_use_tracker, + all_tools, + previous_response_id, + ) + should_run_agent_start_hooks = False - try: - output_guardrail_results = ( - await streamed_result._output_guardrails_task + streamed_result.raw_responses = ( + streamed_result.raw_responses + [turn_result.model_response] + ) + streamed_result.input = turn_result.original_input + streamed_result.new_items = turn_result.generated_items + + if isinstance(turn_result.next_step, NextStepHandoff): + current_agent = turn_result.next_step.new_agent + current_span.finish(reset_current=True) + current_span = None + should_run_agent_start_hooks = True + streamed_result._event_queue.put_nowait( + AgentUpdatedStreamEvent(new_agent=current_agent) + ) + elif isinstance(turn_result.next_step, NextStepFinalOutput): + streamed_result._output_guardrails_task = ( + asyncio.create_task( + cls._run_output_guardrails( + current_agent.output_guardrails + + (run_config.output_guardrails or []), + current_agent, + turn_result.next_step.output, + context_wrapper, + ) + ) ) - except Exception: - # Exceptions will be checked in the stream_events loop - output_guardrail_results = [] - streamed_result.output_guardrail_results = ( - output_guardrail_results - ) - streamed_result.final_output = turn_result.next_step.output - streamed_result.is_complete = True + try: + output_guardrail_results = ( + await streamed_result._output_guardrails_task + ) + except Exception: + # Exceptions will be checked in the stream_events loop + output_guardrail_results = [] - # Save the conversation to session memory if enabled - # Create a temporary RunResult for memory saving - temp_result = RunResult( - input=streamed_result.input, - new_items=streamed_result.new_items, - raw_responses=streamed_result.raw_responses, - final_output=streamed_result.final_output, - _last_agent=current_agent, - 
input_guardrail_results=streamed_result.input_guardrail_results, - output_guardrail_results=streamed_result.output_guardrail_results, - context_wrapper=context_wrapper, - ) - await cls._save_result_to_memory( - session_memory, session_id, starting_input, temp_result - ) + streamed_result.output_guardrail_results = ( + output_guardrail_results + ) + streamed_result.final_output = turn_result.next_step.output + streamed_result.is_complete = True + + # Save the conversation to session memory if enabled + # Create a temporary RunResult for memory saving + temp_result = RunResult( + input=streamed_result.input, + new_items=streamed_result.new_items, + raw_responses=streamed_result.raw_responses, + final_output=streamed_result.final_output, + _last_agent=current_agent, + input_guardrail_results=streamed_result.input_guardrail_results, + output_guardrail_results=streamed_result.output_guardrail_results, + context_wrapper=context_wrapper, + ) + await cls._save_result_to_memory( + session_memory, session_id, starting_input, temp_result + ) + streamed_result._event_queue.put_nowait( + QueueCompleteSentinel() + ) + elif isinstance(turn_result.next_step, NextStepRunAgain): + pass + except Exception as e: + if current_span: + _error_tracing.attach_error_to_span( + current_span, + SpanError( + message="Error in agent run", + data={"error": str(e)}, + ), + ) + streamed_result.is_complete = True streamed_result._event_queue.put_nowait(QueueCompleteSentinel()) - elif isinstance(turn_result.next_step, NextStepRunAgain): - pass - except Exception as e: - if current_span: - _error_tracing.attach_error_to_span( - current_span, - SpanError( - message="Error in agent run", - data={"error": str(e)}, - ), - ) - streamed_result.is_complete = True - streamed_result._event_queue.put_nowait(QueueCompleteSentinel()) - raise + raise + + streamed_result.is_complete = True + finally: + if current_span: + current_span.finish(reset_current=True) + if streamed_result.trace: + 
streamed_result.trace.finish(reset_current=True) + except Exception: + # Ensure that any exception (including those during setup) results in a completion sentinel + # being put in the queue so that stream_events() doesn't hang streamed_result.is_complete = True - finally: - if current_span: - current_span.finish(reset_current=True) - if streamed_result.trace: - streamed_result.trace.finish(reset_current=True) + streamed_result._event_queue.put_nowait(QueueCompleteSentinel()) + raise @classmethod async def _run_single_turn_streamed( diff --git a/tests/test_session_memory.py b/tests/test_session_memory.py index b76e8fa7..54699434 100644 --- a/tests/test_session_memory.py +++ b/tests/test_session_memory.py @@ -4,6 +4,7 @@ import tempfile import os from pathlib import Path +import asyncio from agents import Agent, Runner, RunConfig, SQLiteSessionMemory from agents.memory import SessionMemory @@ -12,9 +13,48 @@ from .test_responses import get_text_message +# Helper functions for parametrized testing of different Runner methods +def _run_sync_wrapper(agent, input_data, **kwargs): + """Wrapper for run_sync that properly sets up an event loop.""" + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + try: + return Runner.run_sync(agent, input_data, **kwargs) + finally: + loop.close() + + +async def run_agent_async(runner_method: str, agent, input_data, **kwargs): + """Helper function to run agent with different methods.""" + if runner_method == "run": + return await Runner.run(agent, input_data, **kwargs) + elif runner_method == "run_sync": + # For run_sync, we need to run it in a thread with its own event loop + return await asyncio.to_thread(_run_sync_wrapper, agent, input_data, **kwargs) + elif runner_method == "run_streamed": + result = Runner.run_streamed(agent, input_data, **kwargs) + # For streaming, we first try to get at least one event to trigger any early exceptions + # If there's an exception in setup (like memory validation), it will be raised here 
+ try: + first_event = None + async for event in result.stream_events(): + if first_event is None: + first_event = event + # Continue consuming all events + pass + except Exception: + # If an exception occurs during streaming, we let it propagate up + raise + return result + else: + raise ValueError(f"Unknown runner method: {runner_method}") + + +# Parametrized tests for different runner methods +@pytest.mark.parametrize("runner_method", ["run", "run_sync", "run_streamed"]) @pytest.mark.asyncio -async def test_session_memory_basic_functionality(): - """Test basic session memory functionality with SQLite backend.""" +async def test_session_memory_basic_functionality_parametrized(runner_method): + """Test basic session memory functionality with SQLite backend across all runner methods.""" with tempfile.TemporaryDirectory() as temp_dir: db_path = Path(temp_dir) / "test_memory.db" memory = SQLiteSessionMemory(db_path) @@ -26,7 +66,8 @@ async def test_session_memory_basic_functionality(): # First turn model.set_next_output([get_text_message("San Francisco")]) - result1 = await Runner.run( + result1 = await run_agent_async( + runner_method, agent, "What city is the Golden Gate Bridge in?", memory=memory, @@ -36,8 +77,12 @@ async def test_session_memory_basic_functionality(): # Second turn - should have conversation history model.set_next_output([get_text_message("California")]) - result2 = await Runner.run( - agent, "What state is it in?", memory=memory, session_id=session_id + result2 = await run_agent_async( + runner_method, + agent, + "What state is it in?", + memory=memory, + session_id=session_id, ) assert result2.final_output == "California" @@ -49,9 +94,10 @@ async def test_session_memory_basic_functionality(): memory.close() +@pytest.mark.parametrize("runner_method", ["run", "run_sync", "run_streamed"]) @pytest.mark.asyncio -async def test_session_memory_with_explicit_instance(): - """Test session memory with an explicit SQLiteSessionMemory instance.""" +async 
def test_session_memory_with_explicit_instance_parametrized(runner_method): + """Test session memory with an explicit SQLiteSessionMemory instance across all runner methods.""" with tempfile.TemporaryDirectory() as temp_dir: db_path = Path(temp_dir) / "test_memory.db" memory = SQLiteSessionMemory(db_path) @@ -63,35 +109,42 @@ async def test_session_memory_with_explicit_instance(): # First turn model.set_next_output([get_text_message("Hello")]) - result1 = await Runner.run( - agent, "Hi there", memory=memory, session_id=session_id + result1 = await run_agent_async( + runner_method, agent, "Hi there", memory=memory, session_id=session_id ) assert result1.final_output == "Hello" # Second turn model.set_next_output([get_text_message("I remember you said hi")]) - result2 = await Runner.run( - agent, "Do you remember what I said?", memory=memory, session_id=session_id + result2 = await run_agent_async( + runner_method, + agent, + "Do you remember what I said?", + memory=memory, + session_id=session_id, ) assert result2.final_output == "I remember you said hi" memory.close() +@pytest.mark.parametrize("runner_method", ["run", "run_sync", "run_streamed"]) @pytest.mark.asyncio -async def test_session_memory_disabled(): - """Test that session memory is disabled when memory=None.""" +async def test_session_memory_disabled_parametrized(runner_method): + """Test that session memory is disabled when memory=None across all runner methods.""" model = FakeModel() agent = Agent(name="test", model=model) # First turn (no memory parameters = disabled) model.set_next_output([get_text_message("Hello")]) - result1 = await Runner.run(agent, "Hi there") + result1 = await run_agent_async(runner_method, agent, "Hi there") assert result1.final_output == "Hello" # Second turn - should NOT have conversation history model.set_next_output([get_text_message("I don't remember")]) - result2 = await Runner.run(agent, "Do you remember what I said?") + result2 = await run_agent_async( + runner_method, 
agent, "Do you remember what I said?" + ) assert result2.final_output == "I don't remember" # Verify that the input to the second turn is just the current message @@ -99,9 +152,10 @@ async def test_session_memory_disabled(): assert len(last_input) == 1 # Should only have the current message +@pytest.mark.parametrize("runner_method", ["run", "run_sync", "run_streamed"]) @pytest.mark.asyncio -async def test_session_memory_different_sessions(): - """Test that different session IDs maintain separate conversation histories.""" +async def test_session_memory_different_sessions_parametrized(runner_method): + """Test that different session IDs maintain separate conversation histories across all runner methods.""" with tempfile.TemporaryDirectory() as temp_dir: db_path = Path(temp_dir) / "test_memory.db" memory = SQLiteSessionMemory(db_path) @@ -113,8 +167,8 @@ async def test_session_memory_different_sessions(): session_id_1 = "session_1" model.set_next_output([get_text_message("I like cats")]) - result1 = await Runner.run( - agent, "I like cats", memory=memory, session_id=session_id_1 + result1 = await run_agent_async( + runner_method, agent, "I like cats", memory=memory, session_id=session_id_1 ) assert result1.final_output == "I like cats" @@ -122,24 +176,29 @@ async def test_session_memory_different_sessions(): session_id_2 = "session_2" model.set_next_output([get_text_message("I like dogs")]) - result2 = await Runner.run( - agent, "I like dogs", memory=memory, session_id=session_id_2 + result2 = await run_agent_async( + runner_method, agent, "I like dogs", memory=memory, session_id=session_id_2 ) assert result2.final_output == "I like dogs" # Back to Session 1 - should remember cats, not dogs model.set_next_output([get_text_message("Yes, you mentioned cats")]) - result3 = await Runner.run( - agent, "What did I say I like?", memory=memory, session_id=session_id_1 + result3 = await run_agent_async( + runner_method, + agent, + "What did I say I like?", + memory=memory, + 
session_id=session_id_1, ) assert result3.final_output == "Yes, you mentioned cats" memory.close() +@pytest.mark.parametrize("runner_method", ["run", "run_sync", "run_streamed"]) @pytest.mark.asyncio -async def test_session_memory_no_session_id(): - """Test that session memory raises an exception when no session_id is provided.""" +async def test_session_memory_no_session_id_parametrized(runner_method): + """Test that session memory raises an exception when no session_id is provided across all runner methods.""" model = FakeModel() agent = Agent(name="test", model=model) memory = SQLiteSessionMemory() @@ -148,12 +207,13 @@ async def test_session_memory_no_session_id(): with pytest.raises( ValueError, match="session_id is required when memory is enabled" ): - await Runner.run(agent, "Hi there", memory=memory) + await run_agent_async(runner_method, agent, "Hi there", memory=memory) +@pytest.mark.parametrize("runner_method", ["run", "run_sync", "run_streamed"]) @pytest.mark.asyncio -async def test_session_id_without_memory(): - """Test that providing session_id without memory raises an exception.""" +async def test_session_id_without_memory_parametrized(runner_method): + """Test that providing session_id without memory raises an exception across all runner methods.""" model = FakeModel() agent = Agent(name="test", model=model) @@ -161,7 +221,7 @@ async def test_session_id_without_memory(): # Should raise ValueError when trying to run with session_id but no memory with pytest.raises(ValueError, match="session_id provided but memory is disabled"): - await Runner.run(agent, "Hi there", session_id=session_id) + await run_agent_async(runner_method, agent, "Hi there", session_id=session_id) @pytest.mark.asyncio @@ -294,3 +354,156 @@ async def test_session_memory_pop_different_sessions(): assert session_2_messages[0]["content"] == "Session 2 message 1" memory.close() + + +# Original non-parametrized tests for backwards compatibility +@pytest.mark.asyncio +async def 
test_session_memory_basic_functionality(): + """Test basic session memory functionality with SQLite backend.""" + with tempfile.TemporaryDirectory() as temp_dir: + db_path = Path(temp_dir) / "test_memory.db" + memory = SQLiteSessionMemory(db_path) + + model = FakeModel() + agent = Agent(name="test", model=model) + + session_id = "test_session_123" + + # First turn + model.set_next_output([get_text_message("San Francisco")]) + result1 = await Runner.run( + agent, + "What city is the Golden Gate Bridge in?", + memory=memory, + session_id=session_id, + ) + assert result1.final_output == "San Francisco" + + # Second turn - should have conversation history + model.set_next_output([get_text_message("California")]) + result2 = await Runner.run( + agent, "What state is it in?", memory=memory, session_id=session_id + ) + assert result2.final_output == "California" + + # Verify that the input to the second turn includes the previous conversation + # The model should have received the full conversation history + last_input = model.last_turn_args["input"] + assert len(last_input) > 1 # Should have more than just the current message + + memory.close() + + +@pytest.mark.asyncio +async def test_session_memory_with_explicit_instance(): + """Test session memory with an explicit SQLiteSessionMemory instance.""" + with tempfile.TemporaryDirectory() as temp_dir: + db_path = Path(temp_dir) / "test_memory.db" + memory = SQLiteSessionMemory(db_path) + + model = FakeModel() + agent = Agent(name="test", model=model) + + session_id = "test_session_456" + + # First turn + model.set_next_output([get_text_message("Hello")]) + result1 = await Runner.run( + agent, "Hi there", memory=memory, session_id=session_id + ) + assert result1.final_output == "Hello" + + # Second turn + model.set_next_output([get_text_message("I remember you said hi")]) + result2 = await Runner.run( + agent, "Do you remember what I said?", memory=memory, session_id=session_id + ) + assert result2.final_output == "I 
remember you said hi" + + memory.close() + + +@pytest.mark.asyncio +async def test_session_memory_disabled(): + """Test that session memory is disabled when memory=None.""" + model = FakeModel() + agent = Agent(name="test", model=model) + + # First turn (no memory parameters = disabled) + model.set_next_output([get_text_message("Hello")]) + result1 = await Runner.run(agent, "Hi there") + assert result1.final_output == "Hello" + + # Second turn - should NOT have conversation history + model.set_next_output([get_text_message("I don't remember")]) + result2 = await Runner.run(agent, "Do you remember what I said?") + assert result2.final_output == "I don't remember" + + # Verify that the input to the second turn is just the current message + last_input = model.last_turn_args["input"] + assert len(last_input) == 1 # Should only have the current message + + +@pytest.mark.asyncio +async def test_session_memory_different_sessions(): + """Test that different session IDs maintain separate conversation histories.""" + with tempfile.TemporaryDirectory() as temp_dir: + db_path = Path(temp_dir) / "test_memory.db" + memory = SQLiteSessionMemory(db_path) + + model = FakeModel() + agent = Agent(name="test", model=model) + + # Session 1 + session_id_1 = "session_1" + + model.set_next_output([get_text_message("I like cats")]) + result1 = await Runner.run( + agent, "I like cats", memory=memory, session_id=session_id_1 + ) + assert result1.final_output == "I like cats" + + # Session 2 - different session + session_id_2 = "session_2" + + model.set_next_output([get_text_message("I like dogs")]) + result2 = await Runner.run( + agent, "I like dogs", memory=memory, session_id=session_id_2 + ) + assert result2.final_output == "I like dogs" + + # Back to Session 1 - should remember cats, not dogs + model.set_next_output([get_text_message("Yes, you mentioned cats")]) + result3 = await Runner.run( + agent, "What did I say I like?", memory=memory, session_id=session_id_1 + ) + assert 
result3.final_output == "Yes, you mentioned cats" + + memory.close() + + +@pytest.mark.asyncio +async def test_session_memory_no_session_id(): + """Test that session memory raises an exception when no session_id is provided.""" + model = FakeModel() + agent = Agent(name="test", model=model) + memory = SQLiteSessionMemory() + + # Should raise ValueError when trying to run with memory enabled but no session_id + with pytest.raises( + ValueError, match="session_id is required when memory is enabled" + ): + await Runner.run(agent, "Hi there", memory=memory) + + +@pytest.mark.asyncio +async def test_session_id_without_memory(): + """Test that providing session_id without memory raises an exception.""" + model = FakeModel() + agent = Agent(name="test", model=model) + + session_id = "test_session_without_memory" + + # Should raise ValueError when trying to run with session_id but no memory + with pytest.raises(ValueError, match="session_id provided but memory is disabled"): + await Runner.run(agent, "Hi there", session_id=session_id) From 8b827f208f10f95c39e87937c7930bd8ffde2287 Mon Sep 17 00:00:00 2001 From: Stephan Fitzpatrick Date: Sat, 24 May 2025 14:36:55 -0700 Subject: [PATCH 19/26] Refactor test_session_memory.py to simplify imports - Removed unused imports and streamlined the import statements for clarity and maintainability. - This change enhances the readability of the test file by focusing on the necessary components. 
--- tests/test_session_memory.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/tests/test_session_memory.py b/tests/test_session_memory.py index 54699434..27f301dd 100644 --- a/tests/test_session_memory.py +++ b/tests/test_session_memory.py @@ -2,12 +2,10 @@ import pytest import tempfile -import os from pathlib import Path import asyncio -from agents import Agent, Runner, RunConfig, SQLiteSessionMemory -from agents.memory import SessionMemory +from agents import Agent, Runner, SQLiteSessionMemory from .fake_model import FakeModel from .test_responses import get_text_message From 0cf4f880b1e8861c424abf61c5ab94dcd97970d1 Mon Sep 17 00:00:00 2001 From: Stephan Fitzpatrick Date: Sat, 24 May 2025 15:45:08 -0700 Subject: [PATCH 20/26] Refactor session memory implementation to simplify usage and improve clarity - Replaced `SQLiteSessionMemory` with `SQLiteSession` in the codebase, streamlining session management. - Updated documentation and examples to reflect the new session handling approach, removing the need for `session_id` when using sessions. - Enhanced the `Session` protocol to better define session behavior and improve consistency across implementations. - Adjusted tests to ensure compatibility with the new session structure, maintaining functionality across various scenarios. 
--- README.md | 78 +++------ docs/ref/memory.md | 2 + docs/running_agents.md | 13 +- docs/session_memory.md | 177 ++++++++----------- examples/basic/session_memory_example.py | 18 +- src/agents/__init__.py | 4 +- src/agents/memory/__init__.py | 4 +- src/agents/memory/session_memory.py | 84 ++++----- src/agents/run.py | 104 ++++------- tests/test_session_memory.py | 213 ++++++++--------------- 10 files changed, 255 insertions(+), 442 deletions(-) diff --git a/README.md b/README.md index 3a96fc43..004fa178 100644 --- a/README.md +++ b/README.md @@ -21,7 +21,7 @@ The Agents SDK provides built-in session memory to automatically maintain conver ### Quick start ```python -from agents import Agent, Runner, SQLiteSessionMemory +from agents import Agent, Runner, SQLiteSession # Create agent agent = Agent( @@ -29,15 +29,14 @@ agent = Agent( instructions="Reply very concisely.", ) -# Create a session memory instance -memory = SQLiteSessionMemory() +# Create a session instance +session = SQLiteSession("conversation_123") # First turn result = await Runner.run( agent, "What city is the Golden Gate Bridge in?", - memory=memory, - session_id="conversation_123" + session=session ) print(result.final_output) # "San Francisco" @@ -45,8 +44,7 @@ print(result.final_output) # "San Francisco" result = await Runner.run( agent, "What state is it in?", - memory=memory, - session_id="conversation_123" + session=session ) print(result.final_output) # "California" @@ -54,91 +52,73 @@ print(result.final_output) # "California" result = Runner.run_sync( agent, "What's the population?", - memory=memory, - session_id="conversation_123" + session=session ) print(result.final_output) # "Approximately 39 million" ``` ### Memory options -- **No memory** (default): No session memory when memory parameter is omitted -- **`memory=SessionMemory`**: Use the provided session memory implementation +- **No memory** (default): No session memory when session parameter is omitted +- **`session=Session`**: Use 
the provided session implementation ```python -from agents import Agent, Runner, SQLiteSessionMemory +from agents import Agent, Runner, SQLiteSession # Custom SQLite database file -memory = SQLiteSessionMemory("conversations.db") +session = SQLiteSession("user_123", "conversations.db") agent = Agent(name="Assistant") # Different session IDs maintain separate conversation histories result1 = await Runner.run( agent, "Hello", - memory=memory, - session_id="user_123" + session=session ) result2 = await Runner.run( agent, "Hello", - memory=memory, - session_id="user_456" + session=SQLiteSession("user_456", "conversations.db") ) ``` ### Custom memory implementations -You can implement your own session memory by creating a class that follows the `SessionMemory` protocol: +You can implement your own session memory by creating a class that follows the `Session` protocol: ```python -from agents.memory import SessionMemory +from agents.memory import Session from typing import List -class MyCustomMemory: - """Custom memory implementation following the SessionMemory protocol.""" +class MyCustomSession: + """Custom session implementation following the Session protocol.""" - async def get_messages(self, session_id: str) -> List[dict]: + def __init__(self, session_id: str): + self.session_id = session_id + # Your initialization here + + async def get_messages(self) -> List[dict]: # Retrieve conversation history for the session pass - async def add_messages(self, session_id: str, messages: List[dict]) -> None: + async def add_messages(self, messages: List[dict]) -> None: # Store new messages for the session pass - async def clear_session(self, session_id: str) -> None: - # Clear all messages for the session + async def pop_message(self) -> dict | None: + # Remove and return the most recent message from the session pass -# Use your custom memory -agent = Agent(name="Assistant") -result = await Runner.run( - agent, - "Hello", - memory=MyCustomMemory(), - session_id="my_session" -) 
-``` - -### Important: session_id requirement - -When session memory is enabled, you **must** provide a `session_id`. If you don't, the runner will raise a `ValueError`: - -```python -from agents import Agent, Runner, SQLiteSessionMemory + async def clear_session(self) -> None: + # Clear all messages for the session + pass +# Use your custom session agent = Agent(name="Assistant") -memory = SQLiteSessionMemory() - -# This will raise ValueError: "session_id is required when memory is enabled" -result = await Runner.run(agent, "Hello", memory=memory) - -# This works correctly result = await Runner.run( agent, "Hello", - memory=memory, - session_id="my_session" + session=MyCustomSession("my_session") ) ``` diff --git a/docs/ref/memory.md b/docs/ref/memory.md index a72612f9..5502c4ab 100644 --- a/docs/ref/memory.md +++ b/docs/ref/memory.md @@ -4,5 +4,7 @@ options: members: + - Session + - SQLiteSession - SessionMemory - SQLiteSessionMemory diff --git a/docs/running_agents.md b/docs/running_agents.md index 5c30a34b..d86714c7 100644 --- a/docs/running_agents.md +++ b/docs/running_agents.md @@ -49,8 +49,6 @@ The `run_config` parameter lets you configure some global settings for the agent - [`model`][agents.run.RunConfig.model]: Allows setting a global LLM model to use, irrespective of what `model` each Agent has. - [`model_provider`][agents.run.RunConfig.model_provider]: A model provider for looking up model names, which defaults to OpenAI. - [`model_settings`][agents.run.RunConfig.model_settings]: Overrides agent-specific settings. For example, you can set a global `temperature` or `top_p`. -- [`memory`][agents.run.RunConfig.memory]: Session memory instance for automatic conversation history management. See [Session Memory](session_memory.md) for details. -- [`session_id`][agents.run.RunConfig.session_id]: Unique identifier for the conversation session. Required when `memory` is enabled. 
- [`input_guardrails`][agents.run.RunConfig.input_guardrails], [`output_guardrails`][agents.run.RunConfig.output_guardrails]: A list of input or output guardrails to include on all runs. - [`handoff_input_filter`][agents.run.RunConfig.handoff_input_filter]: A global input filter to apply to all handoffs, if the handoff doesn't already have one. The input filter allows you to edit the inputs that are sent to the new agent. See the documentation in [`Handoff.input_filter`][agents.handoffs.Handoff.input_filter] for more details. - [`tracing_disabled`][agents.run.RunConfig.tracing_disabled]: Allows you to disable [tracing](tracing.md) for the entire run. @@ -93,23 +91,22 @@ async def main(): For a simpler approach, you can use [Session Memory](session_memory.md) to automatically handle conversation history without manually calling `.to_input_list()`: ```python -from agents import Agent, Runner, RunConfig, SQLiteSessionMemory +from agents import Agent, Runner, SQLiteSession async def main(): agent = Agent(name="Assistant", instructions="Reply very concisely.") - # Create session memory and run config - memory = SQLiteSessionMemory() - run_config = RunConfig(memory=memory, session_id="conversation_123") + # Create session instance + session = SQLiteSession("conversation_123") with trace(workflow_name="Conversation", group_id=thread_id): # First turn - result = await Runner.run(agent, "What city is the Golden Gate Bridge in?", run_config=run_config) + result = await Runner.run(agent, "What city is the Golden Gate Bridge in?", session=session) print(result.final_output) # San Francisco # Second turn - agent automatically remembers previous context - result = await Runner.run(agent, "What state is it in?", run_config=run_config) + result = await Runner.run(agent, "What state is it in?", session=session) print(result.final_output) # California ``` diff --git a/docs/session_memory.md b/docs/session_memory.md index 5ecaee4f..49316ebc 100644 --- a/docs/session_memory.md +++ 
b/docs/session_memory.md @@ -2,12 +2,12 @@ The Agents SDK provides built-in session memory to automatically maintain conversation history across multiple agent runs, eliminating the need to manually handle `.to_input_list()` between turns. -Session memory stores conversation history across agent runs, allowing agents to maintain context without requiring explicit manual memory management. This is particularly useful for building chat applications or multi-turn conversations where you want the agent to remember previous interactions. +Session memory stores conversation history for a specific session, allowing agents to maintain context without requiring explicit manual memory management. This is particularly useful for building chat applications or multi-turn conversations where you want the agent to remember previous interactions. ## Quick start ```python -from agents import Agent, Runner, SQLiteSessionMemory +from agents import Agent, Runner, SQLiteSession # Create agent agent = Agent( @@ -15,15 +15,14 @@ agent = Agent( instructions="Reply very concisely.", ) -# Create a session memory instance -memory = SQLiteSessionMemory() +# Create a session instance with a session ID +session = SQLiteSession("conversation_123") # First turn result = await Runner.run( agent, "What city is the Golden Gate Bridge in?", - memory=memory, - session_id="conversation_123" + session=session ) print(result.final_output) # "San Francisco" @@ -31,8 +30,7 @@ print(result.final_output) # "San Francisco" result = await Runner.run( agent, "What state is it in?", - memory=memory, - session_id="conversation_123" + session=session ) print(result.final_output) # "California" @@ -40,8 +38,7 @@ print(result.final_output) # "California" result = Runner.run_sync( agent, "What's the population?", - memory=memory, - session_id="conversation_123" + session=session ) print(result.final_output) # "Approximately 39 million" ``` @@ -50,9 +47,9 @@ print(result.final_output) # "Approximately 39 million" When 
session memory is enabled: -1. **Before each run**: The runner automatically retrieves the conversation history for the given `session_id` and prepends it to the input messages. -2. **After each run**: All new messages generated during the run (user input, assistant responses, tool calls, etc.) are automatically stored in the session memory. -3. **Context preservation**: Each subsequent run in the same session includes the full conversation history, allowing the agent to maintain context. +1. **Before each run**: The runner automatically retrieves the conversation history for the session and prepends it to the input messages. +2. **After each run**: All new messages generated during the run (user input, assistant responses, tool calls, etc.) are automatically stored in the session. +3. **Context preservation**: Each subsequent run with the same session includes the full conversation history, allowing the agent to maintain context. This eliminates the need to manually call `.to_input_list()` and manage conversation state between runs. 
@@ -63,27 +60,26 @@ This eliminates the need to manually call `.to_input_list()` and manage conversa Session memory supports several operations for managing conversation history: ```python -from agents import SQLiteSessionMemory +from agents import SQLiteSession -memory = SQLiteSessionMemory("conversations.db") -session_id = "user_123" +session = SQLiteSession("user_123", "conversations.db") # Get all messages in a session -messages = await memory.get_messages(session_id) +messages = await session.get_messages() # Add new messages to a session new_messages = [ {"role": "user", "content": "Hello"}, {"role": "assistant", "content": "Hi there!"} ] -await memory.add_messages(session_id, new_messages) +await session.add_messages(new_messages) # Remove and return the most recent message -last_message = await memory.pop_message(session_id) +last_message = await session.pop_message() print(last_message) # {"role": "assistant", "content": "Hi there!"} # Clear all messages from a session -await memory.clear_session(session_id) +await session.clear_session() ``` ### Using pop_message for corrections @@ -91,31 +87,28 @@ await memory.clear_session(session_id) The `pop_message` method is particularly useful when you want to undo or modify the last message in a conversation: ```python -from agents import Agent, Runner, SQLiteSessionMemory +from agents import Agent, Runner, SQLiteSession agent = Agent(name="Assistant") -memory = SQLiteSessionMemory() -session_id = "correction_example" +session = SQLiteSession("correction_example") # Initial conversation result = await Runner.run( agent, "What's 2 + 2?", - memory=memory, - session_id=session_id + session=session ) print(f"Agent: {result.final_output}") # User wants to correct their question -user_message = await memory.pop_message(session_id) # Remove user's question -assistant_message = await memory.pop_message(session_id) # Remove agent's response +assistant_message = await session.pop_message() # Remove agent's response
+user_message = await session.pop_message() # Remove user's question # Ask a corrected question result = await Runner.run( agent, "What's 2 + 3?", - memory=memory, - session_id=session_id + session=session ) print(f"Agent: {result.final_output}") ``` @@ -132,111 +125,89 @@ result = await Runner.run(agent, "Hello") ### SQLite memory ```python -from agents import SQLiteSessionMemory +from agents import SQLiteSession # In-memory database (lost when process ends) -memory = SQLiteSessionMemory() +session = SQLiteSession("user_123") # Persistent file-based database -memory = SQLiteSessionMemory("conversations.db") +session = SQLiteSession("user_123", "conversations.db") -# Use the memory with session IDs +# Use the session result = await Runner.run( agent, "Hello", - memory=memory, - session_id="user_123" + session=session ) ``` ### Multiple sessions ```python -from agents import Agent, Runner, SQLiteSessionMemory +from agents import Agent, Runner, SQLiteSession -memory = SQLiteSessionMemory("conversations.db") agent = Agent(name="Assistant") -# Different session IDs maintain separate conversation histories +# Different sessions maintain separate conversation histories +session_1 = SQLiteSession("user_123", "conversations.db") +session_2 = SQLiteSession("user_456", "conversations.db") + result1 = await Runner.run( agent, "Hello", - memory=memory, - session_id="user_123" + session=session_1 ) result2 = await Runner.run( agent, "Hello", - memory=memory, - session_id="user_456" + session=session_2 ) ``` ## Custom memory implementations -You can implement your own session memory by creating a class that follows the [`SessionMemory`][agents.memory.session_memory.SessionMemory] protocol: +You can implement your own session memory by creating a class that follows the [`Session`][agents.memory.session_memory.Session] protocol: ````python -from agents.memory import SessionMemory +from agents.memory import Session from typing import List -class MyCustomMemory: - """Custom
memory implementation following the SessionMemory protocol.""" +class MyCustomSession: + """Custom session implementation following the Session protocol.""" + + def __init__(self, session_id: str): + self.session_id = session_id + # Your initialization here - async def get_messages(self, session_id: str) -> List[dict]: - """Retrieve conversation history for the session.""" + async def get_messages(self) -> List[dict]: + """Retrieve conversation history for this session.""" # Your implementation here pass - async def add_messages(self, session_id: str, messages: List[dict]) -> None: - """Store new messages for the session.""" + async def add_messages(self, messages: List[dict]) -> None: + """Store new messages for this session.""" # Your implementation here pass - async def pop_message(self, session_id: str) -> dict | None: - """Remove and return the most recent message from the session.""" + async def pop_message(self) -> dict | None: + """Remove and return the most recent message from this session.""" # Your implementation here pass - async def clear_session(self, session_id: str) -> None: - """Clear all messages for the session.""" + async def clear_session(self) -> None: + """Clear all messages for this session.""" # Your implementation here pass -# Use your custom memory +# Use your custom session agent = Agent(name="Assistant") result = await Runner.run( agent, "Hello", - memory=MyCustomMemory(), - session_id="my_session" + session=MyCustomSession("my_session") ) -## Requirements and validation - -### session_id requirement - -When session memory is enabled, you **must** provide a `session_id`. 
If you don't, the runner will raise a `ValueError`: - -```python -from agents import Agent, Runner, SQLiteSessionMemory - -agent = Agent(name="Assistant") -memory = SQLiteSessionMemory() - -# This will raise ValueError: "session_id is required when memory is enabled" -result = await Runner.run(agent, "Hello", memory=memory) - -# This works correctly -result = await Runner.run( - agent, - "Hello", - memory=memory, - session_id="my_session" -) -``` - -## Best practices +## Session management ### Session ID naming @@ -248,32 +219,31 @@ Use meaningful session IDs that help you organize conversations: ### Memory persistence -- Use in-memory SQLite (`SQLiteSessionMemory()`) for temporary conversations -- Use file-based SQLite (`SQLiteSessionMemory("path/to/db.sqlite")`) for persistent conversations -- Consider implementing custom memory backends for production systems (Redis, PostgreSQL, etc.) +- Use in-memory SQLite (`SQLiteSession("session_id")`) for temporary conversations +- Use file-based SQLite (`SQLiteSession("session_id", "path/to/db.sqlite")`) for persistent conversations +- Consider implementing custom session backends for production systems (Redis, PostgreSQL, etc.) 
### Session management ```python # Clear a session when conversation should start fresh -await memory.clear_session("user_123") +await session.clear_session() -# Different agents can share the same session memory +# Different agents can share the same session support_agent = Agent(name="Support") billing_agent = Agent(name="Billing") +session = SQLiteSession("user_123") # Both agents will see the same conversation history result1 = await Runner.run( support_agent, "Help me with my account", - memory=memory, - session_id="user_123" + session=session ) result2 = await Runner.run( billing_agent, "What are my charges?", - memory=memory, - session_id="user_123" + session=session ) ``` @@ -283,7 +253,7 @@ Here's a complete example showing session memory in action: ```python import asyncio -from agents import Agent, Runner, SQLiteSessionMemory +from agents import Agent, Runner, SQLiteSession async def main(): @@ -293,11 +263,8 @@ async def main(): instructions="Reply very concisely.", ) - # Create a session memory instance that will persist across runs - memory = SQLiteSessionMemory("conversation_history.db") - - # Define a session ID for this conversation - session_id = "conversation_123" + # Create a session instance that will persist across runs + session = SQLiteSession("conversation_123", "conversation_history.db") print("=== Session Memory Example ===") print("The agent will remember previous messages automatically.\n") @@ -308,8 +275,7 @@ async def main(): result = await Runner.run( agent, "What city is the Golden Gate Bridge in?", - memory=memory, - session_id=session_id + session=session ) print(f"Assistant: {result.final_output}") print() @@ -320,8 +286,7 @@ async def main(): result = await Runner.run( agent, "What state is it in?", - memory=memory, - session_id=session_id + session=session ) print(f"Assistant: {result.final_output}") print() @@ -332,8 +297,7 @@ async def main(): result = await Runner.run( agent, "What's the population of that state?", - 
memory=memory, - session_id=session_id + session=session ) print(f"Assistant: {result.final_output}") print() @@ -351,8 +315,7 @@ if __name__ == "__main__": For detailed API documentation, see: -- [`SessionMemory`][agents.memory.SessionMemory] - Protocol interface -- [`SQLiteSessionMemory`][agents.memory.SQLiteSessionMemory] - SQLite implementation -- [`RunConfig.memory`][agents.run.RunConfig.memory] - Run configuration -- [`RunConfig.session_id`][agents.run.RunConfig.session_id] - Session identifier +- [`Session`][agents.memory.Session] - Protocol interface +- [`SQLiteSession`][agents.memory.SQLiteSession] - SQLite implementation +- [`RunConfig.session`][agents.run.RunConfig.session] - Run configuration ```` diff --git a/examples/basic/session_memory_example.py b/examples/basic/session_memory_example.py index 0c4feef1..4a85e4db 100644 --- a/examples/basic/session_memory_example.py +++ b/examples/basic/session_memory_example.py @@ -6,7 +6,7 @@ """ import asyncio -from agents import Agent, Runner, SQLiteSessionMemory +from agents import Agent, Runner, SQLiteSession async def main(): @@ -16,11 +16,9 @@ async def main(): instructions="Reply very concisely.", ) - # Create a session memory instance that will persist across runs - memory = SQLiteSessionMemory() - - # Define a session ID for this conversation + # Create a session instance that will persist across runs session_id = "conversation_123" + session = SQLiteSession(session_id) print("=== Session Memory Example ===") print("The agent will remember previous messages automatically.\n") @@ -31,8 +29,7 @@ async def main(): result = await Runner.run( agent, "What city is the Golden Gate Bridge in?", - memory=memory, - session_id=session_id, + session=session, ) print(f"Assistant: {result.final_output}") print() @@ -40,9 +37,7 @@ async def main(): # Second turn - the agent will remember the previous conversation print("Second turn:") print("User: What state is it in?") - result = await Runner.run( - agent, "What state 
is it in?", memory=memory, session_id=session_id - ) + result = await Runner.run(agent, "What state is it in?", session=session) print(f"Assistant: {result.final_output}") print() @@ -52,8 +47,7 @@ async def main(): result = await Runner.run( agent, "What's the population of that state?", - memory=memory, - session_id=session_id, + session=session, ) print(f"Assistant: {result.final_output}") print() diff --git a/src/agents/__init__.py b/src/agents/__init__.py index 93653ef3..7a852644 100644 --- a/src/agents/__init__.py +++ b/src/agents/__init__.py @@ -39,7 +39,7 @@ TResponseInputItem, ) from .lifecycle import AgentHooks, RunHooks -from .memory import SessionMemory, SQLiteSessionMemory +from .memory import Session, SQLiteSession, SessionMemory, SQLiteSessionMemory from .model_settings import ModelSettings from .models.interface import Model, ModelProvider, ModelTracing from .models.openai_chatcompletions import OpenAIChatCompletionsModel @@ -205,6 +205,8 @@ def enable_verbose_stdout_logging(): "ItemHelpers", "RunHooks", "AgentHooks", + "Session", + "SQLiteSession", "SessionMemory", "SQLiteSessionMemory", "RunContextWrapper", diff --git a/src/agents/memory/__init__.py b/src/agents/memory/__init__.py index 59a83d2d..fa7ba1f7 100644 --- a/src/agents/memory/__init__.py +++ b/src/agents/memory/__init__.py @@ -1,3 +1,3 @@ -from .session_memory import SessionMemory, SQLiteSessionMemory +from .session_memory import Session, SQLiteSession, SessionMemory, SQLiteSessionMemory -__all__ = ["SessionMemory", "SQLiteSessionMemory"] +__all__ = ["Session", "SQLiteSession", "SessionMemory", "SQLiteSessionMemory"] diff --git a/src/agents/memory/session_memory.py b/src/agents/memory/session_memory.py index d034aab5..98b49d59 100644 --- a/src/agents/memory/session_memory.py +++ b/src/agents/memory/session_memory.py @@ -12,57 +12,44 @@ @runtime_checkable -class SessionMemory(Protocol): - """Protocol for session memory implementations. 
+class Session(Protocol): + """Protocol for session implementations. - Session memory stores conversation history across agent runs, allowing + Session stores conversation history for a specific session, allowing agents to maintain context without requiring explicit manual memory management. """ - async def get_messages(self, session_id: str) -> list[TResponseInputItem]: - """Retrieve the conversation history for a given session. - - Args: - session_id: Unique identifier for the conversation session + async def get_messages(self) -> list[TResponseInputItem]: + """Retrieve the conversation history for this session. Returns: List of input items representing the conversation history """ ... - async def add_messages( - self, session_id: str, messages: list[TResponseInputItem] - ) -> None: + async def add_messages(self, messages: list[TResponseInputItem]) -> None: """Add new messages to the conversation history. Args: - session_id: Unique identifier for the conversation session messages: List of input items to add to the history """ ... - async def pop_message(self, session_id: str) -> TResponseInputItem | None: + async def pop_message(self) -> TResponseInputItem | None: """Remove and return the most recent message from the session. - Args: - session_id: Unique identifier for the conversation session - Returns: The most recent message if it exists, None if the session is empty """ ... - async def clear_session(self, session_id: str) -> None: - """Clear all messages for a given session. - - Args: - session_id: Unique identifier for the conversation session - """ + async def clear_session(self) -> None: + """Clear all messages for this session.""" ... -class SQLiteSessionMemory(SessionMemory): - """SQLite-based implementation of session memory. +class SQLiteSession(Session): + """SQLite-based implementation of session storage. This implementation stores conversation history in a SQLite database. By default, uses an in-memory database that is lost when the process ends. 
@@ -71,17 +58,20 @@ class SQLiteSessionMemory(SessionMemory): def __init__( self, + session_id: str, db_path: str | Path = ":memory:", sessions_table: str = "agent_sessions", messages_table: str = "agent_messages", ): - """Initialize the SQLite session memory. + """Initialize the SQLite session. Args: + session_id: Unique identifier for the conversation session db_path: Path to the SQLite database file. Defaults to ':memory:' (in-memory database) sessions_table: Name of the table to store session metadata. Defaults to 'agent_sessions' messages_table: Name of the table to store message data. Defaults to 'agent_messages' """ + self.session_id = session_id self.db_path = db_path self.sessions_table = sessions_table self.messages_table = messages_table @@ -152,11 +142,8 @@ def _init_db_for_connection(self, conn: sqlite3.Connection) -> None: conn.commit() - async def get_messages(self, session_id: str) -> list[TResponseInputItem]: - """Retrieve the conversation history for a given session. - - Args: - session_id: Unique identifier for the conversation session + async def get_messages(self) -> list[TResponseInputItem]: + """Retrieve the conversation history for this session. Returns: List of input items representing the conversation history @@ -171,7 +158,7 @@ def _get_messages_sync(): WHERE session_id = ? ORDER BY created_at ASC """, - (session_id,), + (self.session_id,), ) messages = [] @@ -187,13 +174,10 @@ def _get_messages_sync(): return await asyncio.to_thread(_get_messages_sync) - async def add_messages( - self, session_id: str, messages: list[TResponseInputItem] - ) -> None: + async def add_messages(self, messages: list[TResponseInputItem]) -> None: """Add new messages to the conversation history. Args: - session_id: Unique identifier for the conversation session messages: List of input items to add to the history """ if not messages: @@ -208,12 +192,12 @@ def _add_messages_sync(): f""" INSERT OR IGNORE INTO {self.sessions_table} (session_id) VALUES (?) 
""", - (session_id,), + (self.session_id,), ) # Add messages message_data = [ - (session_id, json.dumps(message)) for message in messages + (self.session_id, json.dumps(message)) for message in messages ] conn.executemany( f""" @@ -227,19 +211,16 @@ def _add_messages_sync(): f""" UPDATE {self.sessions_table} SET updated_at = CURRENT_TIMESTAMP WHERE session_id = ? """, - (session_id,), + (self.session_id,), ) conn.commit() await asyncio.to_thread(_add_messages_sync) - async def pop_message(self, session_id: str) -> TResponseInputItem | None: + async def pop_message(self) -> TResponseInputItem | None: """Remove and return the most recent message from the session. - Args: - session_id: Unique identifier for the conversation session - Returns: The most recent message if it exists, None if the session is empty """ @@ -254,7 +235,7 @@ def _pop_message_sync(): ORDER BY created_at DESC LIMIT 1 """, - (session_id,), + (self.session_id,), ) result = cursor.fetchone() @@ -286,23 +267,19 @@ def _pop_message_sync(): return await asyncio.to_thread(_pop_message_sync) - async def clear_session(self, session_id: str) -> None: - """Clear all messages for a given session. 
- - Args: - session_id: Unique identifier for the conversation session - """ + async def clear_session(self) -> None: + """Clear all messages for this session.""" def _clear_session_sync(): conn = self._get_connection() with self._lock if self._is_memory_db else threading.Lock(): conn.execute( f"DELETE FROM {self.messages_table} WHERE session_id = ?", - (session_id,), + (self.session_id,), ) conn.execute( f"DELETE FROM {self.sessions_table} WHERE session_id = ?", - (session_id,), + (self.session_id,), ) conn.commit() @@ -316,3 +293,8 @@ def close(self) -> None: else: if hasattr(self._local, "connection"): self._local.connection.close() + + +# Legacy aliases for backwards compatibility +SessionMemory = Session +SQLiteSessionMemory = SQLiteSession diff --git a/src/agents/run.py b/src/agents/run.py index 1dd2d258..2a96c82f 100644 --- a/src/agents/run.py +++ b/src/agents/run.py @@ -37,7 +37,7 @@ from .items import ItemHelpers, ModelResponse, RunItem, TResponseInputItem from .lifecycle import RunHooks from .logger import logger -from .memory import SessionMemory, SQLiteSessionMemory +from .memory import Session from .model_settings import ModelSettings from .models.interface import Model, ModelProvider from .models.multi_provider import MultiProvider @@ -111,12 +111,6 @@ class RunConfig: An optional dictionary of additional metadata to include with the trace. """ - session_id: str | None = None - """ - A session identifier for memory persistence. Required when memory is enabled. - Conversation history will be automatically managed using this session ID. - """ - class Runner: @classmethod @@ -130,8 +124,7 @@ async def run( hooks: RunHooks[TContext] | None = None, run_config: RunConfig | None = None, previous_response_id: str | None = None, - memory: SessionMemory | None = None, - session_id: str | None = None, + session: Session | None = None, ) -> RunResult: """Run a workflow starting at the given agent. The agent will run in a loop until a final output is generated. 
The loop runs like so: @@ -158,10 +151,8 @@ async def run( run_config: Global settings for the entire agent run. previous_response_id: The ID of the previous response, if using OpenAI models via the Responses API, this allows you to skip passing in input from the previous turn. - memory: Session memory instance for conversation history persistence. + session: Session instance for conversation history persistence. If None, no conversation history will be maintained. - session_id: A session identifier for memory persistence. Required when memory is provided. - Conversation history will be automatically managed using this session ID. Returns: A run result containing all the inputs, guardrail results and the output of the last @@ -172,10 +163,8 @@ async def run( if run_config is None: run_config = RunConfig() - # Prepare input with session memory if enabled - prepared_input, session_memory = await cls._prepare_input_with_memory( - input, memory, session_id - ) + # Prepare input with session if enabled + prepared_input = await cls._prepare_input_with_session(input, session) tool_use_tracker = AgentToolUseTracker() @@ -301,10 +290,8 @@ async def run( context_wrapper=context_wrapper, ) - # Save the conversation to session memory if enabled - await cls._save_result_to_memory( - session_memory, session_id, input, result - ) + # Save the conversation to session if enabled + await cls._save_result_to_session(session, input, result) return result elif isinstance(turn_result.next_step, NextStepHandoff): @@ -335,8 +322,7 @@ def run_sync( hooks: RunHooks[TContext] | None = None, run_config: RunConfig | None = None, previous_response_id: str | None = None, - memory: SessionMemory | None = None, - session_id: str | None = None, + session: Session | None = None, ) -> RunResult: """Run a workflow synchronously, starting at the given agent. Note that this just wraps the `run` method, so it will not work if there's already an event loop (e.g. 
inside an async @@ -367,10 +353,8 @@ def run_sync( run_config: Global settings for the entire agent run. previous_response_id: The ID of the previous response, if using OpenAI models via the Responses API, this allows you to skip passing in input from the previous turn. - memory: Session memory instance for conversation history persistence. + session: Session instance for conversation history persistence. If None, no conversation history will be maintained. - session_id: A session identifier for memory persistence. Required when memory is provided. - Conversation history will be automatically managed using this session ID. Returns: A run result containing all the inputs, guardrail results and the output of the last @@ -385,8 +369,7 @@ def run_sync( hooks=hooks, run_config=run_config, previous_response_id=previous_response_id, - memory=memory, - session_id=session_id, + session=session, ) ) @@ -400,8 +383,7 @@ def run_streamed( hooks: RunHooks[TContext] | None = None, run_config: RunConfig | None = None, previous_response_id: str | None = None, - memory: SessionMemory | None = None, - session_id: str | None = None, + session: Session | None = None, ) -> RunResultStreaming: """Run a workflow starting at the given agent in streaming mode. The returned result object contains a method you can use to stream semantic events as they are generated. @@ -430,10 +412,8 @@ def run_streamed( run_config: Global settings for the entire agent run. previous_response_id: The ID of the previous response, if using OpenAI models via the Responses API, this allows you to skip passing in input from the previous turn. - memory: Session memory instance for conversation history persistence. + session: Session instance for conversation history persistence. If None, no conversation history will be maintained. - session_id: A session identifier for memory persistence. Required when memory is provided. - Conversation history will be automatically managed using this session ID. 
Returns: A result object that contains data about the run, as well as a method to stream events. """ @@ -489,8 +469,7 @@ def run_streamed( context_wrapper=context_wrapper, run_config=run_config, previous_response_id=previous_response_id, - memory=memory, - session_id=session_id, + session=session, ) ) return streamed_result @@ -549,8 +528,7 @@ async def _run_streamed_impl( context_wrapper: RunContextWrapper[TContext], run_config: RunConfig, previous_response_id: str | None, - memory: SessionMemory | None, - session_id: str | None, + session: Session | None, ): current_span: Span[AgentSpanData] | None = None @@ -558,9 +536,9 @@ async def _run_streamed_impl( if streamed_result.trace: streamed_result.trace.start(mark_as_current=True) - # Prepare input with session memory if enabled - prepared_input, session_memory = await cls._prepare_input_with_memory( - starting_input, memory, session_id + # Prepare input with session if enabled + prepared_input = await cls._prepare_input_with_session( + starting_input, session ) # Update the streamed result with the prepared input @@ -685,8 +663,8 @@ async def _run_streamed_impl( streamed_result.final_output = turn_result.next_step.output streamed_result.is_complete = True - # Save the conversation to session memory if enabled - # Create a temporary RunResult for memory saving + # Save the conversation to session if enabled + # Create a temporary RunResult for session saving temp_result = RunResult( input=streamed_result.input, new_items=streamed_result.new_items, @@ -697,8 +675,8 @@ async def _run_streamed_impl( output_guardrail_results=streamed_result.output_guardrail_results, context_wrapper=context_wrapper, ) - await cls._save_result_to_memory( - session_memory, session_id, starting_input, temp_result + await cls._save_result_to_session( + session, starting_input, temp_result ) streamed_result._event_queue.put_nowait( @@ -1084,30 +1062,17 @@ def _get_model(cls, agent: Agent[Any], run_config: RunConfig) -> Model: return 
run_config.model_provider.get_model(agent.model) @classmethod - async def _prepare_input_with_memory( + async def _prepare_input_with_session( cls, input: str | list[TResponseInputItem], - memory: SessionMemory | None, - session_id: str | None, - ) -> tuple[str | list[TResponseInputItem], SessionMemory | None]: - """Prepare input by combining it with session memory if enabled.""" - if memory is None: - # Check if session_id is provided without memory - if session_id is not None: - raise ValueError( - "session_id provided but memory is disabled. " - "Please provide a memory instance or remove session_id." - ) - return input, memory - - if session_id is None: - raise ValueError( - "session_id is required when memory is enabled. " - "Please provide a session_id." - ) + session: Session | None, + ) -> str | list[TResponseInputItem]: + """Prepare input by combining it with session history if enabled.""" + if session is None: + return input # Get previous conversation history - history = await memory.get_messages(session_id) + history = await session.get_messages() # Convert input to list format new_input_list = ItemHelpers.input_to_new_input_list(input) @@ -1115,18 +1080,17 @@ async def _prepare_input_with_memory( # Combine history with new input combined_input = history + new_input_list - return combined_input, memory + return combined_input @classmethod - async def _save_result_to_memory( + async def _save_result_to_session( cls, - memory: SessionMemory | None, - session_id: str | None, + session: Session | None, original_input: str | list[TResponseInputItem], result: RunResult, ) -> None: - """Save the conversation turn to session memory.""" - if memory is None or session_id is None: + """Save the conversation turn to session.""" + if session is None: return # Convert original input to list format if needed @@ -1137,4 +1101,4 @@ async def _save_result_to_memory( # Save all messages from this turn messages_to_save = input_list + new_items_as_input - await 
memory.add_messages(session_id, messages_to_save) + await session.add_messages(messages_to_save) diff --git a/tests/test_session_memory.py b/tests/test_session_memory.py index 27f301dd..bf81ce49 100644 --- a/tests/test_session_memory.py +++ b/tests/test_session_memory.py @@ -5,7 +5,7 @@ from pathlib import Path import asyncio -from agents import Agent, Runner, SQLiteSessionMemory +from agents import Agent, Runner, SQLiteSession from .fake_model import FakeModel from .test_responses import get_text_message @@ -55,21 +55,19 @@ async def test_session_memory_basic_functionality_parametrized(runner_method): """Test basic session memory functionality with SQLite backend across all runner methods.""" with tempfile.TemporaryDirectory() as temp_dir: db_path = Path(temp_dir) / "test_memory.db" - memory = SQLiteSessionMemory(db_path) + session_id = "test_session_123" + session = SQLiteSession(session_id, db_path) model = FakeModel() agent = Agent(name="test", model=model) - session_id = "test_session_123" - # First turn model.set_next_output([get_text_message("San Francisco")]) result1 = await run_agent_async( runner_method, agent, "What city is the Golden Gate Bridge in?", - memory=memory, - session_id=session_id, + session=session, ) assert result1.final_output == "San Francisco" @@ -79,8 +77,7 @@ async def test_session_memory_basic_functionality_parametrized(runner_method): runner_method, agent, "What state is it in?", - memory=memory, - session_id=session_id, + session=session, ) assert result2.final_output == "California" @@ -89,26 +86,25 @@ async def test_session_memory_basic_functionality_parametrized(runner_method): last_input = model.last_turn_args["input"] assert len(last_input) > 1 # Should have more than just the current message - memory.close() + session.close() @pytest.mark.parametrize("runner_method", ["run", "run_sync", "run_streamed"]) @pytest.mark.asyncio async def test_session_memory_with_explicit_instance_parametrized(runner_method): - """Test session 
memory with an explicit SQLiteSessionMemory instance across all runner methods.""" + """Test session memory with an explicit SQLiteSession instance across all runner methods.""" with tempfile.TemporaryDirectory() as temp_dir: db_path = Path(temp_dir) / "test_memory.db" - memory = SQLiteSessionMemory(db_path) + session_id = "test_session_456" + session = SQLiteSession(session_id, db_path) model = FakeModel() agent = Agent(name="test", model=model) - session_id = "test_session_456" - # First turn model.set_next_output([get_text_message("Hello")]) result1 = await run_agent_async( - runner_method, agent, "Hi there", memory=memory, session_id=session_id + runner_method, agent, "Hi there", session=session ) assert result1.final_output == "Hello" @@ -118,22 +114,21 @@ async def test_session_memory_with_explicit_instance_parametrized(runner_method) runner_method, agent, "Do you remember what I said?", - memory=memory, - session_id=session_id, + session=session, ) assert result2.final_output == "I remember you said hi" - memory.close() + session.close() @pytest.mark.parametrize("runner_method", ["run", "run_sync", "run_streamed"]) @pytest.mark.asyncio async def test_session_memory_disabled_parametrized(runner_method): - """Test that session memory is disabled when memory=None across all runner methods.""" + """Test that session memory is disabled when session=None across all runner methods.""" model = FakeModel() agent = Agent(name="test", model=model) - # First turn (no memory parameters = disabled) + # First turn (no session parameters = disabled) model.set_next_output([get_text_message("Hello")]) result1 = await run_agent_async(runner_method, agent, "Hi there") assert result1.final_output == "Hello" @@ -156,26 +151,27 @@ async def test_session_memory_different_sessions_parametrized(runner_method): """Test that different session IDs maintain separate conversation histories across all runner methods.""" with tempfile.TemporaryDirectory() as temp_dir: db_path = 
Path(temp_dir) / "test_memory.db" - memory = SQLiteSessionMemory(db_path) model = FakeModel() agent = Agent(name="test", model=model) # Session 1 session_id_1 = "session_1" + session_1 = SQLiteSession(session_id_1, db_path) model.set_next_output([get_text_message("I like cats")]) result1 = await run_agent_async( - runner_method, agent, "I like cats", memory=memory, session_id=session_id_1 + runner_method, agent, "I like cats", session=session_1 ) assert result1.final_output == "I like cats" # Session 2 - different session session_id_2 = "session_2" + session_2 = SQLiteSession(session_id_2, db_path) model.set_next_output([get_text_message("I like dogs")]) result2 = await run_agent_async( - runner_method, agent, "I like dogs", memory=memory, session_id=session_id_2 + runner_method, agent, "I like dogs", session=session_2 ) assert result2.final_output == "I like dogs" @@ -185,51 +181,21 @@ async def test_session_memory_different_sessions_parametrized(runner_method): runner_method, agent, "What did I say I like?", - memory=memory, - session_id=session_id_1, + session=session_1, ) assert result3.final_output == "Yes, you mentioned cats" - memory.close() - - -@pytest.mark.parametrize("runner_method", ["run", "run_sync", "run_streamed"]) -@pytest.mark.asyncio -async def test_session_memory_no_session_id_parametrized(runner_method): - """Test that session memory raises an exception when no session_id is provided across all runner methods.""" - model = FakeModel() - agent = Agent(name="test", model=model) - memory = SQLiteSessionMemory() - - # Should raise ValueError when trying to run with memory enabled but no session_id - with pytest.raises( - ValueError, match="session_id is required when memory is enabled" - ): - await run_agent_async(runner_method, agent, "Hi there", memory=memory) - - -@pytest.mark.parametrize("runner_method", ["run", "run_sync", "run_streamed"]) -@pytest.mark.asyncio -async def test_session_id_without_memory_parametrized(runner_method): - """Test 
that providing session_id without memory raises an exception across all runner methods.""" - model = FakeModel() - agent = Agent(name="test", model=model) - - session_id = "test_session_without_memory" - - # Should raise ValueError when trying to run with session_id but no memory - with pytest.raises(ValueError, match="session_id provided but memory is disabled"): - await run_agent_async(runner_method, agent, "Hi there", session_id=session_id) + session_1.close() + session_2.close() @pytest.mark.asyncio async def test_sqlite_session_memory_direct(): - """Test SQLiteSessionMemory class directly.""" + """Test SQLiteSession class directly.""" with tempfile.TemporaryDirectory() as temp_dir: db_path = Path(temp_dir) / "test_direct.db" - memory = SQLiteSessionMemory(db_path) - session_id = "direct_test" + session = SQLiteSession(session_id, db_path) # Test adding and retrieving messages messages = [ @@ -237,8 +203,8 @@ async def test_sqlite_session_memory_direct(): {"role": "assistant", "content": "Hi there!"}, ] - await memory.add_messages(session_id, messages) - retrieved = await memory.get_messages(session_id) + await session.add_messages(messages) + retrieved = await session.get_messages() assert len(retrieved) == 2 assert retrieved[0]["role"] == "user" @@ -247,24 +213,23 @@ async def test_sqlite_session_memory_direct(): assert retrieved[1]["content"] == "Hi there!" 
# Test clearing session - await memory.clear_session(session_id) - retrieved_after_clear = await memory.get_messages(session_id) + await session.clear_session() + retrieved_after_clear = await session.get_messages() assert len(retrieved_after_clear) == 0 - memory.close() + session.close() @pytest.mark.asyncio async def test_sqlite_session_memory_pop_message(): - """Test SQLiteSessionMemory pop_message functionality.""" + """Test SQLiteSession pop_message functionality.""" with tempfile.TemporaryDirectory() as temp_dir: db_path = Path(temp_dir) / "test_pop.db" - memory = SQLiteSessionMemory(db_path) - session_id = "pop_test" + session = SQLiteSession(session_id, db_path) # Test popping from empty session - popped = await memory.pop_message(session_id) + popped = await session.pop_message() assert popped is None # Add messages @@ -274,44 +239,44 @@ async def test_sqlite_session_memory_pop_message(): {"role": "user", "content": "How are you?"}, ] - await memory.add_messages(session_id, messages) + await session.add_messages(messages) # Verify all messages are there - retrieved = await memory.get_messages(session_id) + retrieved = await session.get_messages() assert len(retrieved) == 3 # Pop the most recent message - popped = await memory.pop_message(session_id) + popped = await session.pop_message() assert popped is not None assert popped["role"] == "user" assert popped["content"] == "How are you?" # Verify message was removed - retrieved_after_pop = await memory.get_messages(session_id) + retrieved_after_pop = await session.get_messages() assert len(retrieved_after_pop) == 2 assert retrieved_after_pop[-1]["content"] == "Hi there!" # Pop another message - popped2 = await memory.pop_message(session_id) + popped2 = await session.pop_message() assert popped2 is not None assert popped2["role"] == "assistant" assert popped2["content"] == "Hi there!" 
# Pop the last message - popped3 = await memory.pop_message(session_id) + popped3 = await session.pop_message() assert popped3 is not None assert popped3["role"] == "user" assert popped3["content"] == "Hello" # Try to pop from empty session again - popped4 = await memory.pop_message(session_id) + popped4 = await session.pop_message() assert popped4 is None # Verify session is empty - final_messages = await memory.get_messages(session_id) + final_messages = await session.get_messages() assert len(final_messages) == 0 - memory.close() + session.close() @pytest.mark.asyncio @@ -319,10 +284,11 @@ async def test_session_memory_pop_different_sessions(): """Test that pop_message only affects the specified session.""" with tempfile.TemporaryDirectory() as temp_dir: db_path = Path(temp_dir) / "test_pop_sessions.db" - memory = SQLiteSessionMemory(db_path) - session_1 = "session_1" - session_2 = "session_2" + session_1_id = "session_1" + session_2_id = "session_2" + session_1 = SQLiteSession(session_1_id, db_path) + session_2 = SQLiteSession(session_2_id, db_path) # Add messages to both sessions messages_1 = [ @@ -333,25 +299,26 @@ async def test_session_memory_pop_different_sessions(): {"role": "user", "content": "Session 2 message 2"}, ] - await memory.add_messages(session_1, messages_1) - await memory.add_messages(session_2, messages_2) + await session_1.add_messages(messages_1) + await session_2.add_messages(messages_2) # Pop from session 2 - popped = await memory.pop_message(session_2) + popped = await session_2.pop_message() assert popped is not None assert popped["content"] == "Session 2 message 2" # Verify session 1 is unaffected - session_1_messages = await memory.get_messages(session_1) + session_1_messages = await session_1.get_messages() assert len(session_1_messages) == 1 assert session_1_messages[0]["content"] == "Session 1 message" # Verify session 2 has one message left - session_2_messages = await memory.get_messages(session_2) + session_2_messages = await 
session_2.get_messages() assert len(session_2_messages) == 1 assert session_2_messages[0]["content"] == "Session 2 message 1" - memory.close() + session_1.close() + session_2.close() # Original non-parametrized tests for backwards compatibility @@ -360,28 +327,24 @@ async def test_session_memory_basic_functionality(): """Test basic session memory functionality with SQLite backend.""" with tempfile.TemporaryDirectory() as temp_dir: db_path = Path(temp_dir) / "test_memory.db" - memory = SQLiteSessionMemory(db_path) + session_id = "test_session_123" + session = SQLiteSession(session_id, db_path) model = FakeModel() agent = Agent(name="test", model=model) - session_id = "test_session_123" - # First turn model.set_next_output([get_text_message("San Francisco")]) result1 = await Runner.run( agent, "What city is the Golden Gate Bridge in?", - memory=memory, - session_id=session_id, + session=session, ) assert result1.final_output == "San Francisco" # Second turn - should have conversation history model.set_next_output([get_text_message("California")]) - result2 = await Runner.run( - agent, "What state is it in?", memory=memory, session_id=session_id - ) + result2 = await Runner.run(agent, "What state is it in?", session=session) assert result2.final_output == "California" # Verify that the input to the second turn includes the previous conversation @@ -389,45 +352,42 @@ async def test_session_memory_basic_functionality(): last_input = model.last_turn_args["input"] assert len(last_input) > 1 # Should have more than just the current message - memory.close() + session.close() @pytest.mark.asyncio async def test_session_memory_with_explicit_instance(): - """Test session memory with an explicit SQLiteSessionMemory instance.""" + """Test session memory with an explicit SQLiteSession instance.""" with tempfile.TemporaryDirectory() as temp_dir: db_path = Path(temp_dir) / "test_memory.db" - memory = SQLiteSessionMemory(db_path) + session_id = "test_session_456" + session = 
SQLiteSession(session_id, db_path) model = FakeModel() agent = Agent(name="test", model=model) - session_id = "test_session_456" - # First turn model.set_next_output([get_text_message("Hello")]) - result1 = await Runner.run( - agent, "Hi there", memory=memory, session_id=session_id - ) + result1 = await Runner.run(agent, "Hi there", session=session) assert result1.final_output == "Hello" # Second turn model.set_next_output([get_text_message("I remember you said hi")]) result2 = await Runner.run( - agent, "Do you remember what I said?", memory=memory, session_id=session_id + agent, "Do you remember what I said?", session=session ) assert result2.final_output == "I remember you said hi" - memory.close() + session.close() @pytest.mark.asyncio async def test_session_memory_disabled(): - """Test that session memory is disabled when memory=None.""" + """Test that session memory is disabled when session=None.""" model = FakeModel() agent = Agent(name="test", model=model) - # First turn (no memory parameters = disabled) + # First turn (no session parameters = disabled) model.set_next_output([get_text_message("Hello")]) result1 = await Runner.run(agent, "Hi there") assert result1.final_output == "Hello" @@ -447,61 +407,30 @@ async def test_session_memory_different_sessions(): """Test that different session IDs maintain separate conversation histories.""" with tempfile.TemporaryDirectory() as temp_dir: db_path = Path(temp_dir) / "test_memory.db" - memory = SQLiteSessionMemory(db_path) model = FakeModel() agent = Agent(name="test", model=model) # Session 1 session_id_1 = "session_1" + session_1 = SQLiteSession(session_id_1, db_path) model.set_next_output([get_text_message("I like cats")]) - result1 = await Runner.run( - agent, "I like cats", memory=memory, session_id=session_id_1 - ) + result1 = await Runner.run(agent, "I like cats", session=session_1) assert result1.final_output == "I like cats" # Session 2 - different session session_id_2 = "session_2" + session_2 = 
SQLiteSession(session_id_2, db_path) model.set_next_output([get_text_message("I like dogs")]) - result2 = await Runner.run( - agent, "I like dogs", memory=memory, session_id=session_id_2 - ) + result2 = await Runner.run(agent, "I like dogs", session=session_2) assert result2.final_output == "I like dogs" # Back to Session 1 - should remember cats, not dogs model.set_next_output([get_text_message("Yes, you mentioned cats")]) - result3 = await Runner.run( - agent, "What did I say I like?", memory=memory, session_id=session_id_1 - ) + result3 = await Runner.run(agent, "What did I say I like?", session=session_1) assert result3.final_output == "Yes, you mentioned cats" - memory.close() - - -@pytest.mark.asyncio -async def test_session_memory_no_session_id(): - """Test that session memory raises an exception when no session_id is provided.""" - model = FakeModel() - agent = Agent(name="test", model=model) - memory = SQLiteSessionMemory() - - # Should raise ValueError when trying to run with memory enabled but no session_id - with pytest.raises( - ValueError, match="session_id is required when memory is enabled" - ): - await Runner.run(agent, "Hi there", memory=memory) - - -@pytest.mark.asyncio -async def test_session_id_without_memory(): - """Test that providing session_id without memory raises an exception.""" - model = FakeModel() - agent = Agent(name="test", model=model) - - session_id = "test_session_without_memory" - - # Should raise ValueError when trying to run with session_id but no memory - with pytest.raises(ValueError, match="session_id provided but memory is disabled"): - await Runner.run(agent, "Hi there", session_id=session_id) + session_1.close() + session_2.close() From b9be6b9420ec2ce3a7578d23d641bbb4dee6f8c8 Mon Sep 17 00:00:00 2001 From: Stephan Fitzpatrick Date: Sat, 24 May 2025 15:53:19 -0700 Subject: [PATCH 21/26] Update session memory documentation and refactor references - Renamed `session_memory.md` to `session.md` for clarity and consistency. 
- Updated links in `running_agents.md` to reflect the new documentation filename. - Added comprehensive documentation for session memory functionality, including usage examples and API reference. - Removed references to `SessionMemory` and `SQLiteSessionMemory` from the codebase to streamline session management. --- docs/ref/memory.md | 2 -- docs/running_agents.md | 4 ++-- docs/{session_memory.md => session.md} | 5 ++--- .../{session_memory_example.py => session_example.py} | 0 mkdocs.yml | 4 ++-- src/agents/__init__.py | 4 +--- src/agents/memory/__init__.py | 4 ++-- src/agents/memory/{session_memory.py => session.py} | 7 ++----- tests/{test_session_memory.py => test_session.py} | 0 9 files changed, 11 insertions(+), 19 deletions(-) rename docs/{session_memory.md => session.md} (99%) rename examples/basic/{session_memory_example.py => session_example.py} (100%) rename src/agents/memory/{session_memory.py => session.py} (99%) rename tests/{test_session_memory.py => test_session.py} (100%) diff --git a/docs/ref/memory.md b/docs/ref/memory.md index 5502c4ab..04a2258b 100644 --- a/docs/ref/memory.md +++ b/docs/ref/memory.md @@ -6,5 +6,3 @@ members: - Session - SQLiteSession - - SessionMemory - - SQLiteSessionMemory diff --git a/docs/running_agents.md b/docs/running_agents.md index d86714c7..87e9c7cf 100644 --- a/docs/running_agents.md +++ b/docs/running_agents.md @@ -88,7 +88,7 @@ async def main(): ### Automatic conversation management with Session Memory -For a simpler approach, you can use [Session Memory](session_memory.md) to automatically handle conversation history without manually calling `.to_input_list()`: +For a simpler approach, you can use [Session Memory](session.md) to automatically handle conversation history without manually calling `.to_input_list()`: ```python from agents import Agent, Runner, SQLiteSession @@ -117,7 +117,7 @@ Session memory automatically: - Stores new messages after each run - Maintains separate conversations for different session 
IDs -See the [Session Memory documentation](session_memory.md) for more details. +See the [Session Memory documentation](session.md) for more details. ## Exceptions diff --git a/docs/session_memory.md b/docs/session.md similarity index 99% rename from docs/session_memory.md rename to docs/session.md index 49316ebc..32cd51fd 100644 --- a/docs/session_memory.md +++ b/docs/session.md @@ -166,7 +166,7 @@ result2 = await Runner.run( ## Custom memory implementations -You can implement your own session memory by creating a class that follows the [`Session`][agents.memory.session_memory.Session] protocol: +You can implement your own session memory by creating a class that follows the [`Session`][agents.memory.session.Session] protocol: ````python from agents.memory import Session @@ -245,7 +245,7 @@ result2 = await Runner.run( "What are my charges?", session=session ) -``` +```` ## Complete example @@ -318,4 +318,3 @@ For detailed API documentation, see: - [`Session`][agents.memory.Session] - Protocol interface - [`SQLiteSession`][agents.memory.SQLiteSession] - SQLite implementation - [`RunConfig.session`][agents.run.RunConfig.session] - Run configuration -```` diff --git a/examples/basic/session_memory_example.py b/examples/basic/session_example.py similarity index 100% rename from examples/basic/session_memory_example.py rename to examples/basic/session_example.py diff --git a/mkdocs.yml b/mkdocs.yml index cfdac7dc..1dbccbd8 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -57,7 +57,7 @@ plugins: - Documentation: - agents.md - running_agents.md - - session_memory.md + - session.md - results.md - streaming.md - tools.md @@ -139,7 +139,7 @@ plugins: - ドキュメント: - agents.md - running_agents.md - - session_memory.md + - session.md - results.md - streaming.md - tools.md diff --git a/src/agents/__init__.py b/src/agents/__init__.py index 7a852644..6ee871b2 100644 --- a/src/agents/__init__.py +++ b/src/agents/__init__.py @@ -39,7 +39,7 @@ TResponseInputItem, ) from .lifecycle import 
AgentHooks, RunHooks -from .memory import Session, SQLiteSession, SessionMemory, SQLiteSessionMemory +from .memory import Session, SQLiteSession from .model_settings import ModelSettings from .models.interface import Model, ModelProvider, ModelTracing from .models.openai_chatcompletions import OpenAIChatCompletionsModel @@ -207,8 +207,6 @@ def enable_verbose_stdout_logging(): "AgentHooks", "Session", "SQLiteSession", - "SessionMemory", - "SQLiteSessionMemory", "RunContextWrapper", "TContext", "RunResult", diff --git a/src/agents/memory/__init__.py b/src/agents/memory/__init__.py index fa7ba1f7..059ca57a 100644 --- a/src/agents/memory/__init__.py +++ b/src/agents/memory/__init__.py @@ -1,3 +1,3 @@ -from .session_memory import Session, SQLiteSession, SessionMemory, SQLiteSessionMemory +from .session import Session, SQLiteSession -__all__ = ["Session", "SQLiteSession", "SessionMemory", "SQLiteSessionMemory"] +__all__ = ["Session", "SQLiteSession"] diff --git a/src/agents/memory/session_memory.py b/src/agents/memory/session.py similarity index 99% rename from src/agents/memory/session_memory.py rename to src/agents/memory/session.py index 98b49d59..7e62a53c 100644 --- a/src/agents/memory/session_memory.py +++ b/src/agents/memory/session.py @@ -19,6 +19,8 @@ class Session(Protocol): agents to maintain context without requiring explicit manual memory management. """ + session_id: str + async def get_messages(self) -> list[TResponseInputItem]: """Retrieve the conversation history for this session. 
@@ -293,8 +295,3 @@ def close(self) -> None: else: if hasattr(self._local, "connection"): self._local.connection.close() - - -# Legacy aliases for backwards compatibility -SessionMemory = Session -SQLiteSessionMemory = SQLiteSession diff --git a/tests/test_session_memory.py b/tests/test_session.py similarity index 100% rename from tests/test_session_memory.py rename to tests/test_session.py From 3a79e38bd47d56060ebbb2f2b5b629dd8106c83a Mon Sep 17 00:00:00 2001 From: Stephan Fitzpatrick Date: Sat, 24 May 2025 16:02:15 -0700 Subject: [PATCH 22/26] Update documentation to reflect renaming of Session Memory to Sessions - Changed all references from "Session Memory" to "Sessions" in README, documentation, and example files for consistency. - Updated descriptions to clarify the functionality of Sessions in managing conversation history across agent runs. --- README.md | 4 ++-- docs/index.md | 4 ++-- docs/running_agents.md | 8 ++++---- docs/session.md | 11 +++++------ examples/basic/session_example.py | 4 ++-- 5 files changed, 15 insertions(+), 16 deletions(-) diff --git a/README.md b/README.md index 004fa178..05060d98 100644 --- a/README.md +++ b/README.md @@ -9,12 +9,12 @@ The OpenAI Agents SDK is a lightweight yet powerful framework for building multi 1. [**Agents**](https://openai.github.io/openai-agents-python/agents): LLMs configured with instructions, tools, guardrails, and handoffs 2. [**Handoffs**](https://openai.github.io/openai-agents-python/handoffs/): A specialized tool call used by the Agents SDK for transferring control between agents 3. [**Guardrails**](https://openai.github.io/openai-agents-python/guardrails/): Configurable safety checks for input and output validation -4. [**Session Memory**](#session-memory): Automatic conversation history management across agent runs +4. [**Sessions**](#sessions): Automatic conversation history management across agent runs 5. 
[**Tracing**](https://openai.github.io/openai-agents-python/tracing/): Built-in tracking of agent runs, allowing you to view, debug and optimize your workflows Explore the [examples](examples) directory to see the SDK in action, and read our [documentation](https://openai.github.io/openai-agents-python/) for more details. -## Session Memory +## Sessions The Agents SDK provides built-in session memory to automatically maintain conversation history across multiple agent runs, eliminating the need to manually handle `.to_input_list()` between turns. diff --git a/docs/index.md b/docs/index.md index 5111dda7..935c4be5 100644 --- a/docs/index.md +++ b/docs/index.md @@ -5,7 +5,7 @@ The [OpenAI Agents SDK](https://github.com/openai/openai-agents-python) enables - **Agents**, which are LLMs equipped with instructions and tools - **Handoffs**, which allow agents to delegate to other agents for specific tasks - **Guardrails**, which enable the inputs to agents to be validated -- **Session Memory**, which automatically maintains conversation history across agent runs +- **Sessions**, which automatically maintain conversation history across agent runs In combination with Python, these primitives are powerful enough to express complex relationships between tools and agents, and allow you to build real-world applications without a steep learning curve. In addition, the SDK comes with built-in **tracing** that lets you visualize and debug your agentic flows, as well as evaluate them and even fine-tune models for your application. @@ -22,7 +22,7 @@ Here are the main features of the SDK: - Python-first: Use built-in language features to orchestrate and chain agents, rather than needing to learn new abstractions. - Handoffs: A powerful feature to coordinate and delegate between multiple agents. - Guardrails: Run input validations and checks in parallel to your agents, breaking early if the checks fail. 
-- Session Memory: Automatic conversation history management across agent runs, eliminating manual state handling. +- Sessions: Automatic conversation history management across agent runs, eliminating manual state handling. - Function tools: Turn any Python function into a tool, with automatic schema generation and Pydantic-powered validation. - Tracing: Built-in tracing that lets you visualize, debug and monitor your workflows, as well as use the OpenAI suite of evaluation, fine-tuning and distillation tools. diff --git a/docs/running_agents.md b/docs/running_agents.md index 87e9c7cf..4355fff1 100644 --- a/docs/running_agents.md +++ b/docs/running_agents.md @@ -86,9 +86,9 @@ async def main(): # California ``` -### Automatic conversation management with Session Memory +### Automatic conversation management with Sessions -For a simpler approach, you can use [Session Memory](session.md) to automatically handle conversation history without manually calling `.to_input_list()`: +For a simpler approach, you can use [Sessions](session.md) to automatically handle conversation history without manually calling `.to_input_list()`: ```python from agents import Agent, Runner, SQLiteSession @@ -111,13 +111,13 @@ async def main(): # California ``` -Session memory automatically: +Sessions automatically: - Retrieves conversation history before each run - Stores new messages after each run - Maintains separate conversations for different session IDs -See the [Session Memory documentation](session.md) for more details. +See the [Sessions documentation](session.md) for more details. ## Exceptions diff --git a/docs/session.md b/docs/session.md index 32cd51fd..36f89c63 100644 --- a/docs/session.md +++ b/docs/session.md @@ -1,8 +1,8 @@ -# Session Memory +# Sessions The Agents SDK provides built-in session memory to automatically maintain conversation history across multiple agent runs, eliminating the need to manually handle `.to_input_list()` between turns. 
-Session memory stores conversation history for a specific session, allowing agents to maintain context without requiring explicit manual memory management. This is particularly useful for building chat applications or multi-turn conversations where you want the agent to remember previous interactions. +Sessions store conversation history for a specific session, allowing agents to maintain context without requiring explicit manual memory management. This is particularly useful for building chat applications or multi-turn conversations where you want the agent to remember previous interactions. ## Quick start @@ -57,7 +57,7 @@ This eliminates the need to manually call `.to_input_list()` and manage conversa ### Basic operations -Session memory supports several operations for managing conversation history: +Sessions support several operations for managing conversation history: ```python from agents import SQLiteSession @@ -266,7 +266,7 @@ async def main(): # Create a session instance that will persist across runs session = SQLiteSession("conversation_123", "conversation_history.db") - print("=== Session Memory Example ===") + print("=== Sessions Example ===") print("The agent will remember previous messages automatically.\n") # First turn @@ -304,7 +304,7 @@ async def main(): print("=== Conversation Complete ===") print("Notice how the agent remembered the context from previous turns!") - print("Session memory automatically handles conversation history.") + print("Sessions automatically handles conversation history.") if __name__ == "__main__": @@ -317,4 +317,3 @@ For detailed API documentation, see: - [`Session`][agents.memory.Session] - Protocol interface - [`SQLiteSession`][agents.memory.SQLiteSession] - SQLite implementation -- [`RunConfig.session`][agents.run.RunConfig.session] - Run configuration diff --git a/examples/basic/session_example.py b/examples/basic/session_example.py index 4a85e4db..a8ea4601 100644 --- a/examples/basic/session_example.py +++ 
b/examples/basic/session_example.py @@ -20,7 +20,7 @@ async def main(): session_id = "conversation_123" session = SQLiteSession(session_id) - print("=== Session Memory Example ===") + print("=== Session Example ===") print("The agent will remember previous messages automatically.\n") # First turn @@ -54,7 +54,7 @@ async def main(): print("=== Conversation Complete ===") print("Notice how the agent remembered the context from previous turns!") - print("Session memory automatically handles conversation history.") + print("Sessions automatically handles conversation history.") if __name__ == "__main__": From 8e9edd83261bf348c25695da928e124ee1672536 Mon Sep 17 00:00:00 2001 From: Stephan Fitzpatrick Date: Sat, 24 May 2025 16:14:45 -0700 Subject: [PATCH 23/26] Update documentation and references for Sessions - Renamed instances of "session.md" to "sessions.md" in mkdocs.yml and running_agents.md for consistency. - Added new sessions.md file detailing the functionality and usage of session memory in the Agents SDK, including examples and API reference. 
--- docs/running_agents.md | 4 ++-- docs/{session.md => sessions.md} | 0 mkdocs.yml | 4 ++-- 3 files changed, 4 insertions(+), 4 deletions(-) rename docs/{session.md => sessions.md} (100%) diff --git a/docs/running_agents.md b/docs/running_agents.md index 4355fff1..6898f510 100644 --- a/docs/running_agents.md +++ b/docs/running_agents.md @@ -88,7 +88,7 @@ async def main(): ### Automatic conversation management with Sessions -For a simpler approach, you can use [Sessions](session.md) to automatically handle conversation history without manually calling `.to_input_list()`: +For a simpler approach, you can use [Sessions](sessions.md) to automatically handle conversation history without manually calling `.to_input_list()`: ```python from agents import Agent, Runner, SQLiteSession @@ -117,7 +117,7 @@ Sessions automatically: - Stores new messages after each run - Maintains separate conversations for different session IDs -See the [Sessions documentation](session.md) for more details. +See the [Sessions documentation](sessions.md) for more details. ## Exceptions diff --git a/docs/session.md b/docs/sessions.md similarity index 100% rename from docs/session.md rename to docs/sessions.md diff --git a/mkdocs.yml b/mkdocs.yml index 1dbccbd8..8bc3b57f 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -57,7 +57,7 @@ plugins: - Documentation: - agents.md - running_agents.md - - session.md + - sessions.md - results.md - streaming.md - tools.md @@ -139,7 +139,7 @@ plugins: - ドキュメント: - agents.md - running_agents.md - - session.md + - sessions.md - results.md - streaming.md - tools.md From 92f7db51eed3189b96bda110419762cdeddee956 Mon Sep 17 00:00:00 2001 From: Stephan Fitzpatrick Date: Sat, 24 May 2025 16:38:15 -0700 Subject: [PATCH 24/26] Enhance session message retrieval functionality - Updated the `get_messages` method in the `Session` and `SQLiteSession` classes to accept an optional `amount` parameter, allowing retrieval of the latest N messages or all messages if not specified. 
- Added a demonstration in `session_example.py` to showcase the new functionality for fetching the latest messages. - Implemented tests in `test_session.py` to verify the behavior of the `get_messages` method with various amounts, ensuring correct message retrieval. --- examples/basic/session_example.py | 15 +++++++++ src/agents/memory/session.py | 54 +++++++++++++++++++++++++------ tests/test_session.py | 53 ++++++++++++++++++++++++++++++ 3 files changed, 112 insertions(+), 10 deletions(-) diff --git a/examples/basic/session_example.py b/examples/basic/session_example.py index a8ea4601..56431721 100644 --- a/examples/basic/session_example.py +++ b/examples/basic/session_example.py @@ -56,6 +56,21 @@ async def main(): print("Notice how the agent remembered the context from previous turns!") print("Sessions automatically handles conversation history.") + # Demonstrate the amount parameter - get only the latest 2 messages + print("\n=== Latest Messages Demo ===") + latest_messages = await session.get_messages(amount=2) + print("Latest 2 messages:") + for i, msg in enumerate(latest_messages, 1): + role = msg.get("role", "unknown") + content = msg.get("content", "") + print(f" {i}. {role}: {content}") + + print(f"\nFetched {len(latest_messages)} out of total conversation history.") + + # Get all messages to show the difference + all_messages = await session.get_messages() + print(f"Total messages in session: {len(all_messages)}") + if __name__ == "__main__": asyncio.run(main()) diff --git a/src/agents/memory/session.py b/src/agents/memory/session.py index 7e62a53c..230981e0 100644 --- a/src/agents/memory/session.py +++ b/src/agents/memory/session.py @@ -21,9 +21,13 @@ class Session(Protocol): session_id: str - async def get_messages(self) -> list[TResponseInputItem]: + async def get_messages(self, amount: int | None = None) -> list[TResponseInputItem]: """Retrieve the conversation history for this session. + Args: + amount: Maximum number of messages to retrieve. 
If None, retrieves all messages. + When specified, returns the latest N messages in chronological order. + Returns: List of input items representing the conversation history """ @@ -144,9 +148,13 @@ def _init_db_for_connection(self, conn: sqlite3.Connection) -> None: conn.commit() - async def get_messages(self) -> list[TResponseInputItem]: + async def get_messages(self, amount: int | None = None) -> list[TResponseInputItem]: """Retrieve the conversation history for this session. + Args: + amount: Maximum number of messages to retrieve. If None, retrieves all messages. + When specified, returns the latest N messages in chronological order. + Returns: List of input items representing the conversation history """ @@ -154,14 +162,40 @@ async def get_messages(self) -> list[TResponseInputItem]: def _get_messages_sync(): conn = self._get_connection() with self._lock if self._is_memory_db else threading.Lock(): - cursor = conn.execute( - f""" - SELECT message_data FROM {self.messages_table} - WHERE session_id = ? - ORDER BY created_at ASC - """, - (self.session_id,), - ) + if amount is None: + # Fetch all messages in chronological order + cursor = conn.execute( + f""" + SELECT message_data FROM {self.messages_table} + WHERE session_id = ? + ORDER BY created_at ASC + """, + (self.session_id,), + ) + else: + # Fetch the latest N messages in chronological order + # First get the total count to calculate offset + count_cursor = conn.execute( + f""" + SELECT COUNT(*) FROM {self.messages_table} + WHERE session_id = ? + """, + (self.session_id,), + ) + total_count = count_cursor.fetchone()[0] + + # Calculate offset to get the latest N messages + offset = max(0, total_count - amount) + + cursor = conn.execute( + f""" + SELECT message_data FROM {self.messages_table} + WHERE session_id = ? + ORDER BY created_at ASC + LIMIT ? OFFSET ? 
+ """, + (self.session_id, amount, offset), + ) messages = [] for (message_data,) in cursor.fetchall(): diff --git a/tests/test_session.py b/tests/test_session.py index bf81ce49..41418a72 100644 --- a/tests/test_session.py +++ b/tests/test_session.py @@ -321,6 +321,59 @@ async def test_session_memory_pop_different_sessions(): session_2.close() +@pytest.mark.asyncio +async def test_sqlite_session_get_messages_with_amount(): + """Test SQLiteSession get_messages with amount parameter.""" + with tempfile.TemporaryDirectory() as temp_dir: + db_path = Path(temp_dir) / "test_amount.db" + session_id = "amount_test" + session = SQLiteSession(session_id, db_path) + + # Add multiple messages + messages = [ + {"role": "user", "content": "Message 1"}, + {"role": "assistant", "content": "Response 1"}, + {"role": "user", "content": "Message 2"}, + {"role": "assistant", "content": "Response 2"}, + {"role": "user", "content": "Message 3"}, + {"role": "assistant", "content": "Response 3"}, + ] + + await session.add_messages(messages) + + # Test getting all messages (default behavior) + all_messages = await session.get_messages() + assert len(all_messages) == 6 + assert all_messages[0]["content"] == "Message 1" + assert all_messages[-1]["content"] == "Response 3" + + # Test getting latest 2 messages + latest_2 = await session.get_messages(amount=2) + assert len(latest_2) == 2 + assert latest_2[0]["content"] == "Message 3" + assert latest_2[1]["content"] == "Response 3" + + # Test getting latest 4 messages + latest_4 = await session.get_messages(amount=4) + assert len(latest_4) == 4 + assert latest_4[0]["content"] == "Message 2" + assert latest_4[1]["content"] == "Response 2" + assert latest_4[2]["content"] == "Message 3" + assert latest_4[3]["content"] == "Response 3" + + # Test getting more messages than available + latest_10 = await session.get_messages(amount=10) + assert len(latest_10) == 6 # Should return all available messages + assert latest_10[0]["content"] == "Message 1" + 
assert latest_10[-1]["content"] == "Response 3" + + # Test getting 0 messages + latest_0 = await session.get_messages(amount=0) + assert len(latest_0) == 0 + + session.close() + + # Original non-parametrized tests for backwards compatibility @pytest.mark.asyncio async def test_session_memory_basic_functionality(): From b5ad785dad27853909516aa628f41c497e9c198c Mon Sep 17 00:00:00 2001 From: Stephan Fitzpatrick Date: Sat, 24 May 2025 16:52:57 -0700 Subject: [PATCH 25/26] Update AGENTS.md to include note on executing CLI commands in virtual environment --- AGENTS.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/AGENTS.md b/AGENTS.md index ff37db32..9e857801 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -35,6 +35,8 @@ Welcome to the OpenAI Agents SDK repository. This file contains the main points Coverage can be generated with `make coverage`. +*Note*: Use `uv run ...` to execute arbitrary cli commands within the project's virtual environment. + ## Snapshot tests Some tests rely on inline snapshots. See `tests/README.md` for details on updating them: From e26aed14991dd14928e8c7f73e13dad0fc8221ba Mon Sep 17 00:00:00 2001 From: Stephan Fitzpatrick Date: Sat, 24 May 2025 16:59:15 -0700 Subject: [PATCH 26/26] Update README to reflect changes from "Memory options" to "Session options" and clarify session implementation details. Adjusted section headers and descriptions for consistency with recent documentation updates. 
--- README.md | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index 05060d98..0e8eff64 100644 --- a/README.md +++ b/README.md @@ -57,10 +57,10 @@ result = Runner.run_sync( print(result.final_output) # "Approximately 39 million" ``` -### Memory options +### Session options - **No memory** (default): No session memory when session parameter is omitted -- **`session=Session`**: Use the provided session implementation +- **`session: Session = DatabaseSession(...)`**: Use a Session instance to manage conversation history ```python from agents import Agent, Runner, SQLiteSession @@ -82,7 +82,7 @@ result2 = await Runner.run( ) ``` -### Custom memory implementations +### Custom session implementations You can implement your own session memory by creating a class that follows the `Session` protocol: