Skip to content

Commit e474d0f

Browse files
committed
add tool hooks
Signed-off-by: myan <myan@redhat.com>
1 parent 4187fba commit e474d0f

File tree

6 files changed

+37
-13
lines changed

6 files changed

+37
-13
lines changed

examples/basic/agent_lifecycle_example.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ async def on_handoff(self, context: RunContextWrapper, agent: Agent, source: Age
2828
f"### ({self.display_name}) {self.event_counter}: Agent {source.name} handed off to {agent.name}"
2929
)
3030

31-
async def on_tool_start(self, context: RunContextWrapper, agent: Agent, tool: Tool) -> None:
31+
async def on_tool_start(self, context: RunContextWrapper, agent: Agent, tool: Tool, args: str) -> None:
3232
self.event_counter += 1
3333
print(
3434
f"### ({self.display_name}) {self.event_counter}: Agent {agent.name} started tool {tool.name}"

examples/basic/lifecycle_example.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ async def on_agent_end(self, context: RunContextWrapper, agent: Agent, output: A
2626
f"### {self.event_counter}: Agent {agent.name} ended with output {output}. Usage: {self._usage_to_str(context.usage)}"
2727
)
2828

29-
async def on_tool_start(self, context: RunContextWrapper, agent: Agent, tool: Tool) -> None:
29+
async def on_tool_start(self, context: RunContextWrapper, agent: Agent, tool: Tool, args: str) -> None:
3030
self.event_counter += 1
3131
print(
3232
f"### {self.event_counter}: Tool {tool.name} started. Usage: {self._usage_to_str(context.usage)}"

src/agents/_run_impl.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -437,9 +437,9 @@ async def run_single_tool(
437437
span_fn.span_data.input = tool_call.arguments
438438
try:
439439
_, _, result = await asyncio.gather(
440-
hooks.on_tool_start(context_wrapper, agent, func_tool),
440+
hooks.on_tool_start(context_wrapper, agent, func_tool, tool_call.arguments),
441441
(
442-
agent.hooks.on_tool_start(context_wrapper, agent, func_tool)
442+
agent.hooks.on_tool_start(context_wrapper, agent, func_tool, tool_call.arguments)
443443
if agent.hooks
444444
else _coro.noop_coroutine()
445445
),

src/agents/lifecycle.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ async def on_tool_start(
3939
context: RunContextWrapper[TContext],
4040
agent: Agent[TContext],
4141
tool: Tool,
42+
args: str,
4243
) -> None:
4344
"""Called before a tool is invoked."""
4445
pass
@@ -61,7 +62,9 @@ class AgentHooks(Generic[TContext]):
6162
Subclass and override the methods you need.
6263
"""
6364

64-
async def on_start(self, context: RunContextWrapper[TContext], agent: Agent[TContext]) -> None:
65+
async def on_start(
66+
self, context: RunContextWrapper[TContext], agent: Agent[TContext]
67+
) -> None:
6568
"""Called before the agent is invoked. Called each time the running agent is changed to this
6669
agent."""
6770
pass
@@ -90,6 +93,7 @@ async def on_tool_start(
9093
context: RunContextWrapper[TContext],
9194
agent: Agent[TContext],
9295
tool: Tool,
96+
args: str,
9397
) -> None:
9498
"""Called before a tool is invoked."""
9599
pass

tests/test_agent_hooks.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ async def on_tool_start(
5454
context: RunContextWrapper[TContext],
5555
agent: Agent[TContext],
5656
tool: Tool,
57+
args: str,
5758
) -> None:
5859
self.events["on_tool_start"] += 1
5960

tests/test_computer_action.py

Lines changed: 27 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -127,10 +127,18 @@ async def drag(self, path: list[tuple[int, int]]) -> None:
127127
@pytest.mark.parametrize(
128128
"action,expected_call",
129129
[
130-
(ActionClick(type="click", x=10, y=21, button="left"), ("click", (10, 21, "left"))),
131-
(ActionDoubleClick(type="double_click", x=42, y=47), ("double_click", (42, 47))),
132130
(
133-
ActionDrag(type="drag", path=[ActionDragPath(x=1, y=2), ActionDragPath(x=3, y=4)]),
131+
ActionClick(type="click", x=10, y=21, button="left"),
132+
("click", (10, 21, "left")),
133+
),
134+
(
135+
ActionDoubleClick(type="double_click", x=42, y=47),
136+
("double_click", (42, 47)),
137+
),
138+
(
139+
ActionDrag(
140+
type="drag", path=[ActionDragPath(x=1, y=2), ActionDragPath(x=3, y=4)]
141+
),
134142
("drag", (((1, 2), (3, 4)),)),
135143
),
136144
(ActionKeypress(type="keypress", keys=["a", "b"]), ("keypress", (["a", "b"],))),
@@ -172,13 +180,24 @@ async def test_get_screenshot_sync_executes_action_and_takes_screenshot(
172180
@pytest.mark.parametrize(
173181
"action,expected_call",
174182
[
175-
(ActionClick(type="click", x=2, y=3, button="right"), ("click", (2, 3, "right"))),
176-
(ActionDoubleClick(type="double_click", x=12, y=13), ("double_click", (12, 13))),
177183
(
178-
ActionDrag(type="drag", path=[ActionDragPath(x=5, y=6), ActionDragPath(x=6, y=7)]),
184+
ActionClick(type="click", x=2, y=3, button="right"),
185+
("click", (2, 3, "right")),
186+
),
187+
(
188+
ActionDoubleClick(type="double_click", x=12, y=13),
189+
("double_click", (12, 13)),
190+
),
191+
(
192+
ActionDrag(
193+
type="drag", path=[ActionDragPath(x=5, y=6), ActionDragPath(x=6, y=7)]
194+
),
179195
("drag", (((5, 6), (6, 7)),)),
180196
),
181-
(ActionKeypress(type="keypress", keys=["ctrl", "c"]), ("keypress", (["ctrl", "c"],))),
197+
(
198+
ActionKeypress(type="keypress", keys=["ctrl", "c"]),
199+
("keypress", (["ctrl", "c"],)),
200+
),
182201
(ActionMove(type="move", x=8, y=9), ("move", (8, 9))),
183202
(ActionScreenshot(type="screenshot"), ("screenshot", ())),
184203
(
@@ -241,7 +260,7 @@ def __init__(self) -> None:
241260
self.ended: list[tuple[Agent[Any], Any, str]] = []
242261

243262
async def on_tool_start(
244-
self, context: RunContextWrapper[Any], agent: Agent[Any], tool: Any
263+
self, context: RunContextWrapper[Any], agent: Agent[Any], tool: Any, args: str,
245264
) -> None:
246265
self.started.append((agent, tool))
247266

0 commit comments

Comments
 (0)