From c467d00f9ce4509f09732421dd255c368c91185f Mon Sep 17 00:00:00 2001 From: Viraj <123119434+vrtnis@users.noreply.github.com> Date: Tue, 8 Jul 2025 23:12:28 -0700 Subject: [PATCH 1/4] Add tool_name to ToolContext --- docs/tools.md | 17 ++++++++++ src/agents/_run_impl.py | 4 ++- src/agents/tool_context.py | 11 ++++-- tests/test_function_tool.py | 49 +++++++++++++++++++-------- tests/test_function_tool_decorator.py | 2 +- 5 files changed, 64 insertions(+), 19 deletions(-) diff --git a/docs/tools.md b/docs/tools.md index 6dba1a853..bdfb5441d 100644 --- a/docs/tools.md +++ b/docs/tools.md @@ -213,6 +213,23 @@ tool = FunctionTool( ) ``` +### Tool context + +When `on_invoke_tool` is called, it receives a `ToolContext` instance. The object contains: + +- `context` – the context object you passed to `Runner.run()`. +- `usage` – usage information for the run so far. +- `tool_name` – the name of the tool being invoked. +- `tool_call_id` – the ID of the tool call. + +You can access these fields inside your tool function: + +```python +async def run_function(ctx: ToolContext[Any], args: str) -> str: + print("Tool invoked:", ctx.tool_name) + ... +``` + ### Automatic argument and docstring parsing As mentioned before, we automatically parse the function signature to extract the schema for the tool, and we parse the docstring to extract descriptions for the tool and for individual arguments. Some notes on that: diff --git a/src/agents/_run_impl.py b/src/agents/_run_impl.py index 4ac8b316b..d60c9f711 100644 --- a/src/agents/_run_impl.py +++ b/src/agents/_run_impl.py @@ -548,7 +548,9 @@ async def run_single_tool( func_tool: FunctionTool, tool_call: ResponseFunctionToolCall ) -> Any: with function_span(func_tool.name) as span_fn: - tool_context = ToolContext.from_agent_context(context_wrapper, tool_call.call_id) + tool_context = ToolContext.from_agent_context( + context_wrapper, func_tool.name, tool_call.call_id + ) if config.trace_include_sensitive_data: span_fn.span_data.input = tool_call.arguments try: diff --git a/src/agents/tool_context.py b/src/agents/tool_context.py index c4329b8af..8d98395fe 100644 --- a/src/agents/tool_context.py +++ b/src/agents/tool_context.py @@ -8,16 +8,23 @@ def _assert_must_pass_tool_call_id() -> str: raise ValueError("tool_call_id must be passed to ToolContext") +def _assert_must_pass_tool_name() -> str: + raise ValueError("tool_name must be passed to ToolContext") + + @dataclass class ToolContext(RunContextWrapper[TContext]): """The context of a tool call.""" + tool_name: str = field(default_factory=_assert_must_pass_tool_name) + """The name of the tool being invoked.""" + tool_call_id: str = field(default_factory=_assert_must_pass_tool_call_id) """The ID of the tool call.""" @classmethod def from_agent_context( - cls, context: RunContextWrapper[TContext], tool_call_id: str + cls, context: RunContextWrapper[TContext], tool_name: str, tool_call_id: str ) -> "ToolContext": """ Create a ToolContext from a RunContextWrapper. @@ -26,4 +33,4 @@ def from_agent_context( base_values: dict[str, Any] = { f.name: getattr(context, f.name) for f in fields(RunContextWrapper) if f.init } - return cls(tool_call_id=tool_call_id, **base_values) + return cls(tool_name=tool_name, tool_call_id=tool_call_id, **base_values) diff --git a/tests/test_function_tool.py b/tests/test_function_tool.py index b232bf75e..b339a9794 100644 --- a/tests/test_function_tool.py +++ b/tests/test_function_tool.py @@ -19,7 +19,9 @@ async def test_argless_function(): tool = function_tool(argless_function) assert tool.name == "argless_function" - result = await tool.on_invoke_tool(ToolContext(context=None, tool_call_id="1"), "") + result = await tool.on_invoke_tool( + ToolContext(context=None, tool_name=tool.name, tool_call_id="1"), "" + ) assert result == "ok" @@ -32,11 +34,13 @@ async def test_argless_with_context(): tool = function_tool(argless_with_context) assert tool.name == "argless_with_context" - result = await tool.on_invoke_tool(ToolContext(None, tool_call_id="1"), "") + result = await tool.on_invoke_tool(ToolContext(None, tool_name=tool.name, tool_call_id="1"), "") assert result == "ok" # Extra JSON should not raise an error - result = await tool.on_invoke_tool(ToolContext(None, tool_call_id="1"), '{"a": 1}') + result = await tool.on_invoke_tool( + ToolContext(None, tool_name=tool.name, tool_call_id="1"), '{"a": 1}' + ) assert result == "ok" @@ -49,15 +53,19 @@ async def test_simple_function(): tool = function_tool(simple_function, failure_error_function=None) assert tool.name == "simple_function" - result = await tool.on_invoke_tool(ToolContext(None, tool_call_id="1"), '{"a": 1}') + result = await tool.on_invoke_tool( + ToolContext(None, tool_name=tool.name, tool_call_id="1"), '{"a": 1}' + ) assert result == 6 - result = await tool.on_invoke_tool(ToolContext(None, tool_call_id="1"), '{"a": 1, "b": 2}') + result = await tool.on_invoke_tool( + ToolContext(None, tool_name=tool.name, tool_call_id="1"), '{"a": 1, "b": 2}' + ) assert result == 3 # Missing required argument should raise an error with pytest.raises(ModelBehaviorError): - await tool.on_invoke_tool(ToolContext(None, tool_call_id="1"), "") + await tool.on_invoke_tool(ToolContext(None, tool_name=tool.name, tool_call_id="1"), "") class Foo(BaseModel): @@ -85,7 +93,9 @@ async def test_complex_args_function(): "bar": Bar(x="hello", y=10), } ) - result = await tool.on_invoke_tool(ToolContext(None, tool_call_id="1"), valid_json) + result = await tool.on_invoke_tool( + ToolContext(None, tool_name=tool.name, tool_call_id="1"), valid_json + ) assert result == "6 hello10 hello" valid_json = json.dumps( @@ -94,7 +104,9 @@ async def test_complex_args_function(): "bar": Bar(x="hello", y=10), } ) - result = await tool.on_invoke_tool(ToolContext(None, tool_call_id="1"), valid_json) + result = await tool.on_invoke_tool( + ToolContext(None, tool_name=tool.name, tool_call_id="1"), valid_json + ) assert result == "3 hello10 hello" valid_json = json.dumps( @@ -104,12 +116,16 @@ async def test_complex_args_function(): "baz": "world", } ) - result = await tool.on_invoke_tool(ToolContext(None, tool_call_id="1"), valid_json) + result = await tool.on_invoke_tool( + ToolContext(None, tool_name=tool.name, tool_call_id="1"), valid_json + ) assert result == "3 hello10 world" # Missing required argument should raise an error with pytest.raises(ModelBehaviorError): - await tool.on_invoke_tool(ToolContext(None, tool_call_id="1"), '{"foo": {"a": 1}}') + await tool.on_invoke_tool( + ToolContext(None, tool_name=tool.name, tool_call_id="1"), '{"foo": {"a": 1}}' + ) def test_function_config_overrides(): @@ -169,7 +185,9 @@ async def run_function(ctx: RunContextWrapper[Any], args: str) -> str: assert tool.params_json_schema[key] == value assert tool.strict_json_schema - result = await tool.on_invoke_tool(ToolContext(None, tool_call_id="1"), '{"data": "hello"}') + result = await tool.on_invoke_tool( + ToolContext(None, tool_name=tool.name, tool_call_id="1"), '{"data": "hello"}' + ) assert result == "hello_done" tool_not_strict = FunctionTool( @@ -184,7 +202,8 @@ async def run_function(ctx: RunContextWrapper[Any], args: str) -> str: assert "additionalProperties" not in tool_not_strict.params_json_schema result = await tool_not_strict.on_invoke_tool( - ToolContext(None, tool_call_id="1"), '{"data": "hello", "bar": "baz"}' + ToolContext(None, tool_name=tool_not_strict.name, tool_call_id="1"), + '{"data": "hello", "bar": "baz"}', ) assert result == "hello_done" @@ -195,7 +214,7 @@ def my_func(a: int, b: int = 5): raise ValueError("test") tool = function_tool(my_func) - ctx = ToolContext(None, tool_call_id="1") + ctx = ToolContext(None, tool_name=tool.name, tool_call_id="1") result = await tool.on_invoke_tool(ctx, "") assert "Invalid JSON" in str(result) @@ -219,7 +238,7 @@ def custom_sync_error_function(ctx: RunContextWrapper[Any], error: Exception) -> return f"error_{error.__class__.__name__}" tool = function_tool(my_func, failure_error_function=custom_sync_error_function) - ctx = ToolContext(None, tool_call_id="1") + ctx = ToolContext(None, tool_name=tool.name, tool_call_id="1") result = await tool.on_invoke_tool(ctx, "") assert result == "error_ModelBehaviorError" @@ -243,7 +262,7 @@ def custom_sync_error_function(ctx: RunContextWrapper[Any], error: Exception) -> return f"error_{error.__class__.__name__}" tool = function_tool(my_func, failure_error_function=custom_sync_error_function) - ctx = ToolContext(None, tool_call_id="1") + ctx = ToolContext(None, tool_name=tool.name, tool_call_id="1") result = await tool.on_invoke_tool(ctx, "") assert result == "error_ModelBehaviorError" diff --git a/tests/test_function_tool_decorator.py b/tests/test_function_tool_decorator.py index d334d8f84..b81d5dbe2 100644 --- a/tests/test_function_tool_decorator.py +++ b/tests/test_function_tool_decorator.py @@ -16,7 +16,7 @@ def __init__(self): def ctx_wrapper() -> ToolContext[DummyContext]: - return ToolContext(context=DummyContext(), tool_call_id="1") + return ToolContext(context=DummyContext(), tool_name="dummy", tool_call_id="1") @function_tool From b3ddf23df767e41c407d006995488f6c4f3fba0f Mon Sep 17 00:00:00 2001 From: vrtnis <123119434+vrtnis@users.noreply.github.com> Date: Wed, 9 Jul 2025 23:13:31 -0700 Subject: [PATCH 2/4] Updated docs/tools.md per requested link change, Tweaked signature in tool_context.from_agent_context, Aligned formatting in _run_impl.py --- docs/ref/tool_context.md | 3 +++ docs/tools.md | 2 +- mkdocs.yml | 1 + src/agents/_run_impl.py | 4 +++- src/agents/tool_context.py | 8 +++++++- 5 files changed, 15 insertions(+), 3 deletions(-) create mode 100644 docs/ref/tool_context.md diff --git a/docs/ref/tool_context.md b/docs/ref/tool_context.md new file mode 100644 index 000000000..6f32fac83 --- /dev/null +++ b/docs/ref/tool_context.md @@ -0,0 +1,3 @@ +# `Tool context` + +::: agents.tool_context diff --git a/docs/tools.md b/docs/tools.md index bdfb5441d..e6cca0601 100644 --- a/docs/tools.md +++ b/docs/tools.md @@ -180,7 +180,7 @@ Sometimes, you don't want to use a Python function as a tool. You can directly c - `name` - `description` - `params_json_schema`, which is the JSON schema for the arguments -- `on_invoke_tool`, which is an async function that receives the context and the arguments as a JSON string, and must return the tool output as a string. +- `on_invoke_tool`, which is an async function that receives a [`ToolContext`][agents.tool_context.ToolContext] and the arguments as a JSON string, and must return the tool output as a string. ```python from typing import Any diff --git a/mkdocs.yml b/mkdocs.yml index b79e6454f..0170f71a5 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -90,6 +90,7 @@ plugins: - ref/lifecycle.md - ref/items.md - ref/run_context.md + - ref/tool_context.md - ref/usage.md - ref/exceptions.md - ref/guardrail.md diff --git a/src/agents/_run_impl.py b/src/agents/_run_impl.py index d60c9f711..a83af62a1 100644 --- a/src/agents/_run_impl.py +++ b/src/agents/_run_impl.py @@ -549,7 +549,9 @@ async def run_single_tool( ) -> Any: with function_span(func_tool.name) as span_fn: tool_context = ToolContext.from_agent_context( - context_wrapper, func_tool.name, tool_call.call_id + context_wrapper, + tool_call.call_id, + tool_call=tool_call, ) if config.trace_include_sensitive_data: span_fn.span_data.input = tool_call.arguments diff --git a/src/agents/tool_context.py b/src/agents/tool_context.py index 8d98395fe..0303203df 100644 --- a/src/agents/tool_context.py +++ b/src/agents/tool_context.py @@ -1,6 +1,8 @@ from dataclasses import dataclass, field, fields from typing import Any +from openai.types.responses import ResponseFunctionToolCall + from .run_context import RunContextWrapper, TContext @@ -24,7 +26,10 @@ class ToolContext(RunContextWrapper[TContext]): @classmethod def from_agent_context( - cls, context: RunContextWrapper[TContext], tool_name: str, tool_call_id: str + cls, + context: RunContextWrapper[TContext], + tool_call_id: str, + tool_call: ResponseFunctionToolCall | None = None, ) -> "ToolContext": """ Create a ToolContext from a RunContextWrapper. @@ -33,4 +38,5 @@ def from_agent_context( base_values: dict[str, Any] = { f.name: getattr(context, f.name) for f in fields(RunContextWrapper) if f.init } + tool_name = tool_call.name if tool_call is not None else _assert_must_pass_tool_name() return cls(tool_name=tool_name, tool_call_id=tool_call_id, **base_values) From 5d796487d84764db11aa1a28876ca570be12c9ac Mon Sep 17 00:00:00 2001 From: vrtnis <123119434+vrtnis@users.noreply.github.com> Date: Wed, 9 Jul 2025 23:23:53 -0700 Subject: [PATCH 3/4] fix(tool_context): use Optional for tool_call to support Python 3.9 --- src/agents/tool_context.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/agents/tool_context.py b/src/agents/tool_context.py index 0303203df..16845badd 100644 --- a/src/agents/tool_context.py +++ b/src/agents/tool_context.py @@ -1,5 +1,5 @@ from dataclasses import dataclass, field, fields -from typing import Any +from typing import Any, Optional from openai.types.responses import ResponseFunctionToolCall @@ -29,7 +29,7 @@ def from_agent_context( cls, context: RunContextWrapper[TContext], tool_call_id: str, - tool_call: ResponseFunctionToolCall | None = None, + tool_call: Optional[ResponseFunctionToolCall] = None, ) -> "ToolContext": """ Create a ToolContext from a RunContextWrapper. From 8d7217d5b049f3e6827167c05344543b7f728120 Mon Sep 17 00:00:00 2001 From: Viraj <123119434+vrtnis@users.noreply.github.com> Date: Wed, 9 Jul 2025 23:52:42 -0700 Subject: [PATCH 4/4] Update docs/tools.md Co-authored-by: Kazuhiro Sera --- docs/tools.md | 17 ----------------- 1 file changed, 17 deletions(-) diff --git a/docs/tools.md b/docs/tools.md index e6cca0601..17f7da0a1 100644 --- a/docs/tools.md +++ b/docs/tools.md @@ -213,23 +213,6 @@ tool = FunctionTool( ) ``` -### Tool context - -When `on_invoke_tool` is called, it receives a `ToolContext` instance. The object contains: - -- `context` – the context object you passed to `Runner.run()`. -- `usage` – usage information for the run so far. -- `tool_name` – the name of the tool being invoked. -- `tool_call_id` – the ID of the tool call. - -You can access these fields inside your tool function: - -```python -async def run_function(ctx: ToolContext[Any], args: str) -> str: - print("Tool invoked:", ctx.tool_name) - ... -``` - ### Automatic argument and docstring parsing As mentioned before, we automatically parse the function signature to extract the schema for the tool, and we parse the docstring to extract descriptions for the tool and for individual arguments. Some notes on that: