diff --git a/docs/ref/tool_context.md b/docs/ref/tool_context.md new file mode 100644 index 000000000..6f32fac83 --- /dev/null +++ b/docs/ref/tool_context.md @@ -0,0 +1,3 @@ +# `Tool context` + +::: agents.tool_context diff --git a/docs/tools.md b/docs/tools.md index 6dba1a853..17f7da0a1 100644 --- a/docs/tools.md +++ b/docs/tools.md @@ -180,7 +180,7 @@ Sometimes, you don't want to use a Python function as a tool. You can directly c - `name` - `description` - `params_json_schema`, which is the JSON schema for the arguments -- `on_invoke_tool`, which is an async function that receives the context and the arguments as a JSON string, and must return the tool output as a string. +- `on_invoke_tool`, which is an async function that receives a [`ToolContext`][agents.tool_context.ToolContext] and the arguments as a JSON string, and must return the tool output as a string. ```python from typing import Any diff --git a/mkdocs.yml b/mkdocs.yml index b79e6454f..0170f71a5 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -90,6 +90,7 @@ plugins: - ref/lifecycle.md - ref/items.md - ref/run_context.md + - ref/tool_context.md - ref/usage.md - ref/exceptions.md - ref/guardrail.md diff --git a/src/agents/_run_impl.py b/src/agents/_run_impl.py index 4ac8b316b..a83af62a1 100644 --- a/src/agents/_run_impl.py +++ b/src/agents/_run_impl.py @@ -548,7 +548,11 @@ async def run_single_tool( func_tool: FunctionTool, tool_call: ResponseFunctionToolCall ) -> Any: with function_span(func_tool.name) as span_fn: - tool_context = ToolContext.from_agent_context(context_wrapper, tool_call.call_id) + tool_context = ToolContext.from_agent_context( + context_wrapper, + tool_call.call_id, + tool_call=tool_call, + ) if config.trace_include_sensitive_data: span_fn.span_data.input = tool_call.arguments try: diff --git a/src/agents/tool_context.py b/src/agents/tool_context.py index c4329b8af..16845badd 100644 --- a/src/agents/tool_context.py +++ b/src/agents/tool_context.py @@ -1,5 +1,7 @@ from dataclasses import dataclass, field, fields -from typing import Any +from typing import Any, Optional + +from openai.types.responses import ResponseFunctionToolCall from .run_context import RunContextWrapper, TContext @@ -8,16 +10,26 @@ def _assert_must_pass_tool_call_id() -> str: raise ValueError("tool_call_id must be passed to ToolContext") +def _assert_must_pass_tool_name() -> str: + raise ValueError("tool_name must be passed to ToolContext") + + @dataclass class ToolContext(RunContextWrapper[TContext]): """The context of a tool call.""" + tool_name: str = field(default_factory=_assert_must_pass_tool_name) + """The name of the tool being invoked.""" + tool_call_id: str = field(default_factory=_assert_must_pass_tool_call_id) """The ID of the tool call.""" @classmethod def from_agent_context( - cls, context: RunContextWrapper[TContext], tool_call_id: str + cls, + context: RunContextWrapper[TContext], + tool_call_id: str, + tool_call: Optional[ResponseFunctionToolCall] = None, ) -> "ToolContext": """ Create a ToolContext from a RunContextWrapper. @@ -26,4 +38,5 @@ def from_agent_context( base_values: dict[str, Any] = { f.name: getattr(context, f.name) for f in fields(RunContextWrapper) if f.init } - return cls(tool_call_id=tool_call_id, **base_values) + tool_name = tool_call.name if tool_call is not None else _assert_must_pass_tool_name() + return cls(tool_name=tool_name, tool_call_id=tool_call_id, **base_values) diff --git a/tests/test_function_tool.py b/tests/test_function_tool.py index b232bf75e..b339a9794 100644 --- a/tests/test_function_tool.py +++ b/tests/test_function_tool.py @@ -19,7 +19,9 @@ async def test_argless_function(): tool = function_tool(argless_function) assert tool.name == "argless_function" - result = await tool.on_invoke_tool(ToolContext(context=None, tool_call_id="1"), "") + result = await tool.on_invoke_tool( + ToolContext(context=None, tool_name=tool.name, tool_call_id="1"), "" + ) assert result == "ok" @@ -32,11 +34,13 @@ async def test_argless_with_context(): tool = function_tool(argless_with_context) assert tool.name == "argless_with_context" - result = await tool.on_invoke_tool(ToolContext(None, tool_call_id="1"), "") + result = await tool.on_invoke_tool(ToolContext(None, tool_name=tool.name, tool_call_id="1"), "") assert result == "ok" # Extra JSON should not raise an error - result = await tool.on_invoke_tool(ToolContext(None, tool_call_id="1"), '{"a": 1}') + result = await tool.on_invoke_tool( + ToolContext(None, tool_name=tool.name, tool_call_id="1"), '{"a": 1}' + ) assert result == "ok" @@ -49,15 +53,19 @@ async def test_simple_function(): tool = function_tool(simple_function, failure_error_function=None) assert tool.name == "simple_function" - result = await tool.on_invoke_tool(ToolContext(None, tool_call_id="1"), '{"a": 1}') + result = await tool.on_invoke_tool( + ToolContext(None, tool_name=tool.name, tool_call_id="1"), '{"a": 1}' + ) assert result == 6 - result = await tool.on_invoke_tool(ToolContext(None, tool_call_id="1"), '{"a": 1, "b": 2}') + result = await tool.on_invoke_tool( + ToolContext(None, tool_name=tool.name, tool_call_id="1"), '{"a": 1, "b": 2}' + ) assert result == 3 # Missing required argument should raise an error with pytest.raises(ModelBehaviorError): - await tool.on_invoke_tool(ToolContext(None, tool_call_id="1"), "") + await tool.on_invoke_tool(ToolContext(None, tool_name=tool.name, tool_call_id="1"), "") class Foo(BaseModel): @@ -85,7 +93,9 @@ async def test_complex_args_function(): "bar": Bar(x="hello", y=10), } ) - result = await tool.on_invoke_tool(ToolContext(None, tool_call_id="1"), valid_json) + result = await tool.on_invoke_tool( + ToolContext(None, tool_name=tool.name, tool_call_id="1"), valid_json + ) assert result == "6 hello10 hello" valid_json = json.dumps( @@ -94,7 +104,9 @@ async def test_complex_args_function(): "bar": Bar(x="hello", y=10), } ) - result = await tool.on_invoke_tool(ToolContext(None, tool_call_id="1"), valid_json) + result = await tool.on_invoke_tool( + ToolContext(None, tool_name=tool.name, tool_call_id="1"), valid_json + ) assert result == "3 hello10 hello" valid_json = json.dumps( @@ -104,12 +116,16 @@ async def test_complex_args_function(): "baz": "world", } ) - result = await tool.on_invoke_tool(ToolContext(None, tool_call_id="1"), valid_json) + result = await tool.on_invoke_tool( + ToolContext(None, tool_name=tool.name, tool_call_id="1"), valid_json + ) assert result == "3 hello10 world" # Missing required argument should raise an error with pytest.raises(ModelBehaviorError): - await tool.on_invoke_tool(ToolContext(None, tool_call_id="1"), '{"foo": {"a": 1}}') + await tool.on_invoke_tool( + ToolContext(None, tool_name=tool.name, tool_call_id="1"), '{"foo": {"a": 1}}' + ) def test_function_config_overrides(): @@ -169,7 +185,9 @@ async def run_function(ctx: RunContextWrapper[Any], args: str) -> str: assert tool.params_json_schema[key] == value assert tool.strict_json_schema - result = await tool.on_invoke_tool(ToolContext(None, tool_call_id="1"), '{"data": "hello"}') + result = await tool.on_invoke_tool( + ToolContext(None, tool_name=tool.name, tool_call_id="1"), '{"data": "hello"}' + ) assert result == "hello_done" tool_not_strict = FunctionTool( @@ -184,7 +202,8 @@ async def run_function(ctx: RunContextWrapper[Any], args: str) -> str: assert "additionalProperties" not in tool_not_strict.params_json_schema result = await tool_not_strict.on_invoke_tool( - ToolContext(None, tool_call_id="1"), '{"data": "hello", "bar": "baz"}' + ToolContext(None, tool_name=tool_not_strict.name, tool_call_id="1"), + '{"data": "hello", "bar": "baz"}', ) assert result == "hello_done" @@ -195,7 +214,7 @@ def my_func(a: int, b: int = 5): raise ValueError("test") tool = function_tool(my_func) - ctx = ToolContext(None, tool_call_id="1") + ctx = ToolContext(None, tool_name=tool.name, tool_call_id="1") result = await tool.on_invoke_tool(ctx, "") assert "Invalid JSON" in str(result) @@ -219,7 +238,7 @@ def custom_sync_error_function(ctx: RunContextWrapper[Any], error: Exception) -> return f"error_{error.__class__.__name__}" tool = function_tool(my_func, failure_error_function=custom_sync_error_function) - ctx = ToolContext(None, tool_call_id="1") + ctx = ToolContext(None, tool_name=tool.name, tool_call_id="1") result = await tool.on_invoke_tool(ctx, "") assert result == "error_ModelBehaviorError" @@ -243,7 +262,7 @@ def custom_sync_error_function(ctx: RunContextWrapper[Any], error: Exception) -> return f"error_{error.__class__.__name__}" tool = function_tool(my_func, failure_error_function=custom_sync_error_function) - ctx = ToolContext(None, tool_call_id="1") + ctx = ToolContext(None, tool_name=tool.name, tool_call_id="1") result = await tool.on_invoke_tool(ctx, "") assert result == "error_ModelBehaviorError" diff --git a/tests/test_function_tool_decorator.py b/tests/test_function_tool_decorator.py index d334d8f84..b81d5dbe2 100644 --- a/tests/test_function_tool_decorator.py +++ b/tests/test_function_tool_decorator.py @@ -16,7 +16,7 @@ def __init__(self): def ctx_wrapper() -> ToolContext[DummyContext]: - return ToolContext(context=DummyContext(), tool_call_id="1") + return ToolContext(context=DummyContext(), tool_name="dummy", tool_call_id="1") @function_tool