Commit 6904dcb

fix(run): fire on_llm_start / on_llm_end in Runner.run() for streaming & non-streaming (aligns with docs) (#1619)
1 parent 5de3b58 commit 6904dcb
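
Before this commit, the run-level RunHooks object passed to Runner.run() / Runner.run_sync() / Runner.run_streamed() never received on_llm_start / on_llm_end; only the agent-level agent.hooks did, even though the docs describe both. The sketch below shows the usage this change enables. It is illustrative rather than part of the diff: the agent name, instructions, and input are made up, and it assumes a default model and API key are configured.

import asyncio
from typing import Optional

from agents import Agent, RunContextWrapper, RunHooks, Runner
from agents.items import ModelResponse, TResponseInputItem


class LoggingRunHooks(RunHooks):
    # Run-level hooks: passed to Runner.run(...), not attached to an agent.
    async def on_llm_start(
        self,
        context: RunContextWrapper,
        agent: Agent,
        system_prompt: Optional[str],
        input_items: list[TResponseInputItem],
    ) -> None:
        print(f"LLM call starting for agent {agent.name}")

    async def on_llm_end(
        self, context: RunContextWrapper, agent: Agent, response: ModelResponse
    ) -> None:
        print(f"LLM call finished for agent {agent.name}")


async def main() -> None:
    agent = Agent(name="Assistant", instructions="Reply concisely.")  # illustrative agent
    result = await Runner.run(agent, input="Say hello.", hooks=LoggingRunHooks())
    print(result.final_output)


if __name__ == "__main__":
    asyncio.run(main())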

File tree

3 files changed: +302, -25 lines changed

examples/basic/lifecycle_example.py
src/agents/run.py
tests/test_run_hooks.py

examples/basic/lifecycle_example.py

Lines changed: 33 additions & 8 deletions
@@ -1,10 +1,11 @@
 import asyncio
 import random
-from typing import Any
+from typing import Any, Optional
 
 from pydantic import BaseModel
 
 from agents import Agent, RunContextWrapper, RunHooks, Runner, Tool, Usage, function_tool
+from agents.items import ModelResponse, TResponseInputItem
 
 
 class ExampleHooks(RunHooks):
@@ -20,6 +21,22 @@ async def on_agent_start(self, context: RunContextWrapper, agent: Agent) -> None
             f"### {self.event_counter}: Agent {agent.name} started. Usage: {self._usage_to_str(context.usage)}"
         )
 
+    async def on_llm_start(
+        self,
+        context: RunContextWrapper,
+        agent: Agent,
+        system_prompt: Optional[str],
+        input_items: list[TResponseInputItem],
+    ) -> None:
+        self.event_counter += 1
+        print(f"### {self.event_counter}: LLM started. Usage: {self._usage_to_str(context.usage)}")
+
+    async def on_llm_end(
+        self, context: RunContextWrapper, agent: Agent, response: ModelResponse
+    ) -> None:
+        self.event_counter += 1
+        print(f"### {self.event_counter}: LLM ended. Usage: {self._usage_to_str(context.usage)}")
+
     async def on_agent_end(self, context: RunContextWrapper, agent: Agent, output: Any) -> None:
         self.event_counter += 1
         print(
@@ -109,13 +126,21 @@ async def main() -> None:
 
 Enter a max number: 250
 ### 1: Agent Start Agent started. Usage: 0 requests, 0 input tokens, 0 output tokens, 0 total tokens
-### 2: Tool random_number started. Usage: 1 requests, 148 input tokens, 15 output tokens, 163 total tokens
-### 3: Tool random_number ended with result 101. Usage: 1 requests, 148 input tokens, 15 output tokens, 163 total token
-### 4: Handoff from Start Agent to Multiply Agent. Usage: 2 requests, 323 input tokens, 30 output tokens, 353 total tokens
-### 5: Agent Multiply Agent started. Usage: 2 requests, 323 input tokens, 30 output tokens, 353 total tokens
-### 6: Tool multiply_by_two started. Usage: 3 requests, 504 input tokens, 46 output tokens, 550 total tokens
-### 7: Tool multiply_by_two ended with result 202. Usage: 3 requests, 504 input tokens, 46 output tokens, 550 total tokens
-### 8: Agent Multiply Agent ended with output number=202. Usage: 4 requests, 714 input tokens, 63 output tokens, 777 total tokens
+### 2: LLM started. Usage: 0 requests, 0 input tokens, 0 output tokens, 0 total tokens
+### 3: LLM ended. Usage: 1 requests, 143 input tokens, 15 output tokens, 158 total tokens
+### 4: Tool random_number started. Usage: 1 requests, 143 input tokens, 15 output tokens, 158 total tokens
+### 5: Tool random_number ended with result 69. Usage: 1 requests, 143 input tokens, 15 output tokens, 158 total tokens
+### 6: LLM started. Usage: 1 requests, 143 input tokens, 15 output tokens, 158 total tokens
+### 7: LLM ended. Usage: 2 requests, 310 input tokens, 29 output tokens, 339 total tokens
+### 8: Handoff from Start Agent to Multiply Agent. Usage: 2 requests, 310 input tokens, 29 output tokens, 339 total tokens
+### 9: Agent Multiply Agent started. Usage: 2 requests, 310 input tokens, 29 output tokens, 339 total tokens
+### 10: LLM started. Usage: 2 requests, 310 input tokens, 29 output tokens, 339 total tokens
+### 11: LLM ended. Usage: 3 requests, 472 input tokens, 45 output tokens, 517 total tokens
+### 12: Tool multiply_by_two started. Usage: 3 requests, 472 input tokens, 45 output tokens, 517 total tokens
+### 13: Tool multiply_by_two ended with result 138. Usage: 3 requests, 472 input tokens, 45 output tokens, 517 total tokens
+### 14: LLM started. Usage: 3 requests, 472 input tokens, 45 output tokens, 517 total tokens
+### 15: LLM ended. Usage: 4 requests, 660 input tokens, 56 output tokens, 716 total tokens
+### 16: Agent Multiply Agent ended with output number=138. Usage: 4 requests, 660 input tokens, 56 output tokens, 716 total tokens
 Done!
 
 """

src/agents/run.py

Lines changed: 46 additions & 17 deletions
@@ -994,10 +994,16 @@ async def _run_single_turn_streamed(
         )
 
         # Call hook just before the model is invoked, with the correct system_prompt.
-        if agent.hooks:
-            await agent.hooks.on_llm_start(
-                context_wrapper, agent, filtered.instructions, filtered.input
-            )
+        await asyncio.gather(
+            hooks.on_llm_start(context_wrapper, agent, filtered.instructions, filtered.input),
+            (
+                agent.hooks.on_llm_start(
+                    context_wrapper, agent, filtered.instructions, filtered.input
+                )
+                if agent.hooks
+                else _coro.noop_coroutine()
+            ),
+        )
 
         # 1. Stream the output events
         async for event in model.stream_response(
@@ -1056,8 +1062,15 @@ async def _run_single_turn_streamed(
             streamed_result._event_queue.put_nowait(RawResponsesStreamEvent(data=event))
 
         # Call hook just after the model response is finalized.
-        if agent.hooks and final_response is not None:
-            await agent.hooks.on_llm_end(context_wrapper, agent, final_response)
+        if final_response is not None:
+            await asyncio.gather(
+                (
+                    agent.hooks.on_llm_end(context_wrapper, agent, final_response)
+                    if agent.hooks
+                    else _coro.noop_coroutine()
+                ),
+                hooks.on_llm_end(context_wrapper, agent, final_response),
+            )
 
         # 2. At this point, the streaming is complete for this turn of the agent loop.
         if not final_response:
@@ -1150,6 +1163,7 @@ async def _run_single_turn(
             output_schema,
             all_tools,
             handoffs,
+            hooks,
             context_wrapper,
             run_config,
             tool_use_tracker,
@@ -1345,6 +1359,7 @@ async def _get_new_response(
         output_schema: AgentOutputSchemaBase | None,
         all_tools: list[Tool],
         handoffs: list[Handoff],
+        hooks: RunHooks[TContext],
         context_wrapper: RunContextWrapper[TContext],
         run_config: RunConfig,
         tool_use_tracker: AgentToolUseTracker,
@@ -1364,14 +1379,21 @@ async def _get_new_response(
         model = cls._get_model(agent, run_config)
         model_settings = agent.model_settings.resolve(run_config.model_settings)
         model_settings = RunImpl.maybe_reset_tool_choice(agent, tool_use_tracker, model_settings)
-        # If the agent has hooks, we need to call them before and after the LLM call
-        if agent.hooks:
-            await agent.hooks.on_llm_start(
-                context_wrapper,
-                agent,
-                filtered.instructions,  # Use filtered instructions
-                filtered.input,  # Use filtered input
-            )
+
+        # If we have run hooks, or if the agent has hooks, we need to call them before the LLM call
+        await asyncio.gather(
+            hooks.on_llm_start(context_wrapper, agent, filtered.instructions, filtered.input),
+            (
+                agent.hooks.on_llm_start(
+                    context_wrapper,
+                    agent,
+                    filtered.instructions,  # Use filtered instructions
+                    filtered.input,  # Use filtered input
+                )
+                if agent.hooks
+                else _coro.noop_coroutine()
+            ),
+        )
 
         new_response = await model.get_response(
             system_instructions=filtered.instructions,
@@ -1387,12 +1409,19 @@ async def _get_new_response(
             conversation_id=conversation_id,
             prompt=prompt_config,
         )
-        # If the agent has hooks, we need to call them after the LLM call
-        if agent.hooks:
-            await agent.hooks.on_llm_end(context_wrapper, agent, new_response)
 
         context_wrapper.usage.add(new_response.usage)
 
+        # If we have run hooks, or if the agent has hooks, we need to call them after the LLM call
+        await asyncio.gather(
+            (
+                agent.hooks.on_llm_end(context_wrapper, agent, new_response)
+                if agent.hooks
+                else _coro.noop_coroutine()
+            ),
+            hooks.on_llm_end(context_wrapper, agent, new_response),
+        )
+
         return new_response
 
     @classmethod
tests/test_run_hooks.py

Lines changed: 223 additions & 0 deletions
@@ -0,0 +1,223 @@
from collections import defaultdict
from typing import Any, Optional

import pytest

from agents.agent import Agent
from agents.items import ItemHelpers, ModelResponse, TResponseInputItem
from agents.lifecycle import RunHooks
from agents.models.interface import Model
from agents.run import Runner
from agents.run_context import RunContextWrapper, TContext
from agents.tool import Tool
from tests.test_agent_llm_hooks import AgentHooksForTests

from .fake_model import FakeModel
from .test_responses import (
    get_function_tool,
    get_text_message,
)


class RunHooksForTests(RunHooks):
    def __init__(self):
        self.events: dict[str, int] = defaultdict(int)

    def reset(self):
        self.events.clear()

    async def on_agent_start(
        self, context: RunContextWrapper[TContext], agent: Agent[TContext]
    ) -> None:
        self.events["on_agent_start"] += 1

    async def on_agent_end(
        self, context: RunContextWrapper[TContext], agent: Agent[TContext], output: Any
    ) -> None:
        self.events["on_agent_end"] += 1

    async def on_handoff(
        self,
        context: RunContextWrapper[TContext],
        from_agent: Agent[TContext],
        to_agent: Agent[TContext],
    ) -> None:
        self.events["on_handoff"] += 1

    async def on_tool_start(
        self, context: RunContextWrapper[TContext], agent: Agent[TContext], tool: Tool
    ) -> None:
        self.events["on_tool_start"] += 1

    async def on_tool_end(
        self,
        context: RunContextWrapper[TContext],
        agent: Agent[TContext],
        tool: Tool,
        result: str,
    ) -> None:
        self.events["on_tool_end"] += 1

    async def on_llm_start(
        self,
        context: RunContextWrapper[TContext],
        agent: Agent[TContext],
        system_prompt: Optional[str],
        input_items: list[TResponseInputItem],
    ) -> None:
        self.events["on_llm_start"] += 1

    async def on_llm_end(
        self,
        context: RunContextWrapper[TContext],
        agent: Agent[TContext],
        response: ModelResponse,
    ) -> None:
        self.events["on_llm_end"] += 1


# Example test using the above hooks
@pytest.mark.asyncio
async def test_async_run_hooks_with_llm():
    hooks = RunHooksForTests()
    model = FakeModel()

    agent = Agent(name="A", model=model, tools=[get_function_tool("f", "res")], handoffs=[])
    # Simulate a single LLM call producing an output:
    model.set_next_output([get_text_message("hello")])
    await Runner.run(agent, input="hello", hooks=hooks)
    # Expect one on_agent_start, one on_llm_start, one on_llm_end, and one on_agent_end
    assert hooks.events == {
        "on_agent_start": 1,
        "on_llm_start": 1,
        "on_llm_end": 1,
        "on_agent_end": 1,
    }


# test_sync_run_hook_with_llm()
def test_sync_run_hook_with_llm():
    hooks = RunHooksForTests()
    model = FakeModel()
    agent = Agent(name="A", model=model, tools=[get_function_tool("f", "res")], handoffs=[])
    # Simulate a single LLM call producing an output:
    model.set_next_output([get_text_message("hello")])
    Runner.run_sync(agent, input="hello", hooks=hooks)
    # Expect one on_agent_start, one on_llm_start, one on_llm_end, and one on_agent_end
    assert hooks.events == {
        "on_agent_start": 1,
        "on_llm_start": 1,
        "on_llm_end": 1,
        "on_agent_end": 1,
    }


# test_streamed_run_hooks_with_llm():
@pytest.mark.asyncio
async def test_streamed_run_hooks_with_llm():
    hooks = RunHooksForTests()
    model = FakeModel()
    agent = Agent(name="A", model=model, tools=[get_function_tool("f", "res")], handoffs=[])
    # Simulate a single LLM call producing an output:
    model.set_next_output([get_text_message("hello")])
    stream = Runner.run_streamed(agent, input="hello", hooks=hooks)

    async for event in stream.stream_events():
        if event.type == "raw_response_event":
            continue
        if event.type == "agent_updated_stream_event":
            print(f"[EVENT] agent_updated → {event.new_agent.name}")
        elif event.type == "run_item_stream_event":
            item = event.item
            if item.type == "tool_call_item":
                print("[EVENT] tool_call_item")
            elif item.type == "tool_call_output_item":
                print(f"[EVENT] tool_call_output_item → {item.output}")
            elif item.type == "message_output_item":
                text = ItemHelpers.text_message_output(item)
                print(f"[EVENT] message_output_item → {text}")

    # Expect one on_agent_start, one on_llm_start, one on_llm_end, and one on_agent_end
    assert hooks.events == {
        "on_agent_start": 1,
        "on_llm_start": 1,
        "on_llm_end": 1,
        "on_agent_end": 1,
    }


# test_async_run_hooks_with_agent_hooks_with_llm
@pytest.mark.asyncio
async def test_async_run_hooks_with_agent_hooks_with_llm():
    hooks = RunHooksForTests()
    agent_hooks = AgentHooksForTests()
    model = FakeModel()

    agent = Agent(
        name="A", model=model, tools=[get_function_tool("f", "res")], handoffs=[], hooks=agent_hooks
    )
    # Simulate a single LLM call producing an output:
    model.set_next_output([get_text_message("hello")])
    await Runner.run(agent, input="hello", hooks=hooks)
    # Expect one on_agent_start, one on_llm_start, one on_llm_end, and one on_agent_end
    assert hooks.events == {
        "on_agent_start": 1,
        "on_llm_start": 1,
        "on_llm_end": 1,
        "on_agent_end": 1,
    }
    # Expect one on_start, one on_llm_start, one on_llm_end, and one on_end
    assert agent_hooks.events == {"on_start": 1, "on_llm_start": 1, "on_llm_end": 1, "on_end": 1}


@pytest.mark.asyncio
async def test_run_hooks_llm_error_non_streaming(monkeypatch):
    hooks = RunHooksForTests()
    model = FakeModel()
    agent = Agent(name="A", model=model, tools=[get_function_tool("f", "res")], handoffs=[])

    async def boom(*args, **kwargs):
        raise RuntimeError("boom")

    monkeypatch.setattr(FakeModel, "get_response", boom, raising=True)

    with pytest.raises(RuntimeError, match="boom"):
        await Runner.run(agent, input="hello", hooks=hooks)

    # Current behavior is that hooks will not fire on LLM failure
    assert hooks.events["on_agent_start"] == 1
    assert hooks.events["on_llm_start"] == 1
    assert hooks.events["on_llm_end"] == 0
    assert hooks.events["on_agent_end"] == 0


class BoomModel(Model):
    async def get_response(self, *a, **k):
        raise AssertionError("get_response should not be called in streaming test")

    async def stream_response(self, *a, **k):
        yield {"foo": "bar"}
        raise RuntimeError("stream blew up")


@pytest.mark.asyncio
async def test_streamed_run_hooks_llm_error(monkeypatch):
    """
    Verify that when the streaming path raises, we still emit on_llm_start
    but do NOT emit on_llm_end (current behavior), and the exception propagates.
    """
    hooks = RunHooksForTests()
    agent = Agent(name="A", model=BoomModel(), tools=[get_function_tool("f", "res")], handoffs=[])

    stream = Runner.run_streamed(agent, input="hello", hooks=hooks)

    # Consuming the stream should surface the exception
    with pytest.raises(RuntimeError, match="stream blew up"):
        async for _ in stream.stream_events():
            pass

    # Current behavior: success-only on_llm_end; ensure starts fired but ends did not.
    assert hooks.events["on_agent_start"] == 1
    assert hooks.events["on_llm_start"] == 1
    assert hooks.events["on_llm_end"] == 0
    assert hooks.events["on_agent_end"] == 0
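
The counters above only verify how many times each hook fired, not their relative order. If ordering ever needs to be asserted (for example, that on_llm_start precedes on_tool_start within a turn), the same helper could be extended to record a sequence. This is an illustrative sketch, not part of the commit:

class OrderedRunHooksForTests(RunHooksForTests):
    def __init__(self):
        super().__init__()
        self.order: list[str] = []

    async def on_llm_start(self, context, agent, system_prompt, input_items) -> None:
        await super().on_llm_start(context, agent, system_prompt, input_items)
        self.order.append("on_llm_start")

    async def on_llm_end(self, context, agent, response) -> None:
        await super().on_llm_end(context, agent, response)
        self.order.append("on_llm_end")

A test using it would assert, for a single-turn run, something like hooks.order == ["on_llm_start", "on_llm_end"].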
