From c1466c62167edbb6a10e4cb6381ff47a00708436 Mon Sep 17 00:00:00 2001 From: Rohan Mehta Date: Tue, 24 Jun 2025 10:53:33 -0400 Subject: [PATCH 01/23] Add is_enabled to handoffs (#925) Was added to function tools before, now handoffs. Towards #918 --- src/agents/handoffs.py | 14 +++++ src/agents/run.py | 31 +++++++-- tests/test_agent_config.py | 6 +- tests/test_handoff_tool.py | 100 +++++++++++++++++++++++++++--- tests/test_run_step_execution.py | 2 +- tests/test_run_step_processing.py | 8 +-- 6 files changed, 140 insertions(+), 21 deletions(-) diff --git a/src/agents/handoffs.py b/src/agents/handoffs.py index 76c93a298..cb2752e4f 100644 --- a/src/agents/handoffs.py +++ b/src/agents/handoffs.py @@ -15,6 +15,7 @@ from .strict_schema import ensure_strict_json_schema from .tracing.spans import SpanError from .util import _error_tracing, _json, _transforms +from .util._types import MaybeAwaitable if TYPE_CHECKING: from .agent import Agent @@ -99,6 +100,11 @@ class Handoff(Generic[TContext]): True, as it increases the likelihood of correct JSON input. """ + is_enabled: bool | Callable[[RunContextWrapper[Any], Agent[Any]], MaybeAwaitable[bool]] = True + """Whether the handoff is enabled. Either a bool or a Callable that takes the run context and + agent and returns whether the handoff is enabled. You can use this to dynamically enable/disable + a handoff based on your context/state.""" + def get_transfer_message(self, agent: Agent[Any]) -> str: return json.dumps({"assistant": agent.name}) @@ -121,6 +127,7 @@ def handoff( tool_name_override: str | None = None, tool_description_override: str | None = None, input_filter: Callable[[HandoffInputData], HandoffInputData] | None = None, + is_enabled: bool | Callable[[RunContextWrapper[Any], Agent[Any]], MaybeAwaitable[bool]] = True, ) -> Handoff[TContext]: ... @@ -133,6 +140,7 @@ def handoff( tool_description_override: str | None = None, tool_name_override: str | None = None, input_filter: Callable[[HandoffInputData], HandoffInputData] | None = None, + is_enabled: bool | Callable[[RunContextWrapper[Any], Agent[Any]], MaybeAwaitable[bool]] = True, ) -> Handoff[TContext]: ... @@ -144,6 +152,7 @@ def handoff( tool_description_override: str | None = None, tool_name_override: str | None = None, input_filter: Callable[[HandoffInputData], HandoffInputData] | None = None, + is_enabled: bool | Callable[[RunContextWrapper[Any], Agent[Any]], MaybeAwaitable[bool]] = True, ) -> Handoff[TContext]: ... @@ -154,6 +163,7 @@ def handoff( on_handoff: OnHandoffWithInput[THandoffInput] | OnHandoffWithoutInput | None = None, input_type: type[THandoffInput] | None = None, input_filter: Callable[[HandoffInputData], HandoffInputData] | None = None, + is_enabled: bool | Callable[[RunContextWrapper[Any], Agent[Any]], MaybeAwaitable[bool]] = True, ) -> Handoff[TContext]: """Create a handoff from an agent. @@ -166,6 +176,9 @@ def handoff( input_type: the type of the input to the handoff. If provided, the input will be validated against this type. Only relevant if you pass a function that takes an input. input_filter: a function that filters the inputs that are passed to the next agent. + is_enabled: Whether the handoff is enabled. Can be a bool or a callable that takes the run + context and agent and returns whether the handoff is enabled. Disabled handoffs are + hidden from the LLM at runtime. 
""" assert (on_handoff and input_type) or not (on_handoff and input_type), ( "You must provide either both on_handoff and input_type, or neither" @@ -233,4 +246,5 @@ async def _invoke_handoff( on_invoke_handoff=_invoke_handoff, input_filter=input_filter, agent_name=agent.name, + is_enabled=is_enabled, ) diff --git a/src/agents/run.py b/src/agents/run.py index 8a44a0e54..e5f9378ec 100644 --- a/src/agents/run.py +++ b/src/agents/run.py @@ -2,6 +2,7 @@ import asyncio import copy +import inspect from dataclasses import dataclass, field from typing import Any, Generic, cast @@ -361,7 +362,8 @@ async def run( # agent changes, or if the agent loop ends. if current_span is None: handoff_names = [ - h.agent_name for h in AgentRunner._get_handoffs(current_agent) + h.agent_name + for h in await AgentRunner._get_handoffs(current_agent, context_wrapper) ] if output_schema := AgentRunner._get_output_schema(current_agent): output_type_name = output_schema.name() @@ -641,7 +643,10 @@ async def _start_streaming( # Start an agent span if we don't have one. This span is ended if the current # agent changes, or if the agent loop ends. if current_span is None: - handoff_names = [h.agent_name for h in cls._get_handoffs(current_agent)] + handoff_names = [ + h.agent_name + for h in await cls._get_handoffs(current_agent, context_wrapper) + ] if output_schema := cls._get_output_schema(current_agent): output_type_name = output_schema.name() else: @@ -798,7 +803,7 @@ async def _run_single_turn_streamed( agent.get_prompt(context_wrapper), ) - handoffs = cls._get_handoffs(agent) + handoffs = await cls._get_handoffs(agent, context_wrapper) model = cls._get_model(agent, run_config) model_settings = agent.model_settings.resolve(run_config.model_settings) model_settings = RunImpl.maybe_reset_tool_choice(agent, tool_use_tracker, model_settings) @@ -898,7 +903,7 @@ async def _run_single_turn( ) output_schema = cls._get_output_schema(agent) - handoffs = cls._get_handoffs(agent) + handoffs = await cls._get_handoffs(agent, context_wrapper) input = ItemHelpers.input_to_new_input_list(original_input) input.extend([generated_item.to_input_item() for generated_item in generated_items]) @@ -1091,14 +1096,28 @@ def _get_output_schema(cls, agent: Agent[Any]) -> AgentOutputSchemaBase | None: return AgentOutputSchema(agent.output_type) @classmethod - def _get_handoffs(cls, agent: Agent[Any]) -> list[Handoff]: + async def _get_handoffs( + cls, agent: Agent[Any], context_wrapper: RunContextWrapper[Any] + ) -> list[Handoff]: handoffs = [] for handoff_item in agent.handoffs: if isinstance(handoff_item, Handoff): handoffs.append(handoff_item) elif isinstance(handoff_item, Agent): handoffs.append(handoff(handoff_item)) - return handoffs + + async def _check_handoff_enabled(handoff_obj: Handoff) -> bool: + attr = handoff_obj.is_enabled + if isinstance(attr, bool): + return attr + res = attr(context_wrapper, agent) + if inspect.isawaitable(res): + return bool(await res) + return bool(res) + + results = await asyncio.gather(*(_check_handoff_enabled(h) for h in handoffs)) + enabled: list[Handoff] = [h for h, ok in zip(handoffs, results) if ok] + return enabled @classmethod async def _get_all_tools( diff --git a/tests/test_agent_config.py b/tests/test_agent_config.py index f9423619d..a985fd60d 100644 --- a/tests/test_agent_config.py +++ b/tests/test_agent_config.py @@ -43,7 +43,7 @@ async def test_handoff_with_agents(): handoffs=[agent_1, agent_2], ) - handoffs = AgentRunner._get_handoffs(agent_3) + handoffs = await 
AgentRunner._get_handoffs(agent_3, RunContextWrapper(None)) assert len(handoffs) == 2 assert handoffs[0].agent_name == "agent_1" @@ -78,7 +78,7 @@ async def test_handoff_with_handoff_obj(): ], ) - handoffs = AgentRunner._get_handoffs(agent_3) + handoffs = await AgentRunner._get_handoffs(agent_3, RunContextWrapper(None)) assert len(handoffs) == 2 assert handoffs[0].agent_name == "agent_1" @@ -112,7 +112,7 @@ async def test_handoff_with_handoff_obj_and_agent(): handoffs=[handoff(agent_1), agent_2], ) - handoffs = AgentRunner._get_handoffs(agent_3) + handoffs = await AgentRunner._get_handoffs(agent_3, RunContextWrapper(None)) assert len(handoffs) == 2 assert handoffs[0].agent_name == "agent_1" diff --git a/tests/test_handoff_tool.py b/tests/test_handoff_tool.py index a1b5b80ba..0f7fc2166 100644 --- a/tests/test_handoff_tool.py +++ b/tests/test_handoff_tool.py @@ -38,16 +38,17 @@ def get_len(data: HandoffInputData) -> int: return input_len + pre_handoff_len + new_items_len -def test_single_handoff_setup(): +@pytest.mark.asyncio +async def test_single_handoff_setup(): agent_1 = Agent(name="test_1") agent_2 = Agent(name="test_2", handoffs=[agent_1]) assert not agent_1.handoffs assert agent_2.handoffs == [agent_1] - assert not AgentRunner._get_handoffs(agent_1) + assert not (await AgentRunner._get_handoffs(agent_1, RunContextWrapper(agent_1))) - handoff_objects = AgentRunner._get_handoffs(agent_2) + handoff_objects = await AgentRunner._get_handoffs(agent_2, RunContextWrapper(agent_2)) assert len(handoff_objects) == 1 obj = handoff_objects[0] assert obj.tool_name == Handoff.default_tool_name(agent_1) @@ -55,7 +56,8 @@ def test_single_handoff_setup(): assert obj.agent_name == agent_1.name -def test_multiple_handoffs_setup(): +@pytest.mark.asyncio +async def test_multiple_handoffs_setup(): agent_1 = Agent(name="test_1") agent_2 = Agent(name="test_2") agent_3 = Agent(name="test_3", handoffs=[agent_1, agent_2]) @@ -64,7 +66,7 @@ def test_multiple_handoffs_setup(): assert not agent_1.handoffs assert not agent_2.handoffs - handoff_objects = AgentRunner._get_handoffs(agent_3) + handoff_objects = await AgentRunner._get_handoffs(agent_3, RunContextWrapper(agent_3)) assert len(handoff_objects) == 2 assert handoff_objects[0].tool_name == Handoff.default_tool_name(agent_1) assert handoff_objects[1].tool_name == Handoff.default_tool_name(agent_2) @@ -76,7 +78,8 @@ def test_multiple_handoffs_setup(): assert handoff_objects[1].agent_name == agent_2.name -def test_custom_handoff_setup(): +@pytest.mark.asyncio +async def test_custom_handoff_setup(): agent_1 = Agent(name="test_1") agent_2 = Agent(name="test_2") agent_3 = Agent( @@ -95,7 +98,7 @@ def test_custom_handoff_setup(): assert not agent_1.handoffs assert not agent_2.handoffs - handoff_objects = AgentRunner._get_handoffs(agent_3) + handoff_objects = await AgentRunner._get_handoffs(agent_3, RunContextWrapper(agent_3)) assert len(handoff_objects) == 2 first_handoff = handoff_objects[0] @@ -284,3 +287,86 @@ def test_get_transfer_message_is_valid_json() -> None: obj = handoff(agent) transfer = obj.get_transfer_message(agent) assert json.loads(transfer) == {"assistant": agent.name} + + +def test_handoff_is_enabled_bool(): + """Test that handoff respects is_enabled boolean parameter.""" + agent = Agent(name="test") + + # Test enabled handoff (default) + handoff_enabled = handoff(agent) + assert handoff_enabled.is_enabled is True + + # Test explicitly enabled handoff + handoff_explicit_enabled = handoff(agent, is_enabled=True) + assert 
handoff_explicit_enabled.is_enabled is True + + # Test disabled handoff + handoff_disabled = handoff(agent, is_enabled=False) + assert handoff_disabled.is_enabled is False + + +@pytest.mark.asyncio +async def test_handoff_is_enabled_callable(): + """Test that handoff respects is_enabled callable parameter.""" + agent = Agent(name="test") + + # Test callable that returns True + def always_enabled(ctx: RunContextWrapper[Any], agent: Agent[Any]) -> bool: + return True + + handoff_callable_enabled = handoff(agent, is_enabled=always_enabled) + assert callable(handoff_callable_enabled.is_enabled) + result = handoff_callable_enabled.is_enabled(RunContextWrapper(agent), agent) + assert result is True + + # Test callable that returns False + def always_disabled(ctx: RunContextWrapper[Any], agent: Agent[Any]) -> bool: + return False + + handoff_callable_disabled = handoff(agent, is_enabled=always_disabled) + assert callable(handoff_callable_disabled.is_enabled) + result = handoff_callable_disabled.is_enabled(RunContextWrapper(agent), agent) + assert result is False + + # Test async callable + async def async_enabled(ctx: RunContextWrapper[Any], agent: Agent[Any]) -> bool: + return True + + handoff_async_enabled = handoff(agent, is_enabled=async_enabled) + assert callable(handoff_async_enabled.is_enabled) + result = await handoff_async_enabled.is_enabled(RunContextWrapper(agent), agent) # type: ignore + assert result is True + + +@pytest.mark.asyncio +async def test_handoff_is_enabled_filtering_integration(): + """Integration test that disabled handoffs are filtered out by the runner.""" + + # Set up agents + agent_1 = Agent(name="agent_1") + agent_2 = Agent(name="agent_2") + agent_3 = Agent(name="agent_3") + + # Create main agent with mixed enabled/disabled handoffs + main_agent = Agent( + name="main_agent", + handoffs=[ + handoff(agent_1, is_enabled=True), # enabled + handoff(agent_2, is_enabled=False), # disabled + handoff(agent_3, is_enabled=lambda ctx, agent: True), # enabled callable + ], + ) + + context_wrapper = RunContextWrapper(main_agent) + + # Get filtered handoffs using the runner's method + filtered_handoffs = await AgentRunner._get_handoffs(main_agent, context_wrapper) + + # Should only have 2 handoffs (agent_1 and agent_3), agent_2 should be filtered out + assert len(filtered_handoffs) == 2 + + # Check that the correct agents are present + agent_names = {h.agent_name for h in filtered_handoffs} + assert agent_names == {"agent_1", "agent_3"} + assert "agent_2" not in agent_names diff --git a/tests/test_run_step_execution.py b/tests/test_run_step_execution.py index 2454a4462..4cf9ae832 100644 --- a/tests/test_run_step_execution.py +++ b/tests/test_run_step_execution.py @@ -325,7 +325,7 @@ async def get_execute_result( run_config: RunConfig | None = None, ) -> SingleStepResult: output_schema = AgentRunner._get_output_schema(agent) - handoffs = AgentRunner._get_handoffs(agent) + handoffs = await AgentRunner._get_handoffs(agent, context_wrapper or RunContextWrapper(None)) processed_response = RunImpl.process_model_response( agent=agent, diff --git a/tests/test_run_step_processing.py b/tests/test_run_step_processing.py index 5a75ec837..6a2904791 100644 --- a/tests/test_run_step_processing.py +++ b/tests/test_run_step_processing.py @@ -186,7 +186,7 @@ async def test_handoffs_parsed_correctly(): agent=agent_3, response=response, output_schema=None, - handoffs=AgentRunner._get_handoffs(agent_3), + handoffs=await AgentRunner._get_handoffs(agent_3, _dummy_ctx()), all_tools=await 
agent_3.get_all_tools(_dummy_ctx()), ) assert len(result.handoffs) == 1, "Should have a handoff here" @@ -216,7 +216,7 @@ async def test_missing_handoff_fails(): agent=agent_3, response=response, output_schema=None, - handoffs=AgentRunner._get_handoffs(agent_3), + handoffs=await AgentRunner._get_handoffs(agent_3, _dummy_ctx()), all_tools=await agent_3.get_all_tools(_dummy_ctx()), ) @@ -239,7 +239,7 @@ async def test_multiple_handoffs_doesnt_error(): agent=agent_3, response=response, output_schema=None, - handoffs=AgentRunner._get_handoffs(agent_3), + handoffs=await AgentRunner._get_handoffs(agent_3, _dummy_ctx()), all_tools=await agent_3.get_all_tools(_dummy_ctx()), ) assert len(result.handoffs) == 2, "Should have multiple handoffs here" @@ -471,7 +471,7 @@ async def test_tool_and_handoff_parsed_correctly(): agent=agent_3, response=response, output_schema=None, - handoffs=AgentRunner._get_handoffs(agent_3), + handoffs=await AgentRunner._get_handoffs(agent_3, _dummy_ctx()), all_tools=await agent_3.get_all_tools(_dummy_ctx()), ) assert result.functions and len(result.functions) == 1 From 0f21c8aedb0251c6269934f9543021ee10e9489e Mon Sep 17 00:00:00 2001 From: Daniele Morotti <58258368+DanieleMorotti@users.noreply.github.com> Date: Tue, 24 Jun 2025 16:56:24 +0200 Subject: [PATCH 02/23] Removed lines to avoid duplicate output in REPL utility (#928) In the REPL utility, the final output was printed after the streaming, causing a duplication issue. I only removed the lines within the stream mode that printed the final output. It should not affect any other use case of the utility. It solves a comment done in #784 . --- src/agents/repl.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/src/agents/repl.py b/src/agents/repl.py index 9a4f30759..f7142555f 100644 --- a/src/agents/repl.py +++ b/src/agents/repl.py @@ -5,7 +5,7 @@ from openai.types.responses.response_text_delta_event import ResponseTextDeltaEvent from .agent import Agent -from .items import ItemHelpers, TResponseInputItem +from .items import TResponseInputItem from .result import RunResultBase from .run import Runner from .stream_events import AgentUpdatedStreamEvent, RawResponsesStreamEvent, RunItemStreamEvent @@ -50,9 +50,6 @@ async def run_demo_loop(agent: Agent[Any], *, stream: bool = True) -> None: print("\n[tool called]", flush=True) elif event.item.type == "tool_call_output_item": print(f"\n[tool output: {event.item.output}]", flush=True) - elif event.item.type == "message_output_item": - message = ItemHelpers.text_message_output(event.item) - print(message, end="", flush=True) elif isinstance(event, AgentUpdatedStreamEvent): print(f"\n[Agent updated: {event.new_agent.name}]", flush=True) print() From d88bf14360920a5b8a2576a610a5d04519f1857b Mon Sep 17 00:00:00 2001 From: Niv Hertz <46317590+niv-hertz@users.noreply.github.com> Date: Tue, 24 Jun 2025 18:14:19 +0300 Subject: [PATCH 03/23] Bugfix | Fixed a bug when calling reasoning models with `store=False` (#920) resolves https://github.com/openai/openai-agents-python/issues/919 --- src/agents/model_settings.py | 5 +++++ src/agents/models/openai_responses.py | 6 +++++- tests/model_settings/test_serialization.py | 1 + 3 files changed, 11 insertions(+), 1 deletion(-) diff --git a/src/agents/model_settings.py b/src/agents/model_settings.py index 881390279..f844eb87c 100644 --- a/src/agents/model_settings.py +++ b/src/agents/model_settings.py @@ -5,6 +5,7 @@ from typing import Any, Literal from openai._types import Body, Headers, Query +from openai.types.responses 
import ResponseIncludable from openai.types.shared import Reasoning from pydantic import BaseModel @@ -61,6 +62,10 @@ class ModelSettings: """Whether to include usage chunk. Defaults to True if not provided.""" + response_include: list[ResponseIncludable] | None = None + """Additional output data to include in the model response. + [include parameter](https://platform.openai.com/docs/api-reference/responses/create#responses-create-include)""" + extra_query: Query | None = None """Additional query fields to provide with the request. Defaults to None if not provided.""" diff --git a/src/agents/models/openai_responses.py b/src/agents/models/openai_responses.py index 961f690b0..637adaccd 100644 --- a/src/agents/models/openai_responses.py +++ b/src/agents/models/openai_responses.py @@ -240,6 +240,10 @@ async def _fetch_response( converted_tools = Converter.convert_tools(tools, handoffs) response_format = Converter.get_response_format(output_schema) + include: list[ResponseIncludable] = converted_tools.includes + if model_settings.response_include is not None: + include = list({*include, *model_settings.response_include}) + if _debug.DONT_LOG_MODEL_DATA: logger.debug("Calling LLM") else: @@ -258,7 +262,7 @@ async def _fetch_response( instructions=self._non_null_or_not_given(system_instructions), model=self.model, input=list_input, - include=converted_tools.includes, + include=include, tools=converted_tools.tools, prompt=self._non_null_or_not_given(prompt), temperature=self._non_null_or_not_given(model_settings.temperature), diff --git a/tests/model_settings/test_serialization.py b/tests/model_settings/test_serialization.py index ad4db4019..2bbc7ce2c 100644 --- a/tests/model_settings/test_serialization.py +++ b/tests/model_settings/test_serialization.py @@ -44,6 +44,7 @@ def test_all_fields_serialization() -> None: metadata={"foo": "bar"}, store=False, include_usage=False, + response_include=["reasoning.encrypted_content"], extra_query={"foo": "bar"}, extra_body={"foo": "bar"}, extra_headers={"foo": "bar"}, From 3e5f3c03e539607b65a586ec200f2df8ab1f58d0 Mon Sep 17 00:00:00 2001 From: Rohan Mehta Date: Tue, 24 Jun 2025 12:16:02 -0400 Subject: [PATCH 04/23] Point CLAUDE.md to AGENTS.md (#930) People soemtimes use claude code to send PRs, this allows the agents.md to be shared for that --- AGENTS.md | 4 +++- CLAUDE.md | 1 + 2 files changed, 4 insertions(+), 1 deletion(-) create mode 100644 CLAUDE.md diff --git a/AGENTS.md b/AGENTS.md index ff37db326..291c31837 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -35,6 +35,8 @@ Welcome to the OpenAI Agents SDK repository. This file contains the main points Coverage can be generated with `make coverage`. +All python commands should be run via `uv run python ...` + ## Snapshot tests Some tests rely on inline snapshots. See `tests/README.md` for details on updating them: @@ -64,6 +66,6 @@ Commit messages should be concise and written in the imperative mood. Small, foc ## What reviewers look for - Tests covering new behaviour. -- Consistent style: code formatted with `ruff format`, imports sorted, and type hints passing `mypy`. +- Consistent style: code formatted with `uv run ruff format`, imports sorted, and type hints passing `uv run mypy .`. - Clear documentation for any public API changes. - Clean history and a helpful PR description. diff --git a/CLAUDE.md b/CLAUDE.md new file mode 100644 index 000000000..5e01a1c3d --- /dev/null +++ b/CLAUDE.md @@ -0,0 +1 @@ +Read the AGENTS.md file for instructions. 
\ No newline at end of file From 421693ab84ad7a4ac864e0f331340dba0eaf5908 Mon Sep 17 00:00:00 2001 From: Kazuhiro Sera Date: Wed, 25 Jun 2025 02:46:21 +0900 Subject: [PATCH 05/23] Fix #890 by adjusting the guardrail document page (#903) This pull request resolves #890 and enables the translation script to accept a file to work on. --- docs/guardrails.md | 2 +- docs/ja/guardrails.md | 44 ++++++++++++------------- docs/scripts/translate_docs.py | 59 +++++++++++++++++++++++----------- 3 files changed, 64 insertions(+), 41 deletions(-) diff --git a/docs/guardrails.md b/docs/guardrails.md index 2f0be0f2a..8df904a4c 100644 --- a/docs/guardrails.md +++ b/docs/guardrails.md @@ -23,7 +23,7 @@ Input guardrails run in 3 steps: Output guardrails run in 3 steps: -1. First, the guardrail receives the same input passed to the agent. +1. First, the guardrail receives the output produced by the agent. 2. Next, the guardrail function runs to produce a [`GuardrailFunctionOutput`][agents.guardrail.GuardrailFunctionOutput], which is then wrapped in an [`OutputGuardrailResult`][agents.guardrail.OutputGuardrailResult] 3. Finally, we check if [`.tripwire_triggered`][agents.guardrail.GuardrailFunctionOutput.tripwire_triggered] is true. If true, an [`OutputGuardrailTripwireTriggered`][agents.exceptions.OutputGuardrailTripwireTriggered] exception is raised, so you can appropriately respond to the user or handle the exception. diff --git a/docs/ja/guardrails.md b/docs/ja/guardrails.md index e7b02a6ed..b67bb8bad 100644 --- a/docs/ja/guardrails.md +++ b/docs/ja/guardrails.md @@ -4,44 +4,44 @@ search: --- # ガードレール -ガードレールは エージェント と _並列_ に実行され、 ユーザー入力 のチェックとバリデーションを行います。たとえば、顧客からのリクエストを支援するために非常に賢い (そのため遅く / 高価な) モデルを使うエージェントがあるとします。悪意のある ユーザー がモデルに数学の宿題を手伝わせようとするのは避けたいですよね。その場合、 高速 / 低コスト のモデルでガードレールを実行できます。ガードレールが悪意のある利用を検知した場合、即座にエラーを送出して高価なモデルの実行を停止し、時間と費用を節約できます。 +ガードレールは エージェント と _並行して_ 実行され、ユーザー入力のチェックとバリデーションを行えます。例えば、とても賢い(つまり遅く/高価な)モデルを使用してカスタマーリクエストを処理するエージェントがあるとします。悪意のある ユーザー がモデルに数学の宿題を手伝わせようとするのは避けたいでしょう。そこで、速く/安価なモデルで動くガードレールを実行できます。ガードレールが悪意のある利用を検知すると、直ちにエラーを送出して高価なモデルの実行を停止し、時間とコストを節約できます。 -ガードレールには 2 種類あります。 +ガードレールには 2 種類あります: -1. Input ガードレールは最初の ユーザー入力 に対して実行されます -2. Output ガードレールは最終的なエージェント出力に対して実行されます +1. 入力ガードレール は初期 ユーザー 入力に対して実行されます +2. 出力ガードレール は最終的なエージェント出力に対して実行されます -## Input ガードレール +## 入力ガードレール -Input ガードレールは 3 つのステップで実行されます。 +入力ガードレールは 3 ステップで実行されます: 1. まず、ガードレールはエージェントに渡されたものと同じ入力を受け取ります。 -2. 次に、ガードレール関数が実行され [`GuardrailFunctionOutput`][agents.guardrail.GuardrailFunctionOutput] を生成し、それが [`InputGuardrailResult`][agents.guardrail.InputGuardrailResult] でラップされます。 -3. 最後に [`.tripwire_triggered`][agents.guardrail.GuardrailFunctionOutput.tripwire_triggered] が true かどうかを確認します。true の場合、[`InputGuardrailTripwireTriggered`][agents.exceptions.InputGuardrailTripwireTriggered] 例外が送出されるので、 ユーザー への適切な応答や例外処理を行えます。 +2. 次に、ガードレール関数が実行され [`GuardrailFunctionOutput`][agents.guardrail.GuardrailFunctionOutput] を生成し、それが [`InputGuardrailResult`][agents.guardrail.InputGuardrailResult] にラップされます。 +3. 最後に [.tripwire_triggered][agents.guardrail.GuardrailFunctionOutput.tripwire_triggered] が true かどうかを確認します。true の場合、[`InputGuardrailTripwireTriggered`][agents.exceptions.InputGuardrailTripwireTriggered] 例外が送出されるので、適切に ユーザー に応答したり例外を処理できます。 !!! 
Note - Input ガードレールは ユーザー入力 に対して実行されることを想定しているため、エージェントのガードレールが実行されるのはそのエージェントが *最初* のエージェントである場合だけです。「なぜ `guardrails` プロパティがエージェントにあり、 `Runner.run` に渡さないのか?」と思うかもしれません。ガードレールは実際の エージェント に密接に関連する場合が多く、エージェントごとに異なるガードレールを実行するため、コードを同じ場所に置くことで可読性が向上するからです。 + 入力ガードレールは ユーザー 入力に対して実行されることを意図しているため、ガードレールは *最初* のエージェントでのみ実行されます。「なぜ `guardrails` プロパティがエージェントにあり、`Runner.run` に渡さないのか」と疑問に思うかもしれません。これは、ガードレールが実際の エージェント と密接に関連していることが多いからです。異なるエージェントには異なるガードレールを実行するため、コードを同じ場所に置くことで可読性が向上します。 -## Output ガードレール +## 出力ガードレール -Output ガードレールは 3 つのステップで実行されます。 +出力ガードレールは 3 ステップで実行されます: -1. まず、ガードレールはエージェントに渡されたものと同じ入力を受け取ります。 -2. 次に、ガードレール関数が実行され [`GuardrailFunctionOutput`][agents.guardrail.GuardrailFunctionOutput] を生成し、それが [`OutputGuardrailResult`][agents.guardrail.OutputGuardrailResult] でラップされます。 -3. 最後に [`.tripwire_triggered`][agents.guardrail.GuardrailFunctionOutput.tripwire_triggered] が true かどうかを確認します。true の場合、[`OutputGuardrailTripwireTriggered`][agents.exceptions.OutputGuardrailTripwireTriggered] 例外が送出されるので、 ユーザー への適切な応答や例外処理を行えます。 +1. まず、ガードレールはエージェントが生成した出力を受け取ります。 +2. 次に、ガードレール関数が実行され [`GuardrailFunctionOutput`][agents.guardrail.GuardrailFunctionOutput] を生成し、それが [`OutputGuardrailResult`][agents.guardrail.OutputGuardrailResult] にラップされます。 +3. 最後に [.tripwire_triggered][agents.guardrail.GuardrailFunctionOutput.tripwire_triggered] が true かどうかを確認します。true の場合、[`OutputGuardrailTripwireTriggered`][agents.exceptions.OutputGuardrailTripwireTriggered] 例外が送出されるので、適切に ユーザー に応答したり例外を処理できます。 !!! Note - Output ガードレールは最終的なエージェント出力に対して実行されることを想定しているため、エージェントのガードレールが実行されるのはそのエージェントが *最後* のエージェントである場合だけです。Input ガードレール同様、ガードレールは実際の エージェント に密接に関連するため、コードを同じ場所に置くことで可読性が向上します。 + 出力ガードレールは最終的なエージェント出力に対して実行されることを意図しているため、ガードレールは *最後* のエージェントでのみ実行されます。入力ガードレールの場合と同様、ガードレールが実際の エージェント と密接に関連していることが多いため、コードを同じ場所に置くことで可読性が向上します。 -## トリップワイヤ +## トリップワイヤー -入力または出力がガードレールに失敗した場合、ガードレールはトリップワイヤを用いてそれを通知できます。ガードレールがトリップワイヤを発火したことを検知すると、ただちに `{Input,Output}GuardrailTripwireTriggered` 例外を送出してエージェントの実行を停止します。 +入力または出力がガードレールを通過できなかった場合、ガードレールはトリップワイヤーでそれを示すことができます。トリップワイヤーがトリガーされたガードレールを検知した時点で、直ちに `{Input,Output}GuardrailTripwireTriggered` 例外を送出し、エージェントの実行を停止します。 ## ガードレールの実装 -入力を受け取り、[`GuardrailFunctionOutput`][agents.guardrail.GuardrailFunctionOutput] を返す関数を用意する必要があります。次の例では、内部で エージェント を実行してこれを行います。 +入力を受け取り、[`GuardrailFunctionOutput`][agents.guardrail.GuardrailFunctionOutput] を返す関数を提供する必要があります。この例では、内部で エージェント を実行してこれを行います。 ```python from pydantic import BaseModel @@ -94,12 +94,12 @@ async def main(): print("Math homework guardrail tripped") ``` -1. この エージェント をガードレール関数内で使用します。 -2. これはエージェントの入力 / コンテキストを受け取り、結果を返すガードレール関数です。 +1. このエージェントをガードレール関数内で使用します。 +2. これはエージェントの入力/コンテキストを受け取り、結果を返すガードレール関数です。 3. ガードレール結果に追加情報を含めることができます。 4. これはワークフローを定義する実際のエージェントです。 -Output ガードレールも同様です。 +出力ガードレールも同様です。 ```python from pydantic import BaseModel @@ -155,4 +155,4 @@ async def main(): 1. これは実際のエージェントの出力型です。 2. これはガードレールの出力型です。 3. これはエージェントの出力を受け取り、結果を返すガードレール関数です。 -4. これはワークフローを定義する実際のエージェントです。 \ No newline at end of file +4. 
これはワークフローを定義する実際のエージェントです。 \ No newline at end of file diff --git a/docs/scripts/translate_docs.py b/docs/scripts/translate_docs.py index b2e8b44fc..5dada2681 100644 --- a/docs/scripts/translate_docs.py +++ b/docs/scripts/translate_docs.py @@ -1,5 +1,7 @@ # ruff: noqa import os +import sys +import argparse from openai import OpenAI from concurrent.futures import ThreadPoolExecutor @@ -263,24 +265,45 @@ def translate_single_source_file(file_path: str) -> None: def main(): - # Traverse the source directory - for root, _, file_names in os.walk(source_dir): - # Skip the target directories - if any(lang in root for lang in languages): - continue - # Increasing this will make the translation faster; you can decide considering the model's capacity - concurrency = 6 - with ThreadPoolExecutor(max_workers=concurrency) as executor: - futures = [] - for file_name in file_names: - filepath = os.path.join(root, file_name) - futures.append(executor.submit(translate_single_source_file, filepath)) - if len(futures) >= concurrency: - for future in futures: - future.result() - futures.clear() - - print("Translation completed.") + parser = argparse.ArgumentParser(description="Translate documentation files") + parser.add_argument("--file", type=str, help="Specific file to translate (relative to docs directory)") + args = parser.parse_args() + + if args.file: + # Translate a single file + # Handle both "foo.md" and "docs/foo.md" formats + if args.file.startswith("docs/"): + # Remove "docs/" prefix if present + relative_file = args.file[5:] + else: + relative_file = args.file + + file_path = os.path.join(source_dir, relative_file) + if os.path.exists(file_path): + translate_single_source_file(file_path) + print(f"Translation completed for {relative_file}") + else: + print(f"Error: File {file_path} does not exist") + sys.exit(1) + else: + # Traverse the source directory (original behavior) + for root, _, file_names in os.walk(source_dir): + # Skip the target directories + if any(lang in root for lang in languages): + continue + # Increasing this will make the translation faster; you can decide considering the model's capacity + concurrency = 6 + with ThreadPoolExecutor(max_workers=concurrency) as executor: + futures = [] + for file_name in file_names: + filepath = os.path.join(root, file_name) + futures.append(executor.submit(translate_single_source_file, filepath)) + if len(futures) >= concurrency: + for future in futures: + future.result() + futures.clear() + + print("Translation completed.") if __name__ == "__main__": From 7665dd1acc03ddecc2e8f70f907ee9941060b76f Mon Sep 17 00:00:00 2001 From: Kazuhiro Sera Date: Thu, 26 Jun 2025 00:18:18 +0900 Subject: [PATCH 06/23] Add exempt-issue-labels to the issue triage GH Action job (#936) Like I did for the TypeScript SDK project (see https://github.com/openai/openai-agents-js/pull/97), we may want to have the labels to exclude the issues from auto-closing. Some of the issues that have been open for a while could just need time. --- .github/workflows/issues.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/issues.yml b/.github/workflows/issues.yml index 6447f83ef..dca717752 100644 --- a/.github/workflows/issues.yml +++ b/.github/workflows/issues.yml @@ -21,6 +21,7 @@ jobs: days-before-pr-stale: 10 days-before-pr-close: 7 stale-pr-label: "stale" + exempt-issue-labels: "skip-stale" stale-pr-message: "This PR is stale because it has been open for 10 days with no activity." 
close-pr-message: "This PR was closed because it has been inactive for 7 days since being marked as stale." repo-token: ${{ secrets.GITHUB_TOKEN }} From c2005f8273a899ffd32e4b77c06701f347f8ec92 Mon Sep 17 00:00:00 2001 From: Easonshi Date: Wed, 25 Jun 2025 23:32:20 +0800 Subject: [PATCH 07/23] Remove duplicate entry from `__init__.py` (#897) ### Summary This PR fixes a duplicated line in `__init__.py` ### Checks - [x] I've added/updated the relevant documentation - [x] I've run `make lint` and `make format` Co-authored-by: easonsshi --- src/agents/__init__.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/agents/__init__.py b/src/agents/__init__.py index 3eecaead0..1296b72be 100644 --- a/src/agents/__init__.py +++ b/src/agents/__init__.py @@ -206,7 +206,6 @@ def enable_verbose_stdout_logging(): "ToolCallItem", "ToolCallOutputItem", "ReasoningItem", - "ModelResponse", "ItemHelpers", "RunHooks", "AgentHooks", From 94f80353d9f60d3005f330cd37eea8a7d0ac78a1 Mon Sep 17 00:00:00 2001 From: Kazuhiro Sera Date: Thu, 26 Jun 2025 08:28:10 +0900 Subject: [PATCH 08/23] Fix #604 Chat Completion model raises runtime error when response.choices is empty (#935) This pull request resolves #604; just checking the existence of the data and avoiding the runtime exception should make sense. --- src/agents/models/openai_chatcompletions.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/src/agents/models/openai_chatcompletions.py b/src/agents/models/openai_chatcompletions.py index 08803d8c0..120d726db 100644 --- a/src/agents/models/openai_chatcompletions.py +++ b/src/agents/models/openai_chatcompletions.py @@ -7,7 +7,8 @@ from openai import NOT_GIVEN, AsyncOpenAI, AsyncStream from openai.types import ChatModel -from openai.types.chat import ChatCompletion, ChatCompletionChunk +from openai.types.chat import ChatCompletion, ChatCompletionChunk, ChatCompletionMessage +from openai.types.chat.chat_completion import Choice from openai.types.responses import Response from openai.types.responses.response_prompt_param import ResponsePromptParam from openai.types.responses.response_usage import InputTokensDetails, OutputTokensDetails @@ -74,8 +75,11 @@ async def get_response( prompt=prompt, ) - first_choice = response.choices[0] - message = first_choice.message + message: ChatCompletionMessage | None = None + first_choice: Choice | None = None + if response.choices and len(response.choices) > 0: + first_choice = response.choices[0] + message = first_choice.message if _debug.DONT_LOG_MODEL_DATA: logger.debug("Received model response") @@ -86,10 +90,8 @@ async def get_response( json.dumps(message.model_dump(), indent=2), ) else: - logger.debug( - "LLM resp had no message. finish_reason: %s", - first_choice.finish_reason, - ) + finish_reason = first_choice.finish_reason if first_choice else "-" + logger.debug(f"LLM resp had no message. finish_reason: {finish_reason}") usage = ( Usage( From 653abc5f598204250193cef4664e8bf183436ddb Mon Sep 17 00:00:00 2001 From: Abbas Asad <168946441+Abbas-Asad@users.noreply.github.com> Date: Thu, 26 Jun 2025 04:40:22 +0500 Subject: [PATCH 09/23] Fix #892 by adding proper tool description in context.md (#937) ### Summary This pull request fixes [issue #892](https://github.com/openai/openai-agents-python/issues/892) by adding a missing docstring to the `fetch_user_age` tool function in `docs/context.md`. ### Problem Many non-OpenAI LLMs (such as Claude, Gemini, Mistral, etc.) are unable to use the `fetch_user_age` function because it lacks a docstring. 
As a result, they return responses like: > "I cannot determine the user's age because the available tools lack the ability to fetch user-specific information." ### Changes Made - Added a one-line docstring to the `fetch_user_age` function - Improved return statement to match expected tool output ```python @function_tool async def fetch_user_age(wrapper: RunContextWrapper[UserInfo]) -> str: """Fetch the age of the user. Call this function to get user's age information.""" return f"The user {wrapper.context.name} is 47 years old" --- docs/context.md | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/docs/context.md b/docs/context.md index 4176ec51f..6e54565e0 100644 --- a/docs/context.md +++ b/docs/context.md @@ -38,7 +38,8 @@ class UserInfo: # (1)! @function_tool async def fetch_user_age(wrapper: RunContextWrapper[UserInfo]) -> str: # (2)! - return f"User {wrapper.context.name} is 47 years old" + """Fetch the age of the user. Call this function to get user's age information.""" + return f"The user {wrapper.context.name} is 47 years old" async def main(): user_info = UserInfo(name="John", uid=123) From c1cb7ce5384329272b6b3cd379bd5c70f9655236 Mon Sep 17 00:00:00 2001 From: Daniel Hashmi <151331476+DanielHashmi@users.noreply.github.com> Date: Fri, 27 Jun 2025 11:55:13 +0500 Subject: [PATCH 10/23] Duplicate Classifier in pyproject.toml (#951) "Intended Audience :: Developers" appears twice in the classifiers list --- pyproject.toml | 1 - 1 file changed, 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 75e714768..6c0bbe9f1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -23,7 +23,6 @@ classifiers = [ "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", "Programming Language :: Python :: 3.12", - "Intended Audience :: Developers", "Operating System :: OS Independent", "Topic :: Software Development :: Libraries :: Python Modules", "License :: OSI Approved :: MIT License", From a6580a5090399d1558b219018fa8a3a8aa622654 Mon Sep 17 00:00:00 2001 From: Kat <74958594+KatHaruto@users.noreply.github.com> Date: Fri, 27 Jun 2025 23:31:22 +0900 Subject: [PATCH 11/23] fix: add ensure_ascii=False to json.dumps for correct Unicode output (#639) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Set `ensure_ascii=False` in `json.dumps(...)` to allow proper Unicode character output (e.g., Japanese). Without this, non-ASCII characters are escaped (e.g., 日本語 → \u65e5\u672c\u8a9e), which makes logs or exports harder to read. 
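A minimal, self-contained illustration of the difference, using only the standard-library `json` module (the sample payload is illustrative):

```python
import json

payload = {"text": "日本語"}

# Default: non-ASCII characters are escaped in the output.
print(json.dumps(payload))
# {"text": "\u65e5\u672c\u8a9e"}

# With ensure_ascii=False the original characters are preserved,
# which keeps logged prompts and responses readable.
print(json.dumps(payload, ensure_ascii=False))
# {"text": "日本語"}
```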
--- src/agents/extensions/models/litellm_model.py | 10 +++++++--- src/agents/models/openai_chatcompletions.py | 6 +++--- src/agents/models/openai_responses.py | 12 +++++++++--- 3 files changed, 19 insertions(+), 9 deletions(-) diff --git a/src/agents/extensions/models/litellm_model.py b/src/agents/extensions/models/litellm_model.py index c58a52dae..a06c61dc3 100644 --- a/src/agents/extensions/models/litellm_model.py +++ b/src/agents/extensions/models/litellm_model.py @@ -98,7 +98,11 @@ async def get_response( logger.debug("Received model response") else: logger.debug( - f"LLM resp:\n{json.dumps(response.choices[0].message.model_dump(), indent=2)}\n" + f"""LLM resp:\n{ + json.dumps( + response.choices[0].message.model_dump(), indent=2, ensure_ascii=False + ) + }\n""" ) if hasattr(response, "usage"): @@ -269,8 +273,8 @@ async def _fetch_response( else: logger.debug( f"Calling Litellm model: {self.model}\n" - f"{json.dumps(converted_messages, indent=2)}\n" - f"Tools:\n{json.dumps(converted_tools, indent=2)}\n" + f"{json.dumps(converted_messages, indent=2, ensure_ascii=False)}\n" + f"Tools:\n{json.dumps(converted_tools, indent=2, ensure_ascii=False)}\n" f"Stream: {stream}\n" f"Tool choice: {tool_choice}\n" f"Response format: {response_format}\n" diff --git a/src/agents/models/openai_chatcompletions.py b/src/agents/models/openai_chatcompletions.py index 120d726db..6de431b4d 100644 --- a/src/agents/models/openai_chatcompletions.py +++ b/src/agents/models/openai_chatcompletions.py @@ -87,7 +87,7 @@ async def get_response( if message is not None: logger.debug( "LLM resp:\n%s\n", - json.dumps(message.model_dump(), indent=2), + json.dumps(message.model_dump(), indent=2, ensure_ascii=False), ) else: finish_reason = first_choice.finish_reason if first_choice else "-" @@ -256,8 +256,8 @@ async def _fetch_response( logger.debug("Calling LLM") else: logger.debug( - f"{json.dumps(converted_messages, indent=2)}\n" - f"Tools:\n{json.dumps(converted_tools, indent=2)}\n" + f"{json.dumps(converted_messages, indent=2, ensure_ascii=False)}\n" + f"Tools:\n{json.dumps(converted_tools, indent=2, ensure_ascii=False)}\n" f"Stream: {stream}\n" f"Tool choice: {tool_choice}\n" f"Response format: {response_format}\n" diff --git a/src/agents/models/openai_responses.py b/src/agents/models/openai_responses.py index 637adaccd..a7ce62983 100644 --- a/src/agents/models/openai_responses.py +++ b/src/agents/models/openai_responses.py @@ -96,7 +96,13 @@ async def get_response( else: logger.debug( "LLM resp:\n" - f"{json.dumps([x.model_dump() for x in response.output], indent=2)}\n" + f"""{ + json.dumps( + [x.model_dump() for x in response.output], + indent=2, + ensure_ascii=False, + ) + }\n""" ) usage = ( @@ -249,8 +255,8 @@ async def _fetch_response( else: logger.debug( f"Calling LLM {self.model} with input:\n" - f"{json.dumps(list_input, indent=2)}\n" - f"Tools:\n{json.dumps(converted_tools.tools, indent=2)}\n" + f"{json.dumps(list_input, indent=2, ensure_ascii=False)}\n" + f"Tools:\n{json.dumps(converted_tools.tools, indent=2, ensure_ascii=False)}\n" f"Stream: {stream}\n" f"Tool choice: {tool_choice}\n" f"Response format: {response_format}\n" From 271a1a4b9722016c1baf4d27a0b898fbf1dbf89d Mon Sep 17 00:00:00 2001 From: Richard <164130786@qq.com> Date: Fri, 27 Jun 2025 22:33:08 +0800 Subject: [PATCH 12/23] Add uv as an alternative Python environment setup option for issue #884 (#909) - Add Option A (venv) and Option B (uv) in Get started section - Mark uv as recommended to align with development workflow - Include Windows 
activation commands for both options - Resolves #884 --- README.md | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index 785177916..b334bc20b 100644 --- a/README.md +++ b/README.md @@ -20,14 +20,21 @@ Explore the [examples](examples) directory to see the SDK in action, and read ou 1. Set up your Python environment -``` +- Option A: Using venv (traditional method) +```bash python -m venv env -source env/bin/activate +source env/bin/activate # On Windows: env\Scripts\activate +``` + +- Option B: Using uv (recommended) +```bash +uv venv +source .venv/bin/activate # On Windows: .venv\Scripts\activate ``` 2. Install Agents SDK -``` +```bash pip install openai-agents ``` From 816b7702bc78fed6aae7678f98ac827f03e4e4df Mon Sep 17 00:00:00 2001 From: Rehan Ul Haq Date: Fri, 27 Jun 2025 19:37:57 +0500 Subject: [PATCH 13/23] Fix and Document `parallel_tool_calls` Attribute in ModelSettings (#763) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Closes #762 **Description** This PR updates the docstring for the `parallel_tool_calls` attribute in the `ModelSettings` dataclass to accurately reflect its default behavior.. The previous docstring incorrectly stated that the default of `False`, while the actual behavior is dependent on the underlying model provider's default. As noted in OpenAI's (here refers as model provider) [Function Calling](https://platform.openai.com/docs/guides/function-calling?api-mode=responses#parallel-function-calling) documentation, "The model may choose to call multiple functions in a single turn. You can prevent this by setting parallel_tool_calls to false, which ensures exactly zero or one tool is called." Therefore, when the `parallel_tool_calls` attribute in the `ModelSettings` dataclass is set to `None` (i.e., `parallel_tool_calls: bool | None = None`), and this value is passed directly to the API without modification, it defers to the model provider's default behavior for parallel tool calls. This is typically `True` for most current providers, but it's important to acknowledge that this isn't a fixed default within our codebase. The new docstring is formatted for automatic documentation generation and provides clear, accurate information for users and developers. **Key changes:** * **Clarified the default behavior of `parallel_tool_calls`:** Instead of stating a fixed default, the docstring now accurately reflects that the behavior defaults to whatever the model provider does when the attribute is `None`. * Improved the docstring format for compatibility with documentation tools. * Clarified the purpose and usage of the `parallel_tool_calls` attribute. **Testing:** * Explicitly set `parallel_tool_calls=False` in both the `run_config` of the `Runner.run` method and in the agent’s `model_settings` attribute. * Example for `Runner.run`: ```python Runner.run(agent, input, run_config=RunConfig(model_settings=ModelSettings(parallel_tool_calls=False))) ``` * Example for agent initialization: ```python agent = Agent(..., model_settings=ModelSettings(parallel_tool_calls=False)) ``` * Verified that when `parallel_tool_calls=False`, tools are called sequentially. * Confirmed that by default (without setting the attribute), tools are called in parallel (Tested with openai models). * Checked that the updated docstring renders correctly in the generated documentation. * Ensured the default value in code matches the documentation. 
**Why this is important:** * Prevents confusion for users and developers regarding the default behavior of `parallel_tool_calls`. * Ensures that generated documentation is accurate and up-to-date. * Improves overall code quality and maintainability. --- src/agents/model_settings.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/src/agents/model_settings.py b/src/agents/model_settings.py index f844eb87c..709046b7a 100644 --- a/src/agents/model_settings.py +++ b/src/agents/model_settings.py @@ -37,8 +37,13 @@ class ModelSettings: """The tool choice to use when calling the model.""" parallel_tool_calls: bool | None = None - """Whether to use parallel tool calls when calling the model. - Defaults to False if not provided.""" + """Controls whether the model can make multiple parallel tool calls in a single turn. + If not provided (i.e., set to None), this behavior defers to the underlying + model provider's default. For most current providers (e.g., OpenAI), this typically + means parallel tool calls are enabled (True). + Set to True to explicitly enable parallel tool calls, or False to restrict the + model to at most one tool call per turn. + """ truncation: Literal["auto", "disabled"] | None = None """The truncation strategy to use when calling the model.""" From 0475450d59daec1f1610f17fc549cc3d3dc0f04d Mon Sep 17 00:00:00 2001 From: CCM Date: Fri, 27 Jun 2025 22:39:49 +0800 Subject: [PATCH 14/23] replace .py file with .ipynb for Jupyter example (#262) Co-authored-by: chenchaomin --- README.md | 2 +- examples/basic/hello_world_jupyter.ipynb | 45 ++++++++++++++++++++++++ examples/basic/hello_world_jupyter.py | 11 ------ 3 files changed, 46 insertions(+), 12 deletions(-) create mode 100644 examples/basic/hello_world_jupyter.ipynb delete mode 100644 examples/basic/hello_world_jupyter.py diff --git a/README.md b/README.md index b334bc20b..079aebb06 100644 --- a/README.md +++ b/README.md @@ -57,7 +57,7 @@ print(result.final_output) (_If running this, ensure you set the `OPENAI_API_KEY` environment variable_) -(_For Jupyter notebook users, see [hello_world_jupyter.py](examples/basic/hello_world_jupyter.py)_) +(_For Jupyter notebook users, see [hello_world_jupyter.ipynb](examples/basic/hello_world_jupyter.ipynb)_) ## Handoffs example diff --git a/examples/basic/hello_world_jupyter.ipynb b/examples/basic/hello_world_jupyter.ipynb new file mode 100644 index 000000000..42ee8e6a2 --- /dev/null +++ b/examples/basic/hello_world_jupyter.ipynb @@ -0,0 +1,45 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "id": "8a77ee2e-22f2-409c-837d-b994978b0aa2", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "A function calls self, \n", + "Unraveling layers deep, \n", + "Base case ends the quest. \n", + "\n", + "Infinite loops lurk, \n", + "Mind the base condition well, \n", + "Or it will not work. 
\n", + "\n", + "Trees and lists unfold, \n", + "Elegant solutions bloom, \n", + "Recursion's art told.\n" + ] + } + ], + "source": [ + "from agents import Agent, Runner\n", + "\n", + "agent = Agent(name=\"Assistant\", instructions=\"You are a helpful assistant\")\n", + "\n", + "# Intended for Jupyter notebooks where there's an existing event loop\n", + "result = await Runner.run(agent, \"Write a haiku about recursion in programming.\") # type: ignore[top-level-await] # noqa: F704\n", + "print(result.final_output)" + ] + } + ], + "metadata": { + "language_info": { + "name": "python" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/examples/basic/hello_world_jupyter.py b/examples/basic/hello_world_jupyter.py deleted file mode 100644 index c929a7c68..000000000 --- a/examples/basic/hello_world_jupyter.py +++ /dev/null @@ -1,11 +0,0 @@ -from agents import Agent, Runner - -agent = Agent(name="Assistant", instructions="You are a helpful assistant") - -# Intended for Jupyter notebooks where there's an existing event loop -result = await Runner.run(agent, "Write a haiku about recursion in programming.") # type: ignore[top-level-await] # noqa: F704 -print(result.final_output) - -# Code within code loops, -# Infinite mirrors reflect— -# Logic folds on self. From 445033946bfa5996985c99a9bff5c26705179703 Mon Sep 17 00:00:00 2001 From: devtalker Date: Fri, 27 Jun 2025 23:22:03 +0800 Subject: [PATCH 15/23] feat: add MCP tool filtering support (#861) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## Add MCP Tool Filtering Support This PR implements tool filtering capabilities for MCP servers, addressing multiple community requests for this feature. ### Problem Currently, Agent SDK automatically fetches all available tools from MCP servers without the ability to select specific tools. This creates several issues: - Unwanted tools occupy LLM context and affect tool selection accuracy - Security concerns when limiting tool access scope - Tool name conflicts between multiple servers - Need for different tool subsets across different agents ### Solution Implements a two-level filtering system: **Server-level filtering:** ```python server = MCPServerStdio( params={"command": "npx", "args": ["-y", "@modelcontextprotocol/server-filesystem", "/path"]}, allowed_tools=["read_file", "write_file"], # whitelist excluded_tools=["delete_file"] # blacklist ) ``` **Agent-level filtering:** ```python agent = Agent( name="Assistant", mcp_servers=[server1, server2], mcp_config={ "allowed_tools": {"server1": ["read_file", "write_file"]}, "excluded_tools": {"server2": ["dangerous_tool"]} } ) ``` ### Features - ✅ Server-level `allowed_tools`/`excluded_tools` parameters for all MCP server types - ✅ Agent-level filtering via `mcp_config` - ✅ Hierarchical filtering (server-level first, then agent-level) - ✅ Comprehensive test coverage (8 test cases) - ✅ Updated documentation with examples ### Related Issues #376, #851, #830, #863 ### Testing All existing tests pass + 8 new test cases covering various filtering scenarios. 
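For reference, the updated `docs/mcp.md` in this patch expresses server-level static filtering through a `tool_filter` parameter built with `create_static_tool_filter`; a minimal sketch of that usage (the filesystem-server command and tool names are illustrative):

```python
from agents.mcp import MCPServerStdio, create_static_tool_filter

# The allowlist is applied first, then the blocklist: only read_file and
# write_file remain visible to the agent.
server = MCPServerStdio(
    params={
        "command": "npx",
        "args": ["-y", "@modelcontextprotocol/server-filesystem", "/path"],
    },
    tool_filter=create_static_tool_filter(
        allowed_tool_names=["read_file", "write_file", "delete_file"],
        blocked_tool_names=["delete_file"],
    ),
)
```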
--- docs/ja/mcp.md | 9 +- docs/mcp.md | 96 +++++++++++- src/agents/agent.py | 10 +- src/agents/mcp/__init__.py | 14 +- src/agents/mcp/server.py | 155 ++++++++++++++++++-- src/agents/mcp/util.py | 94 +++++++++++- src/agents/tool.py | 1 + tests/mcp/helpers.py | 49 ++++++- tests/mcp/test_caching.py | 22 ++- tests/mcp/test_mcp_util.py | 20 ++- tests/mcp/test_server_errors.py | 7 +- tests/mcp/test_tool_filtering.py | 243 +++++++++++++++++++++++++++++++ 12 files changed, 675 insertions(+), 45 deletions(-) create mode 100644 tests/mcp/test_tool_filtering.py diff --git a/docs/ja/mcp.md b/docs/ja/mcp.md index 09804beb2..1e394a5e6 100644 --- a/docs/ja/mcp.md +++ b/docs/ja/mcp.md @@ -23,13 +23,20 @@ Agents SDK は MCP をサポートしており、これにより幅広い MCP たとえば、[公式 MCP filesystem サーバー](https://www.npmjs.com/package/@modelcontextprotocol/server-filesystem)を利用する場合は次のようになります。 ```python +from agents.run_context import RunContextWrapper + async with MCPServerStdio( params={ "command": "npx", "args": ["-y", "@modelcontextprotocol/server-filesystem", samples_dir], } ) as server: - tools = await server.list_tools() + # 注意:実際には通常は MCP サーバーをエージェントに追加し、 + # フレームワークがツール一覧の取得を自動的に処理するようにします。 + # list_tools() への直接呼び出しには run_context と agent パラメータが必要です。 + run_context = RunContextWrapper(context=None) + agent = Agent(name="test", instructions="test") + tools = await server.list_tools(run_context, agent) ``` ## MCP サーバーの利用 diff --git a/docs/mcp.md b/docs/mcp.md index 76d142029..d30a916ac 100644 --- a/docs/mcp.md +++ b/docs/mcp.md @@ -19,13 +19,20 @@ You can use the [`MCPServerStdio`][agents.mcp.server.MCPServerStdio], [`MCPServe For example, this is how you'd use the [official MCP filesystem server](https://www.npmjs.com/package/@modelcontextprotocol/server-filesystem). ```python +from agents.run_context import RunContextWrapper + async with MCPServerStdio( params={ "command": "npx", "args": ["-y", "@modelcontextprotocol/server-filesystem", samples_dir], } ) as server: - tools = await server.list_tools() + # Note: In practice, you typically add the server to an Agent + # and let the framework handle tool listing automatically. + # Direct calls to list_tools() require run_context and agent parameters. + run_context = RunContextWrapper(context=None) + agent = Agent(name="test", instructions="test") + tools = await server.list_tools(run_context, agent) ``` ## Using MCP servers @@ -41,6 +48,93 @@ agent=Agent( ) ``` +## Tool filtering + +You can filter which tools are available to your Agent by configuring tool filters on MCP servers. The SDK supports both static and dynamic tool filtering. + +### Static tool filtering + +For simple allow/block lists, you can use static filtering: + +```python +from agents.mcp import create_static_tool_filter + +# Only expose specific tools from this server +server = MCPServerStdio( + params={ + "command": "npx", + "args": ["-y", "@modelcontextprotocol/server-filesystem", samples_dir], + }, + tool_filter=create_static_tool_filter( + allowed_tool_names=["read_file", "write_file"] + ) +) + +# Exclude specific tools from this server +server = MCPServerStdio( + params={ + "command": "npx", + "args": ["-y", "@modelcontextprotocol/server-filesystem", samples_dir], + }, + tool_filter=create_static_tool_filter( + blocked_tool_names=["delete_file"] + ) +) + +``` + +**When both `allowed_tool_names` and `blocked_tool_names` are configured, the processing order is:** +1. First apply `allowed_tool_names` (allowlist) - only keep the specified tools +2. 
Then apply `blocked_tool_names` (blocklist) - exclude specified tools from the remaining tools + +For example, if you configure `allowed_tool_names=["read_file", "write_file", "delete_file"]` and `blocked_tool_names=["delete_file"]`, only `read_file` and `write_file` tools will be available. + +### Dynamic tool filtering + +For more complex filtering logic, you can use dynamic filters with functions: + +```python +from agents.mcp import ToolFilterContext + +# Simple synchronous filter +def custom_filter(context: ToolFilterContext, tool) -> bool: + """Example of a custom tool filter.""" + # Filter logic based on tool name patterns + return tool.name.startswith("allowed_prefix") + +# Context-aware filter +def context_aware_filter(context: ToolFilterContext, tool) -> bool: + """Filter tools based on context information.""" + # Access agent information + agent_name = context.agent.name + + # Access server information + server_name = context.server_name + + # Implement your custom filtering logic here + return some_filtering_logic(agent_name, server_name, tool) + +# Asynchronous filter +async def async_filter(context: ToolFilterContext, tool) -> bool: + """Example of an asynchronous filter.""" + # Perform async operations if needed + result = await some_async_check(context, tool) + return result + +server = MCPServerStdio( + params={ + "command": "npx", + "args": ["-y", "@modelcontextprotocol/server-filesystem", samples_dir], + }, + tool_filter=custom_filter # or context_aware_filter or async_filter +) +``` + +The `ToolFilterContext` provides access to: +- `run_context`: The current run context +- `agent`: The agent requesting the tools +- `server_name`: The name of the MCP server + ## Caching Every time an Agent runs, it calls `list_tools()` on the MCP server. This can be a latency hit, especially if the server is a remote server. To automatically cache the list of tools, you can pass `cache_tools_list=True` to [`MCPServerStdio`][agents.mcp.server.MCPServerStdio], [`MCPServerSse`][agents.mcp.server.MCPServerSse], and [`MCPServerStreamableHttp`][agents.mcp.server.MCPServerStreamableHttp]. You should only do this if you're certain the tool list will not change. 
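A minimal sketch of enabling the cache on a stdio server (the filesystem-server command and path are illustrative):

```python
from agents.mcp import MCPServerStdio

samples_dir = "/path/to/files"  # illustrative path

server = MCPServerStdio(
    params={
        "command": "npx",
        "args": ["-y", "@modelcontextprotocol/server-filesystem", samples_dir],
    },
    cache_tools_list=True,  # reuse the cached tool list instead of re-fetching it on every run
)
```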
diff --git a/src/agents/agent.py b/src/agents/agent.py index 61a9abe0c..6c87297f1 100644 --- a/src/agents/agent.py +++ b/src/agents/agent.py @@ -256,14 +256,18 @@ async def get_prompt( """Get the prompt for the agent.""" return await PromptUtil.to_model_input(self.prompt, run_context, self) - async def get_mcp_tools(self) -> list[Tool]: + async def get_mcp_tools( + self, run_context: RunContextWrapper[TContext] + ) -> list[Tool]: """Fetches the available tools from the MCP servers.""" convert_schemas_to_strict = self.mcp_config.get("convert_schemas_to_strict", False) - return await MCPUtil.get_all_function_tools(self.mcp_servers, convert_schemas_to_strict) + return await MCPUtil.get_all_function_tools( + self.mcp_servers, convert_schemas_to_strict, run_context, self + ) async def get_all_tools(self, run_context: RunContextWrapper[Any]) -> list[Tool]: """All agent tools, including MCP tools and function tools.""" - mcp_tools = await self.get_mcp_tools() + mcp_tools = await self.get_mcp_tools(run_context) async def _check_tool_enabled(tool: Tool) -> bool: if not isinstance(tool, FunctionTool): diff --git a/src/agents/mcp/__init__.py b/src/agents/mcp/__init__.py index d4eb8fa68..da5a68b16 100644 --- a/src/agents/mcp/__init__.py +++ b/src/agents/mcp/__init__.py @@ -11,7 +11,14 @@ except ImportError: pass -from .util import MCPUtil +from .util import ( + MCPUtil, + ToolFilter, + ToolFilterCallable, + ToolFilterContext, + ToolFilterStatic, + create_static_tool_filter, +) __all__ = [ "MCPServer", @@ -22,4 +29,9 @@ "MCPServerStreamableHttp", "MCPServerStreamableHttpParams", "MCPUtil", + "ToolFilter", + "ToolFilterCallable", + "ToolFilterContext", + "ToolFilterStatic", + "create_static_tool_filter", ] diff --git a/src/agents/mcp/server.py b/src/agents/mcp/server.py index 3d1e17790..f6c2b58ef 100644 --- a/src/agents/mcp/server.py +++ b/src/agents/mcp/server.py @@ -2,10 +2,11 @@ import abc import asyncio +import inspect from contextlib import AbstractAsyncContextManager, AsyncExitStack from datetime import timedelta from pathlib import Path -from typing import Any, Literal +from typing import TYPE_CHECKING, Any, Literal, cast from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream from mcp import ClientSession, StdioServerParameters, Tool as MCPTool, stdio_client @@ -17,6 +18,11 @@ from ..exceptions import UserError from ..logger import logger +from ..run_context import RunContextWrapper +from .util import ToolFilter, ToolFilterCallable, ToolFilterContext, ToolFilterStatic + +if TYPE_CHECKING: + from ..agent import Agent class MCPServer(abc.ABC): @@ -44,7 +50,11 @@ async def cleanup(self): pass @abc.abstractmethod - async def list_tools(self) -> list[MCPTool]: + async def list_tools( + self, + run_context: RunContextWrapper[Any], + agent: Agent[Any], + ) -> list[MCPTool]: """List the tools available on the server.""" pass @@ -57,7 +67,12 @@ async def call_tool(self, tool_name: str, arguments: dict[str, Any] | None) -> C class _MCPServerWithClientSession(MCPServer, abc.ABC): """Base class for MCP servers that use a `ClientSession` to communicate with the server.""" - def __init__(self, cache_tools_list: bool, client_session_timeout_seconds: float | None): + def __init__( + self, + cache_tools_list: bool, + client_session_timeout_seconds: float | None, + tool_filter: ToolFilter = None, + ): """ Args: cache_tools_list: Whether to cache the tools list. 
If `True`, the tools list will be @@ -68,6 +83,7 @@ def __init__(self, cache_tools_list: bool, client_session_timeout_seconds: float (by avoiding a round-trip to the server every time). client_session_timeout_seconds: the read timeout passed to the MCP ClientSession. + tool_filter: The tool filter to use for filtering tools. """ self.session: ClientSession | None = None self.exit_stack: AsyncExitStack = AsyncExitStack() @@ -81,6 +97,88 @@ def __init__(self, cache_tools_list: bool, client_session_timeout_seconds: float self._cache_dirty = True self._tools_list: list[MCPTool] | None = None + self.tool_filter = tool_filter + + async def _apply_tool_filter( + self, + tools: list[MCPTool], + run_context: RunContextWrapper[Any], + agent: Agent[Any], + ) -> list[MCPTool]: + """Apply the tool filter to the list of tools.""" + if self.tool_filter is None: + return tools + + # Handle static tool filter + if isinstance(self.tool_filter, dict): + return self._apply_static_tool_filter(tools, self.tool_filter) + + # Handle callable tool filter (dynamic filter) + else: + return await self._apply_dynamic_tool_filter(tools, run_context, agent) + + def _apply_static_tool_filter( + self, + tools: list[MCPTool], + static_filter: ToolFilterStatic + ) -> list[MCPTool]: + """Apply static tool filtering based on allowlist and blocklist.""" + filtered_tools = tools + + # Apply allowed_tool_names filter (whitelist) + if "allowed_tool_names" in static_filter: + allowed_names = static_filter["allowed_tool_names"] + filtered_tools = [t for t in filtered_tools if t.name in allowed_names] + + # Apply blocked_tool_names filter (blacklist) + if "blocked_tool_names" in static_filter: + blocked_names = static_filter["blocked_tool_names"] + filtered_tools = [t for t in filtered_tools if t.name not in blocked_names] + + return filtered_tools + + async def _apply_dynamic_tool_filter( + self, + tools: list[MCPTool], + run_context: RunContextWrapper[Any], + agent: Agent[Any], + ) -> list[MCPTool]: + """Apply dynamic tool filtering using a callable filter function.""" + + # Ensure we have a callable filter and cast to help mypy + if not callable(self.tool_filter): + raise ValueError("Tool filter must be callable for dynamic filtering") + tool_filter_func = cast(ToolFilterCallable, self.tool_filter) + + # Create filter context + filter_context = ToolFilterContext( + run_context=run_context, + agent=agent, + server_name=self.name, + ) + + filtered_tools = [] + for tool in tools: + try: + # Call the filter function with context + result = tool_filter_func(filter_context, tool) + + if inspect.isawaitable(result): + should_include = await result + else: + should_include = result + + if should_include: + filtered_tools.append(tool) + except Exception as e: + logger.error( + f"Error applying tool filter to tool '{tool.name}' on server '{self.name}': {e}" + ) + # On error, exclude the tool for safety + continue + + return filtered_tools + @abc.abstractmethod def create_streams( self, @@ -131,21 +229,30 @@ async def connect(self): await self.cleanup() raise - async def list_tools(self) -> list[MCPTool]: + async def list_tools( + self, + run_context: RunContextWrapper[Any], + agent: Agent[Any], + ) -> list[MCPTool]: """List the tools available on the server.""" if not self.session: raise UserError("Server not initialized. 
Make sure you call `connect()` first.") # Return from cache if caching is enabled, we have tools, and the cache is not dirty if self.cache_tools_list and not self._cache_dirty and self._tools_list: - return self._tools_list - - # Reset the cache dirty to False - self._cache_dirty = False - - # Fetch the tools from the server - self._tools_list = (await self.session.list_tools()).tools - return self._tools_list + tools = self._tools_list + else: + # Reset the cache dirty to False + self._cache_dirty = False + # Fetch the tools from the server + self._tools_list = (await self.session.list_tools()).tools + tools = self._tools_list + + # Filter tools based on tool_filter + filtered_tools = tools + if self.tool_filter is not None: + filtered_tools = await self._apply_tool_filter(filtered_tools, run_context, agent) + return filtered_tools async def call_tool(self, tool_name: str, arguments: dict[str, Any] | None) -> CallToolResult: """Invoke a tool on the server.""" @@ -206,6 +313,7 @@ def __init__( cache_tools_list: bool = False, name: str | None = None, client_session_timeout_seconds: float | None = 5, + tool_filter: ToolFilter = None, ): """Create a new MCP server based on the stdio transport. @@ -223,8 +331,13 @@ def __init__( name: A readable name for the server. If not provided, we'll create one from the command. client_session_timeout_seconds: the read timeout passed to the MCP ClientSession. + tool_filter: The tool filter to use for filtering tools. """ - super().__init__(cache_tools_list, client_session_timeout_seconds) + super().__init__( + cache_tools_list, + client_session_timeout_seconds, + tool_filter, + ) self.params = StdioServerParameters( command=params["command"], @@ -283,6 +396,7 @@ def __init__( cache_tools_list: bool = False, name: str | None = None, client_session_timeout_seconds: float | None = 5, + tool_filter: ToolFilter = None, ): """Create a new MCP server based on the HTTP with SSE transport. @@ -302,8 +416,13 @@ def __init__( URL. client_session_timeout_seconds: the read timeout passed to the MCP ClientSession. + tool_filter: The tool filter to use for filtering tools. """ - super().__init__(cache_tools_list, client_session_timeout_seconds) + super().__init__( + cache_tools_list, + client_session_timeout_seconds, + tool_filter, + ) self.params = params self._name = name or f"sse: {self.params['url']}" @@ -362,6 +481,7 @@ def __init__( cache_tools_list: bool = False, name: str | None = None, client_session_timeout_seconds: float | None = 5, + tool_filter: ToolFilter = None, ): """Create a new MCP server based on the Streamable HTTP transport. @@ -382,8 +502,13 @@ def __init__( URL. client_session_timeout_seconds: the read timeout passed to the MCP ClientSession. + tool_filter: The tool filter to use for filtering tools. 
""" - super().__init__(cache_tools_list, client_session_timeout_seconds) + super().__init__( + cache_tools_list, + client_session_timeout_seconds, + tool_filter, + ) self.params = params self._name = name or f"streamable_http: {self.params['url']}" diff --git a/src/agents/mcp/util.py b/src/agents/mcp/util.py index 5a963bc01..48da9f841 100644 --- a/src/agents/mcp/util.py +++ b/src/agents/mcp/util.py @@ -1,6 +1,9 @@ import functools import json -from typing import TYPE_CHECKING, Any +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any, Callable, Optional, Union + +from typing_extensions import NotRequired, TypedDict from agents.strict_schema import ensure_strict_json_schema @@ -10,25 +13,102 @@ from ..run_context import RunContextWrapper from ..tool import FunctionTool, Tool from ..tracing import FunctionSpanData, get_current_span, mcp_tools_span +from ..util._types import MaybeAwaitable if TYPE_CHECKING: from mcp.types import Tool as MCPTool + from ..agent import Agent from .server import MCPServer +@dataclass +class ToolFilterContext: + """Context information available to tool filter functions.""" + + run_context: RunContextWrapper[Any] + """The current run context.""" + + agent: "Agent[Any]" + """The agent that is requesting the tool list.""" + + server_name: str + """The name of the MCP server.""" + + +ToolFilterCallable = Callable[["ToolFilterContext", "MCPTool"], MaybeAwaitable[bool]] +"""A function that determines whether a tool should be available. + +Args: + context: The context information including run context, agent, and server name. + tool: The MCP tool to filter. + +Returns: + Whether the tool should be available (True) or filtered out (False). +""" + + +class ToolFilterStatic(TypedDict): + """Static tool filter configuration using allowlists and blocklists.""" + + allowed_tool_names: NotRequired[list[str]] + """Optional list of tool names to allow (whitelist). + If set, only these tools will be available.""" + + blocked_tool_names: NotRequired[list[str]] + """Optional list of tool names to exclude (blacklist). + If set, these tools will be filtered out.""" + + +ToolFilter = Union[ToolFilterCallable, ToolFilterStatic, None] +"""A tool filter that can be either a function, static configuration, or None (no filtering).""" + + +def create_static_tool_filter( + allowed_tool_names: Optional[list[str]] = None, + blocked_tool_names: Optional[list[str]] = None, +) -> Optional[ToolFilterStatic]: + """Create a static tool filter from allowlist and blocklist parameters. + + This is a convenience function for creating a ToolFilterStatic. + + Args: + allowed_tool_names: Optional list of tool names to allow (whitelist). + blocked_tool_names: Optional list of tool names to exclude (blacklist). + + Returns: + A ToolFilterStatic if any filtering is specified, None otherwise. 
+ """ + if allowed_tool_names is None and blocked_tool_names is None: + return None + + filter_dict: ToolFilterStatic = {} + if allowed_tool_names is not None: + filter_dict["allowed_tool_names"] = allowed_tool_names + if blocked_tool_names is not None: + filter_dict["blocked_tool_names"] = blocked_tool_names + + return filter_dict + + class MCPUtil: """Set of utilities for interop between MCP and Agents SDK tools.""" @classmethod async def get_all_function_tools( - cls, servers: list["MCPServer"], convert_schemas_to_strict: bool + cls, + servers: list["MCPServer"], + convert_schemas_to_strict: bool, + run_context: RunContextWrapper[Any], + agent: "Agent[Any]", ) -> list[Tool]: """Get all function tools from a list of MCP servers.""" tools = [] tool_names: set[str] = set() for server in servers: - server_tools = await cls.get_function_tools(server, convert_schemas_to_strict) + server_tools = await cls.get_function_tools( + server, convert_schemas_to_strict, run_context, agent + ) server_tool_names = {tool.name for tool in server_tools} if len(server_tool_names & tool_names) > 0: raise UserError( @@ -42,12 +122,16 @@ async def get_all_function_tools( @classmethod async def get_function_tools( - cls, server: "MCPServer", convert_schemas_to_strict: bool + cls, + server: "MCPServer", + convert_schemas_to_strict: bool, + run_context: RunContextWrapper[Any], + agent: "Agent[Any]", ) -> list[Tool]: """Get all function tools from a single MCP server.""" with mcp_tools_span(server=server.name) as span: - tools = await server.list_tools() + tools = await server.list_tools(run_context, agent) span.span_data.result = [tool.name for tool in tools] return [cls.to_function_tool(tool, server, convert_schemas_to_strict) for tool in tools] diff --git a/src/agents/tool.py b/src/agents/tool.py index ce66a53ba..283a9e24f 100644 --- a/src/agents/tool.py +++ b/src/agents/tool.py @@ -26,6 +26,7 @@ from .util._types import MaybeAwaitable if TYPE_CHECKING: + from .agent import Agent ToolParams = ParamSpec("ToolParams") diff --git a/tests/mcp/helpers.py b/tests/mcp/helpers.py index 8ff153c18..e0d8a813d 100644 --- a/tests/mcp/helpers.py +++ b/tests/mcp/helpers.py @@ -1,3 +1,4 @@ +import asyncio import json import shutil from typing import Any @@ -6,6 +7,8 @@ from mcp.types import CallToolResult, TextContent from agents.mcp import MCPServer +from agents.mcp.server import _MCPServerWithClientSession +from agents.mcp.util import ToolFilter tee = shutil.which("tee") or "" assert tee, "tee not found" @@ -28,11 +31,41 @@ async def __aexit__(self, exc_type, exc_val, exc_tb): pass +class _TestFilterServer(_MCPServerWithClientSession): + """Minimal implementation of _MCPServerWithClientSession for testing tool filtering""" + + def __init__(self, tool_filter: ToolFilter, server_name: str): + # Initialize parent class properly to avoid type errors + super().__init__( + cache_tools_list=False, + client_session_timeout_seconds=None, + tool_filter=tool_filter, + ) + self._server_name: str = server_name + # Override some attributes for test isolation + self.session = None + self._cleanup_lock = asyncio.Lock() + + def create_streams(self): + raise NotImplementedError("Not needed for filtering tests") + + @property + def name(self) -> str: + return self._server_name + + class FakeMCPServer(MCPServer): - def __init__(self, tools: list[MCPTool] | None = None): + def __init__( + self, + tools: list[MCPTool] | None = None, + tool_filter: ToolFilter = None, + server_name: str = "fake_mcp_server", + ): self.tools: list[MCPTool] = tools or 
[] self.tool_calls: list[str] = [] self.tool_results: list[str] = [] + self.tool_filter = tool_filter + self._server_name = server_name def add_tool(self, name: str, input_schema: dict[str, Any]): self.tools.append(MCPTool(name=name, inputSchema=input_schema)) @@ -43,8 +76,16 @@ async def connect(self): async def cleanup(self): pass - async def list_tools(self): - return self.tools + async def list_tools(self, run_context=None, agent=None): + tools = self.tools + + # Apply tool filtering using the REAL implementation + if self.tool_filter is not None: + # Use the real _MCPServerWithClientSession filtering logic + filter_server = _TestFilterServer(self.tool_filter, self.name) + tools = await filter_server._apply_tool_filter(tools, run_context, agent) + + return tools async def call_tool(self, tool_name: str, arguments: dict[str, Any] | None) -> CallToolResult: self.tool_calls.append(tool_name) @@ -55,4 +96,4 @@ async def call_tool(self, tool_name: str, arguments: dict[str, Any] | None) -> C @property def name(self) -> str: - return "fake_mcp_server" + return self._server_name diff --git a/tests/mcp/test_caching.py b/tests/mcp/test_caching.py index cac409e6e..f31cdf951 100644 --- a/tests/mcp/test_caching.py +++ b/tests/mcp/test_caching.py @@ -3,7 +3,9 @@ import pytest from mcp.types import ListToolsResult, Tool as MCPTool +from agents import Agent from agents.mcp import MCPServerStdio +from agents.run_context import RunContextWrapper from .helpers import DummyStreamsContextManager, tee @@ -33,25 +35,29 @@ async def test_server_caching_works( mock_list_tools.return_value = ListToolsResult(tools=tools) async with server: + # Create test context and agent + run_context = RunContextWrapper(context=None) + agent = Agent(name="test_agent", instructions="Test agent") + # Call list_tools() multiple times - tools = await server.list_tools() - assert tools == tools + result_tools = await server.list_tools(run_context, agent) + assert result_tools == tools assert mock_list_tools.call_count == 1, "list_tools() should have been called once" # Call list_tools() again, should return the cached value - tools = await server.list_tools() - assert tools == tools + result_tools = await server.list_tools(run_context, agent) + assert result_tools == tools assert mock_list_tools.call_count == 1, "list_tools() should not have been called again" # Invalidate the cache and call list_tools() again server.invalidate_tools_cache() - tools = await server.list_tools() - assert tools == tools + result_tools = await server.list_tools(run_context, agent) + assert result_tools == tools assert mock_list_tools.call_count == 2, "list_tools() should be called again" # Without invalidating the cache, calling list_tools() again should return the cached value - tools = await server.list_tools() - assert tools == tools + result_tools = await server.list_tools(run_context, agent) + assert result_tools == tools diff --git a/tests/mcp/test_mcp_util.py b/tests/mcp/test_mcp_util.py index 74356a16d..3230e63dd 100644 --- a/tests/mcp/test_mcp_util.py +++ b/tests/mcp/test_mcp_util.py @@ -57,7 +57,10 @@ async def test_get_all_function_tools(): server3.add_tool(names[4], schemas[4]) servers: list[MCPServer] = [server1, server2, server3] - tools = await MCPUtil.get_all_function_tools(servers, convert_schemas_to_strict=False) + run_context = RunContextWrapper(context=None) + agent = Agent(name="test_agent", instructions="Test agent") + + tools = await MCPUtil.get_all_function_tools(servers, False, run_context, agent) assert len(tools) == 5 assert 
all(tool.name in names for tool in tools) @@ -70,7 +73,7 @@ async def test_get_all_function_tools(): assert tool.name == names[idx] # Also make sure it works with strict schemas - tools = await MCPUtil.get_all_function_tools(servers, convert_schemas_to_strict=True) + tools = await MCPUtil.get_all_function_tools(servers, True, run_context, agent) assert len(tools) == 5 assert all(tool.name in names for tool in tools) @@ -144,7 +147,8 @@ async def test_agent_convert_schemas_true(): agent = Agent( name="test_agent", mcp_servers=[server], mcp_config={"convert_schemas_to_strict": True} ) - tools = await agent.get_mcp_tools() + run_context = RunContextWrapper(context=None) + tools = await agent.get_mcp_tools(run_context) foo_tool = next(tool for tool in tools if tool.name == "foo") assert isinstance(foo_tool, FunctionTool) @@ -208,7 +212,8 @@ async def test_agent_convert_schemas_false(): agent = Agent( name="test_agent", mcp_servers=[server], mcp_config={"convert_schemas_to_strict": False} ) - tools = await agent.get_mcp_tools() + run_context = RunContextWrapper(context=None) + tools = await agent.get_mcp_tools(run_context) foo_tool = next(tool for tool in tools if tool.name == "foo") assert isinstance(foo_tool, FunctionTool) @@ -245,7 +250,8 @@ async def test_agent_convert_schemas_unset(): server.add_tool("bar", non_strict_schema) server.add_tool("baz", possible_to_convert_schema) agent = Agent(name="test_agent", mcp_servers=[server]) - tools = await agent.get_mcp_tools() + run_context = RunContextWrapper(context=None) + tools = await agent.get_mcp_tools(run_context) foo_tool = next(tool for tool in tools if tool.name == "foo") assert isinstance(foo_tool, FunctionTool) @@ -279,7 +285,9 @@ async def test_util_adds_properties(): server = FakeMCPServer() server.add_tool("test_tool", schema) - tools = await MCPUtil.get_all_function_tools([server], convert_schemas_to_strict=False) + run_context = RunContextWrapper(context=None) + agent = Agent(name="test_agent", instructions="Test agent") + tools = await MCPUtil.get_all_function_tools([server], False, run_context, agent) tool = next(tool for tool in tools if tool.name == "test_tool") assert isinstance(tool, FunctionTool) diff --git a/tests/mcp/test_server_errors.py b/tests/mcp/test_server_errors.py index fbd8db17d..9e0455115 100644 --- a/tests/mcp/test_server_errors.py +++ b/tests/mcp/test_server_errors.py @@ -1,7 +1,9 @@ import pytest +from agents import Agent from agents.exceptions import UserError from agents.mcp.server import _MCPServerWithClientSession +from agents.run_context import RunContextWrapper class CrashingClientSessionServer(_MCPServerWithClientSession): @@ -35,8 +37,11 @@ async def test_server_errors_cause_error_and_cleanup_called(): async def test_not_calling_connect_causes_error(): server = CrashingClientSessionServer() + run_context = RunContextWrapper(context=None) + agent = Agent(name="test_agent", instructions="Test agent") + with pytest.raises(UserError): - await server.list_tools() + await server.list_tools(run_context, agent) with pytest.raises(UserError): await server.call_tool("foo", {}) diff --git a/tests/mcp/test_tool_filtering.py b/tests/mcp/test_tool_filtering.py new file mode 100644 index 000000000..c1ffff4b8 --- /dev/null +++ b/tests/mcp/test_tool_filtering.py @@ -0,0 +1,243 @@ +""" +Tool filtering tests use FakeMCPServer instead of real MCPServer implementations to avoid +external dependencies (processes, network connections) and ensure fast, reliable unit tests. 
+FakeMCPServer delegates filtering logic to the real _MCPServerWithClientSession implementation. +""" +import asyncio + +import pytest +from mcp import Tool as MCPTool + +from agents import Agent +from agents.mcp import ToolFilterContext, create_static_tool_filter +from agents.run_context import RunContextWrapper + +from .helpers import FakeMCPServer + + +def create_test_agent(name: str = "test_agent") -> Agent: + """Create a test agent for filtering tests.""" + return Agent(name=name, instructions="Test agent") + + +def create_test_context() -> RunContextWrapper: + """Create a test run context for filtering tests.""" + return RunContextWrapper(context=None) + + +# === Static Tool Filtering Tests === + +@pytest.mark.asyncio +async def test_static_tool_filtering(): + """Test all static tool filtering scenarios: allowed, blocked, both, none, etc.""" + server = FakeMCPServer(server_name="test_server") + server.add_tool("tool1", {}) + server.add_tool("tool2", {}) + server.add_tool("tool3", {}) + server.add_tool("tool4", {}) + + # Create test context and agent for all calls + run_context = create_test_context() + agent = create_test_agent() + + # Test allowed_tool_names only + server.tool_filter = {"allowed_tool_names": ["tool1", "tool2"]} + tools = await server.list_tools(run_context, agent) + assert len(tools) == 2 + assert {t.name for t in tools} == {"tool1", "tool2"} + + # Test blocked_tool_names only + server.tool_filter = {"blocked_tool_names": ["tool3", "tool4"]} + tools = await server.list_tools(run_context, agent) + assert len(tools) == 2 + assert {t.name for t in tools} == {"tool1", "tool2"} + + # Test both filters together (allowed first, then blocked) + server.tool_filter = { + "allowed_tool_names": ["tool1", "tool2", "tool3"], + "blocked_tool_names": ["tool3"] + } + tools = await server.list_tools(run_context, agent) + assert len(tools) == 2 + assert {t.name for t in tools} == {"tool1", "tool2"} + + # Test no filter + server.tool_filter = None + tools = await server.list_tools(run_context, agent) + assert len(tools) == 4 + + # Test helper function + server.tool_filter = create_static_tool_filter( + allowed_tool_names=["tool1", "tool2"], + blocked_tool_names=["tool2"] + ) + tools = await server.list_tools(run_context, agent) + assert len(tools) == 1 + assert tools[0].name == "tool1" + + +# === Dynamic Tool Filtering Core Tests === + +@pytest.mark.asyncio +async def test_dynamic_filter_sync_and_async(): + """Test both synchronous and asynchronous dynamic filters""" + server = FakeMCPServer(server_name="test_server") + server.add_tool("allowed_tool", {}) + server.add_tool("blocked_tool", {}) + server.add_tool("restricted_tool", {}) + + # Create test context and agent + run_context = create_test_context() + agent = create_test_agent() + + # Test sync filter + def sync_filter(context: ToolFilterContext, tool: MCPTool) -> bool: + return tool.name.startswith("allowed") + + server.tool_filter = sync_filter + tools = await server.list_tools(run_context, agent) + assert len(tools) == 1 + assert tools[0].name == "allowed_tool" + + # Test async filter + async def async_filter(context: ToolFilterContext, tool: MCPTool) -> bool: + await asyncio.sleep(0.001) # Simulate async operation + return "restricted" not in tool.name + + server.tool_filter = async_filter + tools = await server.list_tools(run_context, agent) + assert len(tools) == 2 + assert {t.name for t in tools} == {"allowed_tool", "blocked_tool"} + + +@pytest.mark.asyncio +async def test_dynamic_filter_context_handling(): + """Test 
dynamic filters with context access""" + server = FakeMCPServer(server_name="test_server") + server.add_tool("admin_tool", {}) + server.add_tool("user_tool", {}) + server.add_tool("guest_tool", {}) + + # Test context-independent filter + def context_independent_filter(context: ToolFilterContext, tool: MCPTool) -> bool: + return not tool.name.startswith("admin") + + server.tool_filter = context_independent_filter + run_context = create_test_context() + agent = create_test_agent() + tools = await server.list_tools(run_context, agent) + assert len(tools) == 2 + assert {t.name for t in tools} == {"user_tool", "guest_tool"} + + # Test context-dependent filter (needs context) + def context_dependent_filter(context: ToolFilterContext, tool: MCPTool) -> bool: + assert context is not None + assert context.run_context is not None + assert context.agent is not None + assert context.server_name == "test_server" + + # Only admin tools for agents with "admin" in name + if "admin" in context.agent.name.lower(): + return True + else: + return not tool.name.startswith("admin") + + server.tool_filter = context_dependent_filter + + # Should work with context + run_context = RunContextWrapper(context=None) + regular_agent = create_test_agent("regular_user") + tools = await server.list_tools(run_context, regular_agent) + assert len(tools) == 2 + assert {t.name for t in tools} == {"user_tool", "guest_tool"} + + admin_agent = create_test_agent("admin_user") + tools = await server.list_tools(run_context, admin_agent) + assert len(tools) == 3 + + +@pytest.mark.asyncio +async def test_dynamic_filter_error_handling(): + """Test error handling in dynamic filters""" + server = FakeMCPServer(server_name="test_server") + server.add_tool("good_tool", {}) + server.add_tool("error_tool", {}) + server.add_tool("another_good_tool", {}) + + def error_prone_filter(context: ToolFilterContext, tool: MCPTool) -> bool: + if tool.name == "error_tool": + raise ValueError("Simulated filter error") + return True + + server.tool_filter = error_prone_filter + + # Test with server call + run_context = create_test_context() + agent = create_test_agent() + tools = await server.list_tools(run_context, agent) + assert len(tools) == 2 + assert {t.name for t in tools} == {"good_tool", "another_good_tool"} + + +# === Integration Tests === + +@pytest.mark.asyncio +async def test_agent_dynamic_filtering_integration(): + """Test dynamic filtering integration with Agent methods""" + server = FakeMCPServer() + server.add_tool("file_read", {"type": "object", "properties": {"path": {"type": "string"}}}) + server.add_tool( + "file_write", + { + "type": "object", + "properties": {"path": {"type": "string"}, "content": {"type": "string"}}, + }, + ) + server.add_tool( + "database_query", {"type": "object", "properties": {"query": {"type": "string"}}} + ) + server.add_tool( + "network_request", {"type": "object", "properties": {"url": {"type": "string"}}} + ) + + # Role-based filter for comprehensive testing + async def role_based_filter(context: ToolFilterContext, tool: MCPTool) -> bool: + # Simulate async permission check + await asyncio.sleep(0.001) + + agent_name = context.agent.name.lower() + if "admin" in agent_name: + return True + elif "readonly" in agent_name: + return "read" in tool.name or "query" in tool.name + else: + return tool.name.startswith("file_") + + server.tool_filter = role_based_filter + + # Test admin agent + admin_agent = Agent(name="admin_user", instructions="Admin", mcp_servers=[server]) + run_context = 
RunContextWrapper(context=None) + admin_tools = await admin_agent.get_mcp_tools(run_context) + assert len(admin_tools) == 4 + + # Test readonly agent + readonly_agent = Agent(name="readonly_viewer", instructions="Read-only", mcp_servers=[server]) + readonly_tools = await readonly_agent.get_mcp_tools(run_context) + assert len(readonly_tools) == 2 + assert {t.name for t in readonly_tools} == {"file_read", "database_query"} + + # Test regular agent + regular_agent = Agent(name="regular_user", instructions="Regular", mcp_servers=[server]) + regular_tools = await regular_agent.get_mcp_tools(run_context) + assert len(regular_tools) == 2 + assert {t.name for t in regular_tools} == {"file_read", "file_write"} + + # Test get_all_tools method + all_tools = await regular_agent.get_all_tools(run_context) + mcp_tool_names = { + t.name + for t in all_tools + if t.name in {"file_read", "file_write", "database_query", "network_request"} + } + assert mcp_tool_names == {"file_read", "file_write"} From 25db658f75999f309522819c92f653be5f050134 Mon Sep 17 00:00:00 2001 From: Kazuhiro Sera Date: Sat, 28 Jun 2025 00:22:35 +0900 Subject: [PATCH 16/23] Tweak in pyproject.toml (#952) - Add python 3.13 - Update the homepage URL --- pyproject.toml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 6c0bbe9f1..0dc0ca567 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -23,13 +23,14 @@ classifiers = [ "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", "Programming Language :: Python :: 3.12", + "Programming Language :: Python :: 3.13", "Operating System :: OS Independent", "Topic :: Software Development :: Libraries :: Python Modules", "License :: OSI Approved :: MIT License", ] [project.urls] -Homepage = "https://github.com/openai/openai-agents-python" +Homepage = "https://openai.github.io/openai-agents-python/" Repository = "https://github.com/openai/openai-agents-python" [project.optional-dependencies] From c772205a7ecdd5e795d537fb902a0ff7724023c2 Mon Sep 17 00:00:00 2001 From: Shiraz Ali <147478490+shirazkk@users.noreply.github.com> Date: Fri, 27 Jun 2025 20:23:37 +0500 Subject: [PATCH 17/23] Fix incorrect argument description in `on_trace_end` docstring (#958) The docstring for the `on_trace_end` method in the `TracingProcessor` class incorrectly describes the argument `trace`. The current description states that the argument is the trace that started, but it should refer to the trace that has finished. This commit updates the docstring to clarify that the argument represents the trace that has completed, ensuring accuracy and clarity. Changes: - Updated the argument description for `trace` in `on_trace_end` to clarify it refers to the trace that finished. --- src/agents/tracing/processor_interface.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/agents/tracing/processor_interface.py b/src/agents/tracing/processor_interface.py index 4dcd897c7..0a05bcae2 100644 --- a/src/agents/tracing/processor_interface.py +++ b/src/agents/tracing/processor_interface.py @@ -23,7 +23,7 @@ def on_trace_end(self, trace: "Trace") -> None: """Called when a trace is finished. Args: - trace: The trace that started. + trace: The trace that finished. 
""" pass From 381e90d9b4fda8e7a52c4c8c85c893843ef997a2 Mon Sep 17 00:00:00 2001 From: Kazuhiro Sera Date: Sat, 28 Jun 2025 00:24:10 +0900 Subject: [PATCH 18/23] Fix #944 by updating the models document page to cover extra_args (#950) This pull request resolves #944 --- docs/ja/models.md | 106 --------------------------------------- docs/ja/models/index.md | 107 +++++++++++++++++++++++++++------------- docs/models/index.md | 16 ++++++ 3 files changed, 89 insertions(+), 140 deletions(-) delete mode 100644 docs/ja/models.md diff --git a/docs/ja/models.md b/docs/ja/models.md deleted file mode 100644 index 5a76d60ec..000000000 --- a/docs/ja/models.md +++ /dev/null @@ -1,106 +0,0 @@ -# モデル - -Agents SDK には、OpenAI モデルの 2 種類のサポートが標準で用意されています。 - -- **推奨**: [`OpenAIResponsesModel`][agents.models.openai_responses.OpenAIResponsesModel] は、新しい [Responses API](https://platform.openai.com/docs/api-reference/responses) を使って OpenAI API を呼び出します。 -- [`OpenAIChatCompletionsModel`][agents.models.openai_chatcompletions.OpenAIChatCompletionsModel] は、[Chat Completions API](https://platform.openai.com/docs/api-reference/chat) を使って OpenAI API を呼び出します。 - -## モデルの組み合わせ - -1 つのワークフロー内で、各エージェントごとに異なるモデルを使いたい場合があります。たとえば、トリアージには小型で高速なモデルを使い、複雑なタスクにはより大きく高性能なモデルを使うことができます。[`Agent`][agents.Agent] を設定する際、以下のいずれかの方法で特定のモデルを選択できます。 - -1. OpenAI モデル名を直接渡す。 -2. 任意のモデル名と、その名前を Model インスタンスにマッピングできる [`ModelProvider`][agents.models.interface.ModelProvider] を渡す。 -3. [`Model`][agents.models.interface.Model] 実装を直接指定する。 - -!!!note - - SDK は [`OpenAIResponsesModel`][agents.models.openai_responses.OpenAIResponsesModel] と [`OpenAIChatCompletionsModel`][agents.models.openai_chatcompletions.OpenAIChatCompletionsModel] の両方の形状をサポートしていますが、各ワークフローで 1 つのモデル形状のみを使うことを推奨します。なぜなら、2 つの形状はサポートする機能やツールが異なるためです。ワークフローでモデル形状を組み合わせて使う場合は、利用するすべての機能が両方で利用可能かご確認ください。 - -```python -from agents import Agent, Runner, AsyncOpenAI, OpenAIChatCompletionsModel -import asyncio - -spanish_agent = Agent( - name="Spanish agent", - instructions="You only speak Spanish.", - model="o3-mini", # (1)! -) - -english_agent = Agent( - name="English agent", - instructions="You only speak English", - model=OpenAIChatCompletionsModel( # (2)! - model="gpt-4o", - openai_client=AsyncOpenAI() - ), -) - -triage_agent = Agent( - name="Triage agent", - instructions="Handoff to the appropriate agent based on the language of the request.", - handoffs=[spanish_agent, english_agent], - model="gpt-3.5-turbo", -) - -async def main(): - result = await Runner.run(triage_agent, input="Hola, ¿cómo estás?") - print(result.final_output) -``` - -1. OpenAI モデル名を直接設定します。 -2. [`Model`][agents.models.interface.Model] 実装を指定します。 - -エージェントで使用するモデルをさらに細かく設定したい場合は、[`ModelSettings`][agents.models.interface.ModelSettings] を渡すことができます。これにより、temperature などのオプションのモデル設定パラメーターを指定できます。 - -```python -from agents import Agent, ModelSettings - -english_agent = Agent( - name="English agent", - instructions="You only speak English", - model="gpt-4o", - model_settings=ModelSettings(temperature=0.1), -) -``` - -## 他の LLM プロバイダーの利用 - -他の LLM プロバイダーは、3 つの方法で利用できます([こちら](https://github.com/openai/openai-agents-python/tree/main/examples/model_providers/) に code examples があります)。 - -1. 
[`set_default_openai_client`][agents.set_default_openai_client] は、`AsyncOpenAI` のインスタンスを LLM クライアントとしてグローバルに利用したい場合に便利です。これは、LLM プロバイダーが OpenAI 互換の API エンドポイントを持ち、`base_url` と `api_key` を設定できる場合に使います。[examples/model_providers/custom_example_global.py](https://github.com/openai/openai-agents-python/tree/main/examples/model_providers/custom_example_global.py) に設定例があります。 -2. [`ModelProvider`][agents.models.interface.ModelProvider] は `Runner.run` レベルで利用します。これにより、「この実行のすべてのエージェントでカスタムモデルプロバイダーを使う」と指定できます。[examples/model_providers/custom_example_provider.py](https://github.com/openai/openai-agents-python/tree/main/examples/model_providers/custom_example_provider.py) に設定例があります。 -3. [`Agent.model`][agents.agent.Agent.model] で、特定のエージェントインスタンスにモデルを指定できます。これにより、エージェントごとに異なるプロバイダーを組み合わせて使うことができます。[examples/model_providers/custom_example_agent.py](https://github.com/openai/openai-agents-python/tree/main/examples/model_providers/custom_example_agent.py) に設定例があります。 - -`platform.openai.com` の API キーがない場合は、`set_tracing_disabled()` でトレーシングを無効にするか、[別のトレーシングプロセッサー](tracing.md) を設定することを推奨します。 - -!!! note - - これらの code examples では Chat Completions API/モデルを使っています。なぜなら、ほとんどの LLM プロバイダーはまだ Responses API をサポートしていないためです。もし LLM プロバイダーが Responses API をサポートしている場合は、Responses の利用を推奨します。 - -## 他の LLM プロバイダー利用時のよくある問題 - -### Tracing クライアントの 401 エラー - -トレーシングに関連するエラーが発生した場合、これはトレースが OpenAI サーバーにアップロードされるため、OpenAI API キーがないことが原因です。解決方法は 3 つあります。 - -1. トレーシングを完全に無効化する: [`set_tracing_disabled(True)`][agents.set_tracing_disabled]。 -2. トレーシング用の OpenAI キーを設定する: [`set_tracing_export_api_key(...)`][agents.set_tracing_export_api_key]。この API キーはトレースのアップロードのみに使われ、[platform.openai.com](https://platform.openai.com/) のものが必要です。 -3. OpenAI 以外のトレースプロセッサーを使う。[トレーシングのドキュメント](tracing.md#custom-tracing-processors) をご覧ください。 - -### Responses API サポート - -SDK はデフォルトで Responses API を使いますが、ほとんどの他の LLM プロバイダーはまだ対応していません。そのため、404 エラーなどが発生する場合があります。解決方法は 2 つあります。 - -1. [`set_default_openai_api("chat_completions")`][agents.set_default_openai_api] を呼び出します。これは、環境変数で `OPENAI_API_KEY` と `OPENAI_BASE_URL` を設定している場合に有効です。 -2. 
[`OpenAIChatCompletionsModel`][agents.models.openai_chatcompletions.OpenAIChatCompletionsModel] を使います。[こちら](https://github.com/openai/openai-agents-python/tree/main/examples/model_providers/) に code examples があります。 - -### structured outputs サポート - -一部のモデルプロバイダーは [structured outputs](https://platform.openai.com/docs/guides/structured-outputs) をサポートしていません。その場合、次のようなエラーが発生することがあります。 - -``` -BadRequestError: Error code: 400 - {'error': {'message': "'response_format.type' : value is not one of the allowed values ['text','json_object']", 'type': 'invalid_request_error'}} -``` - -これは一部のモデルプロバイダーの制限で、JSON 出力には対応していても、出力に使う `json_schema` を指定できない場合があります。現在この問題の修正に取り組んでいますが、JSON schema 出力をサポートしているプロバイダーの利用を推奨します。そうでない場合、不正な JSON によりアプリが頻繁に動作しなくなる可能性があります。 \ No newline at end of file diff --git a/docs/ja/models/index.md b/docs/ja/models/index.md index a40ae38f6..410c01676 100644 --- a/docs/ja/models/index.md +++ b/docs/ja/models/index.md @@ -4,21 +4,52 @@ search: --- # モデル -Agents SDK には、標準で 2 種類の OpenAI モデルサポートが含まれています。 +Agents SDK は OpenAI モデルを 2 つの形態で即利用できます。 -- **推奨**: [`OpenAIResponsesModel`][agents.models.openai_responses.OpenAIResponsesModel] — 新しい [Responses API](https://platform.openai.com/docs/api-reference/responses) を利用して OpenAI API を呼び出します。 -- [`OpenAIChatCompletionsModel`][agents.models.openai_chatcompletions.OpenAIChatCompletionsModel] — [Chat Completions API](https://platform.openai.com/docs/api-reference/chat) を利用して OpenAI API を呼び出します。 +- **推奨**: [`OpenAIResponsesModel`][agents.models.openai_responses.OpenAIResponsesModel] は、新しい [Responses API](https://platform.openai.com/docs/api-reference/responses) を使用して OpenAI API を呼び出します。 +- [`OpenAIChatCompletionsModel`][agents.models.openai_chatcompletions.OpenAIChatCompletionsModel] は、[Chat Completions API](https://platform.openai.com/docs/api-reference/chat) を使用して OpenAI API を呼び出します。 + +## 非 OpenAI モデル + +ほとんどの非 OpenAI モデルは [LiteLLM インテグレーション](./litellm.md) 経由で利用できます。まず、litellm 依存グループをインストールします: + +```bash +pip install "openai-agents[litellm]" +``` + +次に、`litellm/` 接頭辞を付けて任意の [サポート対象モデル](https://docs.litellm.ai/docs/providers) を使用します: + +```python +claude_agent = Agent(model="litellm/anthropic/claude-3-5-sonnet-20240620", ...) +gemini_agent = Agent(model="litellm/gemini/gemini-2.5-flash-preview-04-17", ...) +``` + +### 非 OpenAI モデルを利用するその他の方法 + +他の LLM プロバイダーを統合する方法は、あと 3 つあります([こちら](https://github.com/openai/openai-agents-python/tree/main/examples/model_providers/) に例があります)。 + +1. [`set_default_openai_client`][agents.set_default_openai_client] + `AsyncOpenAI` インスタンスを LLM クライアントとしてグローバルに使用したい場合に便利です。LLM プロバイダーが OpenAI 互換の API エンドポイントを持ち、`base_url` と `api_key` を設定できる場合に使用します。設定例は [examples/model_providers/custom_example_global.py](https://github.com/openai/openai-agents-python/tree/main/examples/model_providers/custom_example_global.py) にあります。 +2. [`ModelProvider`][agents.models.interface.ModelProvider] + `Runner.run` レベルでカスタムモデルプロバイダーを指定できます。これにより「この run のすべてのエージェントでカスタムプロバイダーを使う」と宣言できます。設定例は [examples/model_providers/custom_example_provider.py](https://github.com/openai/openai-agents-python/tree/main/examples/model_providers/custom_example_provider.py) にあります。 +3. 
[`Agent.model`][agents.agent.Agent.model] + 特定のエージェントインスタンスにモデルを指定できます。エージェントごとに異なるプロバイダーを組み合わせることが可能です。設定例は [examples/model_providers/custom_example_agent.py](https://github.com/openai/openai-agents-python/tree/main/examples/model_providers/custom_example_agent.py) にあります。ほとんどのモデルを簡単に利用する方法として [LiteLLM インテグレーション](./litellm.md) を利用できます。 + +`platform.openai.com` の API キーを持っていない場合は、`set_tracing_disabled()` でトレーシングを無効化するか、[別のトレーシングプロセッサー](../tracing.md) を設定することをお勧めします。 + +!!! note + これらの例では、Responses API をまだサポートしていない LLM プロバイダーが多いため、Chat Completions API/モデルを使用しています。LLM プロバイダーが Responses API をサポートしている場合は、Responses を使用することを推奨します。 ## モデルの組み合わせ -1 つのワークフロー内で、エージェントごとに異なるモデルを使用したい場合があります。たとえば、振り分けには小さく高速なモデルを、複雑なタスクには大きく高性能なモデルを使う、といった使い分けです。[`Agent`][agents.Agent] を設定する際は、以下のいずれかで特定のモデルを指定できます。 +1 つのワークフロー内でエージェントごとに異なるモデルを使用したい場合があります。たとえば、振り分けには小さく高速なモデルを、複雑なタスクには大きく高性能なモデルを使用するといったケースです。[`Agent`][agents.Agent] を設定する際、次のいずれかの方法でモデルを選択できます。 -1. OpenAI モデル名を直接渡す -2. 任意のモデル名と、それを `Model` インスタンスへマッピングできる [`ModelProvider`][agents.models.interface.ModelProvider] を渡す +1. モデル名を直接指定する +2. 任意のモデル名と、その名前を Model インスタンスへマッピングできる [`ModelProvider`][agents.models.interface.ModelProvider] を指定する 3. [`Model`][agents.models.interface.Model] 実装を直接渡す !!!note - SDK は [`OpenAIResponsesModel`][agents.models.openai_responses.OpenAIResponsesModel] と [`OpenAIChatCompletionsModel`][agents.models.openai_chatcompletions.OpenAIChatCompletionsModel] の両方の形に対応していますが、ワークフローごとに 1 つのモデル形を使用することを推奨します。2 つの形ではサポートする機能・ツールが異なるためです。どうしても混在させる場合は、利用するすべての機能が両方で利用可能であることを確認してください。 + SDK は [`OpenAIResponsesModel`][agents.models.openai_responses.OpenAIResponsesModel] と [`OpenAIChatCompletionsModel`][agents.models.openai_chatcompletions.OpenAIChatCompletionsModel] の両形態をサポートしていますが、各ワークフローで 1 つのモデル形態に統一することを推奨します。2 つの形態はサポートする機能とツールが異なるためです。混在させる場合は、使用する機能が双方で利用可能かを必ず確認してください。 ```python from agents import Agent, Runner, AsyncOpenAI, OpenAIChatCompletionsModel @@ -51,10 +82,10 @@ async def main(): print(result.final_output) ``` -1. OpenAI モデル名を直接指定 +1. OpenAI のモデル名を直接設定 2. [`Model`][agents.models.interface.Model] 実装を提供 -エージェントで使用するモデルをさらに細かく設定したい場合は、`temperature` などのオプションを指定できる [`ModelSettings`][agents.models.interface.ModelSettings] を渡します。 +エージェントで使用するモデルをさらに構成したい場合は、`temperature` などのオプションパラメーターを指定できる [`ModelSettings`][agents.models.interface.ModelSettings] を渡せます。 ```python from agents import Agent, ModelSettings @@ -67,50 +98,58 @@ english_agent = Agent( ) ``` -## 他の LLM プロバイダーの利用 - -他の LLM プロバイダーは 3 通りの方法で利用できます(コード例は [こちら](https://github.com/openai/openai-agents-python/tree/main/examples/model_providers/))。 - -1. [`set_default_openai_client`][agents.set_default_openai_client] - OpenAI 互換の API エンドポイントを持つ場合に、`AsyncOpenAI` インスタンスをグローバルに LLM クライアントとして設定できます。`base_url` と `api_key` を設定するケースです。設定例は [examples/model_providers/custom_example_global.py](https://github.com/openai/openai-agents-python/tree/main/examples/model_providers/custom_example_global.py)。 +OpenAI の Responses API を使用する場合、`user` や `service_tier` など[その他のオプションパラメーター](https://platform.openai.com/docs/api-reference/responses/create) があります。トップレベルで指定できない場合は、`extra_args` で渡してください。 -2. [`ModelProvider`][agents.models.interface.ModelProvider] - `Runner.run` レベルで「この実行中のすべてのエージェントにカスタムモデルプロバイダーを使う」と宣言できます。設定例は [examples/model_providers/custom_example_provider.py](https://github.com/openai/openai-agents-python/tree/main/examples/model_providers/custom_example_provider.py)。 - -3. 
[`Agent.model`][agents.agent.Agent.model] - 特定の Agent インスタンスにモデルを指定できます。エージェントごとに異なるプロバイダーを組み合わせられます。設定例は [examples/model_providers/custom_example_agent.py](https://github.com/openai/openai-agents-python/tree/main/examples/model_providers/custom_example_agent.py)。多くのモデルを簡単に使う方法として [LiteLLM 連携](./litellm.md) があります。 - -`platform.openai.com` の API キーを持たない場合は、`set_tracing_disabled()` でトレーシングを無効化するか、[別のトレーシングプロセッサー](../tracing.md) を設定することを推奨します。 +```python +from agents import Agent, ModelSettings -!!! note - これらの例では Chat Completions API/モデルを使用しています。多くの LLM プロバイダーがまだ Responses API をサポートしていないためです。もしプロバイダーが Responses API をサポートしている場合は、Responses の使用を推奨します。 +english_agent = Agent( + name="English agent", + instructions="You only speak English", + model="gpt-4o", + model_settings=ModelSettings( + temperature=0.1, + extra_args={"service_tier": "flex", "user": "user_12345"}, + ), +) +``` -## 他の LLM プロバイダーでよくある問題 +## 他の LLM プロバイダー使用時の一般的な問題 ### Tracing クライアントの 401 エラー -トレースは OpenAI サーバーへアップロードされるため、OpenAI API キーがない場合にエラーになります。解決策は次の 3 つです。 +Tracing 関連のエラーが発生する場合、トレースは OpenAI サーバーへアップロードされるため、OpenAI API キーが必要です。対応方法は次の 3 つです。 1. トレーシングを完全に無効化する: [`set_tracing_disabled(True)`][agents.set_tracing_disabled] -2. トレーシング用の OpenAI キーを設定する: [`set_tracing_export_api_key(...)`][agents.set_tracing_export_api_key] - このキーはトレースのアップロードにのみ使用され、[platform.openai.com](https://platform.openai.com/) のものが必要です。 -3. OpenAI 以外のトレースプロセッサーを使う。詳しくは [tracing ドキュメント](../tracing.md#custom-tracing-processors) を参照してください。 +2. トレース用に OpenAI キーを設定する: [`set_tracing_export_api_key(...)`][agents.set_tracing_export_api_key] + この API キーはトレースのアップロードのみに使用され、[platform.openai.com](https://platform.openai.com/) で取得したものが必要です。 +3. OpenAI 以外のトレースプロセッサーを使用する。詳細は [tracing のドキュメント](../tracing.md#custom-tracing-processors) を参照してください。 -### Responses API サポート +### Responses API のサポート -SDK は既定で Responses API を使用しますが、多くの LLM プロバイダーはまだ対応していません。そのため 404 などのエラーが発生する場合があります。対処方法は 2 つです。 +SDK はデフォルトで Responses API を使用しますが、ほとんどの LLM プロバイダーはまだ非対応です。その結果、404 などのエラーが発生することがあります。対処方法は次の 2 つです。 1. [`set_default_openai_api("chat_completions")`][agents.set_default_openai_api] を呼び出す - 環境変数 `OPENAI_API_KEY` と `OPENAI_BASE_URL` を設定している場合に機能します。 + `OPENAI_API_KEY` と `OPENAI_BASE_URL` を環境変数で設定している場合に有効です。 2. 
[`OpenAIChatCompletionsModel`][agents.models.openai_chatcompletions.OpenAIChatCompletionsModel] を使用する - コード例は [こちら](https://github.com/openai/openai-agents-python/tree/main/examples/model_providers/) にあります。 + 例は [こちら](https://github.com/openai/openai-agents-python/tree/main/examples/model_providers/) にあります。 ### structured outputs のサポート 一部のモデルプロバイダーは [structured outputs](https://platform.openai.com/docs/guides/structured-outputs) をサポートしていません。その場合、次のようなエラーが発生することがあります。 ``` + BadRequestError: Error code: 400 - {'error': {'message': "'response_format.type' : value is not one of the allowed values ['text','json_object']", 'type': 'invalid_request_error'}} + ``` -これは一部プロバイダーの制限で、JSON 出力はサポートしていても `json_schema` を指定できません。現在修正に取り組んでいますが、JSON スキーマ出力をサポートしているプロバイダーを利用することを推奨します。そうでない場合、不正な JSON によりアプリが頻繁に壊れる可能性があります。 \ No newline at end of file +これは一部プロバイダーの制限で、JSON 出力自体はサポートしていても `json_schema` を指定できないことが原因です。修正に向けて取り組んでいますが、JSON スキーマ出力をサポートしているプロバイダーを使用することをお勧めします。そうでないと、不正な JSON が返されてアプリが頻繁に壊れる可能性があります。 + +## プロバイダーを跨いだモデルの組み合わせ + +モデルプロバイダーごとの機能差に注意しないと、エラーが発生します。たとえば OpenAI は structured outputs、マルチモーダル入力、ホスト型の file search や web search をサポートしていますが、多くの他プロバイダーは非対応です。以下の制限に留意してください。 + +- 対応していないプロバイダーには未サポートの `tools` を送らない +- テキストのみのモデルを呼び出す前にマルチモーダル入力を除外する +- structured JSON 出力をサポートしていないプロバイダーでは、不正な JSON が返ることがある点に注意する \ No newline at end of file diff --git a/docs/models/index.md b/docs/models/index.md index 1c89d778a..b3b2b7f0b 100644 --- a/docs/models/index.md +++ b/docs/models/index.md @@ -93,6 +93,22 @@ english_agent = Agent( ) ``` +Also, when you use OpenAI's Responses API, [there are a few other optional parameters](https://platform.openai.com/docs/api-reference/responses/create) (e.g., `user`, `service_tier`, and so on). If they are not available at the top level, you can use `extra_args` to pass them as well. + +```python +from agents import Agent, ModelSettings + +english_agent = Agent( + name="English agent", + instructions="You only speak English", + model="gpt-4o", + model_settings=ModelSettings( + temperature=0.1, + extra_args={"service_tier": "flex", "user": "user_12345"}, + ), +) +``` + ## Common issues with using other LLM providers ### Tracing client error 401 From 017ad69de11c945b90cdbfe68a6c56c61ac93ff5 Mon Sep 17 00:00:00 2001 From: tconley1428 Date: Fri, 27 Jun 2025 08:25:06 -0700 Subject: [PATCH 19/23] Annotating the openai.Omit type so that ModelSettings can be serialized by pydantic (#938) Because `openai.Omit` is not a type which can be serialized by `pydantic`, it produces difficulty in consistently serializing/deserializing types across the agents SDK, many of which are `pydantic` types. 
This adds an annotation to enable `pydantic` to serialize `Omit`, in particular in the case of `ModelSettings` which contains `Omit` in its `extra_headers` --- src/agents/model_settings.py | 44 ++++++++++++++++++++-- tests/model_settings/test_serialization.py | 31 +++++++++++++++ 2 files changed, 71 insertions(+), 4 deletions(-) diff --git a/src/agents/model_settings.py b/src/agents/model_settings.py index 709046b7a..26af94ba3 100644 --- a/src/agents/model_settings.py +++ b/src/agents/model_settings.py @@ -1,14 +1,50 @@ from __future__ import annotations import dataclasses +from collections.abc import Mapping from dataclasses import dataclass, fields, replace -from typing import Any, Literal +from typing import Annotated, Any, Literal, Union -from openai._types import Body, Headers, Query +from openai import Omit as _Omit +from openai._types import Body, Query from openai.types.responses import ResponseIncludable from openai.types.shared import Reasoning -from pydantic import BaseModel - +from pydantic import BaseModel, GetCoreSchemaHandler +from pydantic_core import core_schema +from typing_extensions import TypeAlias + + +class _OmitTypeAnnotation: + @classmethod + def __get_pydantic_core_schema__( + cls, + _source_type: Any, + _handler: GetCoreSchemaHandler, + ) -> core_schema.CoreSchema: + def validate_from_none(value: None) -> _Omit: + return _Omit() + + from_none_schema = core_schema.chain_schema( + [ + core_schema.none_schema(), + core_schema.no_info_plain_validator_function(validate_from_none), + ] + ) + return core_schema.json_or_python_schema( + json_schema=from_none_schema, + python_schema=core_schema.union_schema( + [ + # check if it's an instance first before doing any further work + core_schema.is_instance_schema(_Omit), + from_none_schema, + ] + ), + serialization=core_schema.plain_serializer_function_ser_schema( + lambda instance: None + ), + ) +Omit = Annotated[_Omit, _OmitTypeAnnotation] +Headers: TypeAlias = Mapping[str, Union[str, Omit]] @dataclass class ModelSettings: diff --git a/tests/model_settings/test_serialization.py b/tests/model_settings/test_serialization.py index 2bbc7ce2c..94d11def3 100644 --- a/tests/model_settings/test_serialization.py +++ b/tests/model_settings/test_serialization.py @@ -2,6 +2,8 @@ from dataclasses import fields from openai.types.shared import Reasoning +from pydantic import TypeAdapter +from pydantic_core import to_json from agents.model_settings import ModelSettings @@ -132,3 +134,32 @@ def test_extra_args_resolve_both_none() -> None: assert resolved.extra_args is None assert resolved.temperature == 0.5 assert resolved.top_p == 0.9 + +def test_pydantic_serialization() -> None: + + """Tests whether ModelSettings can be serialized with Pydantic.""" + + # First, lets create a ModelSettings instance + model_settings = ModelSettings( + temperature=0.5, + top_p=0.9, + frequency_penalty=0.0, + presence_penalty=0.0, + tool_choice="auto", + parallel_tool_calls=True, + truncation="auto", + max_tokens=100, + reasoning=Reasoning(), + metadata={"foo": "bar"}, + store=False, + include_usage=False, + extra_query={"foo": "bar"}, + extra_body={"foo": "bar"}, + extra_headers={"foo": "bar"}, + extra_args={"custom_param": "value", "another_param": 42}, + ) + + json = to_json(model_settings) + deserialized = TypeAdapter(ModelSettings).validate_json(json) + + assert model_settings == deserialized From 27912d4af85baf3acdefc950e4f611db406e1138 Mon Sep 17 00:00:00 2001 From: Daniel Hashmi <151331476+DanielHashmi@users.noreply.github.com> Date: Fri, 27 Jun 2025 
20:59:19 +0500 Subject: [PATCH 20/23] Import Path Inconsistency (#960) There's an inconsistent import style in the tracing module where one import uses an absolute path on line 3, while all others use relative imports. Additionally, the ordering of the imports was fixed. --- src/agents/tracing/__init__.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/agents/tracing/__init__.py b/src/agents/tracing/__init__.py index 64b0bd71f..b45c06d75 100644 --- a/src/agents/tracing/__init__.py +++ b/src/agents/tracing/__init__.py @@ -1,7 +1,5 @@ import atexit -from agents.tracing.provider import DefaultTraceProvider, TraceProvider - from .create import ( agent_span, custom_span, @@ -20,6 +18,7 @@ ) from .processor_interface import TracingProcessor from .processors import default_exporter, default_processor +from .provider import DefaultTraceProvider, TraceProvider from .setup import get_trace_provider, set_trace_provider from .span_data import ( AgentSpanData, From fcafae1d50a3345f795518308e3e87eab15ce1b0 Mon Sep 17 00:00:00 2001 From: Jinseop Song <152702069+axion66@users.noreply.github.com> Date: Sat, 28 Jun 2025 01:06:13 +0900 Subject: [PATCH 21/23] Add reasoning content - fix on #494 (#871) My personal update on https://github.com/openai/openai-agents-python/pull/494 Key changes: - Updated `ResponseReasoningItem` to use the correct structure with `id`, `summary`, and `type="reasoning"` fields - Fixed `ResponseReasoningSummaryPartAddedEvent` to include `summary_index` parameter and use a dictionary for the `part` parameter - Fixed `ResponseReasoningSummaryTextDeltaEvent` to include `summary_index` parameter - Updated `ResponseReasoningSummaryPartDoneEvent` to include `summary_index` and use a dictionary for `part` - Changed how the `Summary` object is accessed in tests (using `.text` property instead of subscript notation) - Updated `chatcmpl_converter.py` to use the `Summary` class instead of a dictionary for better handling Fixes #494 --- examples/reasoning_content/__init__.py | 3 + examples/reasoning_content/main.py | 124 ++++++++ examples/reasoning_content/runner_example.py | 88 ++++++ src/agents/models/chatcmpl_converter.py | 12 + src/agents/models/chatcmpl_stream_handler.py | 142 ++++++++- tests/test_reasoning_content.py | 289 +++++++++++++++++++ 6 files changed, 643 insertions(+), 15 deletions(-) create mode 100644 examples/reasoning_content/__init__.py create mode 100644 examples/reasoning_content/main.py create mode 100644 examples/reasoning_content/runner_example.py create mode 100644 tests/test_reasoning_content.py diff --git a/examples/reasoning_content/__init__.py b/examples/reasoning_content/__init__.py new file mode 100644 index 000000000..f24b2606d --- /dev/null +++ b/examples/reasoning_content/__init__.py @@ -0,0 +1,3 @@ +""" +Examples demonstrating how to use models that provide reasoning content. +""" diff --git a/examples/reasoning_content/main.py b/examples/reasoning_content/main.py new file mode 100644 index 000000000..5f67e1779 --- /dev/null +++ b/examples/reasoning_content/main.py @@ -0,0 +1,124 @@ +""" +Example demonstrating how to use the reasoning content feature with models that support it. + +Some models, like deepseek-reasoner, provide a reasoning_content field in addition to the regular content. +This example shows how to access and use this reasoning content from both streaming and non-streaming responses. + +To run this example, you need to: +1. Set your OPENAI_API_KEY environment variable +2. 
Use a model that supports reasoning content (e.g., deepseek-reasoner) +""" + +import asyncio +import os +from typing import Any, cast + +from agents import ModelSettings +from agents.models.interface import ModelTracing +from agents.models.openai_provider import OpenAIProvider +from agents.types import ResponseOutputRefusal, ResponseOutputText # type: ignore + +MODEL_NAME = os.getenv("EXAMPLE_MODEL_NAME") or "deepseek-reasoner" + + +async def stream_with_reasoning_content(): + """ + Example of streaming a response from a model that provides reasoning content. + The reasoning content will be emitted as separate events. + """ + provider = OpenAIProvider() + model = provider.get_model(MODEL_NAME) + + print("\n=== Streaming Example ===") + print("Prompt: Write a haiku about recursion in programming") + + reasoning_content = "" + regular_content = "" + + async for event in model.stream_response( + system_instructions="You are a helpful assistant that writes creative content.", + input="Write a haiku about recursion in programming", + model_settings=ModelSettings(), + tools=[], + output_schema=None, + handoffs=[], + tracing=ModelTracing.DISABLED, + previous_response_id=None, + prompt=None + ): + if event.type == "response.reasoning_summary_text.delta": + print( + f"\033[33m{event.delta}\033[0m", end="", flush=True + ) # Yellow for reasoning content + reasoning_content += event.delta + elif event.type == "response.output_text.delta": + print(f"\033[32m{event.delta}\033[0m", end="", flush=True) # Green for regular content + regular_content += event.delta + + print("\n\nReasoning Content:") + print(reasoning_content) + print("\nRegular Content:") + print(regular_content) + print("\n") + + +async def get_response_with_reasoning_content(): + """ + Example of getting a complete response from a model that provides reasoning content. + The reasoning content will be available as a separate item in the response. 
+ """ + provider = OpenAIProvider() + model = provider.get_model(MODEL_NAME) + + print("\n=== Non-streaming Example ===") + print("Prompt: Explain the concept of recursion in programming") + + response = await model.get_response( + system_instructions="You are a helpful assistant that explains technical concepts clearly.", + input="Explain the concept of recursion in programming", + model_settings=ModelSettings(), + tools=[], + output_schema=None, + handoffs=[], + tracing=ModelTracing.DISABLED, + previous_response_id=None, + prompt=None + ) + + # Extract reasoning content and regular content from the response + reasoning_content = None + regular_content = None + + for item in response.output: + if hasattr(item, "type") and item.type == "reasoning": + reasoning_content = item.summary[0].text + elif hasattr(item, "type") and item.type == "message": + if item.content and len(item.content) > 0: + content_item = item.content[0] + if isinstance(content_item, ResponseOutputText): + regular_content = content_item.text + elif isinstance(content_item, ResponseOutputRefusal): + refusal_item = cast(Any, content_item) + regular_content = refusal_item.refusal + + print("\nReasoning Content:") + print(reasoning_content or "No reasoning content provided") + + print("\nRegular Content:") + print(regular_content or "No regular content provided") + + print("\n") + + +async def main(): + try: + await stream_with_reasoning_content() + await get_response_with_reasoning_content() + except Exception as e: + print(f"Error: {e}") + print("\nNote: This example requires a model that supports reasoning content.") + print("You may need to use a specific model like deepseek-reasoner or similar.") + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/examples/reasoning_content/runner_example.py b/examples/reasoning_content/runner_example.py new file mode 100644 index 000000000..e51f85799 --- /dev/null +++ b/examples/reasoning_content/runner_example.py @@ -0,0 +1,88 @@ +""" +Example demonstrating how to use the reasoning content feature with the Runner API. + +This example shows how to extract and use reasoning content from responses when using +the Runner API, which is the most common way users interact with the Agents library. + +To run this example, you need to: +1. Set your OPENAI_API_KEY environment variable +2. Use a model that supports reasoning content (e.g., deepseek-reasoner) +""" + +import asyncio +import os +from typing import Any + +from agents import Agent, Runner, trace +from agents.items import ReasoningItem + +MODEL_NAME = os.getenv("EXAMPLE_MODEL_NAME") or "deepseek-reasoner" + + +async def main(): + print(f"Using model: {MODEL_NAME}") + + # Create an agent with a model that supports reasoning content + agent = Agent( + name="Reasoning Agent", + instructions="You are a helpful assistant that explains your reasoning step by step.", + model=MODEL_NAME, + ) + + # Example 1: Non-streaming response + with trace("Reasoning Content - Non-streaming"): + print("\n=== Example 1: Non-streaming response ===") + result = await Runner.run( + agent, "What is the square root of 841? Please explain your reasoning." 
+ ) + + # Extract reasoning content from the result items + reasoning_content = None + # RunResult has 'response' attribute which has 'output' attribute + for item in result.response.output: # type: ignore + if isinstance(item, ReasoningItem): + reasoning_content = item.summary[0].text # type: ignore + break + + print("\nReasoning Content:") + print(reasoning_content or "No reasoning content provided") + + print("\nFinal Output:") + print(result.final_output) + + # Example 2: Streaming response + with trace("Reasoning Content - Streaming"): + print("\n=== Example 2: Streaming response ===") + print("\nStreaming response:") + + # Buffers to collect reasoning and regular content + reasoning_buffer = "" + content_buffer = "" + + # RunResultStreaming is async iterable + stream = Runner.run_streamed(agent, "What is 15 x 27? Please explain your reasoning.") + + async for event in stream: # type: ignore + if isinstance(event, ReasoningItem): + # This is reasoning content + reasoning_item: Any = event + reasoning_buffer += reasoning_item.summary[0].text + print( + f"\033[33m{reasoning_item.summary[0].text}\033[0m", end="", flush=True + ) # Yellow for reasoning + elif hasattr(event, "text"): + # This is regular content + content_buffer += event.text + print( + f"\033[32m{event.text}\033[0m", end="", flush=True + ) # Green for regular content + + print("\n\nCollected Reasoning Content:") + print(reasoning_buffer) + + print("\nCollected Final Answer:") + print(content_buffer) + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/src/agents/models/chatcmpl_converter.py b/src/agents/models/chatcmpl_converter.py index 1d599e8c0..25d9f083d 100644 --- a/src/agents/models/chatcmpl_converter.py +++ b/src/agents/models/chatcmpl_converter.py @@ -33,8 +33,10 @@ ResponseOutputMessageParam, ResponseOutputRefusal, ResponseOutputText, + ResponseReasoningItem, ) from openai.types.responses.response_input_param import FunctionCallOutput, ItemReference, Message +from openai.types.responses.response_reasoning_item import Summary from ..agent_output import AgentOutputSchemaBase from ..exceptions import AgentsException, UserError @@ -85,6 +87,16 @@ def convert_response_format( def message_to_output_items(cls, message: ChatCompletionMessage) -> list[TResponseOutputItem]: items: list[TResponseOutputItem] = [] + # Handle reasoning content if available + if hasattr(message, "reasoning_content") and message.reasoning_content: + items.append( + ResponseReasoningItem( + id=FAKE_RESPONSES_ID, + summary=[Summary(text=message.reasoning_content, type="summary_text")], + type="reasoning", + ) + ) + message_item = ResponseOutputMessage( id=FAKE_RESPONSES_ID, content=[], diff --git a/src/agents/models/chatcmpl_stream_handler.py b/src/agents/models/chatcmpl_stream_handler.py index d18f5912a..0cf1e6a3e 100644 --- a/src/agents/models/chatcmpl_stream_handler.py +++ b/src/agents/models/chatcmpl_stream_handler.py @@ -20,21 +20,38 @@ ResponseOutputMessage, ResponseOutputRefusal, ResponseOutputText, + ResponseReasoningItem, + ResponseReasoningSummaryPartAddedEvent, + ResponseReasoningSummaryPartDoneEvent, + ResponseReasoningSummaryTextDeltaEvent, ResponseRefusalDeltaEvent, ResponseTextDeltaEvent, ResponseUsage, ) +from openai.types.responses.response_reasoning_item import Summary +from openai.types.responses.response_reasoning_summary_part_added_event import ( + Part as AddedEventPart, +) +from openai.types.responses.response_reasoning_summary_part_done_event import Part as DoneEventPart from 
openai.types.responses.response_usage import InputTokensDetails, OutputTokensDetails from ..items import TResponseStreamEvent from .fake_id import FAKE_RESPONSES_ID +# Define a Part class for internal use +class Part: + def __init__(self, text: str, type: str): + self.text = text + self.type = type + + @dataclass class StreamingState: started: bool = False text_content_index_and_output: tuple[int, ResponseOutputText] | None = None refusal_content_index_and_output: tuple[int, ResponseOutputRefusal] | None = None + reasoning_content_index_and_output: tuple[int, ResponseReasoningItem] | None = None function_calls: dict[int, ResponseFunctionToolCall] = field(default_factory=dict) @@ -75,12 +92,65 @@ async def handle_stream( delta = chunk.choices[0].delta - # Handle text - if delta.content: + # Handle reasoning content + if hasattr(delta, "reasoning_content"): + reasoning_content = delta.reasoning_content + if reasoning_content and not state.reasoning_content_index_and_output: + state.reasoning_content_index_and_output = ( + 0, + ResponseReasoningItem( + id=FAKE_RESPONSES_ID, + summary=[Summary(text="", type="summary_text")], + type="reasoning", + ), + ) + yield ResponseOutputItemAddedEvent( + item=ResponseReasoningItem( + id=FAKE_RESPONSES_ID, + summary=[Summary(text="", type="summary_text")], + type="reasoning", + ), + output_index=0, + type="response.output_item.added", + sequence_number=sequence_number.get_and_increment(), + ) + + yield ResponseReasoningSummaryPartAddedEvent( + item_id=FAKE_RESPONSES_ID, + output_index=0, + summary_index=0, + part=AddedEventPart(text="", type="summary_text"), + type="response.reasoning_summary_part.added", + sequence_number=sequence_number.get_and_increment(), + ) + + if reasoning_content and state.reasoning_content_index_and_output: + yield ResponseReasoningSummaryTextDeltaEvent( + delta=reasoning_content, + item_id=FAKE_RESPONSES_ID, + output_index=0, + summary_index=0, + type="response.reasoning_summary_text.delta", + sequence_number=sequence_number.get_and_increment(), + ) + + # Create a new summary with updated text + current_summary = state.reasoning_content_index_and_output[1].summary[0] + updated_text = current_summary.text + reasoning_content + new_summary = Summary(text=updated_text, type="summary_text") + state.reasoning_content_index_and_output[1].summary[0] = new_summary + + # Handle regular content + if delta.content is not None: if not state.text_content_index_and_output: - # Initialize a content tracker for streaming text + content_index = 0 + if state.reasoning_content_index_and_output: + content_index += 1 + if state.refusal_content_index_and_output: + content_index += 1 + state.text_content_index_and_output = ( - 0 if not state.refusal_content_index_and_output else 1, + content_index, ResponseOutputText( text="", type="output_text", @@ -98,14 +168,16 @@ async def handle_stream( # Notify consumers of the start of a new output message + first content part yield ResponseOutputItemAddedEvent( item=assistant_item, - output_index=0, + output_index=state.reasoning_content_index_and_output + is not None, # fixed 0 -> 0 or 1 type="response.output_item.added", sequence_number=sequence_number.get_and_increment(), ) yield ResponseContentPartAddedEvent( content_index=state.text_content_index_and_output[0], item_id=FAKE_RESPONSES_ID, - output_index=0, + output_index=state.reasoning_content_index_and_output + is not None, # fixed 0 -> 0 or 1 part=ResponseOutputText( text="", type="output_text", @@ -119,7 +191,8 @@ async def handle_stream( 
content_index=state.text_content_index_and_output[0], delta=delta.content, item_id=FAKE_RESPONSES_ID, - output_index=0, + output_index=state.reasoning_content_index_and_output + is not None, # fixed 0 -> 0 or 1 type="response.output_text.delta", sequence_number=sequence_number.get_and_increment(), ) @@ -130,9 +203,14 @@ async def handle_stream( # This is always set by the OpenAI API, but not by others e.g. LiteLLM if hasattr(delta, "refusal") and delta.refusal: if not state.refusal_content_index_and_output: - # Initialize a content tracker for streaming refusal text + refusal_index = 0 + if state.reasoning_content_index_and_output: + refusal_index += 1 + if state.text_content_index_and_output: + refusal_index += 1 + state.refusal_content_index_and_output = ( - 0 if not state.text_content_index_and_output else 1, + refusal_index, ResponseOutputRefusal(refusal="", type="refusal"), ) # Start a new assistant message if one doesn't exist yet (in-progress) @@ -146,14 +224,16 @@ async def handle_stream( # Notify downstream that assistant message + first content part are starting yield ResponseOutputItemAddedEvent( item=assistant_item, - output_index=0, + output_index=state.reasoning_content_index_and_output + is not None, # fixed 0 -> 0 or 1 type="response.output_item.added", sequence_number=sequence_number.get_and_increment(), ) yield ResponseContentPartAddedEvent( content_index=state.refusal_content_index_and_output[0], item_id=FAKE_RESPONSES_ID, - output_index=0, + output_index=state.reasoning_content_index_and_output + is not None, # fixed 0 -> 0 or 1 part=ResponseOutputText( text="", type="output_text", @@ -167,7 +247,8 @@ async def handle_stream( content_index=state.refusal_content_index_and_output[0], delta=delta.refusal, item_id=FAKE_RESPONSES_ID, - output_index=0, + output_index=state.reasoning_content_index_and_output + is not None, # fixed 0 -> 0 or 1 type="response.refusal.delta", sequence_number=sequence_number.get_and_increment(), ) @@ -197,14 +278,37 @@ async def handle_stream( ) or "" state.function_calls[tc_delta.index].call_id += tc_delta.id or "" + if state.reasoning_content_index_and_output: + yield ResponseReasoningSummaryPartDoneEvent( + item_id=FAKE_RESPONSES_ID, + output_index=0, + summary_index=0, + part=DoneEventPart( + text=state.reasoning_content_index_and_output[1].summary[0].text, + type="summary_text", + ), + type="response.reasoning_summary_part.done", + sequence_number=sequence_number.get_and_increment(), + ) + yield ResponseOutputItemDoneEvent( + item=state.reasoning_content_index_and_output[1], + output_index=0, + type="response.output_item.done", + sequence_number=sequence_number.get_and_increment(), + ) + function_call_starting_index = 0 + if state.reasoning_content_index_and_output: + function_call_starting_index += 1 + if state.text_content_index_and_output: function_call_starting_index += 1 # Send end event for this content part yield ResponseContentPartDoneEvent( content_index=state.text_content_index_and_output[0], item_id=FAKE_RESPONSES_ID, - output_index=0, + output_index=state.reasoning_content_index_and_output + is not None, # fixed 0 -> 0 or 1 part=state.text_content_index_and_output[1], type="response.content_part.done", sequence_number=sequence_number.get_and_increment(), @@ -216,7 +320,8 @@ async def handle_stream( yield ResponseContentPartDoneEvent( content_index=state.refusal_content_index_and_output[0], item_id=FAKE_RESPONSES_ID, - output_index=0, + output_index=state.reasoning_content_index_and_output + is not None, # fixed 0 -> 0 or 1 
part=state.refusal_content_index_and_output[1], type="response.content_part.done", sequence_number=sequence_number.get_and_increment(), @@ -261,6 +366,12 @@ async def handle_stream( # Finally, send the Response completed event outputs: list[ResponseOutputItem] = [] + + # include Reasoning item if it exists + if state.reasoning_content_index_and_output: + outputs.append(state.reasoning_content_index_and_output[1]) + + # include text or refusal content if they exist if state.text_content_index_and_output or state.refusal_content_index_and_output: assistant_msg = ResponseOutputMessage( id=FAKE_RESPONSES_ID, @@ -278,7 +389,8 @@ async def handle_stream( # send a ResponseOutputItemDone for the assistant message yield ResponseOutputItemDoneEvent( item=assistant_msg, - output_index=0, + output_index=state.reasoning_content_index_and_output + is not None, # fixed 0 -> 0 or 1 type="response.output_item.done", sequence_number=sequence_number.get_and_increment(), ) diff --git a/tests/test_reasoning_content.py b/tests/test_reasoning_content.py new file mode 100644 index 000000000..5160e09c2 --- /dev/null +++ b/tests/test_reasoning_content.py @@ -0,0 +1,289 @@ +from __future__ import annotations + +from collections.abc import AsyncIterator +from typing import Any, cast + +import pytest +from openai.types.chat import ChatCompletion, ChatCompletionChunk, ChatCompletionMessage +from openai.types.chat.chat_completion_chunk import Choice, ChoiceDelta +from openai.types.completion_usage import ( + CompletionTokensDetails, + CompletionUsage, + PromptTokensDetails, +) +from openai.types.responses import ( + Response, + ResponseOutputMessage, + ResponseOutputText, + ResponseReasoningItem, +) + +from agents.model_settings import ModelSettings +from agents.models.interface import ModelTracing +from agents.models.openai_chatcompletions import OpenAIChatCompletionsModel +from agents.models.openai_provider import OpenAIProvider + + +# Helper functions to create test objects consistently +def create_content_delta(content: str) -> dict[str, Any]: + """Create a delta dictionary with regular content""" + return { + "content": content, + "role": None, + "function_call": None, + "tool_calls": None + } + +def create_reasoning_delta(content: str) -> dict[str, Any]: + """Create a delta dictionary with reasoning content. 
The only difference is reasoning_content"""
+    return {
+        "content": None,
+        "role": None,
+        "function_call": None,
+        "tool_calls": None,
+        "reasoning_content": content
+    }
+
+
+def create_chunk(delta: dict[str, Any], include_usage: bool = False) -> ChatCompletionChunk:
+    """Create a ChatCompletionChunk with the given delta"""
+    # Create a ChoiceDelta object from the dictionary
+    delta_obj = ChoiceDelta(
+        content=delta.get("content"),
+        role=delta.get("role"),
+        function_call=delta.get("function_call"),
+        tool_calls=delta.get("tool_calls"),
+    )
+
+    # Add reasoning_content attribute dynamically if present in the delta
+    if "reasoning_content" in delta:
+        # Use direct assignment for the reasoning_content attribute
+        delta_obj_any = cast(Any, delta_obj)
+        delta_obj_any.reasoning_content = delta["reasoning_content"]
+
+    # Create the chunk
+    chunk = ChatCompletionChunk(
+        id="chunk-id",
+        created=1,
+        model="deepseek is usually expected",
+        object="chat.completion.chunk",
+        choices=[Choice(index=0, delta=delta_obj)],
+    )
+
+    if include_usage:
+        chunk.usage = CompletionUsage(
+            completion_tokens=4,
+            prompt_tokens=2,
+            total_tokens=6,
+            completion_tokens_details=CompletionTokensDetails(reasoning_tokens=2),
+            prompt_tokens_details=PromptTokensDetails(cached_tokens=0),
+        )
+
+    return chunk
+
+
+async def create_fake_stream(
+    chunks: list[ChatCompletionChunk],
+) -> AsyncIterator[ChatCompletionChunk]:
+    for chunk in chunks:
+        yield chunk
+
+
+@pytest.mark.allow_call_model_methods
+@pytest.mark.asyncio
+async def test_stream_response_yields_events_for_reasoning_content(monkeypatch) -> None:
+    """
+    Validate that when a model streams reasoning content,
+    `stream_response` emits the appropriate sequence of events including
+    `response.reasoning_summary_text.delta` events for each chunk of the reasoning content and
+    constructs a completed response with a `ResponseReasoningItem` part.
+ """ + # Create test chunks + chunks = [ + # Reasoning content chunks + create_chunk(create_reasoning_delta("Let me think")), + create_chunk(create_reasoning_delta(" about this")), + # Regular content chunks + create_chunk(create_content_delta("The answer")), + create_chunk(create_content_delta(" is 42"), include_usage=True), + ] + + async def patched_fetch_response(self, *args, **kwargs): + resp = Response( + id="resp-id", + created_at=0, + model="fake-model", + object="response", + output=[], + tool_choice="none", + tools=[], + parallel_tool_calls=False, + ) + return resp, create_fake_stream(chunks) + + monkeypatch.setattr(OpenAIChatCompletionsModel, "_fetch_response", patched_fetch_response) + model = OpenAIProvider(use_responses=False).get_model("gpt-4") + output_events = [] + async for event in model.stream_response( + system_instructions=None, + input="", + model_settings=ModelSettings(), + tools=[], + output_schema=None, + handoffs=[], + tracing=ModelTracing.DISABLED, + previous_response_id=None, + prompt=None, + ): + output_events.append(event) + + # verify reasoning content events were emitted + reasoning_delta_events = [ + e for e in output_events if e.type == "response.reasoning_summary_text.delta" + ] + assert len(reasoning_delta_events) == 2 + assert reasoning_delta_events[0].delta == "Let me think" + assert reasoning_delta_events[1].delta == " about this" + + # verify regular content events were emitted + content_delta_events = [e for e in output_events if e.type == "response.output_text.delta"] + assert len(content_delta_events) == 2 + assert content_delta_events[0].delta == "The answer" + assert content_delta_events[1].delta == " is 42" + + # verify the final response contains both types of content + response_event = output_events[-1] + assert response_event.type == "response.completed" + assert len(response_event.response.output) == 2 + + # first item should be reasoning + assert isinstance(response_event.response.output[0], ResponseReasoningItem) + assert response_event.response.output[0].summary[0].text == "Let me think about this" + + # second item should be message with text + assert isinstance(response_event.response.output[1], ResponseOutputMessage) + assert isinstance(response_event.response.output[1].content[0], ResponseOutputText) + assert response_event.response.output[1].content[0].text == "The answer is 42" + + +@pytest.mark.allow_call_model_methods +@pytest.mark.asyncio +async def test_get_response_with_reasoning_content(monkeypatch) -> None: + """ + Test that when a model returns reasoning content in addition to regular content, + `get_response` properly includes both in the response output. 
+ """ + # create a message with reasoning content + msg = ChatCompletionMessage( + role="assistant", + content="The answer is 42", + ) + # Use dynamic attribute for reasoning_content + # We need to cast to Any to avoid mypy errors since reasoning_content is not a defined attribute + msg_with_reasoning = cast(Any, msg) + msg_with_reasoning.reasoning_content = "Let me think about this question carefully" + + # create a choice with the message + mock_choice = { + "index": 0, + "finish_reason": "stop", + "message": msg_with_reasoning, + "delta": None + } + + chat = ChatCompletion( + id="resp-id", + created=0, + model="deepseek is expected", + object="chat.completion", + choices=[mock_choice], # type: ignore[list-item] + usage=CompletionUsage( + completion_tokens=10, + prompt_tokens=5, + total_tokens=15, + completion_tokens_details=CompletionTokensDetails(reasoning_tokens=6), + prompt_tokens_details=PromptTokensDetails(cached_tokens=0), + ), + ) + + async def patched_fetch_response(self, *args, **kwargs): + return chat + + monkeypatch.setattr(OpenAIChatCompletionsModel, "_fetch_response", patched_fetch_response) + model = OpenAIProvider(use_responses=False).get_model("gpt-4") + resp = await model.get_response( + system_instructions=None, + input="", + model_settings=ModelSettings(), + tools=[], + output_schema=None, + handoffs=[], + tracing=ModelTracing.DISABLED, + previous_response_id=None, + prompt=None, + ) + + # should have produced a reasoning item and a message with text content + assert len(resp.output) == 2 + + # first output should be the reasoning item + assert isinstance(resp.output[0], ResponseReasoningItem) + assert resp.output[0].summary[0].text == "Let me think about this question carefully" + + # second output should be the message with text content + assert isinstance(resp.output[1], ResponseOutputMessage) + assert isinstance(resp.output[1].content[0], ResponseOutputText) + assert resp.output[1].content[0].text == "The answer is 42" + + +@pytest.mark.allow_call_model_methods +@pytest.mark.asyncio +async def test_stream_response_with_empty_reasoning_content(monkeypatch) -> None: + """ + Test that when a model streams empty reasoning content, + the response still processes correctly without errors. 
+ """ + # create test chunks with empty reasoning content + chunks = [ + create_chunk(create_reasoning_delta("")), + create_chunk(create_content_delta("The answer is 42"), include_usage=True), + ] + + async def patched_fetch_response(self, *args, **kwargs): + resp = Response( + id="resp-id", + created_at=0, + model="fake-model", + object="response", + output=[], + tool_choice="none", + tools=[], + parallel_tool_calls=False, + ) + return resp, create_fake_stream(chunks) + + monkeypatch.setattr(OpenAIChatCompletionsModel, "_fetch_response", patched_fetch_response) + model = OpenAIProvider(use_responses=False).get_model("gpt-4") + output_events = [] + async for event in model.stream_response( + system_instructions=None, + input="", + model_settings=ModelSettings(), + tools=[], + output_schema=None, + handoffs=[], + tracing=ModelTracing.DISABLED, + previous_response_id=None, + prompt=None + ): + output_events.append(event) + + # verify the final response contains the content + response_event = output_events[-1] + assert response_event.type == "response.completed" + + # should only have the message, not an empty reasoning item + assert len(response_event.response.output) == 1 + assert isinstance(response_event.response.output[0], ResponseOutputMessage) + assert isinstance(response_event.response.output[0].content[0], ResponseOutputText) + assert response_event.response.output[0].content[0].text == "The answer is 42" From 86642716e764ef7cfceb206c576d5e40358819a9 Mon Sep 17 00:00:00 2001 From: Rohan Mehta Date: Fri, 27 Jun 2025 14:32:34 -0400 Subject: [PATCH 22/23] Add safety check handling for ComputerTool (#923) ### Summary - handle safety checks when using computer tools - expose new `on_safety_check` callback and data structure - return acknowledged checks in computer output - test acknowledging safety checks ### Test plan - `make format` - `make lint` - `make mypy` - `make tests` Resolves #843 ------ https://chatgpt.com/codex/tasks/task_i_684f207b8b588321a33642a2a9c96a1a --- src/agents/_run_impl.py | 30 +++++++++++++++++++++++ src/agents/tool.py | 24 +++++++++++++++++++ tests/test_computer_action.py | 45 ++++++++++++++++++++++++++++++++++- 3 files changed, 98 insertions(+), 1 deletion(-) diff --git a/src/agents/_run_impl.py b/src/agents/_run_impl.py index a75c5e825..4ac8b316b 100644 --- a/src/agents/_run_impl.py +++ b/src/agents/_run_impl.py @@ -28,6 +28,9 @@ ActionType, ActionWait, ) +from openai.types.responses.response_input_item_param import ( + ComputerCallOutputAcknowledgedSafetyCheck, +) from openai.types.responses.response_input_param import ComputerCallOutput, McpApprovalResponse from openai.types.responses.response_output_item import ( ImageGenerationCall, @@ -67,6 +70,7 @@ from .stream_events import RunItemStreamEvent, StreamEvent from .tool import ( ComputerTool, + ComputerToolSafetyCheckData, FunctionTool, FunctionToolResult, HostedMCPTool, @@ -638,6 +642,29 @@ async def execute_computer_actions( results: list[RunItem] = [] # Need to run these serially, because each action can affect the computer state for action in actions: + acknowledged: list[ComputerCallOutputAcknowledgedSafetyCheck] | None = None + if action.tool_call.pending_safety_checks and action.computer_tool.on_safety_check: + acknowledged = [] + for check in action.tool_call.pending_safety_checks: + data = ComputerToolSafetyCheckData( + ctx_wrapper=context_wrapper, + agent=agent, + tool_call=action.tool_call, + safety_check=check, + ) + maybe = action.computer_tool.on_safety_check(data) + ack = await maybe if 
inspect.isawaitable(maybe) else maybe + if ack: + acknowledged.append( + ComputerCallOutputAcknowledgedSafetyCheck( + id=check.id, + code=check.code, + message=check.message, + ) + ) + else: + raise UserError("Computer tool safety check was not acknowledged") + results.append( await ComputerAction.execute( agent=agent, @@ -645,6 +672,7 @@ async def execute_computer_actions( hooks=hooks, context_wrapper=context_wrapper, config=config, + acknowledged_safety_checks=acknowledged, ) ) @@ -998,6 +1026,7 @@ async def execute( hooks: RunHooks[TContext], context_wrapper: RunContextWrapper[TContext], config: RunConfig, + acknowledged_safety_checks: list[ComputerCallOutputAcknowledgedSafetyCheck] | None = None, ) -> RunItem: output_func = ( cls._get_screenshot_async(action.computer_tool.computer, action.tool_call) @@ -1036,6 +1065,7 @@ async def execute( "image_url": image_url, }, type="computer_call_output", + acknowledged_safety_checks=acknowledged_safety_checks, ), ) diff --git a/src/agents/tool.py b/src/agents/tool.py index 283a9e24f..3aab47752 100644 --- a/src/agents/tool.py +++ b/src/agents/tool.py @@ -7,6 +7,10 @@ from typing import TYPE_CHECKING, Any, Callable, Literal, Union, overload from openai.types.responses.file_search_tool_param import Filters, RankingOptions +from openai.types.responses.response_computer_tool_call import ( + PendingSafetyCheck, + ResponseComputerToolCall, +) from openai.types.responses.response_output_item import LocalShellCall, McpApprovalRequest from openai.types.responses.tool_param import CodeInterpreter, ImageGeneration, Mcp from openai.types.responses.web_search_tool_param import UserLocation @@ -142,11 +146,31 @@ class ComputerTool: as well as implements the computer actions like click, screenshot, etc. """ + on_safety_check: Callable[[ComputerToolSafetyCheckData], MaybeAwaitable[bool]] | None = None + """Optional callback to acknowledge computer tool safety checks.""" + @property def name(self): return "computer_use_preview" +@dataclass +class ComputerToolSafetyCheckData: + """Information about a computer tool safety check.""" + + ctx_wrapper: RunContextWrapper[Any] + """The run context.""" + + agent: Agent[Any] + """The agent performing the computer action.""" + + tool_call: ResponseComputerToolCall + """The computer tool call.""" + + safety_check: PendingSafetyCheck + """The pending safety check to acknowledge.""" + + @dataclass class MCPToolApprovalRequest: """A request to approve a tool call.""" diff --git a/tests/test_computer_action.py b/tests/test_computer_action.py index 70dcabd59..a306b1841 100644 --- a/tests/test_computer_action.py +++ b/tests/test_computer_action.py @@ -18,6 +18,7 @@ ActionScroll, ActionType, ActionWait, + PendingSafetyCheck, ResponseComputerToolCall, ) @@ -31,8 +32,9 @@ RunContextWrapper, RunHooks, ) -from agents._run_impl import ComputerAction, ToolRunComputerAction +from agents._run_impl import ComputerAction, RunImpl, ToolRunComputerAction from agents.items import ToolCallOutputItem +from agents.tool import ComputerToolSafetyCheckData class LoggingComputer(Computer): @@ -309,3 +311,44 @@ async def test_execute_invokes_hooks_and_returns_tool_call_output() -> None: assert raw["output"]["type"] == "computer_screenshot" assert "image_url" in raw["output"] assert raw["output"]["image_url"].endswith("xyz") + + +@pytest.mark.asyncio +async def test_pending_safety_check_acknowledged() -> None: + """Safety checks should be acknowledged via the callback.""" + + computer = LoggingComputer(screenshot_return="img") + called: 
list[ComputerToolSafetyCheckData] = [] + + def on_sc(data: ComputerToolSafetyCheckData) -> bool: + called.append(data) + return True + + tool = ComputerTool(computer=computer, on_safety_check=on_sc) + safety = PendingSafetyCheck(id="sc", code="c", message="m") + tool_call = ResponseComputerToolCall( + id="t1", + type="computer_call", + action=ActionClick(type="click", x=1, y=1, button="left"), + call_id="t1", + pending_safety_checks=[safety], + status="completed", + ) + run_action = ToolRunComputerAction(tool_call=tool_call, computer_tool=tool) + agent = Agent(name="a", tools=[tool]) + ctx = RunContextWrapper(context=None) + + results = await RunImpl.execute_computer_actions( + agent=agent, + actions=[run_action], + hooks=RunHooks[Any](), + context_wrapper=ctx, + config=RunConfig(), + ) + + assert len(results) == 1 + raw = results[0].raw_item + assert isinstance(raw, dict) + assert raw.get("acknowledged_safety_checks") == [{"id": "sc", "code": "c", "message": "m"}] + assert len(called) == 1 + assert called[0].safety_check.id == "sc" From 18cb55e4c45f096d00263c3561d6f22a2299368e Mon Sep 17 00:00:00 2001 From: Rohan Mehta Date: Fri, 27 Jun 2025 16:56:19 -0400 Subject: [PATCH 23/23] v0.1.0 (#963) There's a small breaking change so incrementing the minor version number. --- docs/release.md | 16 +++++++++++----- pyproject.toml | 2 +- uv.lock | 2 +- 3 files changed, 13 insertions(+), 7 deletions(-) diff --git a/docs/release.md b/docs/release.md index 3bfc6e9b8..a86103f96 100644 --- a/docs/release.md +++ b/docs/release.md @@ -1,4 +1,4 @@ -# Release process +# Release process/changelog The project follows a slightly modified version of semantic versioning using the form `0.Y.Z`. The leading `0` indicates the SDK is still evolving rapidly. Increment the components as follows: @@ -12,7 +12,13 @@ If you don't want breaking changes, we recommend pinning to `0.0.x` versions in We will increment `Z` for non-breaking changes: -- Bug fixes -- New features -- Changes to private interfaces -- Updates to beta features +- Bug fixes +- New features +- Changes to private interfaces +- Updates to beta features + +## Breaking change changelog + +### 0.1.0 + +In this version, [`MCPServer.list_tools()`][agents.mcp.server.MCPServer] has two new params: `run_context` and `agent`. You'll need to add these params to any classes that subclass `MCPServer`. diff --git a/pyproject.toml b/pyproject.toml index 0dc0ca567..e659348ca 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "openai-agents" -version = "0.0.19" +version = "0.1.0" description = "OpenAI Agents SDK" readme = "README.md" requires-python = ">=3.9" diff --git a/uv.lock b/uv.lock index 679d8d2c1..d882c9bc5 100644 --- a/uv.lock +++ b/uv.lock @@ -1480,7 +1480,7 @@ wheels = [ [[package]] name = "openai-agents" -version = "0.0.19" +version = "0.1.0" source = { editable = "." } dependencies = [ { name = "griffe" },
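
For reference, a minimal usage sketch of the `on_safety_check` hook introduced for `ComputerTool` in the patch above (PATCH 22/23). It only wires together the pieces shown in the diff (`ComputerTool`, `ComputerToolSafetyCheckData`, the boolean acknowledgement); `my_computer` is a hypothetical placeholder for an existing `Computer` implementation and is not part of the patch.

from agents import Agent
from agents.tool import ComputerTool, ComputerToolSafetyCheckData


def acknowledge_safety_check(data: ComputerToolSafetyCheckData) -> bool:
    # Inspect the pending check before acknowledging it. Returning False makes the run
    # raise a UserError and abort the computer action (see the _run_impl.py change above).
    check = data.safety_check
    print(f"Safety check {check.id} ({check.code}): {check.message}")
    return True


computer_tool = ComputerTool(
    computer=my_computer,  # placeholder: assumed existing Computer implementation
    on_safety_check=acknowledge_safety_check,  # may also be an async callable
)

agent = Agent(name="Computer user", tools=[computer_tool])

If the callback returns True, the acknowledged checks are echoed back in the `computer_call_output` item, matching what test_pending_safety_check_acknowledged asserts above.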