Skip to content

enhancement: Add tool_name to ToolContext to support shared tool handlers #1043

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Jul 14, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions docs/ref/tool_context.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# `Tool context`

::: agents.tool_context
2 changes: 1 addition & 1 deletion docs/tools.md
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@ Sometimes, you don't want to use a Python function as a tool. You can directly c
- `name`
- `description`
- `params_json_schema`, which is the JSON schema for the arguments
- `on_invoke_tool`, which is an async function that receives the context and the arguments as a JSON string, and must return the tool output as a string.
- `on_invoke_tool`, which is an async function that receives a [`ToolContext`][agents.tool_context.ToolContext] and the arguments as a JSON string, and must return the tool output as a string.

```python
from typing import Any
Expand Down
1 change: 1 addition & 0 deletions mkdocs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ plugins:
- ref/lifecycle.md
- ref/items.md
- ref/run_context.md
- ref/tool_context.md
- ref/usage.md
- ref/exceptions.md
- ref/guardrail.md
Expand Down
6 changes: 5 additions & 1 deletion src/agents/_run_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -548,7 +548,11 @@ async def run_single_tool(
func_tool: FunctionTool, tool_call: ResponseFunctionToolCall
) -> Any:
with function_span(func_tool.name) as span_fn:
tool_context = ToolContext.from_agent_context(context_wrapper, tool_call.call_id)
tool_context = ToolContext.from_agent_context(
context_wrapper,
tool_call.call_id,
tool_call=tool_call,
)
if config.trace_include_sensitive_data:
span_fn.span_data.input = tool_call.arguments
try:
Expand Down
19 changes: 16 additions & 3 deletions src/agents/tool_context.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from dataclasses import dataclass, field, fields
from typing import Any
from typing import Any, Optional

from openai.types.responses import ResponseFunctionToolCall

from .run_context import RunContextWrapper, TContext

Expand All @@ -8,16 +10,26 @@ def _assert_must_pass_tool_call_id() -> str:
raise ValueError("tool_call_id must be passed to ToolContext")


def _assert_must_pass_tool_name() -> str:
raise ValueError("tool_name must be passed to ToolContext")


@dataclass
class ToolContext(RunContextWrapper[TContext]):
"""The context of a tool call."""

tool_name: str = field(default_factory=_assert_must_pass_tool_name)
"""The name of the tool being invoked."""

tool_call_id: str = field(default_factory=_assert_must_pass_tool_call_id)
"""The ID of the tool call."""

@classmethod
def from_agent_context(
cls, context: RunContextWrapper[TContext], tool_call_id: str
cls,
context: RunContextWrapper[TContext],
tool_call_id: str,
tool_call: Optional[ResponseFunctionToolCall] = None,
) -> "ToolContext":
"""
Create a ToolContext from a RunContextWrapper.
Expand All @@ -26,4 +38,5 @@ def from_agent_context(
base_values: dict[str, Any] = {
f.name: getattr(context, f.name) for f in fields(RunContextWrapper) if f.init
}
return cls(tool_call_id=tool_call_id, **base_values)
tool_name = tool_call.name if tool_call is not None else _assert_must_pass_tool_name()
return cls(tool_name=tool_name, tool_call_id=tool_call_id, **base_values)
49 changes: 34 additions & 15 deletions tests/test_function_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,9 @@ async def test_argless_function():
tool = function_tool(argless_function)
assert tool.name == "argless_function"

result = await tool.on_invoke_tool(ToolContext(context=None, tool_call_id="1"), "")
result = await tool.on_invoke_tool(
ToolContext(context=None, tool_name=tool.name, tool_call_id="1"), ""
)
assert result == "ok"


Expand All @@ -32,11 +34,13 @@ async def test_argless_with_context():
tool = function_tool(argless_with_context)
assert tool.name == "argless_with_context"

result = await tool.on_invoke_tool(ToolContext(None, tool_call_id="1"), "")
result = await tool.on_invoke_tool(ToolContext(None, tool_name=tool.name, tool_call_id="1"), "")
assert result == "ok"

# Extra JSON should not raise an error
result = await tool.on_invoke_tool(ToolContext(None, tool_call_id="1"), '{"a": 1}')
result = await tool.on_invoke_tool(
ToolContext(None, tool_name=tool.name, tool_call_id="1"), '{"a": 1}'
)
assert result == "ok"


Expand All @@ -49,15 +53,19 @@ async def test_simple_function():
tool = function_tool(simple_function, failure_error_function=None)
assert tool.name == "simple_function"

result = await tool.on_invoke_tool(ToolContext(None, tool_call_id="1"), '{"a": 1}')
result = await tool.on_invoke_tool(
ToolContext(None, tool_name=tool.name, tool_call_id="1"), '{"a": 1}'
)
assert result == 6

result = await tool.on_invoke_tool(ToolContext(None, tool_call_id="1"), '{"a": 1, "b": 2}')
result = await tool.on_invoke_tool(
ToolContext(None, tool_name=tool.name, tool_call_id="1"), '{"a": 1, "b": 2}'
)
assert result == 3

# Missing required argument should raise an error
with pytest.raises(ModelBehaviorError):
await tool.on_invoke_tool(ToolContext(None, tool_call_id="1"), "")
await tool.on_invoke_tool(ToolContext(None, tool_name=tool.name, tool_call_id="1"), "")


class Foo(BaseModel):
Expand Down Expand Up @@ -85,7 +93,9 @@ async def test_complex_args_function():
"bar": Bar(x="hello", y=10),
}
)
result = await tool.on_invoke_tool(ToolContext(None, tool_call_id="1"), valid_json)
result = await tool.on_invoke_tool(
ToolContext(None, tool_name=tool.name, tool_call_id="1"), valid_json
)
assert result == "6 hello10 hello"

valid_json = json.dumps(
Expand All @@ -94,7 +104,9 @@ async def test_complex_args_function():
"bar": Bar(x="hello", y=10),
}
)
result = await tool.on_invoke_tool(ToolContext(None, tool_call_id="1"), valid_json)
result = await tool.on_invoke_tool(
ToolContext(None, tool_name=tool.name, tool_call_id="1"), valid_json
)
assert result == "3 hello10 hello"

valid_json = json.dumps(
Expand All @@ -104,12 +116,16 @@ async def test_complex_args_function():
"baz": "world",
}
)
result = await tool.on_invoke_tool(ToolContext(None, tool_call_id="1"), valid_json)
result = await tool.on_invoke_tool(
ToolContext(None, tool_name=tool.name, tool_call_id="1"), valid_json
)
assert result == "3 hello10 world"

# Missing required argument should raise an error
with pytest.raises(ModelBehaviorError):
await tool.on_invoke_tool(ToolContext(None, tool_call_id="1"), '{"foo": {"a": 1}}')
await tool.on_invoke_tool(
ToolContext(None, tool_name=tool.name, tool_call_id="1"), '{"foo": {"a": 1}}'
)


def test_function_config_overrides():
Expand Down Expand Up @@ -169,7 +185,9 @@ async def run_function(ctx: RunContextWrapper[Any], args: str) -> str:
assert tool.params_json_schema[key] == value
assert tool.strict_json_schema

result = await tool.on_invoke_tool(ToolContext(None, tool_call_id="1"), '{"data": "hello"}')
result = await tool.on_invoke_tool(
ToolContext(None, tool_name=tool.name, tool_call_id="1"), '{"data": "hello"}'
)
assert result == "hello_done"

tool_not_strict = FunctionTool(
Expand All @@ -184,7 +202,8 @@ async def run_function(ctx: RunContextWrapper[Any], args: str) -> str:
assert "additionalProperties" not in tool_not_strict.params_json_schema

result = await tool_not_strict.on_invoke_tool(
ToolContext(None, tool_call_id="1"), '{"data": "hello", "bar": "baz"}'
ToolContext(None, tool_name=tool_not_strict.name, tool_call_id="1"),
'{"data": "hello", "bar": "baz"}',
)
assert result == "hello_done"

Expand All @@ -195,7 +214,7 @@ def my_func(a: int, b: int = 5):
raise ValueError("test")

tool = function_tool(my_func)
ctx = ToolContext(None, tool_call_id="1")
ctx = ToolContext(None, tool_name=tool.name, tool_call_id="1")

result = await tool.on_invoke_tool(ctx, "")
assert "Invalid JSON" in str(result)
Expand All @@ -219,7 +238,7 @@ def custom_sync_error_function(ctx: RunContextWrapper[Any], error: Exception) ->
return f"error_{error.__class__.__name__}"

tool = function_tool(my_func, failure_error_function=custom_sync_error_function)
ctx = ToolContext(None, tool_call_id="1")
ctx = ToolContext(None, tool_name=tool.name, tool_call_id="1")

result = await tool.on_invoke_tool(ctx, "")
assert result == "error_ModelBehaviorError"
Expand All @@ -243,7 +262,7 @@ def custom_sync_error_function(ctx: RunContextWrapper[Any], error: Exception) ->
return f"error_{error.__class__.__name__}"

tool = function_tool(my_func, failure_error_function=custom_sync_error_function)
ctx = ToolContext(None, tool_call_id="1")
ctx = ToolContext(None, tool_name=tool.name, tool_call_id="1")

result = await tool.on_invoke_tool(ctx, "")
assert result == "error_ModelBehaviorError"
Expand Down
2 changes: 1 addition & 1 deletion tests/test_function_tool_decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ def __init__(self):


def ctx_wrapper() -> ToolContext[DummyContext]:
return ToolContext(context=DummyContext(), tool_call_id="1")
return ToolContext(context=DummyContext(), tool_name="dummy", tool_call_id="1")


@function_tool
Expand Down