From 111fc9ee66c5f6f4af3ba6faaedba4789da3d03a Mon Sep 17 00:00:00 2001 From: Jonny Kalambay Date: Tue, 22 Apr 2025 19:26:47 -0700 Subject: [PATCH 01/33] Adding extra_headers parameters to ModelSettings (#550) --- src/agents/extensions/models/litellm_model.py | 2 +- src/agents/model_settings.py | 6 +- src/agents/models/openai_chatcompletions.py | 2 +- src/agents/models/openai_responses.py | 2 +- tests/test_extra_headers.py | 92 +++++++++++++++++++ 5 files changed, 100 insertions(+), 4 deletions(-) create mode 100644 tests/test_extra_headers.py diff --git a/src/agents/extensions/models/litellm_model.py b/src/agents/extensions/models/litellm_model.py index e939ee8da..f5e7752fa 100644 --- a/src/agents/extensions/models/litellm_model.py +++ b/src/agents/extensions/models/litellm_model.py @@ -286,7 +286,7 @@ async def _fetch_response( stream=stream, stream_options=stream_options, reasoning_effort=reasoning_effort, - extra_headers=HEADERS, + extra_headers={**HEADERS, **(model_settings.extra_headers or {})}, api_key=self.api_key, base_url=self.base_url, **extra_kwargs, diff --git a/src/agents/model_settings.py b/src/agents/model_settings.py index ed9a01318..fee92b4e0 100644 --- a/src/agents/model_settings.py +++ b/src/agents/model_settings.py @@ -3,7 +3,7 @@ from dataclasses import dataclass, fields, replace from typing import Literal -from openai._types import Body, Query +from openai._types import Body, Headers, Query from openai.types.shared import Reasoning @@ -67,6 +67,10 @@ class ModelSettings: """Additional body fields to provide with the request. Defaults to None if not provided.""" + extra_headers: Headers | None = None + """Additional headers to provide with the request. + Defaults to None if not provided.""" + def resolve(self, override: ModelSettings | None) -> ModelSettings: """Produce a new ModelSettings by overlaying any non-None values from the override on top of this instance.""" diff --git a/src/agents/models/openai_chatcompletions.py b/src/agents/models/openai_chatcompletions.py index 9fd102690..15bf19cb3 100644 --- a/src/agents/models/openai_chatcompletions.py +++ b/src/agents/models/openai_chatcompletions.py @@ -255,7 +255,7 @@ async def _fetch_response( stream_options=self._non_null_or_not_given(stream_options), store=self._non_null_or_not_given(store), reasoning_effort=self._non_null_or_not_given(reasoning_effort), - extra_headers=HEADERS, + extra_headers={ **HEADERS, **(model_settings.extra_headers or {}) }, extra_query=model_settings.extra_query, extra_body=model_settings.extra_body, metadata=self._non_null_or_not_given(model_settings.metadata), diff --git a/src/agents/models/openai_responses.py b/src/agents/models/openai_responses.py index b751663da..c1ff85b98 100644 --- a/src/agents/models/openai_responses.py +++ b/src/agents/models/openai_responses.py @@ -253,7 +253,7 @@ async def _fetch_response( tool_choice=tool_choice, parallel_tool_calls=parallel_tool_calls, stream=stream, - extra_headers=_HEADERS, + extra_headers={**_HEADERS, **(model_settings.extra_headers or {})}, extra_query=model_settings.extra_query, extra_body=model_settings.extra_body, text=response_format, diff --git a/tests/test_extra_headers.py b/tests/test_extra_headers.py new file mode 100644 index 000000000..f29c25408 --- /dev/null +++ b/tests/test_extra_headers.py @@ -0,0 +1,92 @@ +import pytest +from openai.types.chat.chat_completion import ChatCompletion, Choice +from openai.types.chat.chat_completion_message import ChatCompletionMessage + +from agents import ModelSettings, ModelTracing, 
OpenAIChatCompletionsModel, OpenAIResponsesModel + + +@pytest.mark.allow_call_model_methods +@pytest.mark.asyncio +async def test_extra_headers_passed_to_openai_responses_model(): + """ + Ensure extra_headers in ModelSettings is passed to the OpenAIResponsesModel client. + """ + called_kwargs = {} + + class DummyResponses: + async def create(self, **kwargs): + nonlocal called_kwargs + called_kwargs = kwargs + class DummyResponse: + id = "dummy" + output = [] + usage = type( + "Usage", (), {"input_tokens": 0, "output_tokens": 0, "total_tokens": 0} + )() + return DummyResponse() + + class DummyClient: + def __init__(self): + self.responses = DummyResponses() + + + + model = OpenAIResponsesModel(model="gpt-4", openai_client=DummyClient()) # type: ignore + extra_headers = {"X-Test-Header": "test-value"} + await model.get_response( + system_instructions=None, + input="hi", + model_settings=ModelSettings(extra_headers=extra_headers), + tools=[], + output_schema=None, + handoffs=[], + tracing=ModelTracing.DISABLED, + previous_response_id=None, + ) + assert "extra_headers" in called_kwargs + assert called_kwargs["extra_headers"]["X-Test-Header"] == "test-value" + + + +@pytest.mark.allow_call_model_methods +@pytest.mark.asyncio +async def test_extra_headers_passed_to_openai_client(): + """ + Ensure extra_headers in ModelSettings is passed to the OpenAI client. + """ + called_kwargs = {} + + class DummyCompletions: + async def create(self, **kwargs): + nonlocal called_kwargs + called_kwargs = kwargs + msg = ChatCompletionMessage(role="assistant", content="Hello") + choice = Choice(index=0, finish_reason="stop", message=msg) + return ChatCompletion( + id="resp-id", + created=0, + model="fake", + object="chat.completion", + choices=[choice], + usage=None, + ) + + class DummyClient: + def __init__(self): + self.chat = type("_Chat", (), {"completions": DummyCompletions()})() + self.base_url = "https://api.openai.com" + + model = OpenAIChatCompletionsModel(model="gpt-4", openai_client=DummyClient()) # type: ignore + extra_headers = {"X-Test-Header": "test-value"} + await model.get_response( + system_instructions=None, + input="hi", + model_settings=ModelSettings(extra_headers=extra_headers), + tools=[], + output_schema=None, + handoffs=[], + tracing=ModelTracing.DISABLED, + previous_response_id=None, + ) + assert "extra_headers" in called_kwargs + assert called_kwargs["extra_headers"]["X-Test-Header"] == "test-value" From 178020ea33980e5873a82dc715e79f0c6a285623 Mon Sep 17 00:00:00 2001 From: Kazuhiro Sera Date: Wed, 23 Apr 2025 11:29:12 +0900 Subject: [PATCH 02/33] Examples: Fix financial_research_agent instructions (#573) --- examples/financial_research_agent/main.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/financial_research_agent/main.py b/examples/financial_research_agent/main.py index 3fa8a7e08..b5b6cfdfd 100644 --- a/examples/financial_research_agent/main.py +++ b/examples/financial_research_agent/main.py @@ -4,7 +4,7 @@ # Entrypoint for the financial bot example. -# Run this as `python -m examples.financial_bot.main` and enter a +# Run this as `python -m examples.financial_research_agent.main` and enter a # financial research query, for example: # "Write up an analysis of Apple Inc.'s most recent quarter." 
async def main() -> None: From a113fea0eef82bb37a0a803eaae42c4761d0ebdf Mon Sep 17 00:00:00 2001 From: Andrew Han Date: Wed, 23 Apr 2025 16:51:10 -0700 Subject: [PATCH 03/33] Allow cancel out of the streaming result (#579) Fix for #574 @rm-openai I'm not sure how to add a test within the repo but I have pasted a test script below that seems to work ```python import asyncio from openai.types.responses import ResponseTextDeltaEvent from agents import Agent, Runner async def main(): agent = Agent( name="Joker", instructions="You are a helpful assistant.", ) result = Runner.run_streamed(agent, input="Please tell me 5 jokes.") num_visible_event = 0 async for event in result.stream_events(): if event.type == "raw_response_event" and isinstance(event.data, ResponseTextDeltaEvent): print(event.data.delta, end="", flush=True) num_visible_event += 1 print(num_visible_event) if num_visible_event == 3: result.cancel() if __name__ == "__main__": asyncio.run(main()) ```` --- src/agents/result.py | 24 +++++++++++++++++++++--- tests/test_cancel_streaming.py | 22 ++++++++++++++++++++++ 2 files changed, 43 insertions(+), 3 deletions(-) create mode 100644 tests/test_cancel_streaming.py diff --git a/src/agents/result.py b/src/agents/result.py index 0d8372c86..1f1c78328 100644 --- a/src/agents/result.py +++ b/src/agents/result.py @@ -75,7 +75,9 @@ def final_output_as(self, cls: type[T], raise_if_incorrect_type: bool = False) - def to_input_list(self) -> list[TResponseInputItem]: """Creates a new input list, merging the original input with all the new items generated.""" - original_items: list[TResponseInputItem] = ItemHelpers.input_to_new_input_list(self.input) + original_items: list[TResponseInputItem] = ItemHelpers.input_to_new_input_list( + self.input + ) new_items = [item.to_input_item() for item in self.new_items] return original_items + new_items @@ -152,6 +154,18 @@ def last_agent(self) -> Agent[Any]: """ return self.current_agent + def cancel(self) -> None: + """Cancels the streaming run, stopping all background tasks and marking the run as + complete.""" + self._cleanup_tasks() # Cancel all running tasks + self.is_complete = True # Mark the run as complete to stop event streaming + + # Optionally, clear the event queue to prevent processing stale events + while not self._event_queue.empty(): + self._event_queue.get_nowait() + while not self._input_guardrail_queue.empty(): + self._input_guardrail_queue.get_nowait() + async def stream_events(self) -> AsyncIterator[StreamEvent]: """Stream deltas for new items as they are generated. 
We're using the types from the OpenAI Responses API, so these are semantic events: each event has a `type` field that @@ -192,13 +206,17 @@ async def stream_events(self) -> AsyncIterator[StreamEvent]: def _check_errors(self): if self.current_turn > self.max_turns: - self._stored_exception = MaxTurnsExceeded(f"Max turns ({self.max_turns}) exceeded") + self._stored_exception = MaxTurnsExceeded( + f"Max turns ({self.max_turns}) exceeded" + ) # Fetch all the completed guardrail results from the queue and raise if needed while not self._input_guardrail_queue.empty(): guardrail_result = self._input_guardrail_queue.get_nowait() if guardrail_result.output.tripwire_triggered: - self._stored_exception = InputGuardrailTripwireTriggered(guardrail_result) + self._stored_exception = InputGuardrailTripwireTriggered( + guardrail_result + ) # Check the tasks for any exceptions if self._run_impl_task and self._run_impl_task.done(): diff --git a/tests/test_cancel_streaming.py b/tests/test_cancel_streaming.py new file mode 100644 index 000000000..6d1807d7c --- /dev/null +++ b/tests/test_cancel_streaming.py @@ -0,0 +1,22 @@ +import pytest + +from agents import Agent, Runner + +from .fake_model import FakeModel + + +@pytest.mark.asyncio +async def test_joker_streamed_jokes_with_cancel(): + model = FakeModel() + agent = Agent(name="Joker", model=model) + + result = Runner.run_streamed(agent, input="Please tell me 5 jokes.") + num_events = 0 + stop_after = 1 # There are two that the model gives back. + + async for _event in result.stream_events(): + num_events += 1 + if num_events == 1: + result.cancel() + + assert num_events == 1, f"Expected {stop_after} visible events, but got {num_events}" From 3755ea86589b8e929c5b2bdd51df9f62c1cad8bf Mon Sep 17 00:00:00 2001 From: Rohan Mehta Date: Wed, 23 Apr 2025 20:39:07 -0400 Subject: [PATCH 04/33] Create to_json_dict for ModelSettings (#582) Now that `ModelSettings` has `Reasoning`, a non-primitive object, `dataclasses.as_dict()` wont work. It will raise an error when you try to serialize (e.g. for tracing). This ensures the object is actually serializable. 
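
For illustration, a minimal sketch of the intended usage (field values here are arbitrary; the tracing hook-up is shown in the diff below):

```python
import json

from openai.types.shared import Reasoning

from agents.model_settings import ModelSettings

settings = ModelSettings(temperature=0.5, reasoning=Reasoning(effort="high"))

# json.dumps(dataclasses.asdict(settings)) would fail here: the `reasoning` value is a
# pydantic BaseModel, which json.dumps cannot serialize.
payload = settings.to_json_dict()  # BaseModel fields are converted via model_dump(mode="json")
print(json.dumps(payload))  # now safe to attach to a generation span's model_config
```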
--- pyproject.toml | 2 +- src/agents/extensions/models/litellm_model.py | 5 +- src/agents/model_settings.py | 17 +++++- src/agents/models/openai_chatcompletions.py | 7 +-- tests/model_settings/test_serialization.py | 59 +++++++++++++++++++ tests/voice/conftest.py | 1 - uv.lock | 8 +-- 7 files changed, 84 insertions(+), 15 deletions(-) create mode 100644 tests/model_settings/test_serialization.py diff --git a/pyproject.toml b/pyproject.toml index eeeb6d3d3..12ffff1ff 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -7,7 +7,7 @@ requires-python = ">=3.9" license = "MIT" authors = [{ name = "OpenAI", email = "support@openai.com" }] dependencies = [ - "openai>=1.66.5", + "openai>=1.76.0", "pydantic>=2.10, <3", "griffe>=1.5.6, <2", "typing-extensions>=4.12.2, <5", diff --git a/src/agents/extensions/models/litellm_model.py b/src/agents/extensions/models/litellm_model.py index f5e7752fa..dc672acd4 100644 --- a/src/agents/extensions/models/litellm_model.py +++ b/src/agents/extensions/models/litellm_model.py @@ -1,6 +1,5 @@ from __future__ import annotations -import dataclasses import json import time from collections.abc import AsyncIterator @@ -75,7 +74,7 @@ async def get_response( ) -> ModelResponse: with generation_span( model=str(self.model), - model_config=dataclasses.asdict(model_settings) + model_config=model_settings.to_json_dict() | {"base_url": str(self.base_url or ""), "model_impl": "litellm"}, disabled=tracing.is_disabled(), ) as span_generation: @@ -147,7 +146,7 @@ async def stream_response( ) -> AsyncIterator[TResponseStreamEvent]: with generation_span( model=str(self.model), - model_config=dataclasses.asdict(model_settings) + model_config=model_settings.to_json_dict() | {"base_url": str(self.base_url or ""), "model_impl": "litellm"}, disabled=tracing.is_disabled(), ) as span_generation: diff --git a/src/agents/model_settings.py b/src/agents/model_settings.py index fee92b4e0..7b016c98f 100644 --- a/src/agents/model_settings.py +++ b/src/agents/model_settings.py @@ -1,10 +1,12 @@ from __future__ import annotations +import dataclasses from dataclasses import dataclass, fields, replace -from typing import Literal +from typing import Any, Literal from openai._types import Body, Headers, Query from openai.types.shared import Reasoning +from pydantic import BaseModel @dataclass @@ -83,3 +85,16 @@ def resolve(self, override: ModelSettings | None) -> ModelSettings: if getattr(override, field.name) is not None } return replace(self, **changes) + + def to_json_dict(self) -> dict[str, Any]: + dataclass_dict = dataclasses.asdict(self) + + json_dict: dict[str, Any] = {} + + for field_name, value in dataclass_dict.items(): + if isinstance(value, BaseModel): + json_dict[field_name] = value.model_dump(mode="json") + else: + json_dict[field_name] = value + + return json_dict diff --git a/src/agents/models/openai_chatcompletions.py b/src/agents/models/openai_chatcompletions.py index 15bf19cb3..89619f838 100644 --- a/src/agents/models/openai_chatcompletions.py +++ b/src/agents/models/openai_chatcompletions.py @@ -1,6 +1,5 @@ from __future__ import annotations -import dataclasses import json import time from collections.abc import AsyncIterator @@ -56,8 +55,7 @@ async def get_response( ) -> ModelResponse: with generation_span( model=str(self.model), - model_config=dataclasses.asdict(model_settings) - | {"base_url": str(self._client.base_url)}, + model_config=model_settings.to_json_dict() | {"base_url": str(self._client.base_url)}, disabled=tracing.is_disabled(), ) as span_generation: response = await 
self._fetch_response( @@ -121,8 +119,7 @@ async def stream_response( """ with generation_span( model=str(self.model), - model_config=dataclasses.asdict(model_settings) - | {"base_url": str(self._client.base_url)}, + model_config=model_settings.to_json_dict() | {"base_url": str(self._client.base_url)}, disabled=tracing.is_disabled(), ) as span_generation: response, stream = await self._fetch_response( diff --git a/tests/model_settings/test_serialization.py b/tests/model_settings/test_serialization.py new file mode 100644 index 000000000..d76a58d17 --- /dev/null +++ b/tests/model_settings/test_serialization.py @@ -0,0 +1,59 @@ +import json +from dataclasses import fields + +from openai.types.shared import Reasoning + +from agents.model_settings import ModelSettings + + +def verify_serialization(model_settings: ModelSettings) -> None: + """Verify that ModelSettings can be serialized to a JSON string.""" + json_dict = model_settings.to_json_dict() + json_string = json.dumps(json_dict) + assert json_string is not None + + +def test_basic_serialization() -> None: + """Tests whether ModelSettings can be serialized to a JSON string.""" + + # First, lets create a ModelSettings instance + model_settings = ModelSettings( + temperature=0.5, + top_p=0.9, + max_tokens=100, + ) + + # Now, lets serialize the ModelSettings instance to a JSON string + verify_serialization(model_settings) + + +def test_all_fields_serialization() -> None: + """Tests whether ModelSettings can be serialized to a JSON string.""" + + # First, lets create a ModelSettings instance + model_settings = ModelSettings( + temperature=0.5, + top_p=0.9, + frequency_penalty=0.0, + presence_penalty=0.0, + tool_choice="auto", + parallel_tool_calls=True, + truncation="auto", + max_tokens=100, + reasoning=Reasoning(), + metadata={"foo": "bar"}, + store=False, + include_usage=False, + extra_query={"foo": "bar"}, + extra_body={"foo": "bar"}, + extra_headers={"foo": "bar"}, + ) + + # Verify that every single field is set to a non-None value + for field in fields(model_settings): + assert getattr(model_settings, field.name) is not None, ( + f"You must set the {field.name} field" + ) + + # Now, lets serialize the ModelSettings instance to a JSON string + verify_serialization(model_settings) diff --git a/tests/voice/conftest.py b/tests/voice/conftest.py index 6ed7422ce..79d85d8b4 100644 --- a/tests/voice/conftest.py +++ b/tests/voice/conftest.py @@ -9,4 +9,3 @@ def pytest_ignore_collect(collection_path, config): if str(collection_path).startswith(this_dir): return True - diff --git a/uv.lock b/uv.lock index 3a737cf37..4c6c370ad 100644 --- a/uv.lock +++ b/uv.lock @@ -1463,7 +1463,7 @@ wheels = [ [[package]] name = "openai" -version = "1.74.0" +version = "1.76.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "anyio" }, @@ -1475,9 +1475,9 @@ dependencies = [ { name = "tqdm" }, { name = "typing-extensions" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/75/86/c605a6e84da0248f2cebfcd864b5a6076ecf78849245af5e11d2a5ec7977/openai-1.74.0.tar.gz", hash = "sha256:592c25b8747a7cad33a841958f5eb859a785caea9ee22b9e4f4a2ec062236526", size = 427571 } +sdist = { url = "https://files.pythonhosted.org/packages/84/51/817969ec969b73d8ddad085670ecd8a45ef1af1811d8c3b8a177ca4d1309/openai-1.76.0.tar.gz", hash = "sha256:fd2bfaf4608f48102d6b74f9e11c5ecaa058b60dad9c36e409c12477dfd91fb2", size = 434660 } wheels = [ - { url = 
"https://files.pythonhosted.org/packages/a9/91/8c150f16a96367e14bd7d20e86e0bbbec3080e3eb593e63f21a7f013f8e4/openai-1.74.0-py3-none-any.whl", hash = "sha256:aff3e0f9fb209836382ec112778667027f4fd6ae38bdb2334bc9e173598b092a", size = 644790 }, + { url = "https://files.pythonhosted.org/packages/59/aa/84e02ab500ca871eb8f62784426963a1c7c17a72fea3c7f268af4bbaafa5/openai-1.76.0-py3-none-any.whl", hash = "sha256:a712b50e78cf78e6d7b2a8f69c4978243517c2c36999756673e07a14ce37dc0a", size = 661201 }, ] [[package]] @@ -1538,7 +1538,7 @@ requires-dist = [ { name = "litellm", marker = "extra == 'litellm'", specifier = ">=1.65.0,<2" }, { name = "mcp", marker = "python_full_version >= '3.10'", specifier = ">=1.6.0,<2" }, { name = "numpy", marker = "python_full_version >= '3.10' and extra == 'voice'", specifier = ">=2.2.0,<3" }, - { name = "openai", specifier = ">=1.66.5" }, + { name = "openai", specifier = ">=1.76.0" }, { name = "pydantic", specifier = ">=2.10,<3" }, { name = "requests", specifier = ">=2.0,<3" }, { name = "types-requests", specifier = ">=2.0,<3" }, From af80e3a97123a5a0ad0fba695fbc257163c23224 Mon Sep 17 00:00:00 2001 From: Nathan Brake <33383515+njbrake@users.noreply.github.com> Date: Thu, 24 Apr 2025 12:12:46 -0400 Subject: [PATCH 05/33] Prevent MCP ClientSession hang (#580) Per https://modelcontextprotocol.io/specification/draft/basic/lifecycle#timeouts "Implementations SHOULD establish timeouts for all sent requests, to prevent hung connections and resource exhaustion. When the request has not received a success or error response within the timeout period, the sender SHOULD issue a cancellation notification for that request and stop waiting for a response. SDKs and other middleware SHOULD allow these timeouts to be configured on a per-request basis." I picked 5 seconds since that's the default for SSE --- src/agents/mcp/server.py | 26 ++++++++++++++++++++++---- tests/mcp/test_server_errors.py | 2 +- 2 files changed, 23 insertions(+), 5 deletions(-) diff --git a/src/agents/mcp/server.py b/src/agents/mcp/server.py index 9a137bbdd..9916c92b0 100644 --- a/src/agents/mcp/server.py +++ b/src/agents/mcp/server.py @@ -3,6 +3,7 @@ import abc import asyncio from contextlib import AbstractAsyncContextManager, AsyncExitStack +from datetime import timedelta from pathlib import Path from typing import Any, Literal @@ -54,7 +55,7 @@ async def call_tool(self, tool_name: str, arguments: dict[str, Any] | None) -> C class _MCPServerWithClientSession(MCPServer, abc.ABC): """Base class for MCP servers that use a `ClientSession` to communicate with the server.""" - def __init__(self, cache_tools_list: bool): + def __init__(self, cache_tools_list: bool, client_session_timeout_seconds: float | None): """ Args: cache_tools_list: Whether to cache the tools list. If `True`, the tools list will be @@ -63,12 +64,16 @@ def __init__(self, cache_tools_list: bool): by calling `invalidate_tools_cache()`. You should set this to `True` if you know the server will not change its tools list, because it can drastically improve latency (by avoiding a round-trip to the server every time). + + client_session_timeout_seconds: the read timeout passed to the MCP ClientSession. 
""" self.session: ClientSession | None = None self.exit_stack: AsyncExitStack = AsyncExitStack() self._cleanup_lock: asyncio.Lock = asyncio.Lock() self.cache_tools_list = cache_tools_list + self.client_session_timeout_seconds = client_session_timeout_seconds + # The cache is always dirty at startup, so that we fetch tools at least once self._cache_dirty = True self._tools_list: list[MCPTool] | None = None @@ -101,7 +106,15 @@ async def connect(self): try: transport = await self.exit_stack.enter_async_context(self.create_streams()) read, write = transport - session = await self.exit_stack.enter_async_context(ClientSession(read, write)) + session = await self.exit_stack.enter_async_context( + ClientSession( + read, + write, + timedelta(seconds=self.client_session_timeout_seconds) + if self.client_session_timeout_seconds + else None, + ) + ) await session.initialize() self.session = session except Exception as e: @@ -183,6 +196,7 @@ def __init__( params: MCPServerStdioParams, cache_tools_list: bool = False, name: str | None = None, + client_session_timeout_seconds: float | None = 5, ): """Create a new MCP server based on the stdio transport. @@ -199,8 +213,9 @@ def __init__( improve latency (by avoiding a round-trip to the server every time). name: A readable name for the server. If not provided, we'll create one from the command. + client_session_timeout_seconds: the read timeout passed to the MCP ClientSession. """ - super().__init__(cache_tools_list) + super().__init__(cache_tools_list, client_session_timeout_seconds) self.params = StdioServerParameters( command=params["command"], @@ -257,6 +272,7 @@ def __init__( params: MCPServerSseParams, cache_tools_list: bool = False, name: str | None = None, + client_session_timeout_seconds: float | None = 5, ): """Create a new MCP server based on the HTTP with SSE transport. @@ -274,8 +290,10 @@ def __init__( name: A readable name for the server. If not provided, we'll create one from the URL. + + client_session_timeout_seconds: the read timeout passed to the MCP ClientSession. """ - super().__init__(cache_tools_list) + super().__init__(cache_tools_list, client_session_timeout_seconds) self.params = params self._name = name or f"sse: {self.params['url']}" diff --git a/tests/mcp/test_server_errors.py b/tests/mcp/test_server_errors.py index bdca7ce62..fbd8db17d 100644 --- a/tests/mcp/test_server_errors.py +++ b/tests/mcp/test_server_errors.py @@ -6,7 +6,7 @@ class CrashingClientSessionServer(_MCPServerWithClientSession): def __init__(self): - super().__init__(cache_tools_list=False) + super().__init__(cache_tools_list=False, client_session_timeout_seconds=5) self.cleanup_called = False def create_streams(self): From e11b822d5f075fc32683c8df71ac9388a7df79e5 Mon Sep 17 00:00:00 2001 From: Daniele Morotti <58258368+DanieleMorotti@users.noreply.github.com> Date: Thu, 24 Apr 2025 18:53:39 +0200 Subject: [PATCH 06/33] Fix stream error using LiteLLM (#589) In response to issue #587 , I implemented a solution to first check if `refusal` and `usage` attributes exist in the `delta` object. I added a unit test similar to `test_openai_chatcompletions_stream.py`. Let me know if I should change something. 
--------- Co-authored-by: Rohan Mehta --- src/agents/models/chatcmpl_stream_handler.py | 6 +- .../test_litellm_chatcompletions_stream.py | 286 ++++++++++++++++++ 2 files changed, 290 insertions(+), 2 deletions(-) create mode 100644 tests/models/test_litellm_chatcompletions_stream.py diff --git a/src/agents/models/chatcmpl_stream_handler.py b/src/agents/models/chatcmpl_stream_handler.py index 32f04acb4..c71adeb55 100644 --- a/src/agents/models/chatcmpl_stream_handler.py +++ b/src/agents/models/chatcmpl_stream_handler.py @@ -56,7 +56,8 @@ async def handle_stream( type="response.created", ) - usage = chunk.usage + # This is always set by the OpenAI API, but not by others e.g. LiteLLM + usage = chunk.usage if hasattr(chunk, "usage") else None if not chunk.choices or not chunk.choices[0].delta: continue @@ -112,7 +113,8 @@ async def handle_stream( state.text_content_index_and_output[1].text += delta.content # Handle refusals (model declines to answer) - if delta.refusal: + # This is always set by the OpenAI API, but not by others e.g. LiteLLM + if hasattr(delta, "refusal") and delta.refusal: if not state.refusal_content_index_and_output: # Initialize a content tracker for streaming refusal text state.refusal_content_index_and_output = ( diff --git a/tests/models/test_litellm_chatcompletions_stream.py b/tests/models/test_litellm_chatcompletions_stream.py new file mode 100644 index 000000000..80bd8ea22 --- /dev/null +++ b/tests/models/test_litellm_chatcompletions_stream.py @@ -0,0 +1,286 @@ +from collections.abc import AsyncIterator + +import pytest +from openai.types.chat.chat_completion_chunk import ( + ChatCompletionChunk, + Choice, + ChoiceDelta, + ChoiceDeltaToolCall, + ChoiceDeltaToolCallFunction, +) +from openai.types.completion_usage import CompletionUsage +from openai.types.responses import ( + Response, + ResponseFunctionToolCall, + ResponseOutputMessage, + ResponseOutputRefusal, + ResponseOutputText, +) + +from agents.extensions.models.litellm_model import LitellmModel +from agents.extensions.models.litellm_provider import LitellmProvider +from agents.model_settings import ModelSettings +from agents.models.interface import ModelTracing + + +@pytest.mark.allow_call_model_methods +@pytest.mark.asyncio +async def test_stream_response_yields_events_for_text_content(monkeypatch) -> None: + """ + Validate that `stream_response` emits the correct sequence of events when + streaming a simple assistant message consisting of plain text content. + We simulate two chunks of text returned from the chat completion stream. + """ + # Create two chunks that will be emitted by the fake stream. + chunk1 = ChatCompletionChunk( + id="chunk-id", + created=1, + model="fake", + object="chat.completion.chunk", + choices=[Choice(index=0, delta=ChoiceDelta(content="He"))], + ) + # Mark last chunk with usage so stream_response knows this is final. 
+ chunk2 = ChatCompletionChunk( + id="chunk-id", + created=1, + model="fake", + object="chat.completion.chunk", + choices=[Choice(index=0, delta=ChoiceDelta(content="llo"))], + usage=CompletionUsage(completion_tokens=5, prompt_tokens=7, total_tokens=12), + ) + + async def fake_stream() -> AsyncIterator[ChatCompletionChunk]: + for c in (chunk1, chunk2): + yield c + + # Patch _fetch_response to inject our fake stream + async def patched_fetch_response(self, *args, **kwargs): + # `_fetch_response` is expected to return a Response skeleton and the async stream + resp = Response( + id="resp-id", + created_at=0, + model="fake-model", + object="response", + output=[], + tool_choice="none", + tools=[], + parallel_tool_calls=False, + ) + return resp, fake_stream() + + monkeypatch.setattr(LitellmModel, "_fetch_response", patched_fetch_response) + model = LitellmProvider().get_model("gpt-4") + output_events = [] + async for event in model.stream_response( + system_instructions=None, + input="", + model_settings=ModelSettings(), + tools=[], + output_schema=None, + handoffs=[], + tracing=ModelTracing.DISABLED, + previous_response_id=None, + ): + output_events.append(event) + # We expect a response.created, then a response.output_item.added, content part added, + # two content delta events (for "He" and "llo"), a content part done, the assistant message + # output_item.done, and finally response.completed. + # There should be 8 events in total. + assert len(output_events) == 8 + # First event indicates creation. + assert output_events[0].type == "response.created" + # The output item added and content part added events should mark the assistant message. + assert output_events[1].type == "response.output_item.added" + assert output_events[2].type == "response.content_part.added" + # Two text delta events. + assert output_events[3].type == "response.output_text.delta" + assert output_events[3].delta == "He" + assert output_events[4].type == "response.output_text.delta" + assert output_events[4].delta == "llo" + # After streaming, the content part and item should be marked done. + assert output_events[5].type == "response.content_part.done" + assert output_events[6].type == "response.output_item.done" + # Last event indicates completion of the stream. + assert output_events[7].type == "response.completed" + # The completed response should have one output message with full text. + completed_resp = output_events[7].response + assert isinstance(completed_resp.output[0], ResponseOutputMessage) + assert isinstance(completed_resp.output[0].content[0], ResponseOutputText) + assert completed_resp.output[0].content[0].text == "Hello" + + assert completed_resp.usage, "usage should not be None" + assert completed_resp.usage.input_tokens == 7 + assert completed_resp.usage.output_tokens == 5 + assert completed_resp.usage.total_tokens == 12 + + +@pytest.mark.allow_call_model_methods +@pytest.mark.asyncio +async def test_stream_response_yields_events_for_refusal_content(monkeypatch) -> None: + """ + Validate that when the model streams a refusal string instead of normal content, + `stream_response` emits the appropriate sequence of events including + `response.refusal.delta` events for each chunk of the refusal message and + constructs a completed assistant message with a `ResponseOutputRefusal` part. + """ + # Simulate refusal text coming in two pieces, like content but using the `refusal` + # field on the delta rather than `content`. 
+ chunk1 = ChatCompletionChunk( + id="chunk-id", + created=1, + model="fake", + object="chat.completion.chunk", + choices=[Choice(index=0, delta=ChoiceDelta(refusal="No"))], + ) + chunk2 = ChatCompletionChunk( + id="chunk-id", + created=1, + model="fake", + object="chat.completion.chunk", + choices=[Choice(index=0, delta=ChoiceDelta(refusal="Thanks"))], + usage=CompletionUsage(completion_tokens=2, prompt_tokens=2, total_tokens=4), + ) + + async def fake_stream() -> AsyncIterator[ChatCompletionChunk]: + for c in (chunk1, chunk2): + yield c + + async def patched_fetch_response(self, *args, **kwargs): + resp = Response( + id="resp-id", + created_at=0, + model="fake-model", + object="response", + output=[], + tool_choice="none", + tools=[], + parallel_tool_calls=False, + ) + return resp, fake_stream() + + monkeypatch.setattr(LitellmModel, "_fetch_response", patched_fetch_response) + model = LitellmProvider().get_model("gpt-4") + output_events = [] + async for event in model.stream_response( + system_instructions=None, + input="", + model_settings=ModelSettings(), + tools=[], + output_schema=None, + handoffs=[], + tracing=ModelTracing.DISABLED, + previous_response_id=None, + ): + output_events.append(event) + # Expect sequence similar to text: created, output_item.added, content part added, + # two refusal delta events, content part done, output_item.done, completed. + assert len(output_events) == 8 + assert output_events[0].type == "response.created" + assert output_events[1].type == "response.output_item.added" + assert output_events[2].type == "response.content_part.added" + assert output_events[3].type == "response.refusal.delta" + assert output_events[3].delta == "No" + assert output_events[4].type == "response.refusal.delta" + assert output_events[4].delta == "Thanks" + assert output_events[5].type == "response.content_part.done" + assert output_events[6].type == "response.output_item.done" + assert output_events[7].type == "response.completed" + completed_resp = output_events[7].response + assert isinstance(completed_resp.output[0], ResponseOutputMessage) + refusal_part = completed_resp.output[0].content[0] + assert isinstance(refusal_part, ResponseOutputRefusal) + assert refusal_part.refusal == "NoThanks" + + +@pytest.mark.allow_call_model_methods +@pytest.mark.asyncio +async def test_stream_response_yields_events_for_tool_call(monkeypatch) -> None: + """ + Validate that `stream_response` emits the correct sequence of events when + the model is streaming a function/tool call instead of plain text. + The function call will be split across two chunks. + """ + # Simulate a single tool call whose ID stays constant and function name/args built over chunks. 
+ tool_call_delta1 = ChoiceDeltaToolCall( + index=0, + id="tool-id", + function=ChoiceDeltaToolCallFunction(name="my_", arguments="arg1"), + type="function", + ) + tool_call_delta2 = ChoiceDeltaToolCall( + index=0, + id="tool-id", + function=ChoiceDeltaToolCallFunction(name="func", arguments="arg2"), + type="function", + ) + chunk1 = ChatCompletionChunk( + id="chunk-id", + created=1, + model="fake", + object="chat.completion.chunk", + choices=[Choice(index=0, delta=ChoiceDelta(tool_calls=[tool_call_delta1]))], + ) + chunk2 = ChatCompletionChunk( + id="chunk-id", + created=1, + model="fake", + object="chat.completion.chunk", + choices=[Choice(index=0, delta=ChoiceDelta(tool_calls=[tool_call_delta2]))], + usage=CompletionUsage(completion_tokens=1, prompt_tokens=1, total_tokens=2), + ) + + async def fake_stream() -> AsyncIterator[ChatCompletionChunk]: + for c in (chunk1, chunk2): + yield c + + async def patched_fetch_response(self, *args, **kwargs): + resp = Response( + id="resp-id", + created_at=0, + model="fake-model", + object="response", + output=[], + tool_choice="none", + tools=[], + parallel_tool_calls=False, + ) + return resp, fake_stream() + + monkeypatch.setattr(LitellmModel, "_fetch_response", patched_fetch_response) + model = LitellmProvider().get_model("gpt-4") + output_events = [] + async for event in model.stream_response( + system_instructions=None, + input="", + model_settings=ModelSettings(), + tools=[], + output_schema=None, + handoffs=[], + tracing=ModelTracing.DISABLED, + previous_response_id=None, + ): + output_events.append(event) + # Sequence should be: response.created, then after loop we expect function call-related events: + # one response.output_item.added for function call, a response.function_call_arguments.delta, + # a response.output_item.done, and finally response.completed. + assert output_events[0].type == "response.created" + # The next three events are about the tool call. + assert output_events[1].type == "response.output_item.added" + # The added item should be a ResponseFunctionToolCall. + added_fn = output_events[1].item + assert isinstance(added_fn, ResponseFunctionToolCall) + assert added_fn.name == "my_func" # Name should be concatenation of both chunks. + assert added_fn.arguments == "arg1arg2" + assert output_events[2].type == "response.function_call_arguments.delta" + assert output_events[2].delta == "arg1arg2" + assert output_events[3].type == "response.output_item.done" + assert output_events[4].type == "response.completed" + assert output_events[2].delta == "arg1arg2" + assert output_events[3].type == "response.output_item.done" + assert output_events[4].type == "response.completed" + assert added_fn.name == "my_func" # Name should be concatenation of both chunks. 
+ assert added_fn.arguments == "arg1arg2" + assert output_events[2].type == "response.function_call_arguments.delta" + assert output_events[2].delta == "arg1arg2" + assert output_events[3].type == "response.output_item.done" + assert output_events[4].type == "response.completed" From 45eb41f1e668a45c6b53b64e06fa7db9eab4db46 Mon Sep 17 00:00:00 2001 From: Rohan Mehta Date: Thu, 24 Apr 2025 14:45:03 -0400 Subject: [PATCH 07/33] More tests for cancelling streamed run (#590) --- tests/fake_model.py | 5 +- tests/test_cancel_streaming.py | 98 +++++++++++++++++++++++++++++++++- 2 files changed, 100 insertions(+), 3 deletions(-) diff --git a/tests/fake_model.py b/tests/fake_model.py index c6b3ba924..da3019a0f 100644 --- a/tests/fake_model.py +++ b/tests/fake_model.py @@ -127,7 +127,10 @@ async def stream_response( ) -def get_response_obj(output: list[TResponseOutputItem], response_id: str | None = None) -> Response: +def get_response_obj( + output: list[TResponseOutputItem], + response_id: str | None = None, +) -> Response: return Response( id=response_id or "123", created_at=123, diff --git a/tests/test_cancel_streaming.py b/tests/test_cancel_streaming.py index 6d1807d7c..3417a3c5d 100644 --- a/tests/test_cancel_streaming.py +++ b/tests/test_cancel_streaming.py @@ -1,12 +1,15 @@ +import json + import pytest from agents import Agent, Runner from .fake_model import FakeModel +from .test_responses import get_function_tool, get_function_tool_call, get_text_message @pytest.mark.asyncio -async def test_joker_streamed_jokes_with_cancel(): +async def test_simple_streaming_with_cancel(): model = FakeModel() agent = Agent(name="Joker", model=model) @@ -16,7 +19,98 @@ async def test_joker_streamed_jokes_with_cancel(): async for _event in result.stream_events(): num_events += 1 - if num_events == 1: + if num_events == stop_after: result.cancel() assert num_events == 1, f"Expected {stop_after} visible events, but got {num_events}" + + +@pytest.mark.asyncio +async def test_multiple_events_streaming_with_cancel(): + model = FakeModel() + agent = Agent( + name="Joker", + model=model, + tools=[get_function_tool("foo", "tool_result")], + ) + + model.add_multiple_turn_outputs( + [ + # First turn: a message and tool call + [ + get_text_message("a_message"), + get_function_tool_call("foo", json.dumps({"a": "b"})), + ], + # Second turn: text message + [get_text_message("done")], + ] + ) + + result = Runner.run_streamed(agent, input="Please tell me 5 jokes.") + num_events = 0 + stop_after = 2 + + async for _ in result.stream_events(): + num_events += 1 + if num_events == stop_after: + result.cancel() + + assert num_events == stop_after, f"Expected {stop_after} visible events, but got {num_events}" + + +@pytest.mark.asyncio +async def test_cancel_prevents_further_events(): + model = FakeModel() + agent = Agent(name="Joker", model=model) + result = Runner.run_streamed(agent, input="Please tell me 5 jokes.") + events = [] + async for event in result.stream_events(): + events.append(event) + result.cancel() + break # Cancel after first event + # Try to get more events after cancel + more_events = [e async for e in result.stream_events()] + assert len(events) == 1 + assert more_events == [], "No events should be yielded after cancel()" + + +@pytest.mark.asyncio +async def test_cancel_is_idempotent(): + model = FakeModel() + agent = Agent(name="Joker", model=model) + result = Runner.run_streamed(agent, input="Please tell me 5 jokes.") + events = [] + async for event in result.stream_events(): + events.append(event) + 
result.cancel() + result.cancel() # Call cancel again + break + # Should not raise or misbehave + assert len(events) == 1 + + +@pytest.mark.asyncio +async def test_cancel_before_streaming(): + model = FakeModel() + agent = Agent(name="Joker", model=model) + result = Runner.run_streamed(agent, input="Please tell me 5 jokes.") + result.cancel() # Cancel before streaming + events = [e async for e in result.stream_events()] + assert events == [], "No events should be yielded if cancel() is called before streaming." + + +@pytest.mark.asyncio +async def test_cancel_cleans_up_resources(): + model = FakeModel() + agent = Agent(name="Joker", model=model) + result = Runner.run_streamed(agent, input="Please tell me 5 jokes.") + # Start streaming, then cancel + async for _ in result.stream_events(): + result.cancel() + break + # After cancel, queues should be empty and is_complete True + assert result.is_complete, "Result should be marked complete after cancel." + assert result._event_queue.empty(), "Event queue should be empty after cancel." + assert result._input_guardrail_queue.empty(), ( + "Input guardrail queue should be empty after cancel." + ) From 3bbc7c48cb9ee80ed4b3dfbbd55efddf7f77d6a3 Mon Sep 17 00:00:00 2001 From: Rohan Mehta Date: Thu, 24 Apr 2025 14:58:38 -0400 Subject: [PATCH 08/33] v0.0.13 (#593) --- pyproject.toml | 2 +- uv.lock | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 12ffff1ff..c1ae467a2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "openai-agents" -version = "0.0.12" +version = "0.0.13" description = "OpenAI Agents SDK" readme = "README.md" requires-python = ">=3.9" diff --git a/uv.lock b/uv.lock index 4c6c370ad..c6824a082 100644 --- a/uv.lock +++ b/uv.lock @@ -1482,7 +1482,7 @@ wheels = [ [[package]] name = "openai-agents" -version = "0.0.12" +version = "0.0.13" source = { editable = "." 
} dependencies = [ { name = "griffe" }, From 8fd7773a5ef5121d9349edbabebc9522a2f3c4f0 Mon Sep 17 00:00:00 2001 From: Rohan Mehta Date: Thu, 24 Apr 2025 18:20:35 -0400 Subject: [PATCH 09/33] Add usage to context in streaming (#595) --- src/agents/result.py | 16 +++++++--------- src/agents/run.py | 3 +++ tests/fake_model.py | 19 ++++++++++++++++--- tests/test_result_cast.py | 3 ++- 4 files changed, 28 insertions(+), 13 deletions(-) diff --git a/src/agents/result.py b/src/agents/result.py index 1f1c78328..243db155c 100644 --- a/src/agents/result.py +++ b/src/agents/result.py @@ -15,6 +15,7 @@ from .guardrail import InputGuardrailResult, OutputGuardrailResult from .items import ItemHelpers, ModelResponse, RunItem, TResponseInputItem from .logger import logger +from .run_context import RunContextWrapper from .stream_events import StreamEvent from .tracing import Trace from .util._pretty_print import pretty_print_result, pretty_print_run_result_streaming @@ -50,6 +51,9 @@ class RunResultBase(abc.ABC): output_guardrail_results: list[OutputGuardrailResult] """Guardrail results for the final output of the agent.""" + context_wrapper: RunContextWrapper[Any] + """The context wrapper for the agent run.""" + @property @abc.abstractmethod def last_agent(self) -> Agent[Any]: @@ -75,9 +79,7 @@ def final_output_as(self, cls: type[T], raise_if_incorrect_type: bool = False) - def to_input_list(self) -> list[TResponseInputItem]: """Creates a new input list, merging the original input with all the new items generated.""" - original_items: list[TResponseInputItem] = ItemHelpers.input_to_new_input_list( - self.input - ) + original_items: list[TResponseInputItem] = ItemHelpers.input_to_new_input_list(self.input) new_items = [item.to_input_item() for item in self.new_items] return original_items + new_items @@ -206,17 +208,13 @@ async def stream_events(self) -> AsyncIterator[StreamEvent]: def _check_errors(self): if self.current_turn > self.max_turns: - self._stored_exception = MaxTurnsExceeded( - f"Max turns ({self.max_turns}) exceeded" - ) + self._stored_exception = MaxTurnsExceeded(f"Max turns ({self.max_turns}) exceeded") # Fetch all the completed guardrail results from the queue and raise if needed while not self._input_guardrail_queue.empty(): guardrail_result = self._input_guardrail_queue.get_nowait() if guardrail_result.output.tripwire_triggered: - self._stored_exception = InputGuardrailTripwireTriggered( - guardrail_result - ) + self._stored_exception = InputGuardrailTripwireTriggered(guardrail_result) # Check the tasks for any exceptions if self._run_impl_task and self._run_impl_task.done(): diff --git a/src/agents/run.py b/src/agents/run.py index 2af558d58..849da7bfc 100644 --- a/src/agents/run.py +++ b/src/agents/run.py @@ -270,6 +270,7 @@ async def run( _last_agent=current_agent, input_guardrail_results=input_guardrail_results, output_guardrail_results=output_guardrail_results, + context_wrapper=context_wrapper, ) elif isinstance(turn_result.next_step, NextStepHandoff): current_agent = cast(Agent[TContext], turn_result.next_step.new_agent) @@ -423,6 +424,7 @@ def run_streamed( output_guardrail_results=[], _current_agent_output_schema=output_schema, trace=new_trace, + context_wrapper=context_wrapper, ) # Kick off the actual agent loop in the background and return the streamed result object. 
@@ -696,6 +698,7 @@ async def _run_single_turn_streamed( usage=usage, response_id=event.response.id, ) + context_wrapper.usage.add(usage) streamed_result._event_queue.put_nowait(RawResponsesStreamEvent(data=event)) diff --git a/tests/fake_model.py b/tests/fake_model.py index da3019a0f..32f919ef1 100644 --- a/tests/fake_model.py +++ b/tests/fake_model.py @@ -3,7 +3,8 @@ from collections.abc import AsyncIterator from typing import Any -from openai.types.responses import Response, ResponseCompletedEvent +from openai.types.responses import Response, ResponseCompletedEvent, ResponseUsage +from openai.types.responses.response_usage import InputTokensDetails, OutputTokensDetails from agents.agent_output import AgentOutputSchemaBase from agents.handoffs import Handoff @@ -33,6 +34,10 @@ def __init__( ) self.tracing_enabled = tracing_enabled self.last_turn_args: dict[str, Any] = {} + self.hardcoded_usage: Usage | None = None + + def set_hardcoded_usage(self, usage: Usage): + self.hardcoded_usage = usage def set_next_output(self, output: list[TResponseOutputItem] | Exception): self.turn_outputs.append(output) @@ -83,7 +88,7 @@ async def get_response( return ModelResponse( output=output, - usage=Usage(), + usage=self.hardcoded_usage or Usage(), response_id=None, ) @@ -123,13 +128,14 @@ async def stream_response( yield ResponseCompletedEvent( type="response.completed", - response=get_response_obj(output), + response=get_response_obj(output, usage=self.hardcoded_usage), ) def get_response_obj( output: list[TResponseOutputItem], response_id: str | None = None, + usage: Usage | None = None, ) -> Response: return Response( id=response_id or "123", @@ -141,4 +147,11 @@ def get_response_obj( tools=[], top_p=None, parallel_tool_calls=False, + usage=ResponseUsage( + input_tokens=usage.input_tokens if usage else 0, + output_tokens=usage.output_tokens if usage else 0, + total_tokens=usage.total_tokens if usage else 0, + input_tokens_details=InputTokensDetails(cached_tokens=0), + output_tokens_details=OutputTokensDetails(reasoning_tokens=0), + ), ) diff --git a/tests/test_result_cast.py b/tests/test_result_cast.py index ec17e3275..c621e7352 100644 --- a/tests/test_result_cast.py +++ b/tests/test_result_cast.py @@ -3,7 +3,7 @@ import pytest from pydantic import BaseModel -from agents import Agent, RunResult +from agents import Agent, RunContextWrapper, RunResult def create_run_result(final_output: Any) -> RunResult: @@ -15,6 +15,7 @@ def create_run_result(final_output: Any) -> RunResult: input_guardrail_results=[], output_guardrail_results=[], _last_agent=Agent(name="test"), + context_wrapper=RunContextWrapper(context=None), ) From aa197e1e1431ec8ad381575d0ea5a9667a9c2fcb Mon Sep 17 00:00:00 2001 From: Stefano Baccianella <4247706+mangiucugna@users.noreply.github.com> Date: Fri, 25 Apr 2025 01:11:25 +0200 Subject: [PATCH 10/33] Make the TTS voices type exportable (#577) When using the voice agent in typed code, it is suboptimal and error prone to type the TTS voice variables in your code independently. With this commit we are making the type exportable so that developers can just use that and be future-proof. Example of usage in code: ``` DEFAULT_TTS_VOICE: TTSModelSettings.TTSVoice = "alloy" ... tts_voice: TTSModelSettings.TTSVoice = DEFAULT_TTS_VOICE ... 
output = await VoicePipeline( workflow=workflow, config=VoicePipelineConfig( tts_settings=TTSModelSettings( buffer_size=512, transform_data=transform_data, voice=tts_voice, instructions=tts_instructions, )) ).run(audio_input) ``` --------- Co-authored-by: Rohan Mehta --- src/agents/voice/__init__.py | 2 ++ src/agents/voice/model.py | 7 +++---- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/src/agents/voice/__init__.py b/src/agents/voice/__init__.py index 499c064c5..e11ee4467 100644 --- a/src/agents/voice/__init__.py +++ b/src/agents/voice/__init__.py @@ -7,6 +7,7 @@ STTModelSettings, TTSModel, TTSModelSettings, + TTSVoice, VoiceModelProvider, ) from .models.openai_model_provider import OpenAIVoiceModelProvider @@ -30,6 +31,7 @@ "STTModelSettings", "TTSModel", "TTSModelSettings", + "TTSVoice", "VoiceModelProvider", "StreamedAudioResult", "SingleAgentVoiceWorkflow", diff --git a/src/agents/voice/model.py b/src/agents/voice/model.py index 220d4b480..c36a4de76 100644 --- a/src/agents/voice/model.py +++ b/src/agents/voice/model.py @@ -14,14 +14,13 @@ ) DEFAULT_TTS_BUFFER_SIZE = 120 +TTSVoice = Literal["alloy", "ash", "coral", "echo", "fable", "onyx", "nova", "sage", "shimmer"] +"""Exportable type for the TTSModelSettings voice enum""" @dataclass class TTSModelSettings: """Settings for a TTS model.""" - - voice: ( - Literal["alloy", "ash", "coral", "echo", "fable", "onyx", "nova", "sage", "shimmer"] | None - ) = None + voice: TTSVoice | None = None """ The voice to use for the TTS model. If not provided, the default voice for the respective model will be used. From 4187fba9556683c5ace4706ebd5dbe90cd14e315 Mon Sep 17 00:00:00 2001 From: N V J K Kartik <114015823+NVJKKartik@users.noreply.github.com> Date: Fri, 25 Apr 2025 21:40:59 +0530 Subject: [PATCH 11/33] docs: add FutureAGI to tracing documentation (#592) Hi Team! This PR adds FutureAGI to the tracing documentation as one of the automatic tracing processors for OpenAI agents SDK. 
![image](https://github.com/user-attachments/assets/4de3aadc-5efa-4712-8b02-decdedf8f8ef) --- docs/tracing.md | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/tracing.md b/docs/tracing.md index ea48a2e28..dd883c5aa 100644 --- a/docs/tracing.md +++ b/docs/tracing.md @@ -101,6 +101,7 @@ To customize this default setup, to send traces to alternative or additional bac - [Weights & Biases](https://weave-docs.wandb.ai/guides/integrations/openai_agents) - [Arize-Phoenix](https://docs.arize.com/phoenix/tracing/integrations-tracing/openai-agents-sdk) +- [Future AGI](https://docs.futureagi.com/future-agi/products/observability/auto-instrumentation/openai_agents) - [MLflow (self-hosted/OSS](https://mlflow.org/docs/latest/tracing/integrations/openai-agent) - [MLflow (Databricks hosted](https://docs.databricks.com/aws/en/mlflow/mlflow-tracing#-automatic-tracing) - [Braintrust](https://braintrust.dev/docs/guides/traces/integrations#openai-agents-sdk) From db0ee9d5a5e7c24adfcc0a98a94540341e43c6e5 Mon Sep 17 00:00:00 2001 From: pakrym-oai Date: Tue, 29 Apr 2025 16:12:24 -0700 Subject: [PATCH 12/33] Update litellm version (#626) Addresses https://github.com/openai/openai-agents-python/issues/614 --- pyproject.toml | 2 +- uv.lock | 9 +++------ 2 files changed, 4 insertions(+), 7 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index c1ae467a2..d5600976b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -36,7 +36,7 @@ Repository = "https://github.com/openai/openai-agents-python" [project.optional-dependencies] voice = ["numpy>=2.2.0, <3; python_version>='3.10'", "websockets>=15.0, <16"] viz = ["graphviz>=0.17"] -litellm = ["litellm>=1.65.0, <2"] +litellm = ["litellm>=1.67.4.post1, <2"] [dependency-groups] dev = [ diff --git a/uv.lock b/uv.lock index c6824a082..636dbdd12 100644 --- a/uv.lock +++ b/uv.lock @@ -928,7 +928,7 @@ wheels = [ [[package]] name = "litellm" -version = "1.66.1" +version = "1.67.4.post1" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "aiohttp" }, @@ -943,10 +943,7 @@ dependencies = [ { name = "tiktoken" }, { name = "tokenizers" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/c1/21/12562c37310254456afdd277454dac4d14b8b40796216e8a438a9e1c5e86/litellm-1.66.1.tar.gz", hash = "sha256:98f7add913e5eae2131dd412ee27532d9a309defd9dbb64f6c6c42ea8a2af068", size = 7203211 } -wheels = [ - { url = "https://files.pythonhosted.org/packages/e9/33/fdc4615ca621940406e3b0b303e900bc2868cfcd8c62c4a6f5e7d2f6a56c/litellm-1.66.1-py3-none-any.whl", hash = "sha256:1f601fea3f086c1d2d91be60b9db115082a2f3a697e4e0def72f8b9c777c7232", size = 7559553 }, -] +sdist = { url = "https://files.pythonhosted.org/packages/4d/89/bacf75633dd43d6c5536380fb652c4af25046c29f5c6e5fdb4e8fe5af505/litellm-1.67.4.post1.tar.gz", hash = "sha256:057f2505f82d8c3f83d705c375b0d1931de998b13e239a6b06e16ee351fda648", size = 7243930 } [[package]] name = "markdown" @@ -1535,7 +1532,7 @@ dev = [ requires-dist = [ { name = "graphviz", marker = "extra == 'viz'", specifier = ">=0.17" }, { name = "griffe", specifier = ">=1.5.6,<2" }, - { name = "litellm", marker = "extra == 'litellm'", specifier = ">=1.65.0,<2" }, + { name = "litellm", marker = "extra == 'litellm'", specifier = ">=1.67.4.post1,<2" }, { name = "mcp", marker = "python_full_version >= '3.10'", specifier = ">=1.6.0,<2" }, { name = "numpy", marker = "python_full_version >= '3.10' and extra == 'voice'", specifier = ">=2.2.0,<3" }, { name = "openai", specifier = ">=1.76.0" }, From f9763495b86afcf0c421451a92200e1141fa8dcb Mon Sep 17 
00:00:00 2001 From: pakrym-oai Date: Wed, 30 Apr 2025 08:15:35 -0700 Subject: [PATCH 13/33] 0.0.14 release (#635) --- pyproject.toml | 2 +- uv.lock | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index d5600976b..22b028ae7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "openai-agents" -version = "0.0.13" +version = "0.0.14" description = "OpenAI Agents SDK" readme = "README.md" requires-python = ">=3.9" diff --git a/uv.lock b/uv.lock index 636dbdd12..87dd3cd2c 100644 --- a/uv.lock +++ b/uv.lock @@ -1479,7 +1479,7 @@ wheels = [ [[package]] name = "openai-agents" -version = "0.0.13" +version = "0.0.14" source = { editable = "." } dependencies = [ { name = "griffe" }, From 2c46dae37787e36f68db954f37addf2ddaa69458 Mon Sep 17 00:00:00 2001 From: Daniele Morotti <58258368+DanieleMorotti@users.noreply.github.com> Date: Wed, 14 May 2025 17:31:42 +0200 Subject: [PATCH 14/33] Fixed a bug for "detail" attribute in input image (#685) When an input image is given as input, the code tries to access the 'detail' key, that may not be present as noted in #159. With this pull request, now it tries to access the key, otherwise set the value to `None`. @pakrym-oai or @rm-openai let me know if you want any changes. --- src/agents/models/chatcmpl_converter.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/agents/models/chatcmpl_converter.py b/src/agents/models/chatcmpl_converter.py index 613a37453..1d599e8c0 100644 --- a/src/agents/models/chatcmpl_converter.py +++ b/src/agents/models/chatcmpl_converter.py @@ -234,7 +234,7 @@ def extract_all_content( type="image_url", image_url={ "url": casted_image_param["image_url"], - "detail": casted_image_param["detail"], + "detail": casted_image_param.get("detail", "auto"), }, ) ) From 1994f9d4c4bb807fd5460919b4352b715ef1324a Mon Sep 17 00:00:00 2001 From: Ashok Saravanan <90977640+AshokSaravanan222@users.noreply.github.com> Date: Wed, 14 May 2025 11:34:27 -0500 Subject: [PATCH 15/33] feat: pass extra_body through to LiteLLM acompletion (#638) **Purpose** Allow arbitrary `extra_body` parameters (e.g. `cached_content`) to be forwarded into the LiteLLM call. Useful for context caching in Gemini models ([docs](https://ai.google.dev/gemini-api/docs/caching?lang=python)). 
**Example usage** ```python import os from agents import Agent, ModelSettings from agents.extensions.models.litellm_model import LitellmModel cache_name = "cachedContents/34jopukfx5di" # previously stored context gemini_model = LitellmModel( model="gemini/gemini-1.5-flash-002", api_key=os.getenv("GOOGLE_API_KEY") ) agent = Agent( name="Cached Gemini Agent", model=gemini_model, model_settings=ModelSettings( extra_body={"cached_content": cache_name} ) ) --- src/agents/extensions/models/litellm_model.py | 2 + tests/models/test_litellm_extra_body.py | 45 +++++++++++++++++++ 2 files changed, 47 insertions(+) create mode 100644 tests/models/test_litellm_extra_body.py diff --git a/src/agents/extensions/models/litellm_model.py b/src/agents/extensions/models/litellm_model.py index dc672acd4..d3b25a198 100644 --- a/src/agents/extensions/models/litellm_model.py +++ b/src/agents/extensions/models/litellm_model.py @@ -269,6 +269,8 @@ async def _fetch_response( extra_kwargs["extra_query"] = model_settings.extra_query if model_settings.metadata: extra_kwargs["metadata"] = model_settings.metadata + if model_settings.extra_body and isinstance(model_settings.extra_body, dict): + extra_kwargs.update(model_settings.extra_body) ret = await litellm.acompletion( model=self.model, diff --git a/tests/models/test_litellm_extra_body.py b/tests/models/test_litellm_extra_body.py new file mode 100644 index 000000000..ac56c25cf --- /dev/null +++ b/tests/models/test_litellm_extra_body.py @@ -0,0 +1,45 @@ +import litellm +import pytest +from litellm.types.utils import Choices, Message, ModelResponse, Usage + +from agents.extensions.models.litellm_model import LitellmModel +from agents.model_settings import ModelSettings +from agents.models.interface import ModelTracing + + +@pytest.mark.allow_call_model_methods +@pytest.mark.asyncio +async def test_extra_body_is_forwarded(monkeypatch): + """ + Forward `extra_body` entries into litellm.acompletion kwargs. + + This ensures that user-provided parameters (e.g. cached_content) + arrive alongside default arguments. + """ + captured: dict[str, object] = {} + + async def fake_acompletion(model, messages=None, **kwargs): + captured.update(kwargs) + msg = Message(role="assistant", content="ok") + choice = Choices(index=0, message=msg) + return ModelResponse(choices=[choice], usage=Usage(0, 0, 0)) + + monkeypatch.setattr(litellm, "acompletion", fake_acompletion) + settings = ModelSettings( + temperature=0.1, + extra_body={"cached_content": "some_cache", "foo": 123} + ) + model = LitellmModel(model="test-model") + + await model.get_response( + system_instructions=None, + input=[], + model_settings=settings, + tools=[], + output_schema=None, + handoffs=[], + tracing=ModelTracing.DISABLED, + previous_response_id=None, + ) + + assert {"cached_content": "some_cache", "foo": 123}.items() <= captured.items() From 02b6e7013cf00de92b453610f459be40f344e6a8 Mon Sep 17 00:00:00 2001 From: leohpark <135409779+leohpark@users.noreply.github.com> Date: Wed, 14 May 2025 09:37:06 -0700 Subject: [PATCH 16/33] Update search_agent.py (#677) Added missing word "be" in prompt instructions. This is unlikely to change the agent functionality in most cases, but optimal clarity in prompt language is a best practice. 
--- examples/research_bot/agents/search_agent.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/research_bot/agents/search_agent.py b/examples/research_bot/agents/search_agent.py index 0212ce5b5..61f91701f 100644 --- a/examples/research_bot/agents/search_agent.py +++ b/examples/research_bot/agents/search_agent.py @@ -3,7 +3,7 @@ INSTRUCTIONS = ( "You are a research assistant. Given a search term, you search the web for that term and " - "produce a concise summary of the results. The summary must 2-3 paragraphs and less than 300 " + "produce a concise summary of the results. The summary must be 2-3 paragraphs and less than 300 " "words. Capture the main points. Write succinctly, no need to have complete sentences or good " "grammar. This will be consumed by someone synthesizing a report, so its vital you capture the " "essence and ignore any fluff. Do not include any additional commentary other than the summary " From 1847008e0f247e5468d691e4e9d780f984db9a18 Mon Sep 17 00:00:00 2001 From: Akshit97 Date: Thu, 15 May 2025 00:15:14 +0530 Subject: [PATCH 17/33] feat: Streamable HTTP support (#643) Co-authored-by: aagarwal25 --- examples/mcp/streamablehttp_example/README.md | 13 +++ examples/mcp/streamablehttp_example/main.py | 83 ++++++++++++++ examples/mcp/streamablehttp_example/server.py | 33 ++++++ pyproject.toml | 2 +- src/agents/mcp/__init__.py | 4 + src/agents/mcp/server.py | 106 ++++++++++++++++-- uv.lock | 20 +++- 7 files changed, 247 insertions(+), 14 deletions(-) create mode 100644 examples/mcp/streamablehttp_example/README.md create mode 100644 examples/mcp/streamablehttp_example/main.py create mode 100644 examples/mcp/streamablehttp_example/server.py diff --git a/examples/mcp/streamablehttp_example/README.md b/examples/mcp/streamablehttp_example/README.md new file mode 100644 index 000000000..a07fe19be --- /dev/null +++ b/examples/mcp/streamablehttp_example/README.md @@ -0,0 +1,13 @@ +# MCP Streamable HTTP Example + +This example uses a local Streamable HTTP server in [server.py](server.py). + +Run the example via: + +``` +uv run python examples/mcp/streamablehttp_example/main.py +``` + +## Details + +The example uses the `MCPServerStreamableHttp` class from `agents.mcp`. The server runs in a sub-process at `https://localhost:8000/mcp`. diff --git a/examples/mcp/streamablehttp_example/main.py b/examples/mcp/streamablehttp_example/main.py new file mode 100644 index 000000000..cc95e798b --- /dev/null +++ b/examples/mcp/streamablehttp_example/main.py @@ -0,0 +1,83 @@ +import asyncio +import os +import shutil +import subprocess +import time +from typing import Any + +from agents import Agent, Runner, gen_trace_id, trace +from agents.mcp import MCPServer, MCPServerStreamableHttp +from agents.model_settings import ModelSettings + + +async def run(mcp_server: MCPServer): + agent = Agent( + name="Assistant", + instructions="Use the tools to answer the questions.", + mcp_servers=[mcp_server], + model_settings=ModelSettings(tool_choice="required"), + ) + + # Use the `add` tool to add two numbers + message = "Add these numbers: 7 and 22." + print(f"Running: {message}") + result = await Runner.run(starting_agent=agent, input=message) + print(result.final_output) + + # Run the `get_weather` tool + message = "What's the weather in Tokyo?" + print(f"\n\nRunning: {message}") + result = await Runner.run(starting_agent=agent, input=message) + print(result.final_output) + + # Run the `get_secret_word` tool + message = "What's the secret word?" 
+ print(f"\n\nRunning: {message}") + result = await Runner.run(starting_agent=agent, input=message) + print(result.final_output) + + +async def main(): + async with MCPServerStreamableHttp( + name="Streamable HTTP Python Server", + params={ + "url": "http://localhost:8000/mcp", + }, + ) as server: + trace_id = gen_trace_id() + with trace(workflow_name="Streamable HTTP Example", trace_id=trace_id): + print(f"View trace: https://platform.openai.com/traces/trace?trace_id={trace_id}\n") + await run(server) + + +if __name__ == "__main__": + # Let's make sure the user has uv installed + if not shutil.which("uv"): + raise RuntimeError( + "uv is not installed. Please install it: https://docs.astral.sh/uv/getting-started/installation/" + ) + + # We'll run the Streamable HTTP server in a subprocess. Usually this would be a remote server, but for this + # demo, we'll run it locally at http://localhost:8000/mcp + process: subprocess.Popen[Any] | None = None + try: + this_dir = os.path.dirname(os.path.abspath(__file__)) + server_file = os.path.join(this_dir, "server.py") + + print("Starting Streamable HTTP server at http://localhost:8000/mcp ...") + + # Run `uv run server.py` to start the Streamable HTTP server + process = subprocess.Popen(["uv", "run", server_file]) + # Give it 3 seconds to start + time.sleep(3) + + print("Streamable HTTP server started. Running example...\n\n") + except Exception as e: + print(f"Error starting Streamable HTTP server: {e}") + exit(1) + + try: + asyncio.run(main()) + finally: + if process: + process.terminate() diff --git a/examples/mcp/streamablehttp_example/server.py b/examples/mcp/streamablehttp_example/server.py new file mode 100644 index 000000000..d8f839652 --- /dev/null +++ b/examples/mcp/streamablehttp_example/server.py @@ -0,0 +1,33 @@ +import random + +import requests +from mcp.server.fastmcp import FastMCP + +# Create server +mcp = FastMCP("Echo Server") + + +@mcp.tool() +def add(a: int, b: int) -> int: + """Add two numbers""" + print(f"[debug-server] add({a}, {b})") + return a + b + + +@mcp.tool() +def get_secret_word() -> str: + print("[debug-server] get_secret_word()") + return random.choice(["apple", "banana", "cherry"]) + + +@mcp.tool() +def get_current_weather(city: str) -> str: + print(f"[debug-server] get_current_weather({city})") + + endpoint = "https://wttr.in" + response = requests.get(f"{endpoint}/{city}") + return response.text + + +if __name__ == "__main__": + mcp.run(transport="streamable-http") diff --git a/pyproject.toml b/pyproject.toml index 22b028ae7..87a707d3d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -13,7 +13,7 @@ dependencies = [ "typing-extensions>=4.12.2, <5", "requests>=2.0, <3", "types-requests>=2.0, <3", - "mcp>=1.6.0, <2; python_version >= '3.10'", + "mcp>=1.8.0, <2; python_version >= '3.10'", ] classifiers = [ "Typing :: Typed", diff --git a/src/agents/mcp/__init__.py b/src/agents/mcp/__init__.py index 1a72a89f0..d4eb8fa68 100644 --- a/src/agents/mcp/__init__.py +++ b/src/agents/mcp/__init__.py @@ -5,6 +5,8 @@ MCPServerSseParams, MCPServerStdio, MCPServerStdioParams, + MCPServerStreamableHttp, + MCPServerStreamableHttpParams, ) except ImportError: pass @@ -17,5 +19,7 @@ "MCPServerSseParams", "MCPServerStdio", "MCPServerStdioParams", + "MCPServerStreamableHttp", + "MCPServerStreamableHttpParams", "MCPUtil", ] diff --git a/src/agents/mcp/server.py b/src/agents/mcp/server.py index 9916c92b0..c5255ead7 100644 --- a/src/agents/mcp/server.py +++ b/src/agents/mcp/server.py @@ -10,7 +10,9 @@ from anyio.streams.memory import 
MemoryObjectReceiveStream, MemoryObjectSendStream from mcp import ClientSession, StdioServerParameters, Tool as MCPTool, stdio_client from mcp.client.sse import sse_client -from mcp.types import CallToolResult, JSONRPCMessage +from mcp.client.streamable_http import GetSessionIdCallback, streamablehttp_client +from mcp.shared.message import SessionMessage +from mcp.types import CallToolResult from typing_extensions import NotRequired, TypedDict from ..exceptions import UserError @@ -83,8 +85,9 @@ def create_streams( self, ) -> AbstractAsyncContextManager[ tuple[ - MemoryObjectReceiveStream[JSONRPCMessage | Exception], - MemoryObjectSendStream[JSONRPCMessage], + MemoryObjectReceiveStream[SessionMessage | Exception], + MemoryObjectSendStream[SessionMessage], + GetSessionIdCallback | None ] ]: """Create the streams for the server.""" @@ -105,7 +108,11 @@ async def connect(self): """Connect to the server.""" try: transport = await self.exit_stack.enter_async_context(self.create_streams()) - read, write = transport + # streamablehttp_client returns (read, write, get_session_id) + # sse_client returns (read, write) + + read, write, *_ = transport + session = await self.exit_stack.enter_async_context( ClientSession( read, @@ -232,8 +239,9 @@ def create_streams( self, ) -> AbstractAsyncContextManager[ tuple[ - MemoryObjectReceiveStream[JSONRPCMessage | Exception], - MemoryObjectSendStream[JSONRPCMessage], + MemoryObjectReceiveStream[SessionMessage | Exception], + MemoryObjectSendStream[SessionMessage], + GetSessionIdCallback | None ] ]: """Create the streams for the server.""" @@ -302,8 +310,9 @@ def create_streams( self, ) -> AbstractAsyncContextManager[ tuple[ - MemoryObjectReceiveStream[JSONRPCMessage | Exception], - MemoryObjectSendStream[JSONRPCMessage], + MemoryObjectReceiveStream[SessionMessage | Exception], + MemoryObjectSendStream[SessionMessage], + GetSessionIdCallback | None ] ]: """Create the streams for the server.""" @@ -318,3 +327,84 @@ def create_streams( def name(self) -> str: """A readable name for the server.""" return self._name + + +class MCPServerStreamableHttpParams(TypedDict): + """Mirrors the params in`mcp.client.streamable_http.streamablehttp_client`.""" + + url: str + """The URL of the server.""" + + headers: NotRequired[dict[str, str]] + """The headers to send to the server.""" + + timeout: NotRequired[timedelta] + """The timeout for the HTTP request. Defaults to 5 seconds.""" + + sse_read_timeout: NotRequired[timedelta] + """The timeout for the SSE connection, in seconds. Defaults to 5 minutes.""" + + terminate_on_close: NotRequired[bool] + """Terminate on close""" + + +class MCPServerStreamableHttp(_MCPServerWithClientSession): + """MCP server implementation that uses the Streamable HTTP transport. See the [spec] + (https://modelcontextprotocol.io/specification/2025-03-26/basic/transports#streamable-http) + for details. + """ + + def __init__( + self, + params: MCPServerStreamableHttpParams, + cache_tools_list: bool = False, + name: str | None = None, + client_session_timeout_seconds: float | None = 5, + ): + """Create a new MCP server based on the Streamable HTTP transport. + + Args: + params: The params that configure the server. This includes the URL of the server, + the headers to send to the server, the timeout for the HTTP request, and the + timeout for the Streamable HTTP connection and whether we need to + terminate on close. + + cache_tools_list: Whether to cache the tools list. If `True`, the tools list will be + cached and only fetched from the server once. 
If `False`, the tools list will be + fetched from the server on each call to `list_tools()`. The cache can be + invalidated by calling `invalidate_tools_cache()`. You should set this to `True` + if you know the server will not change its tools list, because it can drastically + improve latency (by avoiding a round-trip to the server every time). + + name: A readable name for the server. If not provided, we'll create one from the + URL. + + client_session_timeout_seconds: the read timeout passed to the MCP ClientSession. + """ + super().__init__(cache_tools_list, client_session_timeout_seconds) + + self.params = params + self._name = name or f"streamable_http: {self.params['url']}" + + def create_streams( + self, + ) -> AbstractAsyncContextManager[ + tuple[ + MemoryObjectReceiveStream[SessionMessage | Exception], + MemoryObjectSendStream[SessionMessage], + GetSessionIdCallback | None + ] + ]: + """Create the streams for the server.""" + return streamablehttp_client( + url=self.params["url"], + headers=self.params.get("headers", None), + timeout=self.params.get("timeout", timedelta(seconds=30)), + sse_read_timeout=self.params.get("sse_read_timeout", timedelta(seconds=60 * 5)), + terminate_on_close=self.params.get("terminate_on_close", True) + ) + + @property + def name(self) -> str: + """A readable name for the server.""" + return self._name diff --git a/uv.lock b/uv.lock index 87dd3cd2c..6ccc19966 100644 --- a/uv.lock +++ b/uv.lock @@ -1047,7 +1047,7 @@ wheels = [ [[package]] name = "mcp" -version = "1.6.0" +version = "1.8.1" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "anyio", marker = "python_full_version >= '3.10'" }, @@ -1055,13 +1055,14 @@ dependencies = [ { name = "httpx-sse", marker = "python_full_version >= '3.10'" }, { name = "pydantic", marker = "python_full_version >= '3.10'" }, { name = "pydantic-settings", marker = "python_full_version >= '3.10'" }, + { name = "python-multipart", marker = "python_full_version >= '3.10'" }, { name = "sse-starlette", marker = "python_full_version >= '3.10'" }, { name = "starlette", marker = "python_full_version >= '3.10'" }, - { name = "uvicorn", marker = "python_full_version >= '3.10'" }, + { name = "uvicorn", marker = "python_full_version >= '3.10' and sys_platform != 'emscripten'" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/95/d2/f587cb965a56e992634bebc8611c5b579af912b74e04eb9164bd49527d21/mcp-1.6.0.tar.gz", hash = "sha256:d9324876de2c5637369f43161cd71eebfd803df5a95e46225cab8d280e366723", size = 200031 } +sdist = { url = "https://files.pythonhosted.org/packages/7c/13/16b712e8a3be6a736b411df2fc6b4e75eb1d3e99b1cd57a3a1decf17f612/mcp-1.8.1.tar.gz", hash = "sha256:ec0646271d93749f784d2316fb5fe6102fb0d1be788ec70a9e2517e8f2722c0e", size = 265605 } wheels = [ - { url = "https://files.pythonhosted.org/packages/10/30/20a7f33b0b884a9d14dd3aa94ff1ac9da1479fe2ad66dd9e2736075d2506/mcp-1.6.0-py3-none-any.whl", hash = "sha256:7bd24c6ea042dbec44c754f100984d186620d8b841ec30f1b19eda9b93a634d0", size = 76077 }, + { url = "https://files.pythonhosted.org/packages/1c/5d/91cf0d40e40ae9ecf8d4004e0f9611eea86085aa0b5505493e0ff53972da/mcp-1.8.1-py3-none-any.whl", hash = "sha256:948e03783859fa35abe05b9b6c0a1d5519be452fc079dc8d7f682549591c1770", size = 119761 }, ] [[package]] @@ -1533,7 +1534,7 @@ requires-dist = [ { name = "graphviz", marker = "extra == 'viz'", specifier = ">=0.17" }, { name = "griffe", specifier = ">=1.5.6,<2" }, { name = "litellm", marker = "extra == 'litellm'", specifier = ">=1.67.4.post1,<2" }, - { 
name = "mcp", marker = "python_full_version >= '3.10'", specifier = ">=1.6.0,<2" }, + { name = "mcp", marker = "python_full_version >= '3.10'", specifier = ">=1.8.0,<2" }, { name = "numpy", marker = "python_full_version >= '3.10' and extra == 'voice'", specifier = ">=2.2.0,<3" }, { name = "openai", specifier = ">=1.76.0" }, { name = "pydantic", specifier = ">=2.10,<3" }, @@ -2085,6 +2086,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/1e/18/98a99ad95133c6a6e2005fe89faedf294a748bd5dc803008059409ac9b1e/python_dotenv-1.1.0-py3-none-any.whl", hash = "sha256:d7c01d9e2293916c18baf562d95698754b0dbbb5e74d457c45d4f6561fb9d55d", size = 20256 }, ] +[[package]] +name = "python-multipart" +version = "0.0.20" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/f3/87/f44d7c9f274c7ee665a29b885ec97089ec5dc034c7f3fafa03da9e39a09e/python_multipart-0.0.20.tar.gz", hash = "sha256:8dd0cab45b8e23064ae09147625994d090fa46f5b0d1e13af944c331a7fa9d13", size = 37158 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/45/58/38b5afbc1a800eeea951b9285d3912613f2603bdf897a4ab0f4bd7f405fc/python_multipart-0.0.20-py3-none-any.whl", hash = "sha256:8a62d3a8335e06589fe01f2a3e178cdcc632f3fbe0d492ad9ee0ec35aab1f104", size = 24546 }, +] + [[package]] name = "python-xlib" version = "0.33" From 5fe096df67cf7ada53f0bd38954b7f12743b790b Mon Sep 17 00:00:00 2001 From: Rohan Mehta Date: Thu, 15 May 2025 18:41:00 -0400 Subject: [PATCH 18/33] v0.0.15 (#701) --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 87a707d3d..672258c42 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "openai-agents" -version = "0.0.14" +version = "0.0.15" description = "OpenAI Agents SDK" readme = "README.md" requires-python = ">=3.9" From c282324d9519b037ced8abe3678c86f7dd344a97 Mon Sep 17 00:00:00 2001 From: Dominik Kundel Date: Sun, 18 May 2025 10:24:24 -0700 Subject: [PATCH 19/33] Create AGENTS.md (#707) Adding an AGENTS.md file for Codex use --- AGENTS.md | 69 +++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 69 insertions(+) create mode 100644 AGENTS.md diff --git a/AGENTS.md b/AGENTS.md new file mode 100644 index 000000000..ff37db326 --- /dev/null +++ b/AGENTS.md @@ -0,0 +1,69 @@ +Welcome to the OpenAI Agents SDK repository. This file contains the main points for new contributors. + +## Repository overview + +- **Source code**: `src/agents/` contains the implementation. +- **Tests**: `tests/` with a short guide in `tests/README.md`. +- **Examples**: under `examples/`. +- **Documentation**: markdown pages live in `docs/` with `mkdocs.yml` controlling the site. +- **Utilities**: developer commands are defined in the `Makefile`. +- **PR template**: `.github/PULL_REQUEST_TEMPLATE/pull_request_template.md` describes the information every PR must include. + +## Local workflow + +1. Format, lint and type‑check your changes: + + ```bash + make format + make lint + make mypy + ``` + +2. Run the tests: + + ```bash + make tests + ``` + + To run a single test, use `uv run pytest -s -k `. + +3. Build the documentation (optional but recommended for docs changes): + + ```bash + make build-docs + ``` + + Coverage can be generated with `make coverage`. + +## Snapshot tests + +Some tests rely on inline snapshots. 
See `tests/README.md` for details on updating them: + +```bash +make snapshots-fix # update existing snapshots +make snapshots-create # create new snapshots +``` + +Run `make tests` again after updating snapshots to ensure they pass. + +## Style notes + +- Write comments as full sentences and end them with a period. + +## Pull request expectations + +PRs should use the template located at `.github/PULL_REQUEST_TEMPLATE/pull_request_template.md`. Provide a summary, test plan and issue number if applicable, then check that: + +- New tests are added when needed. +- Documentation is updated. +- `make lint` and `make format` have been run. +- The full test suite passes. + +Commit messages should be concise and written in the imperative mood. Small, focused commits are preferred. + +## What reviewers look for + +- Tests covering new behaviour. +- Consistent style: code formatted with `ruff format`, imports sorted, and type hints passing `mypy`. +- Clear documentation for any public API changes. +- Clean history and a helpful PR description. From 003cbfe5f5820cd73ad3adfbae56c054e3cb73ca Mon Sep 17 00:00:00 2001 From: Daniele Morotti <58258368+DanieleMorotti@users.noreply.github.com> Date: Sun, 18 May 2025 19:25:08 +0200 Subject: [PATCH 20/33] Added mcp 'instructions' attribute to the server (#706) Added the `instructions` attribute to the MCP servers to solve #704 . Let me know if you want to add an example to the documentation. --- src/agents/mcp/server.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/agents/mcp/server.py b/src/agents/mcp/server.py index c5255ead7..414b517ab 100644 --- a/src/agents/mcp/server.py +++ b/src/agents/mcp/server.py @@ -12,7 +12,7 @@ from mcp.client.sse import sse_client from mcp.client.streamable_http import GetSessionIdCallback, streamablehttp_client from mcp.shared.message import SessionMessage -from mcp.types import CallToolResult +from mcp.types import CallToolResult, InitializeResult from typing_extensions import NotRequired, TypedDict from ..exceptions import UserError @@ -73,6 +73,7 @@ def __init__(self, cache_tools_list: bool, client_session_timeout_seconds: float self.exit_stack: AsyncExitStack = AsyncExitStack() self._cleanup_lock: asyncio.Lock = asyncio.Lock() self.cache_tools_list = cache_tools_list + self.server_initialize_result: InitializeResult | None = None self.client_session_timeout_seconds = client_session_timeout_seconds @@ -122,7 +123,8 @@ async def connect(self): else None, ) ) - await session.initialize() + server_result = await session.initialize() + self.server_initialize_result = server_result self.session = session except Exception as e: logger.error(f"Error initializing MCP server: {e}") From 428c9a65bf0c17198d9a2a616159b9eb8badb2b6 Mon Sep 17 00:00:00 2001 From: franz101 Date: Mon, 19 May 2025 21:41:11 +0200 Subject: [PATCH 21/33] Add Galileo to external tracing processors list (#662) --- docs/tracing.md | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/tracing.md b/docs/tracing.md index dd883c5aa..4a9c1bd90 100644 --- a/docs/tracing.md +++ b/docs/tracing.md @@ -115,3 +115,4 @@ To customize this default setup, to send traces to alternative or additional bac - [Langfuse](https://langfuse.com/docs/integrations/openaiagentssdk/openai-agents) - [Langtrace](https://docs.langtrace.ai/supported-integrations/llm-frameworks/openai-agents-sdk) - [Okahu-Monocle](https://github.com/monocle2ai/monocle) +- [Galileo](https://v2docs.galileo.ai/integrations/openai-agent-integration#openai-agent-integration) From 
466b44df180718a5d53c45293db2f57b6e719f95 Mon Sep 17 00:00:00 2001 From: WJPBProjects <76624567+WJPBProjects@users.noreply.github.com> Date: Tue, 20 May 2025 18:23:56 +0100 Subject: [PATCH 22/33] Dev/add usage details to Usage class (#726) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit PR to enhance the `Usage` object and related logic, to support more granular token accounting, matching the details available in the [OpenAI Responses API](https://platform.openai.com/docs/api-reference/responses) . Specifically, it: - Adds `input_tokens_details` and `output_tokens_details` fields to the `Usage` dataclass, storing detailed token breakdowns (e.g., `cached_tokens`, `reasoning_tokens`). - Flows this change through - Updates and extends tests to match - Adds a test for the Usage.add method ### Motivation - Aligns the SDK’s usage with the latest OpenAI responses API Usage object - Supports downstream use cases that require fine-grained token usage data (e.g., billing, analytics, optimization) requested by startups --------- Co-authored-by: Wulfie Bain --- src/agents/extensions/models/litellm_model.py | 11 ++++ src/agents/models/openai_chatcompletions.py | 15 +++++- src/agents/models/openai_responses.py | 2 + src/agents/run.py | 2 + src/agents/usage.py | 22 +++++++- .../test_litellm_chatcompletions_stream.py | 16 +++++- tests/test_extra_headers.py | 20 ++++--- tests/test_openai_chatcompletions.py | 17 +++++- tests/test_openai_chatcompletions_stream.py | 16 +++++- tests/test_responses_tracing.py | 20 ++++++- tests/test_usage.py | 52 +++++++++++++++++++ 11 files changed, 178 insertions(+), 15 deletions(-) create mode 100644 tests/test_usage.py diff --git a/src/agents/extensions/models/litellm_model.py b/src/agents/extensions/models/litellm_model.py index d3b25a198..ffb2c3c1c 100644 --- a/src/agents/extensions/models/litellm_model.py +++ b/src/agents/extensions/models/litellm_model.py @@ -6,6 +6,7 @@ from typing import Any, Literal, cast, overload import litellm.types +from openai.types.responses.response_usage import InputTokensDetails, OutputTokensDetails from agents.exceptions import ModelBehaviorError @@ -107,6 +108,16 @@ async def get_response( input_tokens=response_usage.prompt_tokens, output_tokens=response_usage.completion_tokens, total_tokens=response_usage.total_tokens, + input_tokens_details=InputTokensDetails( + cached_tokens=getattr( + response_usage.prompt_tokens_details, "cached_tokens", 0 + ) + ), + output_tokens_details=OutputTokensDetails( + reasoning_tokens=getattr( + response_usage.completion_tokens_details, "reasoning_tokens", 0 + ) + ), ) if response.usage else Usage() diff --git a/src/agents/models/openai_chatcompletions.py b/src/agents/models/openai_chatcompletions.py index 89619f838..4465ff2fd 100644 --- a/src/agents/models/openai_chatcompletions.py +++ b/src/agents/models/openai_chatcompletions.py @@ -9,6 +9,7 @@ from openai.types import ChatModel from openai.types.chat import ChatCompletion, ChatCompletionChunk from openai.types.responses import Response +from openai.types.responses.response_usage import InputTokensDetails, OutputTokensDetails from .. 
import _debug from ..agent_output import AgentOutputSchemaBase @@ -83,6 +84,18 @@ async def get_response( input_tokens=response.usage.prompt_tokens, output_tokens=response.usage.completion_tokens, total_tokens=response.usage.total_tokens, + input_tokens_details=InputTokensDetails( + cached_tokens=getattr( + response.usage.prompt_tokens_details, "cached_tokens", 0 + ) + or 0, + ), + output_tokens_details=OutputTokensDetails( + reasoning_tokens=getattr( + response.usage.completion_tokens_details, "reasoning_tokens", 0 + ) + or 0, + ), ) if response.usage else Usage() @@ -252,7 +265,7 @@ async def _fetch_response( stream_options=self._non_null_or_not_given(stream_options), store=self._non_null_or_not_given(store), reasoning_effort=self._non_null_or_not_given(reasoning_effort), - extra_headers={ **HEADERS, **(model_settings.extra_headers or {}) }, + extra_headers={**HEADERS, **(model_settings.extra_headers or {})}, extra_query=model_settings.extra_query, extra_body=model_settings.extra_body, metadata=self._non_null_or_not_given(model_settings.metadata), diff --git a/src/agents/models/openai_responses.py b/src/agents/models/openai_responses.py index c1ff85b98..6ec8f8f7b 100644 --- a/src/agents/models/openai_responses.py +++ b/src/agents/models/openai_responses.py @@ -98,6 +98,8 @@ async def get_response( input_tokens=response.usage.input_tokens, output_tokens=response.usage.output_tokens, total_tokens=response.usage.total_tokens, + input_tokens_details=response.usage.input_tokens_details, + output_tokens_details=response.usage.output_tokens_details, ) if response.usage else Usage() diff --git a/src/agents/run.py b/src/agents/run.py index 849da7bfc..b196c3bf1 100644 --- a/src/agents/run.py +++ b/src/agents/run.py @@ -689,6 +689,8 @@ async def _run_single_turn_streamed( input_tokens=event.response.usage.input_tokens, output_tokens=event.response.usage.output_tokens, total_tokens=event.response.usage.total_tokens, + input_tokens_details=event.response.usage.input_tokens_details, + output_tokens_details=event.response.usage.output_tokens_details, ) if event.response.usage else Usage() diff --git a/src/agents/usage.py b/src/agents/usage.py index 23d989b4b..843f62937 100644 --- a/src/agents/usage.py +++ b/src/agents/usage.py @@ -1,4 +1,6 @@ -from dataclasses import dataclass +from dataclasses import dataclass, field + +from openai.types.responses.response_usage import InputTokensDetails, OutputTokensDetails @dataclass @@ -9,9 +11,18 @@ class Usage: input_tokens: int = 0 """Total input tokens sent, across all requests.""" + input_tokens_details: InputTokensDetails = field( + default_factory=lambda: InputTokensDetails(cached_tokens=0) + ) + """Details about the input tokens, matching responses API usage details.""" output_tokens: int = 0 """Total output tokens received, across all requests.""" + output_tokens_details: OutputTokensDetails = field( + default_factory=lambda: OutputTokensDetails(reasoning_tokens=0) + ) + """Details about the output tokens, matching responses API usage details.""" + total_tokens: int = 0 """Total tokens sent and received, across all requests.""" @@ -20,3 +31,12 @@ def add(self, other: "Usage") -> None: self.input_tokens += other.input_tokens if other.input_tokens else 0 self.output_tokens += other.output_tokens if other.output_tokens else 0 self.total_tokens += other.total_tokens if other.total_tokens else 0 + self.input_tokens_details = InputTokensDetails( + cached_tokens=self.input_tokens_details.cached_tokens + + other.input_tokens_details.cached_tokens + ) + + 
self.output_tokens_details = OutputTokensDetails( + reasoning_tokens=self.output_tokens_details.reasoning_tokens + + other.output_tokens_details.reasoning_tokens + ) diff --git a/tests/models/test_litellm_chatcompletions_stream.py b/tests/models/test_litellm_chatcompletions_stream.py index 80bd8ea22..06e46b39c 100644 --- a/tests/models/test_litellm_chatcompletions_stream.py +++ b/tests/models/test_litellm_chatcompletions_stream.py @@ -8,7 +8,11 @@ ChoiceDeltaToolCall, ChoiceDeltaToolCallFunction, ) -from openai.types.completion_usage import CompletionUsage +from openai.types.completion_usage import ( + CompletionTokensDetails, + CompletionUsage, + PromptTokensDetails, +) from openai.types.responses import ( Response, ResponseFunctionToolCall, @@ -46,7 +50,13 @@ async def test_stream_response_yields_events_for_text_content(monkeypatch) -> No model="fake", object="chat.completion.chunk", choices=[Choice(index=0, delta=ChoiceDelta(content="llo"))], - usage=CompletionUsage(completion_tokens=5, prompt_tokens=7, total_tokens=12), + usage=CompletionUsage( + completion_tokens=5, + prompt_tokens=7, + total_tokens=12, + completion_tokens_details=CompletionTokensDetails(reasoning_tokens=2), + prompt_tokens_details=PromptTokensDetails(cached_tokens=6), + ), ) async def fake_stream() -> AsyncIterator[ChatCompletionChunk]: @@ -112,6 +122,8 @@ async def patched_fetch_response(self, *args, **kwargs): assert completed_resp.usage.input_tokens == 7 assert completed_resp.usage.output_tokens == 5 assert completed_resp.usage.total_tokens == 12 + assert completed_resp.usage.input_tokens_details.cached_tokens == 6 + assert completed_resp.usage.output_tokens_details.reasoning_tokens == 2 @pytest.mark.allow_call_model_methods diff --git a/tests/test_extra_headers.py b/tests/test_extra_headers.py index f29c25408..a6af30077 100644 --- a/tests/test_extra_headers.py +++ b/tests/test_extra_headers.py @@ -1,6 +1,7 @@ import pytest from openai.types.chat.chat_completion import ChatCompletion, Choice from openai.types.chat.chat_completion_message import ChatCompletionMessage +from openai.types.responses.response_usage import InputTokensDetails, OutputTokensDetails from agents import ModelSettings, ModelTracing, OpenAIChatCompletionsModel, OpenAIResponsesModel @@ -17,21 +18,29 @@ class DummyResponses: async def create(self, **kwargs): nonlocal called_kwargs called_kwargs = kwargs + class DummyResponse: id = "dummy" output = [] usage = type( - "Usage", (), {"input_tokens": 0, "output_tokens": 0, "total_tokens": 0} + "Usage", + (), + { + "input_tokens": 0, + "output_tokens": 0, + "total_tokens": 0, + "input_tokens_details": InputTokensDetails(cached_tokens=0), + "output_tokens_details": OutputTokensDetails(reasoning_tokens=0), + }, )() + return DummyResponse() class DummyClient: def __init__(self): self.responses = DummyResponses() - - - model = OpenAIResponsesModel(model="gpt-4", openai_client=DummyClient()) # type: ignore + model = OpenAIResponsesModel(model="gpt-4", openai_client=DummyClient()) # type: ignore extra_headers = {"X-Test-Header": "test-value"} await model.get_response( system_instructions=None, @@ -47,7 +56,6 @@ def __init__(self): assert called_kwargs["extra_headers"]["X-Test-Header"] == "test-value" - @pytest.mark.allow_call_model_methods @pytest.mark.asyncio async def test_extra_headers_passed_to_openai_client(): @@ -76,7 +84,7 @@ def __init__(self): self.chat = type("_Chat", (), {"completions": DummyCompletions()})() self.base_url = "https://api.openai.com" - model = 
OpenAIChatCompletionsModel(model="gpt-4", openai_client=DummyClient()) # type: ignore + model = OpenAIChatCompletionsModel(model="gpt-4", openai_client=DummyClient()) # type: ignore extra_headers = {"X-Test-Header": "test-value"} await model.get_response( system_instructions=None, diff --git a/tests/test_openai_chatcompletions.py b/tests/test_openai_chatcompletions.py index ba3ec68d0..ba4605d08 100644 --- a/tests/test_openai_chatcompletions.py +++ b/tests/test_openai_chatcompletions.py @@ -13,7 +13,10 @@ ChatCompletionMessageToolCall, Function, ) -from openai.types.completion_usage import CompletionUsage +from openai.types.completion_usage import ( + CompletionUsage, + PromptTokensDetails, +) from openai.types.responses import ( Response, ResponseFunctionToolCall, @@ -51,7 +54,13 @@ async def test_get_response_with_text_message(monkeypatch) -> None: model="fake", object="chat.completion", choices=[choice], - usage=CompletionUsage(completion_tokens=5, prompt_tokens=7, total_tokens=12), + usage=CompletionUsage( + completion_tokens=5, + prompt_tokens=7, + total_tokens=12, + # completion_tokens_details left blank to test default + prompt_tokens_details=PromptTokensDetails(cached_tokens=3), + ), ) async def patched_fetch_response(self, *args, **kwargs): @@ -81,6 +90,8 @@ async def patched_fetch_response(self, *args, **kwargs): assert resp.usage.input_tokens == 7 assert resp.usage.output_tokens == 5 assert resp.usage.total_tokens == 12 + assert resp.usage.input_tokens_details.cached_tokens == 3 + assert resp.usage.output_tokens_details.reasoning_tokens == 0 assert resp.response_id is None @@ -127,6 +138,8 @@ async def patched_fetch_response(self, *args, **kwargs): assert resp.usage.requests == 0 assert resp.usage.input_tokens == 0 assert resp.usage.output_tokens == 0 + assert resp.usage.input_tokens_details.cached_tokens == 0 + assert resp.usage.output_tokens_details.reasoning_tokens == 0 @pytest.mark.allow_call_model_methods diff --git a/tests/test_openai_chatcompletions_stream.py b/tests/test_openai_chatcompletions_stream.py index b82f24303..5c8bb9e3a 100644 --- a/tests/test_openai_chatcompletions_stream.py +++ b/tests/test_openai_chatcompletions_stream.py @@ -8,7 +8,11 @@ ChoiceDeltaToolCall, ChoiceDeltaToolCallFunction, ) -from openai.types.completion_usage import CompletionUsage +from openai.types.completion_usage import ( + CompletionTokensDetails, + CompletionUsage, + PromptTokensDetails, +) from openai.types.responses import ( Response, ResponseFunctionToolCall, @@ -46,7 +50,13 @@ async def test_stream_response_yields_events_for_text_content(monkeypatch) -> No model="fake", object="chat.completion.chunk", choices=[Choice(index=0, delta=ChoiceDelta(content="llo"))], - usage=CompletionUsage(completion_tokens=5, prompt_tokens=7, total_tokens=12), + usage=CompletionUsage( + completion_tokens=5, + prompt_tokens=7, + total_tokens=12, + prompt_tokens_details=PromptTokensDetails(cached_tokens=2), + completion_tokens_details=CompletionTokensDetails(reasoning_tokens=3), + ), ) async def fake_stream() -> AsyncIterator[ChatCompletionChunk]: @@ -112,6 +122,8 @@ async def patched_fetch_response(self, *args, **kwargs): assert completed_resp.usage.input_tokens == 7 assert completed_resp.usage.output_tokens == 5 assert completed_resp.usage.total_tokens == 12 + assert completed_resp.usage.input_tokens_details.cached_tokens == 2 + assert completed_resp.usage.output_tokens_details.reasoning_tokens == 3 @pytest.mark.allow_call_model_methods diff --git a/tests/test_responses_tracing.py 
b/tests/test_responses_tracing.py index 0bc97a953..dfac74bb9 100644 --- a/tests/test_responses_tracing.py +++ b/tests/test_responses_tracing.py @@ -1,7 +1,10 @@ +from typing import Optional + import pytest from inline_snapshot import snapshot from openai import AsyncOpenAI from openai.types.responses import ResponseCompletedEvent +from openai.types.responses.response_usage import InputTokensDetails, OutputTokensDetails from agents import ModelSettings, ModelTracing, OpenAIResponsesModel, trace from agents.tracing.span_data import ResponseSpanData @@ -16,10 +19,25 @@ def is_disabled(self): class DummyUsage: - def __init__(self, input_tokens=1, output_tokens=1, total_tokens=2): + def __init__( + self, + input_tokens: int = 1, + input_tokens_details: Optional[InputTokensDetails] = None, + output_tokens: int = 1, + output_tokens_details: Optional[OutputTokensDetails] = None, + total_tokens: int = 2, + ): self.input_tokens = input_tokens self.output_tokens = output_tokens self.total_tokens = total_tokens + self.input_tokens_details = ( + input_tokens_details if input_tokens_details else InputTokensDetails(cached_tokens=0) + ) + self.output_tokens_details = ( + output_tokens_details + if output_tokens_details + else OutputTokensDetails(reasoning_tokens=0) + ) class DummyResponse: diff --git a/tests/test_usage.py b/tests/test_usage.py new file mode 100644 index 000000000..405f99ddf --- /dev/null +++ b/tests/test_usage.py @@ -0,0 +1,52 @@ +from openai.types.responses.response_usage import InputTokensDetails, OutputTokensDetails + +from agents.usage import Usage + + +def test_usage_add_aggregates_all_fields(): + u1 = Usage( + requests=1, + input_tokens=10, + input_tokens_details=InputTokensDetails(cached_tokens=3), + output_tokens=20, + output_tokens_details=OutputTokensDetails(reasoning_tokens=5), + total_tokens=30, + ) + u2 = Usage( + requests=2, + input_tokens=7, + input_tokens_details=InputTokensDetails(cached_tokens=4), + output_tokens=8, + output_tokens_details=OutputTokensDetails(reasoning_tokens=6), + total_tokens=15, + ) + + u1.add(u2) + + assert u1.requests == 3 + assert u1.input_tokens == 17 + assert u1.output_tokens == 28 + assert u1.total_tokens == 45 + assert u1.input_tokens_details.cached_tokens == 7 + assert u1.output_tokens_details.reasoning_tokens == 11 + + +def test_usage_add_aggregates_with_none_values(): + u1 = Usage() + u2 = Usage( + requests=2, + input_tokens=7, + input_tokens_details=InputTokensDetails(cached_tokens=4), + output_tokens=8, + output_tokens_details=OutputTokensDetails(reasoning_tokens=6), + total_tokens=15, + ) + + u1.add(u2) + + assert u1.requests == 2 + assert u1.input_tokens == 7 + assert u1.output_tokens == 8 + assert u1.total_tokens == 15 + assert u1.input_tokens_details.cached_tokens == 4 + assert u1.output_tokens_details.reasoning_tokens == 6 From ce2e2a4571c2b176e8641c558fedaa7bc1692013 Mon Sep 17 00:00:00 2001 From: Rohan Mehta Date: Wed, 21 May 2025 15:17:58 -0400 Subject: [PATCH 23/33] Upgrade openAI sdk version (#730) --- [//]: # (BEGIN SAPLING FOOTER) * #732 * #731 * __->__ #730 --- pyproject.toml | 2 +- src/agents/models/chatcmpl_stream_handler.py | 26 +++++++++++++++++++- src/agents/models/openai_responses.py | 16 ++++-------- tests/fake_model.py | 1 + tests/test_responses_tracing.py | 4 +++ tests/voice/test_workflow.py | 2 ++ uv.lock | 10 ++++---- 7 files changed, 43 insertions(+), 18 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 672258c42..200ac2485 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -7,7 +7,7 @@ 
requires-python = ">=3.9" license = "MIT" authors = [{ name = "OpenAI", email = "support@openai.com" }] dependencies = [ - "openai>=1.76.0", + "openai>=1.81.0", "pydantic>=2.10, <3", "griffe>=1.5.6, <2", "typing-extensions>=4.12.2, <5", diff --git a/src/agents/models/chatcmpl_stream_handler.py b/src/agents/models/chatcmpl_stream_handler.py index c71adeb55..d18f5912a 100644 --- a/src/agents/models/chatcmpl_stream_handler.py +++ b/src/agents/models/chatcmpl_stream_handler.py @@ -38,6 +38,16 @@ class StreamingState: function_calls: dict[int, ResponseFunctionToolCall] = field(default_factory=dict) +class SequenceNumber: + def __init__(self): + self._sequence_number = 0 + + def get_and_increment(self) -> int: + num = self._sequence_number + self._sequence_number += 1 + return num + + class ChatCmplStreamHandler: @classmethod async def handle_stream( @@ -47,13 +57,14 @@ async def handle_stream( ) -> AsyncIterator[TResponseStreamEvent]: usage: CompletionUsage | None = None state = StreamingState() - + sequence_number = SequenceNumber() async for chunk in stream: if not state.started: state.started = True yield ResponseCreatedEvent( response=response, type="response.created", + sequence_number=sequence_number.get_and_increment(), ) # This is always set by the OpenAI API, but not by others e.g. LiteLLM @@ -89,6 +100,7 @@ async def handle_stream( item=assistant_item, output_index=0, type="response.output_item.added", + sequence_number=sequence_number.get_and_increment(), ) yield ResponseContentPartAddedEvent( content_index=state.text_content_index_and_output[0], @@ -100,6 +112,7 @@ async def handle_stream( annotations=[], ), type="response.content_part.added", + sequence_number=sequence_number.get_and_increment(), ) # Emit the delta for this segment of content yield ResponseTextDeltaEvent( @@ -108,6 +121,7 @@ async def handle_stream( item_id=FAKE_RESPONSES_ID, output_index=0, type="response.output_text.delta", + sequence_number=sequence_number.get_and_increment(), ) # Accumulate the text into the response part state.text_content_index_and_output[1].text += delta.content @@ -134,6 +148,7 @@ async def handle_stream( item=assistant_item, output_index=0, type="response.output_item.added", + sequence_number=sequence_number.get_and_increment(), ) yield ResponseContentPartAddedEvent( content_index=state.refusal_content_index_and_output[0], @@ -145,6 +160,7 @@ async def handle_stream( annotations=[], ), type="response.content_part.added", + sequence_number=sequence_number.get_and_increment(), ) # Emit the delta for this segment of refusal yield ResponseRefusalDeltaEvent( @@ -153,6 +169,7 @@ async def handle_stream( item_id=FAKE_RESPONSES_ID, output_index=0, type="response.refusal.delta", + sequence_number=sequence_number.get_and_increment(), ) # Accumulate the refusal string in the output part state.refusal_content_index_and_output[1].refusal += delta.refusal @@ -190,6 +207,7 @@ async def handle_stream( output_index=0, part=state.text_content_index_and_output[1], type="response.content_part.done", + sequence_number=sequence_number.get_and_increment(), ) if state.refusal_content_index_and_output: @@ -201,6 +219,7 @@ async def handle_stream( output_index=0, part=state.refusal_content_index_and_output[1], type="response.content_part.done", + sequence_number=sequence_number.get_and_increment(), ) # Actually send events for the function calls @@ -216,6 +235,7 @@ async def handle_stream( ), output_index=function_call_starting_index, type="response.output_item.added", + 
sequence_number=sequence_number.get_and_increment(), ) # Then, yield the args yield ResponseFunctionCallArgumentsDeltaEvent( @@ -223,6 +243,7 @@ async def handle_stream( item_id=FAKE_RESPONSES_ID, output_index=function_call_starting_index, type="response.function_call_arguments.delta", + sequence_number=sequence_number.get_and_increment(), ) # Finally, the ResponseOutputItemDone yield ResponseOutputItemDoneEvent( @@ -235,6 +256,7 @@ async def handle_stream( ), output_index=function_call_starting_index, type="response.output_item.done", + sequence_number=sequence_number.get_and_increment(), ) # Finally, send the Response completed event @@ -258,6 +280,7 @@ async def handle_stream( item=assistant_msg, output_index=0, type="response.output_item.done", + sequence_number=sequence_number.get_and_increment(), ) for function_call in state.function_calls.values(): @@ -289,4 +312,5 @@ async def handle_stream( yield ResponseCompletedEvent( response=final_response, type="response.completed", + sequence_number=sequence_number.get_and_increment(), ) diff --git a/src/agents/models/openai_responses.py b/src/agents/models/openai_responses.py index 6ec8f8f7b..cb6567909 100644 --- a/src/agents/models/openai_responses.py +++ b/src/agents/models/openai_responses.py @@ -10,6 +10,7 @@ from openai.types.responses import ( Response, ResponseCompletedEvent, + ResponseIncludable, ResponseStreamEvent, ResponseTextConfigParam, ToolParam, @@ -36,13 +37,6 @@ _USER_AGENT = f"Agents/Python {__version__}" _HEADERS = {"User-Agent": _USER_AGENT} -# From the Responses API -IncludeLiteral = Literal[ - "file_search_call.results", - "message.input_image.image_url", - "computer_call_output.output.image_url", -] - class OpenAIResponsesModel(Model): """ @@ -273,7 +267,7 @@ def _get_client(self) -> AsyncOpenAI: @dataclass class ConvertedTools: tools: list[ToolParam] - includes: list[IncludeLiteral] + includes: list[ResponseIncludable] class Converter: @@ -330,7 +324,7 @@ def convert_tools( handoffs: list[Handoff[Any]], ) -> ConvertedTools: converted_tools: list[ToolParam] = [] - includes: list[IncludeLiteral] = [] + includes: list[ResponseIncludable] = [] computer_tools = [tool for tool in tools if isinstance(tool, ComputerTool)] if len(computer_tools) > 1: @@ -348,7 +342,7 @@ def convert_tools( return ConvertedTools(tools=converted_tools, includes=includes) @classmethod - def _convert_tool(cls, tool: Tool) -> tuple[ToolParam, IncludeLiteral | None]: + def _convert_tool(cls, tool: Tool) -> tuple[ToolParam, ResponseIncludable | None]: """Returns converted tool and includes""" if isinstance(tool, FunctionTool): @@ -359,7 +353,7 @@ def _convert_tool(cls, tool: Tool) -> tuple[ToolParam, IncludeLiteral | None]: "type": "function", "description": tool.description, } - includes: IncludeLiteral | None = None + includes: ResponseIncludable | None = None elif isinstance(tool, WebSearchTool): ws: WebSearchToolParam = { "type": "web_search_preview", diff --git a/tests/fake_model.py b/tests/fake_model.py index 32f919ef1..9f0c83a2f 100644 --- a/tests/fake_model.py +++ b/tests/fake_model.py @@ -129,6 +129,7 @@ async def stream_response( yield ResponseCompletedEvent( type="response.completed", response=get_response_obj(output, usage=self.hardcoded_usage), + sequence_number=0, ) diff --git a/tests/test_responses_tracing.py b/tests/test_responses_tracing.py index dfac74bb9..db24fe496 100644 --- a/tests/test_responses_tracing.py +++ b/tests/test_responses_tracing.py @@ -50,6 +50,7 @@ def __aiter__(self): yield ResponseCompletedEvent( 
type="response.completed", response=fake_model.get_response_obj(self.output), + sequence_number=0, ) @@ -201,6 +202,7 @@ async def __aiter__(self): yield ResponseCompletedEvent( type="response.completed", response=fake_model.get_response_obj([], "dummy-id-123"), + sequence_number=0, ) return DummyStream() @@ -253,6 +255,7 @@ async def __aiter__(self): yield ResponseCompletedEvent( type="response.completed", response=fake_model.get_response_obj([], "dummy-id-123"), + sequence_number=0, ) return DummyStream() @@ -304,6 +307,7 @@ async def __aiter__(self): yield ResponseCompletedEvent( type="response.completed", response=fake_model.get_response_obj([], "dummy-id-123"), + sequence_number=0, ) return DummyStream() diff --git a/tests/voice/test_workflow.py b/tests/voice/test_workflow.py index 2bdf2a657..035a05d56 100644 --- a/tests/voice/test_workflow.py +++ b/tests/voice/test_workflow.py @@ -81,11 +81,13 @@ async def stream_response( type="response.output_text.delta", output_index=0, item_id=item.id, + sequence_number=0, ) yield ResponseCompletedEvent( type="response.completed", response=get_response_obj(output), + sequence_number=1, ) diff --git a/uv.lock b/uv.lock index 6ccc19966..7a0cb1e6b 100644 --- a/uv.lock +++ b/uv.lock @@ -1461,7 +1461,7 @@ wheels = [ [[package]] name = "openai" -version = "1.76.0" +version = "1.81.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "anyio" }, @@ -1473,14 +1473,14 @@ dependencies = [ { name = "tqdm" }, { name = "typing-extensions" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/84/51/817969ec969b73d8ddad085670ecd8a45ef1af1811d8c3b8a177ca4d1309/openai-1.76.0.tar.gz", hash = "sha256:fd2bfaf4608f48102d6b74f9e11c5ecaa058b60dad9c36e409c12477dfd91fb2", size = 434660 } +sdist = { url = "https://files.pythonhosted.org/packages/1c/89/a1e4f3fa7ca4f7fec90dbf47d93b7cd5ff65924926733af15044e302a192/openai-1.81.0.tar.gz", hash = "sha256:349567a8607e0bcffd28e02f96b5c2397d0d25d06732d90ab3ecbf97abf030f9", size = 456861 } wheels = [ - { url = "https://files.pythonhosted.org/packages/59/aa/84e02ab500ca871eb8f62784426963a1c7c17a72fea3c7f268af4bbaafa5/openai-1.76.0-py3-none-any.whl", hash = "sha256:a712b50e78cf78e6d7b2a8f69c4978243517c2c36999756673e07a14ce37dc0a", size = 661201 }, + { url = "https://files.pythonhosted.org/packages/02/66/bcc7f9bf48e8610a33e3b5c96a5a644dad032d92404ea2a5e8b43ba067e8/openai-1.81.0-py3-none-any.whl", hash = "sha256:1c71572e22b43876c5d7d65ade0b7b516bb527c3d44ae94111267a09125f7bae", size = 717529 }, ] [[package]] name = "openai-agents" -version = "0.0.14" +version = "0.0.15" source = { editable = "." 
} dependencies = [ { name = "griffe" }, @@ -1536,7 +1536,7 @@ requires-dist = [ { name = "litellm", marker = "extra == 'litellm'", specifier = ">=1.67.4.post1,<2" }, { name = "mcp", marker = "python_full_version >= '3.10'", specifier = ">=1.8.0,<2" }, { name = "numpy", marker = "python_full_version >= '3.10' and extra == 'voice'", specifier = ">=2.2.0,<3" }, - { name = "openai", specifier = ">=1.76.0" }, + { name = "openai", specifier = ">=1.81.0" }, { name = "pydantic", specifier = ">=2.10,<3" }, { name = "requests", specifier = ">=2.0,<3" }, { name = "types-requests", specifier = ">=2.0,<3" }, From 9fa5c39d69937a215a6f247883243fe38c5a39c2 Mon Sep 17 00:00:00 2001 From: Rohan Mehta Date: Wed, 21 May 2025 15:21:37 -0400 Subject: [PATCH 24/33] Hosted MCP support (#731) --- [//]: # (BEGIN SAPLING FOOTER) * #732 * __->__ #731 --- examples/hosted_mcp/__init__.py | 0 examples/hosted_mcp/approvals.py | 61 ++++++++++++++ examples/hosted_mcp/simple.py | 47 +++++++++++ src/agents/__init__.py | 8 ++ src/agents/_run_impl.py | 114 ++++++++++++++++++++++++-- src/agents/items.py | 41 ++++++++- src/agents/models/openai_responses.py | 6 +- src/agents/stream_events.py | 2 + src/agents/tool.py | 64 ++++++++++++++- 9 files changed, 332 insertions(+), 11 deletions(-) create mode 100644 examples/hosted_mcp/__init__.py create mode 100644 examples/hosted_mcp/approvals.py create mode 100644 examples/hosted_mcp/simple.py diff --git a/examples/hosted_mcp/__init__.py b/examples/hosted_mcp/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/examples/hosted_mcp/approvals.py b/examples/hosted_mcp/approvals.py new file mode 100644 index 000000000..2cabb3ee2 --- /dev/null +++ b/examples/hosted_mcp/approvals.py @@ -0,0 +1,61 @@ +import argparse +import asyncio + +from agents import ( + Agent, + HostedMCPTool, + MCPToolApprovalFunctionResult, + MCPToolApprovalRequest, + Runner, +) + +"""This example demonstrates how to use the hosted MCP support in the OpenAI Responses API, with +approval callbacks.""" + + +def approval_callback(request: MCPToolApprovalRequest) -> MCPToolApprovalFunctionResult: + answer = input(f"Approve running the tool `{request.data.name}`? 
(y/n) ") + result: MCPToolApprovalFunctionResult = {"approve": answer == "y"} + if not result["approve"]: + result["reason"] = "User denied" + return result + + +async def main(verbose: bool, stream: bool): + agent = Agent( + name="Assistant", + tools=[ + HostedMCPTool( + tool_config={ + "type": "mcp", + "server_label": "gitmcp", + "server_url": "https://gitmcp.io/openai/codex", + "require_approval": "always", + }, + on_approval_request=approval_callback, + ) + ], + ) + + if stream: + result = Runner.run_streamed(agent, "Which language is this repo written in?") + async for event in result.stream_events(): + if event.type == "run_item_stream_event": + print(f"Got event of type {event.item.__class__.__name__}") + print(f"Done streaming; final result: {result.final_output}") + else: + res = await Runner.run(agent, "Which language is this repo written in?") + print(res.final_output) + + if verbose: + for item in result.new_items: + print(item) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--verbose", action="store_true", default=False) + parser.add_argument("--stream", action="store_true", default=False) + args = parser.parse_args() + + asyncio.run(main(args.verbose, args.stream)) diff --git a/examples/hosted_mcp/simple.py b/examples/hosted_mcp/simple.py new file mode 100644 index 000000000..508c3a7ae --- /dev/null +++ b/examples/hosted_mcp/simple.py @@ -0,0 +1,47 @@ +import argparse +import asyncio + +from agents import Agent, HostedMCPTool, Runner + +"""This example demonstrates how to use the hosted MCP support in the OpenAI Responses API, with +approvals not required for any tools. You should only use this for trusted MCP servers.""" + + +async def main(verbose: bool, stream: bool): + agent = Agent( + name="Assistant", + tools=[ + HostedMCPTool( + tool_config={ + "type": "mcp", + "server_label": "gitmcp", + "server_url": "https://gitmcp.io/openai/codex", + "require_approval": "never", + } + ) + ], + ) + + if stream: + result = Runner.run_streamed(agent, "Which language is this repo written in?") + async for event in result.stream_events(): + if event.type == "run_item_stream_event": + print(f"Got event of type {event.item.__class__.__name__}") + print(f"Done streaming; final result: {result.final_output}") + else: + res = await Runner.run(agent, "Which language is this repo written in?") + print(res.final_output) + # The repository is primarily written in multiple languages, including Rust and TypeScript... 
+ + if verbose: + for item in result.new_items: + print(item) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--verbose", action="store_true", default=False) + parser.add_argument("--stream", action="store_true", default=False) + args = parser.parse_args() + + asyncio.run(main(args.verbose, args.stream)) diff --git a/src/agents/__init__.py b/src/agents/__init__.py index 6d7c90b4f..36c26b80d 100644 --- a/src/agents/__init__.py +++ b/src/agents/__init__.py @@ -58,6 +58,10 @@ FileSearchTool, FunctionTool, FunctionToolResult, + HostedMCPTool, + MCPToolApprovalFunction, + MCPToolApprovalFunctionResult, + MCPToolApprovalRequest, Tool, WebSearchTool, default_tool_error_function, @@ -208,6 +212,10 @@ def enable_verbose_stdout_logging(): "FileSearchTool", "Tool", "WebSearchTool", + "HostedMCPTool", + "MCPToolApprovalFunction", + "MCPToolApprovalRequest", + "MCPToolApprovalFunctionResult", "function_tool", "Usage", "add_trace_processor", diff --git a/src/agents/_run_impl.py b/src/agents/_run_impl.py index b5a83685c..ab1e78797 100644 --- a/src/agents/_run_impl.py +++ b/src/agents/_run_impl.py @@ -25,7 +25,8 @@ ActionType, ActionWait, ) -from openai.types.responses.response_input_param import ComputerCallOutput +from openai.types.responses.response_input_param import ComputerCallOutput, McpApprovalResponse +from openai.types.responses.response_output_item import McpApprovalRequest, McpCall, McpListTools from openai.types.responses.response_reasoning_item import ResponseReasoningItem from .agent import Agent, ToolsToFinalOutputResult @@ -38,6 +39,9 @@ HandoffCallItem, HandoffOutputItem, ItemHelpers, + MCPApprovalRequestItem, + MCPApprovalResponseItem, + MCPListToolsItem, MessageOutputItem, ModelResponse, ReasoningItem, @@ -52,7 +56,14 @@ from .models.interface import ModelTracing from .run_context import RunContextWrapper, TContext from .stream_events import RunItemStreamEvent, StreamEvent -from .tool import ComputerTool, FunctionTool, FunctionToolResult, Tool +from .tool import ( + ComputerTool, + FunctionTool, + FunctionToolResult, + HostedMCPTool, + MCPToolApprovalRequest, + Tool, +) from .tracing import ( SpanError, Trace, @@ -112,6 +123,12 @@ class ToolRunComputerAction: computer_tool: ComputerTool +@dataclass +class ToolRunMCPApprovalRequest: + request_item: McpApprovalRequest + mcp_tool: HostedMCPTool + + @dataclass class ProcessedResponse: new_items: list[RunItem] @@ -119,8 +136,9 @@ class ProcessedResponse: functions: list[ToolRunFunction] computer_actions: list[ToolRunComputerAction] tools_used: list[str] # Names of all tools used, including hosted tools + mcp_approval_requests: list[ToolRunMCPApprovalRequest] # Only requests with callbacks - def has_tools_to_run(self) -> bool: + def has_tools_or_approvals_to_run(self) -> bool: # Handoffs, functions and computer actions need local processing # Hosted tools have already run, so there's nothing to do. 
return any( @@ -128,6 +146,7 @@ def has_tools_to_run(self) -> bool: self.handoffs, self.functions, self.computer_actions, + self.mcp_approval_requests, ] ) @@ -226,7 +245,16 @@ async def execute_tools_and_side_effects( new_step_items.extend([result.run_item for result in function_results]) new_step_items.extend(computer_results) - # Second, check if there are any handoffs + # Next, run the MCP approval requests + if processed_response.mcp_approval_requests: + approval_results = await cls.execute_mcp_approval_requests( + agent=agent, + approval_requests=processed_response.mcp_approval_requests, + context_wrapper=context_wrapper, + ) + new_step_items.extend(approval_results) + + # Next, check if there are any handoffs if run_handoffs := processed_response.handoffs: return await cls.execute_handoffs( agent=agent, @@ -240,7 +268,7 @@ async def execute_tools_and_side_effects( run_config=run_config, ) - # Third, we'll check if the tool use should result in a final output + # Next, we'll check if the tool use should result in a final output check_tool_use = await cls._check_for_final_output_from_tools( agent=agent, tool_results=function_results, @@ -295,7 +323,7 @@ async def execute_tools_and_side_effects( ) elif ( not output_schema or output_schema.is_plain_text() - ) and not processed_response.has_tools_to_run(): + ) and not processed_response.has_tools_or_approvals_to_run(): return await cls.execute_final_output( agent=agent, original_input=original_input, @@ -343,10 +371,16 @@ def process_model_response( run_handoffs = [] functions = [] computer_actions = [] + mcp_approval_requests = [] tools_used: list[str] = [] handoff_map = {handoff.tool_name: handoff for handoff in handoffs} function_map = {tool.name: tool for tool in all_tools if isinstance(tool, FunctionTool)} computer_tool = next((tool for tool in all_tools if isinstance(tool, ComputerTool)), None) + hosted_mcp_server_map = { + tool.tool_config["server_label"]: tool + for tool in all_tools + if isinstance(tool, HostedMCPTool) + } for output in response.output: if isinstance(output, ResponseOutputMessage): @@ -375,6 +409,34 @@ def process_model_response( computer_actions.append( ToolRunComputerAction(tool_call=output, computer_tool=computer_tool) ) + elif isinstance(output, McpApprovalRequest): + items.append(MCPApprovalRequestItem(raw_item=output, agent=agent)) + if output.server_label not in hosted_mcp_server_map: + _error_tracing.attach_error_to_current_span( + SpanError( + message="MCP server label not found", + data={"server_label": output.server_label}, + ) + ) + raise ModelBehaviorError(f"MCP server label {output.server_label} not found") + else: + server = hosted_mcp_server_map[output.server_label] + if server.on_approval_request: + mcp_approval_requests.append( + ToolRunMCPApprovalRequest( + request_item=output, + mcp_tool=server, + ) + ) + else: + logger.warning( + f"MCP server {output.server_label} has no on_approval_request hook" + ) + elif isinstance(output, McpListTools): + items.append(MCPListToolsItem(raw_item=output, agent=agent)) + elif isinstance(output, McpCall): + items.append(ToolCallItem(raw_item=output, agent=agent)) + tools_used.append(output.name) elif not isinstance(output, ResponseFunctionToolCall): logger.warning(f"Unexpected output type, ignoring: {type(output)}") continue @@ -417,6 +479,7 @@ def process_model_response( functions=functions, computer_actions=computer_actions, tools_used=tools_used, + mcp_approval_requests=mcp_approval_requests, ) @classmethod @@ -643,6 +706,40 @@ async def execute_handoffs( 
next_step=NextStepHandoff(new_agent), ) + @classmethod + async def execute_mcp_approval_requests( + cls, + *, + agent: Agent[TContext], + approval_requests: list[ToolRunMCPApprovalRequest], + context_wrapper: RunContextWrapper[TContext], + ) -> list[RunItem]: + async def run_single_approval(approval_request: ToolRunMCPApprovalRequest) -> RunItem: + callback = approval_request.mcp_tool.on_approval_request + assert callback is not None, "Callback is required for MCP approval requests" + maybe_awaitable_result = callback( + MCPToolApprovalRequest(context_wrapper, approval_request.request_item) + ) + if inspect.isawaitable(maybe_awaitable_result): + result = await maybe_awaitable_result + else: + result = maybe_awaitable_result + reason = result.get("reason", None) + raw_item: McpApprovalResponse = { + "approval_request_id": approval_request.request_item.id, + "approve": result["approve"], + "type": "mcp_approval_response", + } + if not result["approve"] and reason: + raw_item["reason"] = reason + return MCPApprovalResponseItem( + raw_item=raw_item, + agent=agent, + ) + + tasks = [run_single_approval(approval_request) for approval_request in approval_requests] + return await asyncio.gather(*tasks) + @classmethod async def execute_final_output( cls, @@ -727,6 +824,11 @@ def stream_step_result_to_queue( event = RunItemStreamEvent(item=item, name="tool_output") elif isinstance(item, ReasoningItem): event = RunItemStreamEvent(item=item, name="reasoning_item_created") + elif isinstance(item, MCPApprovalRequestItem): + event = RunItemStreamEvent(item=item, name="mcp_approval_requested") + elif isinstance(item, MCPListToolsItem): + event = RunItemStreamEvent(item=item, name="mcp_list_tools") + else: logger.warning(f"Unexpected item type: {type(item)}") event = None diff --git a/src/agents/items.py b/src/agents/items.py index 8fb2b52a3..65a911798 100644 --- a/src/agents/items.py +++ b/src/agents/items.py @@ -18,7 +18,12 @@ ResponseOutputText, ResponseStreamEvent, ) -from openai.types.responses.response_input_item_param import ComputerCallOutput, FunctionCallOutput +from openai.types.responses.response_input_item_param import ( + ComputerCallOutput, + FunctionCallOutput, + McpApprovalResponse, +) +from openai.types.responses.response_output_item import McpApprovalRequest, McpCall, McpListTools from openai.types.responses.response_reasoning_item import ResponseReasoningItem from pydantic import BaseModel from typing_extensions import TypeAlias @@ -108,6 +113,7 @@ class HandoffOutputItem(RunItemBase[TResponseInputItem]): ResponseComputerToolCall, ResponseFileSearchToolCall, ResponseFunctionWebSearch, + McpCall, ] """A type that represents a tool call item.""" @@ -147,6 +153,36 @@ class ReasoningItem(RunItemBase[ResponseReasoningItem]): type: Literal["reasoning_item"] = "reasoning_item" +@dataclass +class MCPListToolsItem(RunItemBase[McpListTools]): + """Represents a call to an MCP server to list tools.""" + + raw_item: McpListTools + """The raw MCP list tools call.""" + + type: Literal["mcp_list_tools_item"] = "mcp_list_tools_item" + + +@dataclass +class MCPApprovalRequestItem(RunItemBase[McpApprovalRequest]): + """Represents a request for MCP approval.""" + + raw_item: McpApprovalRequest + """The raw MCP approval request.""" + + type: Literal["mcp_approval_request_item"] = "mcp_approval_request_item" + + +@dataclass +class MCPApprovalResponseItem(RunItemBase[McpApprovalResponse]): + """Represents a response to an MCP approval request.""" + + raw_item: McpApprovalResponse + """The raw MCP approval 
response.""" + + type: Literal["mcp_approval_response_item"] = "mcp_approval_response_item" + + RunItem: TypeAlias = Union[ MessageOutputItem, HandoffCallItem, @@ -154,6 +190,9 @@ class ReasoningItem(RunItemBase[ResponseReasoningItem]): ToolCallItem, ToolCallOutputItem, ReasoningItem, + MCPListToolsItem, + MCPApprovalRequestItem, + MCPApprovalResponseItem, ] """An item generated by an agent.""" diff --git a/src/agents/models/openai_responses.py b/src/agents/models/openai_responses.py index cb6567909..65a4f5caf 100644 --- a/src/agents/models/openai_responses.py +++ b/src/agents/models/openai_responses.py @@ -24,7 +24,7 @@ from ..handoffs import Handoff from ..items import ItemHelpers, ModelResponse, TResponseInputItem from ..logger import logger -from ..tool import ComputerTool, FileSearchTool, FunctionTool, Tool, WebSearchTool +from ..tool import ComputerTool, FileSearchTool, FunctionTool, HostedMCPTool, Tool, WebSearchTool from ..tracing import SpanError, response_span from ..usage import Usage from ..version import __version__ @@ -383,7 +383,9 @@ def _convert_tool(cls, tool: Tool) -> tuple[ToolParam, ResponseIncludable | None "display_height": tool.computer.dimensions[1], } includes = None - + elif isinstance(tool, HostedMCPTool): + converted_tool = tool.tool_config + includes = None else: raise UserError(f"Unknown tool type: {type(tool)}, tool") diff --git a/src/agents/stream_events.py b/src/agents/stream_events.py index bd37d11f3..111d0b951 100644 --- a/src/agents/stream_events.py +++ b/src/agents/stream_events.py @@ -35,6 +35,8 @@ class RunItemStreamEvent: "tool_called", "tool_output", "reasoning_item_created", + "mcp_approval_requested", + "mcp_list_tools", ] """The name of the event.""" diff --git a/src/agents/tool.py b/src/agents/tool.py index c1c162423..3bcd57c2e 100644 --- a/src/agents/tool.py +++ b/src/agents/tool.py @@ -7,9 +7,11 @@ from typing import Any, Callable, Literal, Union, overload from openai.types.responses.file_search_tool_param import Filters, RankingOptions +from openai.types.responses.response_output_item import McpApprovalRequest +from openai.types.responses.tool_param import Mcp from openai.types.responses.web_search_tool_param import UserLocation from pydantic import ValidationError -from typing_extensions import Concatenate, ParamSpec +from typing_extensions import Concatenate, NotRequired, ParamSpec, TypedDict from . import _debug from .computer import AsyncComputer, Computer @@ -130,7 +132,55 @@ def name(self): return "computer_use_preview" -Tool = Union[FunctionTool, FileSearchTool, WebSearchTool, ComputerTool] +@dataclass +class MCPToolApprovalRequest: + """A request to approve a tool call.""" + + ctx_wrapper: RunContextWrapper[Any] + """The run context.""" + + data: McpApprovalRequest + """The data from the MCP tool approval request.""" + + +class MCPToolApprovalFunctionResult(TypedDict): + """The result of an MCP tool approval function.""" + + approve: bool + """Whether to approve the tool call.""" + + reason: NotRequired[str] + """An optional reason, if rejected.""" + + +MCPToolApprovalFunction = Callable[ + [MCPToolApprovalRequest], MaybeAwaitable[MCPToolApprovalFunctionResult] +] +"""A function that approves or rejects a tool call.""" + + +@dataclass +class HostedMCPTool: + """A tool that allows the LLM to use a remote MCP server. The LLM will automatically list and + call tools, without requiring a a round trip back to your code. 
+ If you want to run MCP servers locally via stdio, in a VPC or other non-publicly-accessible + environment, or you just prefer to run tool calls locally, then you can instead use the servers + in `agents.mcp` and pass `Agent(mcp_servers=[...])` to the agent.""" + + tool_config: Mcp + """The MCP tool config, which includes the server URL and other settings.""" + + on_approval_request: MCPToolApprovalFunction | None = None + """An optional function that will be called if approval is requested for an MCP tool. If not + provided, you will need to manually add approvals/rejections to the input and call + `Runner.run(...)` again.""" + + @property + def name(self): + return "hosted_mcp" + + +Tool = Union[FunctionTool, FileSearchTool, WebSearchTool, ComputerTool, HostedMCPTool] """A tool that can be used in an agent.""" @@ -308,3 +358,13 @@ def decorator(real_func: ToolFunction[...]) -> FunctionTool: return _create_function_tool(real_func) return decorator + return decorator + return decorator + return decorator + return decorator + return decorator + return decorator + return decorator + return decorator + return decorator + return decorator From 079764f0ab463fda9ecf397b0a5d8e466e87a86c Mon Sep 17 00:00:00 2001 From: Rohan Mehta Date: Wed, 21 May 2025 15:26:22 -0400 Subject: [PATCH 25/33] Add support for local shell, image generator, code interpreter tools (#732) --- examples/tools/code_interpreter.py | 34 +++++++ examples/tools/image_generator.py | 54 +++++++++++ src/agents/__init__.py | 10 +++ src/agents/_run_impl.py | 124 +++++++++++++++++++++++++- src/agents/items.py | 21 ++++- src/agents/models/openai_responses.py | 35 +++++++- src/agents/tool.py | 76 +++++++++++++--- 7 files changed, 334 insertions(+), 20 deletions(-) create mode 100644 examples/tools/code_interpreter.py create mode 100644 examples/tools/image_generator.py diff --git a/examples/tools/code_interpreter.py b/examples/tools/code_interpreter.py new file mode 100644 index 000000000..a5843ce3f --- /dev/null +++ b/examples/tools/code_interpreter.py @@ -0,0 +1,34 @@ +import asyncio + +from agents import Agent, CodeInterpreterTool, Runner, trace + + +async def main(): + agent = Agent( + name="Code interpreter", + instructions="You love doing math.", + tools=[ + CodeInterpreterTool( + tool_config={"type": "code_interpreter", "container": {"type": "auto"}}, + ) + ], + ) + + with trace("Code interpreter example"): + print("Solving math problem...") + result = Runner.run_streamed(agent, "What is the square root of273 * 312821 plus 1782?") + async for event in result.stream_events(): + if ( + event.type == "run_item_stream_event" + and event.item.type == "tool_call_item" + and event.item.raw_item.type == "code_interpreter_call" + ): + print(f"Code interpreter code:\n```\n{event.item.raw_item.code}\n```\n") + elif event.type == "run_item_stream_event": + print(f"Other event: {event.item.type}") + + print(f"Final output: {result.final_output}") + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/examples/tools/image_generator.py b/examples/tools/image_generator.py new file mode 100644 index 000000000..fd6fcc6ba --- /dev/null +++ b/examples/tools/image_generator.py @@ -0,0 +1,54 @@ +import asyncio +import base64 +import os +import subprocess +import sys +import tempfile + +from agents import Agent, ImageGenerationTool, Runner, trace + + +def open_file(path: str) -> None: + if sys.platform.startswith("darwin"): + subprocess.run(["open", path], check=False) # macOS + elif os.name == "nt": # Windows + os.astartfile(path) # type: 
ignore + elif os.name == "posix": + subprocess.run(["xdg-open", path], check=False) # Linux/Unix + else: + print(f"Don't know how to open files on this platform: {sys.platform}") + + +async def main(): + agent = Agent( + name="Image generator", + instructions="You are a helpful agent.", + tools=[ + ImageGenerationTool( + tool_config={"type": "image_generation", "quality": "low"}, + ) + ], + ) + + with trace("Image generation example"): + print("Generating image, this may take a while...") + result = await Runner.run( + agent, "Create an image of a frog eating a pizza, comic book style." + ) + print(result.final_output) + for item in result.new_items: + if ( + item.type == "tool_call_item" + and item.raw_item.type == "image_generation_call" + and (img_result := item.raw_item.result) + ): + with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp: + tmp.write(base64.b64decode(img_result)) + temp_path = tmp.name + + # Open the image + open_file(temp_path) + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/src/agents/__init__.py b/src/agents/__init__.py index 36c26b80d..58949157a 100644 --- a/src/agents/__init__.py +++ b/src/agents/__init__.py @@ -54,11 +54,16 @@ StreamEvent, ) from .tool import ( + CodeInterpreterTool, ComputerTool, FileSearchTool, FunctionTool, FunctionToolResult, HostedMCPTool, + ImageGenerationTool, + LocalShellCommandRequest, + LocalShellExecutor, + LocalShellTool, MCPToolApprovalFunction, MCPToolApprovalFunctionResult, MCPToolApprovalRequest, @@ -210,6 +215,11 @@ def enable_verbose_stdout_logging(): "FunctionToolResult", "ComputerTool", "FileSearchTool", + "CodeInterpreterTool", + "ImageGenerationTool", + "LocalShellCommandRequest", + "LocalShellExecutor", + "LocalShellTool", "Tool", "WebSearchTool", "HostedMCPTool", diff --git a/src/agents/_run_impl.py b/src/agents/_run_impl.py index ab1e78797..2cfa270e0 100644 --- a/src/agents/_run_impl.py +++ b/src/agents/_run_impl.py @@ -14,6 +14,9 @@ ResponseFunctionWebSearch, ResponseOutputMessage, ) +from openai.types.responses.response_code_interpreter_tool_call import ( + ResponseCodeInterpreterToolCall, +) from openai.types.responses.response_computer_tool_call import ( ActionClick, ActionDoubleClick, @@ -26,7 +29,12 @@ ActionWait, ) from openai.types.responses.response_input_param import ComputerCallOutput, McpApprovalResponse -from openai.types.responses.response_output_item import McpApprovalRequest, McpCall, McpListTools +from openai.types.responses.response_output_item import ( + ImageGenerationCall, + LocalShellCall, + McpApprovalRequest, + McpListTools, +) from openai.types.responses.response_reasoning_item import ResponseReasoningItem from .agent import Agent, ToolsToFinalOutputResult @@ -61,6 +69,8 @@ FunctionTool, FunctionToolResult, HostedMCPTool, + LocalShellCommandRequest, + LocalShellTool, MCPToolApprovalRequest, Tool, ) @@ -129,12 +139,19 @@ class ToolRunMCPApprovalRequest: mcp_tool: HostedMCPTool +@dataclass +class ToolRunLocalShellCall: + tool_call: LocalShellCall + local_shell_tool: LocalShellTool + + @dataclass class ProcessedResponse: new_items: list[RunItem] handoffs: list[ToolRunHandoff] functions: list[ToolRunFunction] computer_actions: list[ToolRunComputerAction] + local_shell_calls: list[ToolRunLocalShellCall] tools_used: list[str] # Names of all tools used, including hosted tools mcp_approval_requests: list[ToolRunMCPApprovalRequest] # Only requests with callbacks @@ -146,6 +163,7 @@ def has_tools_or_approvals_to_run(self) -> bool: self.handoffs, self.functions, 
self.computer_actions, + self.local_shell_calls, self.mcp_approval_requests, ] ) @@ -371,11 +389,15 @@ def process_model_response( run_handoffs = [] functions = [] computer_actions = [] + local_shell_calls = [] mcp_approval_requests = [] tools_used: list[str] = [] handoff_map = {handoff.tool_name: handoff for handoff in handoffs} function_map = {tool.name: tool for tool in all_tools if isinstance(tool, FunctionTool)} computer_tool = next((tool for tool in all_tools if isinstance(tool, ComputerTool)), None) + local_shell_tool = next( + (tool for tool in all_tools if isinstance(tool, LocalShellTool)), None + ) hosted_mcp_server_map = { tool.tool_config["server_label"]: tool for tool in all_tools @@ -434,9 +456,29 @@ def process_model_response( ) elif isinstance(output, McpListTools): items.append(MCPListToolsItem(raw_item=output, agent=agent)) - elif isinstance(output, McpCall): + elif isinstance(output, ImageGenerationCall): + items.append(ToolCallItem(raw_item=output, agent=agent)) + tools_used.append("image_generation") + elif isinstance(output, ResponseCodeInterpreterToolCall): items.append(ToolCallItem(raw_item=output, agent=agent)) - tools_used.append(output.name) + tools_used.append("code_interpreter") + elif isinstance(output, LocalShellCall): + items.append(ToolCallItem(raw_item=output, agent=agent)) + tools_used.append("local_shell") + if not local_shell_tool: + _error_tracing.attach_error_to_current_span( + SpanError( + message="Local shell tool not found", + data={}, + ) + ) + raise ModelBehaviorError( + "Model produced local shell call without a local shell tool." + ) + local_shell_calls.append( + ToolRunLocalShellCall(tool_call=output, local_shell_tool=local_shell_tool) + ) + elif not isinstance(output, ResponseFunctionToolCall): logger.warning(f"Unexpected output type, ignoring: {type(output)}") continue @@ -478,6 +520,7 @@ def process_model_response( handoffs=run_handoffs, functions=functions, computer_actions=computer_actions, + local_shell_calls=local_shell_calls, tools_used=tools_used, mcp_approval_requests=mcp_approval_requests, ) @@ -552,6 +595,30 @@ async def run_single_tool( for tool_run, result in zip(tool_runs, results) ] + @classmethod + async def execute_local_shell_calls( + cls, + *, + agent: Agent[TContext], + calls: list[ToolRunLocalShellCall], + context_wrapper: RunContextWrapper[TContext], + hooks: RunHooks[TContext], + config: RunConfig, + ) -> list[RunItem]: + results: list[RunItem] = [] + # Need to run these serially, because each call can affect the local shell state + for call in calls: + results.append( + await LocalShellAction.execute( + agent=agent, + call=call, + hooks=hooks, + context_wrapper=context_wrapper, + config=config, + ) + ) + return results + @classmethod async def execute_computer_actions( cls, @@ -1021,3 +1088,54 @@ async def _get_screenshot_async( await computer.wait() return await computer.screenshot() + + +class LocalShellAction: + @classmethod + async def execute( + cls, + *, + agent: Agent[TContext], + call: ToolRunLocalShellCall, + hooks: RunHooks[TContext], + context_wrapper: RunContextWrapper[TContext], + config: RunConfig, + ) -> RunItem: + await asyncio.gather( + hooks.on_tool_start(context_wrapper, agent, call.local_shell_tool), + ( + agent.hooks.on_tool_start(context_wrapper, agent, call.local_shell_tool) + if agent.hooks + else _coro.noop_coroutine() + ), + ) + + request = LocalShellCommandRequest( + ctx_wrapper=context_wrapper, + data=call.tool_call, + ) + output = call.local_shell_tool.executor(request) + if 
inspect.isawaitable(output): + result = await output + else: + result = output + + await asyncio.gather( + hooks.on_tool_end(context_wrapper, agent, call.local_shell_tool, result), + ( + agent.hooks.on_tool_end(context_wrapper, agent, call.local_shell_tool, result) + if agent.hooks + else _coro.noop_coroutine() + ), + ) + + return ToolCallOutputItem( + agent=agent, + output=output, + raw_item={ + "type": "local_shell_call_output", + "id": call.tool_call.call_id, + "output": result, + # "id": "out" + call.tool_call.id, # TODO remove this, it should be optional + }, + ) diff --git a/src/agents/items.py b/src/agents/items.py index 65a911798..64797ad22 100644 --- a/src/agents/items.py +++ b/src/agents/items.py @@ -18,12 +18,22 @@ ResponseOutputText, ResponseStreamEvent, ) +from openai.types.responses.response_code_interpreter_tool_call import ( + ResponseCodeInterpreterToolCall, +) from openai.types.responses.response_input_item_param import ( ComputerCallOutput, FunctionCallOutput, + LocalShellCallOutput, McpApprovalResponse, ) -from openai.types.responses.response_output_item import McpApprovalRequest, McpCall, McpListTools +from openai.types.responses.response_output_item import ( + ImageGenerationCall, + LocalShellCall, + McpApprovalRequest, + McpCall, + McpListTools, +) from openai.types.responses.response_reasoning_item import ResponseReasoningItem from pydantic import BaseModel from typing_extensions import TypeAlias @@ -114,6 +124,9 @@ class HandoffOutputItem(RunItemBase[TResponseInputItem]): ResponseFileSearchToolCall, ResponseFunctionWebSearch, McpCall, + ResponseCodeInterpreterToolCall, + ImageGenerationCall, + LocalShellCall, ] """A type that represents a tool call item.""" @@ -129,10 +142,12 @@ class ToolCallItem(RunItemBase[ToolCallItemTypes]): @dataclass -class ToolCallOutputItem(RunItemBase[Union[FunctionCallOutput, ComputerCallOutput]]): +class ToolCallOutputItem( + RunItemBase[Union[FunctionCallOutput, ComputerCallOutput, LocalShellCallOutput]] +): """Represents the output of a tool call.""" - raw_item: FunctionCallOutput | ComputerCallOutput + raw_item: FunctionCallOutput | ComputerCallOutput | LocalShellCallOutput """The raw item from the model.""" output: Any diff --git a/src/agents/models/openai_responses.py b/src/agents/models/openai_responses.py index 65a4f5caf..86c8e69cb 100644 --- a/src/agents/models/openai_responses.py +++ b/src/agents/models/openai_responses.py @@ -24,7 +24,17 @@ from ..handoffs import Handoff from ..items import ItemHelpers, ModelResponse, TResponseInputItem from ..logger import logger -from ..tool import ComputerTool, FileSearchTool, FunctionTool, HostedMCPTool, Tool, WebSearchTool +from ..tool import ( + CodeInterpreterTool, + ComputerTool, + FileSearchTool, + FunctionTool, + HostedMCPTool, + ImageGenerationTool, + LocalShellTool, + Tool, + WebSearchTool, +) from ..tracing import SpanError, response_span from ..usage import Usage from ..version import __version__ @@ -295,6 +305,18 @@ def convert_tool_choice( return { "type": "computer_use_preview", } + elif tool_choice == "image_generation": + return { + "type": "image_generation", + } + elif tool_choice == "code_interpreter": + return { + "type": "code_interpreter", + } + elif tool_choice == "mcp": + return { + "type": "mcp", + } else: return { "type": "function", @@ -386,6 +408,17 @@ def _convert_tool(cls, tool: Tool) -> tuple[ToolParam, ResponseIncludable | None elif isinstance(tool, HostedMCPTool): converted_tool = tool.tool_config includes = None + elif isinstance(tool, ImageGenerationTool): 
+ converted_tool = tool.tool_config + includes = None + elif isinstance(tool, CodeInterpreterTool): + converted_tool = tool.tool_config + includes = None + elif isinstance(tool, LocalShellTool): + converted_tool = { + "type": "local_shell", + } + includes = None else: raise UserError(f"Unknown tool type: {type(tool)}, tool") diff --git a/src/agents/tool.py b/src/agents/tool.py index 3bcd57c2e..fd5a21c89 100644 --- a/src/agents/tool.py +++ b/src/agents/tool.py @@ -7,8 +7,8 @@ from typing import Any, Callable, Literal, Union, overload from openai.types.responses.file_search_tool_param import Filters, RankingOptions -from openai.types.responses.response_output_item import McpApprovalRequest -from openai.types.responses.tool_param import Mcp +from openai.types.responses.response_output_item import LocalShellCall, McpApprovalRequest +from openai.types.responses.tool_param import CodeInterpreter, ImageGeneration, Mcp from openai.types.responses.web_search_tool_param import UserLocation from pydantic import ValidationError from typing_extensions import Concatenate, NotRequired, ParamSpec, TypedDict @@ -180,7 +180,67 @@ def name(self): return "hosted_mcp" -Tool = Union[FunctionTool, FileSearchTool, WebSearchTool, ComputerTool, HostedMCPTool] +@dataclass +class CodeInterpreterTool: + """A tool that allows the LLM to execute code in a sandboxed environment.""" + + tool_config: CodeInterpreter + """The tool config, which includes the container and other settings.""" + + @property + def name(self): + return "code_interpreter" + + +@dataclass +class ImageGenerationTool: + """A tool that allows the LLM to generate images.""" + + tool_config: ImageGeneration + """The tool config, which image generation settings.""" + + @property + def name(self): + return "image_generation" + + +@dataclass +class LocalShellCommandRequest: + """A request to execute a command on a shell.""" + + ctx_wrapper: RunContextWrapper[Any] + """The run context.""" + + data: LocalShellCall + """The data from the local shell tool call.""" + + +LocalShellExecutor = Callable[[LocalShellCommandRequest], MaybeAwaitable[str]] +"""A function that executes a command on a shell.""" + + +@dataclass +class LocalShellTool: + """A tool that allows the LLM to execute commands on a shell.""" + + executor: LocalShellExecutor + """A function that executes a command on a shell.""" + + @property + def name(self): + return "local_shell" + + +Tool = Union[ + FunctionTool, + FileSearchTool, + WebSearchTool, + ComputerTool, + HostedMCPTool, + LocalShellTool, + ImageGenerationTool, + CodeInterpreterTool, +] """A tool that can be used in an agent.""" @@ -358,13 +418,3 @@ def decorator(real_func: ToolFunction[...]) -> FunctionTool: return _create_function_tool(real_func) return decorator - return decorator - return decorator - return decorator - return decorator - return decorator - return decorator - return decorator - return decorator - return decorator - return decorator From 1992be3e8d1746164f0f47f3e2001de4ab4059b9 Mon Sep 17 00:00:00 2001 From: Rohan Mehta Date: Wed, 21 May 2025 16:06:13 -0400 Subject: [PATCH 26/33] v0.0.16 (#733) --- pyproject.toml | 2 +- uv.lock | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 200ac2485..38a2f2b64 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "openai-agents" -version = "0.0.15" +version = "0.0.16" description = "OpenAI Agents SDK" readme = "README.md" requires-python = ">=3.9" diff --git a/uv.lock b/uv.lock index 
7a0cb1e6b..6f2f3f843 100644 --- a/uv.lock +++ b/uv.lock @@ -1480,7 +1480,7 @@ wheels = [ [[package]] name = "openai-agents" -version = "0.0.15" +version = "0.0.16" source = { editable = "." } dependencies = [ { name = "griffe" }, From 1364f4408e8e0d8bec1d6ac8337f68b3b888a4a7 Mon Sep 17 00:00:00 2001 From: Andrew Han Date: Wed, 21 May 2025 14:59:47 -0700 Subject: [PATCH 27/33] fix Gemini token validation issue with LiteLLM (#735) Fix for #734 --- src/agents/extensions/models/litellm_model.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/agents/extensions/models/litellm_model.py b/src/agents/extensions/models/litellm_model.py index ffb2c3c1c..49e2d42d7 100644 --- a/src/agents/extensions/models/litellm_model.py +++ b/src/agents/extensions/models/litellm_model.py @@ -111,12 +111,12 @@ async def get_response( input_tokens_details=InputTokensDetails( cached_tokens=getattr( response_usage.prompt_tokens_details, "cached_tokens", 0 - ) + ) or 0 ), output_tokens_details=OutputTokensDetails( reasoning_tokens=getattr( response_usage.completion_tokens_details, "reasoning_tokens", 0 - ) + ) or 0 ), ) if response.usage From db462e32a30182d1e9fffe7328609e05c1bc347f Mon Sep 17 00:00:00 2001 From: Rohan Mehta Date: Fri, 23 May 2025 13:00:10 -0400 Subject: [PATCH 28/33] Fix visualization recursion with cycle detection (#737) ## Summary - avoid infinite recursion in visualization by tracking visited agents - test cycle detection in graph utility ## Testing - `make mypy` - `make tests` Resolves #668 --- src/agents/extensions/visualization.py | 53 +++++++++++++++++--------- tests/test_visualization.py | 15 ++++++++ 2 files changed, 50 insertions(+), 18 deletions(-) diff --git a/src/agents/extensions/visualization.py b/src/agents/extensions/visualization.py index 888e262c3..be762a330 100644 --- a/src/agents/extensions/visualization.py +++ b/src/agents/extensions/visualization.py @@ -1,4 +1,4 @@ -from typing import Optional +from __future__ import annotations import graphviz # type: ignore @@ -31,7 +31,9 @@ def get_main_graph(agent: Agent) -> str: return "".join(parts) -def get_all_nodes(agent: Agent, parent: Optional[Agent] = None) -> str: +def get_all_nodes( + agent: Agent, parent: Agent | None = None, visited: set[str] | None = None +) -> str: """ Recursively generates the nodes for the given agent and its handoffs in DOT format. @@ -41,17 +43,23 @@ def get_all_nodes(agent: Agent, parent: Optional[Agent] = None) -> str: Returns: str: The DOT format string representing the nodes. 
""" + if visited is None: + visited = set() + if agent.name in visited: + return "" + visited.add(agent.name) + parts = [] # Start and end the graph - parts.append( - '"__start__" [label="__start__", shape=ellipse, style=filled, ' - "fillcolor=lightblue, width=0.5, height=0.3];" - '"__end__" [label="__end__", shape=ellipse, style=filled, ' - "fillcolor=lightblue, width=0.5, height=0.3];" - ) - # Ensure parent agent node is colored if not parent: + parts.append( + '"__start__" [label="__start__", shape=ellipse, style=filled, ' + "fillcolor=lightblue, width=0.5, height=0.3];" + '"__end__" [label="__end__", shape=ellipse, style=filled, ' + "fillcolor=lightblue, width=0.5, height=0.3];" + ) + # Ensure parent agent node is colored parts.append( f'"{agent.name}" [label="{agent.name}", shape=box, style=filled, ' "fillcolor=lightyellow, width=1.5, height=0.8];" @@ -71,17 +79,20 @@ def get_all_nodes(agent: Agent, parent: Optional[Agent] = None) -> str: f"fillcolor=lightyellow, width=1.5, height=0.8];" ) if isinstance(handoff, Agent): - parts.append( - f'"{handoff.name}" [label="{handoff.name}", ' - f"shape=box, style=filled, style=rounded, " - f"fillcolor=lightyellow, width=1.5, height=0.8];" - ) - parts.append(get_all_nodes(handoff)) + if handoff.name not in visited: + parts.append( + f'"{handoff.name}" [label="{handoff.name}", ' + f"shape=box, style=filled, style=rounded, " + f"fillcolor=lightyellow, width=1.5, height=0.8];" + ) + parts.append(get_all_nodes(handoff, agent, visited)) return "".join(parts) -def get_all_edges(agent: Agent, parent: Optional[Agent] = None) -> str: +def get_all_edges( + agent: Agent, parent: Agent | None = None, visited: set[str] | None = None +) -> str: """ Recursively generates the edges for the given agent and its handoffs in DOT format. @@ -92,6 +103,12 @@ def get_all_edges(agent: Agent, parent: Optional[Agent] = None) -> str: Returns: str: The DOT format string representing the edges. """ + if visited is None: + visited = set() + if agent.name in visited: + return "" + visited.add(agent.name) + parts = [] if not parent: @@ -109,7 +126,7 @@ def get_all_edges(agent: Agent, parent: Optional[Agent] = None) -> str: if isinstance(handoff, Agent): parts.append(f""" "{agent.name}" -> "{handoff.name}";""") - parts.append(get_all_edges(handoff, agent)) + parts.append(get_all_edges(handoff, agent, visited)) if not agent.handoffs and not isinstance(agent, Tool): # type: ignore parts.append(f'"{agent.name}" -> "__end__";') @@ -117,7 +134,7 @@ def get_all_edges(agent: Agent, parent: Optional[Agent] = None) -> str: return "".join(parts) -def draw_graph(agent: Agent, filename: Optional[str] = None) -> graphviz.Source: +def draw_graph(agent: Agent, filename: str | None = None) -> graphviz.Source: """ Draws the graph for the given agent and optionally saves it as a PNG file. 
diff --git a/tests/test_visualization.py b/tests/test_visualization.py index 6aa867743..8bce897e9 100644 --- a/tests/test_visualization.py +++ b/tests/test_visualization.py @@ -134,3 +134,18 @@ def test_draw_graph(mock_agent): '"Handoff1" [label="Handoff1", shape=box, style=filled, style=rounded, ' "fillcolor=lightyellow, width=1.5, height=0.8];" in graph.source ) + + +def test_cycle_detection(): + agent_a = Agent(name="A") + agent_b = Agent(name="B") + agent_a.handoffs.append(agent_b) + agent_b.handoffs.append(agent_a) + + nodes = get_all_nodes(agent_a) + edges = get_all_edges(agent_a) + + assert nodes.count('"A" [label="A"') == 1 + assert nodes.count('"B" [label="B"') == 1 + assert '"A" -> "B"' in edges + assert '"B" -> "A"' in edges From a96108e27915e55c133eaa66d155c19f043d7aca Mon Sep 17 00:00:00 2001 From: Rohan Mehta Date: Fri, 23 May 2025 13:00:21 -0400 Subject: [PATCH 29/33] Update MCP and tool docs (#736) ## Summary - mention MCPServerStreamableHttp in MCP server docs - document CodeInterpreterTool, HostedMCPTool, ImageGenerationTool and LocalShellTool - update Japanese translations --- docs/ja/mcp.md | 9 +++++---- docs/ja/tools.md | 4 ++++ docs/mcp.md | 5 +++-- docs/tools.md | 4 ++++ 4 files changed, 16 insertions(+), 6 deletions(-) diff --git a/docs/ja/mcp.md b/docs/ja/mcp.md index 7cdaa57ee..09804beb2 100644 --- a/docs/ja/mcp.md +++ b/docs/ja/mcp.md @@ -12,12 +12,13 @@ Agents SDK は MCP をサポートしており、これにより幅広い MCP ## MCP サーバー -現在、MCP 仕様では使用するトランスポート方式に基づき 2 種類のサーバーが定義されています。 +現在、MCP 仕様では使用するトランスポート方式に基づき 3 種類のサーバーが定義されています。 -1. **stdio** サーバー: アプリケーションのサブプロセスとして実行されます。ローカルで動かすイメージです。 +1. **stdio** サーバー: アプリケーションのサブプロセスとして実行されます。ローカルで動かすイメージです。 2. **HTTP over SSE** サーバー: リモートで動作し、 URL 経由で接続します。 +3. **Streamable HTTP** サーバー: MCP 仕様に定義された Streamable HTTP トランスポートを使用してリモートで動作します。 -これらのサーバーへは [`MCPServerStdio`][agents.mcp.server.MCPServerStdio] と [`MCPServerSse`][agents.mcp.server.MCPServerSse] クラスを使用して接続できます。 +これらのサーバーへは [`MCPServerStdio`][agents.mcp.server.MCPServerStdio]、[`MCPServerSse`][agents.mcp.server.MCPServerSse]、[`MCPServerStreamableHttp`][agents.mcp.server.MCPServerStreamableHttp] クラスを使用して接続できます。 たとえば、[公式 MCP filesystem サーバー](https://www.npmjs.com/package/@modelcontextprotocol/server-filesystem)を利用する場合は次のようになります。 @@ -46,7 +47,7 @@ agent=Agent( ## キャッシュ -エージェントが実行されるたびに、MCP サーバーへ `list_tools()` が呼び出されます。サーバーがリモートの場合は特にレイテンシが発生します。ツール一覧を自動でキャッシュしたい場合は、[`MCPServerStdio`][agents.mcp.server.MCPServerStdio] と [`MCPServerSse`][agents.mcp.server.MCPServerSse] の両方に `cache_tools_list=True` を渡してください。ツール一覧が変更されないと確信できる場合のみ使用してください。 +エージェントが実行されるたびに、MCP サーバーへ `list_tools()` が呼び出されます。サーバーがリモートの場合は特にレイテンシが発生します。ツール一覧を自動でキャッシュしたい場合は、[`MCPServerStdio`][agents.mcp.server.MCPServerStdio]、[`MCPServerSse`][agents.mcp.server.MCPServerSse]、[`MCPServerStreamableHttp`][agents.mcp.server.MCPServerStreamableHttp] の各クラスに `cache_tools_list=True` を渡してください。ツール一覧が変更されないと確信できる場合のみ使用してください。 キャッシュを無効化したい場合は、サーバーで `invalidate_tools_cache()` を呼び出します。 diff --git a/docs/ja/tools.md b/docs/ja/tools.md index 7ab15e472..cd80092d5 100644 --- a/docs/ja/tools.md +++ b/docs/ja/tools.md @@ -17,6 +17,10 @@ OpenAI は [`OpenAIResponsesModel`][agents.models.openai_responses.OpenAIRespons - [`WebSearchTool`][agents.tool.WebSearchTool] はエージェントに Web 検索を行わせます。 - [`FileSearchTool`][agents.tool.FileSearchTool] は OpenAI ベクトルストアから情報を取得します。 - [`ComputerTool`][agents.tool.ComputerTool] はコンピュータ操作タスクを自動化します。 +- [`CodeInterpreterTool`][agents.tool.CodeInterpreterTool] はサンドボックス環境でコードを実行します。 +- [`HostedMCPTool`][agents.tool.HostedMCPTool] はリモート 
MCP サーバーのツールをモデルから直接利用できるようにします。 +- [`ImageGenerationTool`][agents.tool.ImageGenerationTool] はプロンプトから画像を生成します。 +- [`LocalShellTool`][agents.tool.LocalShellTool] はローカルマシンでシェルコマンドを実行します。 ```python from agents import Agent, FileSearchTool, Runner, WebSearchTool diff --git a/docs/mcp.md b/docs/mcp.md index e279a25e0..2cd0aad9e 100644 --- a/docs/mcp.md +++ b/docs/mcp.md @@ -12,8 +12,9 @@ Currently, the MCP spec defines two kinds of servers, based on the transport mec 1. **stdio** servers run as a subprocess of your application. You can think of them as running "locally". 2. **HTTP over SSE** servers run remotely. You connect to them via a URL. +3. **Streamable HTTP** servers run remotely using the Streamable HTTP transport defined in the MCP spec. -You can use the [`MCPServerStdio`][agents.mcp.server.MCPServerStdio] and [`MCPServerSse`][agents.mcp.server.MCPServerSse] classes to connect to these servers. +You can use the [`MCPServerStdio`][agents.mcp.server.MCPServerStdio], [`MCPServerSse`][agents.mcp.server.MCPServerSse], and [`MCPServerStreamableHttp`][agents.mcp.server.MCPServerStreamableHttp] classes to connect to these servers. For example, this is how you'd use the [official MCP filesystem server](https://www.npmjs.com/package/@modelcontextprotocol/server-filesystem). @@ -42,7 +43,7 @@ agent=Agent( ## Caching -Every time an Agent runs, it calls `list_tools()` on the MCP server. This can be a latency hit, especially if the server is a remote server. To automatically cache the list of tools, you can pass `cache_tools_list=True` to both [`MCPServerStdio`][agents.mcp.server.MCPServerStdio] and [`MCPServerSse`][agents.mcp.server.MCPServerSse]. You should only do this if you're certain the tool list will not change. +Every time an Agent runs, it calls `list_tools()` on the MCP server. This can be a latency hit, especially if the server is a remote server. To automatically cache the list of tools, you can pass `cache_tools_list=True` to [`MCPServerStdio`][agents.mcp.server.MCPServerStdio], [`MCPServerSse`][agents.mcp.server.MCPServerSse], and [`MCPServerStreamableHttp`][agents.mcp.server.MCPServerStreamableHttp]. You should only do this if you're certain the tool list will not change. If you want to invalidate the cache, you can call `invalidate_tools_cache()` on the servers. diff --git a/docs/tools.md b/docs/tools.md index 5fe2ecedb..89e28d998 100644 --- a/docs/tools.md +++ b/docs/tools.md @@ -13,6 +13,10 @@ OpenAI offers a few built-in tools when using the [`OpenAIResponsesModel`][agent - The [`WebSearchTool`][agents.tool.WebSearchTool] lets an agent search the web. - The [`FileSearchTool`][agents.tool.FileSearchTool] allows retrieving information from your OpenAI Vector Stores. - The [`ComputerTool`][agents.tool.ComputerTool] allows automating computer use tasks. +- The [`CodeInterpreterTool`][agents.tool.CodeInterpreterTool] lets the LLM execute code in a sandboxed environment. +- The [`HostedMCPTool`][agents.tool.HostedMCPTool] exposes a remote MCP server's tools to the model. +- The [`ImageGenerationTool`][agents.tool.ImageGenerationTool] generates images from a prompt. +- The [`LocalShellTool`][agents.tool.LocalShellTool] runs shell commands on your machine. 
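A condensed sketch of how the newly listed hosted tools are attached to an agent, based on the examples added earlier in this patch series (`examples/tools/code_interpreter.py` and `examples/tools/image_generator.py`); the agent name and prompt below are illustrative.

```python
import asyncio

from agents import Agent, CodeInterpreterTool, ImageGenerationTool, Runner

agent = Agent(
    name="Assistant",
    instructions="You are a helpful agent.",
    tools=[
        # Hosted tools are configured through their tool_config payloads.
        CodeInterpreterTool(
            tool_config={"type": "code_interpreter", "container": {"type": "auto"}},
        ),
        ImageGenerationTool(
            tool_config={"type": "image_generation", "quality": "low"},
        ),
    ],
)


async def main():
    result = await Runner.run(agent, "Plot y = x ** 2 and describe the chart.")
    print(result.final_output)


if __name__ == "__main__":
    asyncio.run(main())
```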
```python from agents import Agent, FileSearchTool, Runner, WebSearchTool From 6e078bf7a99badc5a36be961e4e6dd6dde4c3674 Mon Sep 17 00:00:00 2001 From: Rohan Mehta Date: Fri, 23 May 2025 13:00:30 -0400 Subject: [PATCH 30/33] Fix Gemini API content filter handling (#746) ## Summary - avoid AttributeError when Gemini API returns `None` for chat message - return empty output if message is filtered - add regression test ## Testing - `make format` - `make lint` - `make mypy` - `make tests` Towards #744 --- src/agents/models/openai_chatcompletions.py | 22 ++++++++++--- tests/test_openai_chatcompletions.py | 34 +++++++++++++++++++++ 2 files changed, 51 insertions(+), 5 deletions(-) diff --git a/src/agents/models/openai_chatcompletions.py b/src/agents/models/openai_chatcompletions.py index 4465ff2fd..6b4045d21 100644 --- a/src/agents/models/openai_chatcompletions.py +++ b/src/agents/models/openai_chatcompletions.py @@ -71,12 +71,22 @@ async def get_response( stream=False, ) + first_choice = response.choices[0] + message = first_choice.message + if _debug.DONT_LOG_MODEL_DATA: logger.debug("Received model response") else: - logger.debug( - f"LLM resp:\n{json.dumps(response.choices[0].message.model_dump(), indent=2)}\n" - ) + if message is not None: + logger.debug( + "LLM resp:\n%s\n", + json.dumps(message.model_dump(), indent=2), + ) + else: + logger.debug( + "LLM resp had no message. finish_reason: %s", + first_choice.finish_reason, + ) usage = ( Usage( @@ -101,13 +111,15 @@ async def get_response( else Usage() ) if tracing.include_data(): - span_generation.span_data.output = [response.choices[0].message.model_dump()] + span_generation.span_data.output = ( + [message.model_dump()] if message is not None else [] + ) span_generation.span_data.usage = { "input_tokens": usage.input_tokens, "output_tokens": usage.output_tokens, } - items = Converter.message_to_output_items(response.choices[0].message) + items = Converter.message_to_output_items(message) if message is not None else [] return ModelResponse( output=items, diff --git a/tests/test_openai_chatcompletions.py b/tests/test_openai_chatcompletions.py index ba4605d08..9a85dcb7b 100644 --- a/tests/test_openai_chatcompletions.py +++ b/tests/test_openai_chatcompletions.py @@ -191,6 +191,40 @@ async def patched_fetch_response(self, *args, **kwargs): assert fn_call_item.arguments == "{'x':1}" +@pytest.mark.allow_call_model_methods +@pytest.mark.asyncio +async def test_get_response_with_no_message(monkeypatch) -> None: + """If the model returns no message, get_response should return an empty output.""" + msg = ChatCompletionMessage(role="assistant", content="ignored") + choice = Choice(index=0, finish_reason="content_filter", message=msg) + choice.message = None # type: ignore[assignment] + chat = ChatCompletion( + id="resp-id", + created=0, + model="fake", + object="chat.completion", + choices=[choice], + usage=None, + ) + + async def patched_fetch_response(self, *args, **kwargs): + return chat + + monkeypatch.setattr(OpenAIChatCompletionsModel, "_fetch_response", patched_fetch_response) + model = OpenAIProvider(use_responses=False).get_model("gpt-4") + resp: ModelResponse = await model.get_response( + system_instructions=None, + input="", + model_settings=ModelSettings(), + tools=[], + output_schema=None, + handoffs=[], + tracing=ModelTracing.DISABLED, + previous_response_id=None, + ) + assert resp.output == [] + + @pytest.mark.asyncio async def test_fetch_response_non_stream(monkeypatch) -> None: """ From d46e2ec35baeec3c84a873438eb341ed64a6eede Mon 
Sep 17 00:00:00 2001 From: siddharth Sambharia Date: Thu, 29 May 2025 21:13:25 +0530 Subject: [PATCH 31/33] Add Portkey AI as a tracing provider (#785) This PR adds Portkey AI as a tracing provider. Portkey helps you take your OpenAI agents from prototype to production. Portkey turns your experimental OpenAI Agents into production-ready systems by providing: - Complete observability of every agent step, tool use, and interaction - Built-in reliability with fallbacks, retries, and load balancing - Cost tracking and optimization to manage your AI spend - Access to 1600+ LLMs through a single integration - Guardrails to keep agent behavior safe and compliant - Version-controlled prompts for consistent agent performance Towards #786 --- docs/ja/tracing.md | 3 ++- docs/tracing.md | 1 + 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/docs/ja/tracing.md b/docs/ja/tracing.md index 0e0d0e77d..d67200ce4 100644 --- a/docs/ja/tracing.md +++ b/docs/ja/tracing.md @@ -119,4 +119,5 @@ async def main(): - [Comet Opik](https://www.comet.com/docs/opik/tracing/integrations/openai_agents) - [Langfuse](https://langfuse.com/docs/integrations/openaiagentssdk/openai-agents) - [Langtrace](https://docs.langtrace.ai/supported-integrations/llm-frameworks/openai-agents-sdk) -- [Okahu‑Monocle](https://github.com/monocle2ai/monocle) \ No newline at end of file +- [Okahu‑Monocle](https://github.com/monocle2ai/monocle) +- [Portkey AI](https://portkey.ai/docs/integrations/agents/openai-agents) diff --git a/docs/tracing.md b/docs/tracing.md index 4a9c1bd90..c7776ad7b 100644 --- a/docs/tracing.md +++ b/docs/tracing.md @@ -116,3 +116,4 @@ To customize this default setup, to send traces to alternative or additional bac - [Langtrace](https://docs.langtrace.ai/supported-integrations/llm-frameworks/openai-agents-sdk) - [Okahu-Monocle](https://github.com/monocle2ai/monocle) - [Galileo](https://v2docs.galileo.ai/integrations/openai-agent-integration#openai-agent-integration) +- [Portkey AI](https://portkey.ai/docs/integrations/agents/openai-agents) From 71968625ccafe62f7015dc6ab5f8eaea81a95bd4 Mon Sep 17 00:00:00 2001 From: Daniele Morotti <58258368+DanieleMorotti@users.noreply.github.com> Date: Thu, 29 May 2025 22:11:33 +0200 Subject: [PATCH 32/33] Added RunErrorDetails object for MaxTurnsExceeded exception (#743) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ### Summary Introduced the `RunErrorDetails` object to get partial results from a run interrupted by `MaxTurnsExceeded` exception. In this proposal the `RunErrorDetails` object contains all the fields from `RunResult` with `final_output` set to `None` and `output_guardrail_results` set to an empty list. We can decide to return less information. @rm-openai At the moment the exception doesn't return the `RunErrorDetails` object for the streaming mode. Do you have any suggestions on how to deal with it? In the `_check_errors` function of `agents/result.py` file. ### Test plan I have not implemented any tests currently, but if needed I can implement a basic test to retrieve partial data. 
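To make the intended usage concrete, here is a hedged sketch of catching the exception and reading the new `run_data` field; the agent setup is illustrative and mirrors the regression tests added below.

```python
import asyncio

from agents import Agent, MaxTurnsExceeded, Runner


async def main():
    agent = Agent(name="Looper", instructions="Keep calling tools until told to stop.")
    try:
        await Runner.run(agent, input="hello", max_turns=1)
    except MaxTurnsExceeded as exc:
        details = exc.run_data  # RunErrorDetails | None
        if details is not None:
            # Partial results collected before the run was aborted.
            print(details.last_agent.name)
            print(len(details.new_items), "new items")
            print(len(details.raw_responses), "raw model responses")


if __name__ == "__main__":
    asyncio.run(main())
```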
### Issue number This PR is an attempt to solve issue #719 ### Checks - [✅ ] I've added new tests (if relevant) - [ ] I've added/updated the relevant documentation - [ ✅] I've run `make lint` and `make format` - [ ✅] I've made sure tests pass --- src/agents/__init__.py | 2 + src/agents/exceptions.py | 41 ++++++++++++++++--- src/agents/result.py | 57 +++++++++++++++++++++------ src/agents/run.py | 30 +++++++++++++- src/agents/util/_pretty_print.py | 12 ++++++ tests/test_run_error_details.py | 44 +++++++++++++++++++++ tests/test_tracing_errors_streamed.py | 4 -- 7 files changed, 167 insertions(+), 23 deletions(-) create mode 100644 tests/test_run_error_details.py diff --git a/src/agents/__init__.py b/src/agents/__init__.py index 58949157a..820616437 100644 --- a/src/agents/__init__.py +++ b/src/agents/__init__.py @@ -14,6 +14,7 @@ MaxTurnsExceeded, ModelBehaviorError, OutputGuardrailTripwireTriggered, + RunErrorDetails, UserError, ) from .guardrail import ( @@ -204,6 +205,7 @@ def enable_verbose_stdout_logging(): "AgentHooks", "RunContextWrapper", "TContext", + "RunErrorDetails", "RunResult", "RunResultStreaming", "RunConfig", diff --git a/src/agents/exceptions.py b/src/agents/exceptions.py index 78898f017..4f6e2e768 100644 --- a/src/agents/exceptions.py +++ b/src/agents/exceptions.py @@ -1,11 +1,39 @@ -from typing import TYPE_CHECKING +from __future__ import annotations + +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any if TYPE_CHECKING: + from .agent import Agent from .guardrail import InputGuardrailResult, OutputGuardrailResult + from .items import ModelResponse, RunItem, TResponseInputItem + from .run_context import RunContextWrapper + +from .util._pretty_print import pretty_print_run_error_details + + +@dataclass +class RunErrorDetails: + """Data collected from an agent run when an exception occurs.""" + input: str | list[TResponseInputItem] + new_items: list[RunItem] + raw_responses: list[ModelResponse] + last_agent: Agent[Any] + context_wrapper: RunContextWrapper[Any] + input_guardrail_results: list[InputGuardrailResult] + output_guardrail_results: list[OutputGuardrailResult] + + def __str__(self) -> str: + return pretty_print_run_error_details(self) class AgentsException(Exception): """Base class for all exceptions in the Agents SDK.""" + run_data: RunErrorDetails | None + + def __init__(self, *args: object) -> None: + super().__init__(*args) + self.run_data = None class MaxTurnsExceeded(AgentsException): @@ -15,6 +43,7 @@ class MaxTurnsExceeded(AgentsException): def __init__(self, message: str): self.message = message + super().__init__(message) class ModelBehaviorError(AgentsException): @@ -26,6 +55,7 @@ class ModelBehaviorError(AgentsException): def __init__(self, message: str): self.message = message + super().__init__(message) class UserError(AgentsException): @@ -35,15 +65,16 @@ class UserError(AgentsException): def __init__(self, message: str): self.message = message + super().__init__(message) class InputGuardrailTripwireTriggered(AgentsException): """Exception raised when a guardrail tripwire is triggered.""" - guardrail_result: "InputGuardrailResult" + guardrail_result: InputGuardrailResult """The result data of the guardrail that was triggered.""" - def __init__(self, guardrail_result: "InputGuardrailResult"): + def __init__(self, guardrail_result: InputGuardrailResult): self.guardrail_result = guardrail_result super().__init__( f"Guardrail {guardrail_result.guardrail.__class__.__name__} triggered tripwire" @@ -53,10 +84,10 @@ def __init__(self, 
guardrail_result: "InputGuardrailResult"): class OutputGuardrailTripwireTriggered(AgentsException): """Exception raised when a guardrail tripwire is triggered.""" - guardrail_result: "OutputGuardrailResult" + guardrail_result: OutputGuardrailResult """The result data of the guardrail that was triggered.""" - def __init__(self, guardrail_result: "OutputGuardrailResult"): + def __init__(self, guardrail_result: OutputGuardrailResult): self.guardrail_result = guardrail_result super().__init__( f"Guardrail {guardrail_result.guardrail.__class__.__name__} triggered tripwire" diff --git a/src/agents/result.py b/src/agents/result.py index 243db155c..764815246 100644 --- a/src/agents/result.py +++ b/src/agents/result.py @@ -11,14 +11,22 @@ from ._run_impl import QueueCompleteSentinel from .agent import Agent from .agent_output import AgentOutputSchemaBase -from .exceptions import InputGuardrailTripwireTriggered, MaxTurnsExceeded +from .exceptions import ( + AgentsException, + InputGuardrailTripwireTriggered, + MaxTurnsExceeded, + RunErrorDetails, +) from .guardrail import InputGuardrailResult, OutputGuardrailResult from .items import ItemHelpers, ModelResponse, RunItem, TResponseInputItem from .logger import logger from .run_context import RunContextWrapper from .stream_events import StreamEvent from .tracing import Trace -from .util._pretty_print import pretty_print_result, pretty_print_run_result_streaming +from .util._pretty_print import ( + pretty_print_result, + pretty_print_run_result_streaming, +) if TYPE_CHECKING: from ._run_impl import QueueCompleteSentinel @@ -206,31 +214,53 @@ async def stream_events(self) -> AsyncIterator[StreamEvent]: if self._stored_exception: raise self._stored_exception + def _create_error_details(self) -> RunErrorDetails: + """Return a `RunErrorDetails` object considering the current attributes of the class.""" + return RunErrorDetails( + input=self.input, + new_items=self.new_items, + raw_responses=self.raw_responses, + last_agent=self.current_agent, + context_wrapper=self.context_wrapper, + input_guardrail_results=self.input_guardrail_results, + output_guardrail_results=self.output_guardrail_results, + ) + def _check_errors(self): if self.current_turn > self.max_turns: - self._stored_exception = MaxTurnsExceeded(f"Max turns ({self.max_turns}) exceeded") + max_turns_exc = MaxTurnsExceeded(f"Max turns ({self.max_turns}) exceeded") + max_turns_exc.run_data = self._create_error_details() + self._stored_exception = max_turns_exc # Fetch all the completed guardrail results from the queue and raise if needed while not self._input_guardrail_queue.empty(): guardrail_result = self._input_guardrail_queue.get_nowait() if guardrail_result.output.tripwire_triggered: - self._stored_exception = InputGuardrailTripwireTriggered(guardrail_result) + tripwire_exc = InputGuardrailTripwireTriggered(guardrail_result) + tripwire_exc.run_data = self._create_error_details() + self._stored_exception = tripwire_exc # Check the tasks for any exceptions if self._run_impl_task and self._run_impl_task.done(): - exc = self._run_impl_task.exception() - if exc and isinstance(exc, Exception): - self._stored_exception = exc + run_impl_exc = self._run_impl_task.exception() + if run_impl_exc and isinstance(run_impl_exc, Exception): + if isinstance(run_impl_exc, AgentsException) and run_impl_exc.run_data is None: + run_impl_exc.run_data = self._create_error_details() + self._stored_exception = run_impl_exc if self._input_guardrails_task and self._input_guardrails_task.done(): - exc = 
self._input_guardrails_task.exception() - if exc and isinstance(exc, Exception): - self._stored_exception = exc + in_guard_exc = self._input_guardrails_task.exception() + if in_guard_exc and isinstance(in_guard_exc, Exception): + if isinstance(in_guard_exc, AgentsException) and in_guard_exc.run_data is None: + in_guard_exc.run_data = self._create_error_details() + self._stored_exception = in_guard_exc if self._output_guardrails_task and self._output_guardrails_task.done(): - exc = self._output_guardrails_task.exception() - if exc and isinstance(exc, Exception): - self._stored_exception = exc + out_guard_exc = self._output_guardrails_task.exception() + if out_guard_exc and isinstance(out_guard_exc, Exception): + if isinstance(out_guard_exc, AgentsException) and out_guard_exc.run_data is None: + out_guard_exc.run_data = self._create_error_details() + self._stored_exception = out_guard_exc def _cleanup_tasks(self): if self._run_impl_task and not self._run_impl_task.done(): @@ -244,3 +274,4 @@ def _cleanup_tasks(self): def __str__(self) -> str: return pretty_print_run_result_streaming(self) + diff --git a/src/agents/run.py b/src/agents/run.py index b196c3bf1..c67386495 100644 --- a/src/agents/run.py +++ b/src/agents/run.py @@ -1,3 +1,4 @@ + from __future__ import annotations import asyncio @@ -26,6 +27,7 @@ MaxTurnsExceeded, ModelBehaviorError, OutputGuardrailTripwireTriggered, + RunErrorDetails, ) from .guardrail import InputGuardrail, InputGuardrailResult, OutputGuardrail, OutputGuardrailResult from .handoffs import Handoff, HandoffInputFilter, handoff @@ -208,7 +210,9 @@ async def run( data={"max_turns": max_turns}, ), ) - raise MaxTurnsExceeded(f"Max turns ({max_turns}) exceeded") + raise MaxTurnsExceeded( + f"Max turns ({max_turns}) exceeded" + ) logger.debug( f"Running agent {current_agent.name} (turn {current_turn})", @@ -283,6 +287,17 @@ async def run( raise AgentsException( f"Unknown next step type: {type(turn_result.next_step)}" ) + except AgentsException as exc: + exc.run_data = RunErrorDetails( + input=original_input, + new_items=generated_items, + raw_responses=model_responses, + last_agent=current_agent, + context_wrapper=context_wrapper, + input_guardrail_results=input_guardrail_results, + output_guardrail_results=[] + ) + raise finally: if current_span: current_span.finish(reset_current=True) @@ -609,6 +624,19 @@ async def _run_streamed_impl( streamed_result._event_queue.put_nowait(QueueCompleteSentinel()) elif isinstance(turn_result.next_step, NextStepRunAgain): pass + except AgentsException as exc: + streamed_result.is_complete = True + streamed_result._event_queue.put_nowait(QueueCompleteSentinel()) + exc.run_data = RunErrorDetails( + input=streamed_result.input, + new_items=streamed_result.new_items, + raw_responses=streamed_result.raw_responses, + last_agent=current_agent, + context_wrapper=context_wrapper, + input_guardrail_results=streamed_result.input_guardrail_results, + output_guardrail_results=streamed_result.output_guardrail_results, + ) + raise except Exception as e: if current_span: _error_tracing.attach_error_to_span( diff --git a/src/agents/util/_pretty_print.py b/src/agents/util/_pretty_print.py index afd3e2b1b..29df3562e 100644 --- a/src/agents/util/_pretty_print.py +++ b/src/agents/util/_pretty_print.py @@ -3,6 +3,7 @@ from pydantic import BaseModel if TYPE_CHECKING: + from ..exceptions import RunErrorDetails from ..result import RunResult, RunResultBase, RunResultStreaming @@ -38,6 +39,17 @@ def pretty_print_result(result: "RunResult") -> str: return output 
+def pretty_print_run_error_details(result: "RunErrorDetails") -> str: + output = "RunErrorDetails:" + output += f'\n- Last agent: Agent(name="{result.last_agent.name}", ...)' + output += f"\n- {len(result.new_items)} new item(s)" + output += f"\n- {len(result.raw_responses)} raw response(s)" + output += f"\n- {len(result.input_guardrail_results)} input guardrail result(s)" + output += "\n(See `RunErrorDetails` for more details)" + + return output + + def pretty_print_run_result_streaming(result: "RunResultStreaming") -> str: output = "RunResultStreaming:" output += f'\n- Current agent: Agent(name="{result.current_agent.name}", ...)' diff --git a/tests/test_run_error_details.py b/tests/test_run_error_details.py new file mode 100644 index 000000000..2268b3780 --- /dev/null +++ b/tests/test_run_error_details.py @@ -0,0 +1,44 @@ +import json + +import pytest + +from agents import Agent, MaxTurnsExceeded, RunErrorDetails, Runner + +from .fake_model import FakeModel +from .test_responses import get_function_tool, get_function_tool_call, get_text_message + + +@pytest.mark.asyncio +async def test_run_error_includes_data(): + model = FakeModel() + agent = Agent(name="test", model=model, tools=[get_function_tool("foo", "res")]) + model.add_multiple_turn_outputs([ + [get_text_message("1"), get_function_tool_call("foo", json.dumps({"a": "b"}))], + [get_text_message("done")], + ]) + with pytest.raises(MaxTurnsExceeded) as exc: + await Runner.run(agent, input="hello", max_turns=1) + data = exc.value.run_data + assert isinstance(data, RunErrorDetails) + assert data.last_agent == agent + assert len(data.raw_responses) == 1 + assert len(data.new_items) > 0 + + +@pytest.mark.asyncio +async def test_streamed_run_error_includes_data(): + model = FakeModel() + agent = Agent(name="test", model=model, tools=[get_function_tool("foo", "res")]) + model.add_multiple_turn_outputs([ + [get_text_message("1"), get_function_tool_call("foo", json.dumps({"a": "b"}))], + [get_text_message("done")], + ]) + result = Runner.run_streamed(agent, input="hello", max_turns=1) + with pytest.raises(MaxTurnsExceeded) as exc: + async for _ in result.stream_events(): + pass + data = exc.value.run_data + assert isinstance(data, RunErrorDetails) + assert data.last_agent == agent + assert len(data.raw_responses) == 1 + assert len(data.new_items) > 0 diff --git a/tests/test_tracing_errors_streamed.py b/tests/test_tracing_errors_streamed.py index 416793e70..40efef3fa 100644 --- a/tests/test_tracing_errors_streamed.py +++ b/tests/test_tracing_errors_streamed.py @@ -168,10 +168,6 @@ async def test_tool_call_error(): "children": [ { "type": "agent", - "error": { - "message": "Error in agent run", - "data": {"error": "Invalid JSON input for tool foo: bad_json"}, - }, "data": { "name": "test_agent", "handoffs": [], From 47fa8e87b1b1553f1733dd1909b784cedbc98872 Mon Sep 17 00:00:00 2001 From: Sarmad Gulzar Date: Fri, 30 May 2025 02:24:31 +0500 Subject: [PATCH 33/33] Fixed Python syntax (#665) --- docs/tools.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/tools.md b/docs/tools.md index 89e28d998..4e9a20d32 100644 --- a/docs/tools.md +++ b/docs/tools.md @@ -270,7 +270,7 @@ The `agent.as_tool` function is a convenience method to make it easy to turn an ```python @function_tool async def run_my_agent() -> str: - """A tool that runs the agent with custom configs". + """A tool that runs the agent with custom configs""" agent = Agent(name="My agent", instructions="...")