Skip to content

Commit 7196862

Browse files
Added RunErrorDetails object for MaxTurnsExceeded exception (#743)
### Summary Introduced the `RunErrorDetails` object to get partial results from a run interrupted by `MaxTurnsExceeded` exception. In this proposal the `RunErrorDetails` object contains all the fields from `RunResult` with `final_output` set to `None` and `output_guardrail_results` set to an empty list. We can decide to return less information. @rm-openai At the moment the exception doesn't return the `RunErrorDetails` object for the streaming mode. Do you have any suggestions on how to deal with it? In the `_check_errors` function of `agents/result.py` file. ### Test plan I have not implemented any tests currently, but if needed I can implement a basic test to retrieve partial data. ### Issue number This PR is an attempt to solve issue #719 ### Checks - [✅ ] I've added new tests (if relevant) - [ ] I've added/updated the relevant documentation - [ ✅] I've run `make lint` and `make format` - [ ✅] I've made sure tests pass
1 parent d46e2ec commit 7196862

File tree

7 files changed

+167
-23
lines changed

7 files changed

+167
-23
lines changed

src/agents/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
MaxTurnsExceeded,
1515
ModelBehaviorError,
1616
OutputGuardrailTripwireTriggered,
17+
RunErrorDetails,
1718
UserError,
1819
)
1920
from .guardrail import (
@@ -204,6 +205,7 @@ def enable_verbose_stdout_logging():
204205
"AgentHooks",
205206
"RunContextWrapper",
206207
"TContext",
208+
"RunErrorDetails",
207209
"RunResult",
208210
"RunResultStreaming",
209211
"RunConfig",

src/agents/exceptions.py

Lines changed: 36 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,39 @@
1-
from typing import TYPE_CHECKING
1+
from __future__ import annotations
2+
3+
from dataclasses import dataclass
4+
from typing import TYPE_CHECKING, Any
25

36
if TYPE_CHECKING:
7+
from .agent import Agent
48
from .guardrail import InputGuardrailResult, OutputGuardrailResult
9+
from .items import ModelResponse, RunItem, TResponseInputItem
10+
from .run_context import RunContextWrapper
11+
12+
from .util._pretty_print import pretty_print_run_error_details
13+
14+
15+
@dataclass
16+
class RunErrorDetails:
17+
"""Data collected from an agent run when an exception occurs."""
18+
input: str | list[TResponseInputItem]
19+
new_items: list[RunItem]
20+
raw_responses: list[ModelResponse]
21+
last_agent: Agent[Any]
22+
context_wrapper: RunContextWrapper[Any]
23+
input_guardrail_results: list[InputGuardrailResult]
24+
output_guardrail_results: list[OutputGuardrailResult]
25+
26+
def __str__(self) -> str:
27+
return pretty_print_run_error_details(self)
528

629

730
class AgentsException(Exception):
831
"""Base class for all exceptions in the Agents SDK."""
32+
run_data: RunErrorDetails | None
33+
34+
def __init__(self, *args: object) -> None:
35+
super().__init__(*args)
36+
self.run_data = None
937

1038

1139
class MaxTurnsExceeded(AgentsException):
@@ -15,6 +43,7 @@ class MaxTurnsExceeded(AgentsException):
1543

1644
def __init__(self, message: str):
1745
self.message = message
46+
super().__init__(message)
1847

1948

2049
class ModelBehaviorError(AgentsException):
@@ -26,6 +55,7 @@ class ModelBehaviorError(AgentsException):
2655

2756
def __init__(self, message: str):
2857
self.message = message
58+
super().__init__(message)
2959

3060

3161
class UserError(AgentsException):
@@ -35,15 +65,16 @@ class UserError(AgentsException):
3565

3666
def __init__(self, message: str):
3767
self.message = message
68+
super().__init__(message)
3869

3970

4071
class InputGuardrailTripwireTriggered(AgentsException):
4172
"""Exception raised when a guardrail tripwire is triggered."""
4273

43-
guardrail_result: "InputGuardrailResult"
74+
guardrail_result: InputGuardrailResult
4475
"""The result data of the guardrail that was triggered."""
4576

46-
def __init__(self, guardrail_result: "InputGuardrailResult"):
77+
def __init__(self, guardrail_result: InputGuardrailResult):
4778
self.guardrail_result = guardrail_result
4879
super().__init__(
4980
f"Guardrail {guardrail_result.guardrail.__class__.__name__} triggered tripwire"
@@ -53,10 +84,10 @@ def __init__(self, guardrail_result: "InputGuardrailResult"):
5384
class OutputGuardrailTripwireTriggered(AgentsException):
5485
"""Exception raised when a guardrail tripwire is triggered."""
5586

56-
guardrail_result: "OutputGuardrailResult"
87+
guardrail_result: OutputGuardrailResult
5788
"""The result data of the guardrail that was triggered."""
5889

59-
def __init__(self, guardrail_result: "OutputGuardrailResult"):
90+
def __init__(self, guardrail_result: OutputGuardrailResult):
6091
self.guardrail_result = guardrail_result
6192
super().__init__(
6293
f"Guardrail {guardrail_result.guardrail.__class__.__name__} triggered tripwire"

src/agents/result.py

Lines changed: 44 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -11,14 +11,22 @@
1111
from ._run_impl import QueueCompleteSentinel
1212
from .agent import Agent
1313
from .agent_output import AgentOutputSchemaBase
14-
from .exceptions import InputGuardrailTripwireTriggered, MaxTurnsExceeded
14+
from .exceptions import (
15+
AgentsException,
16+
InputGuardrailTripwireTriggered,
17+
MaxTurnsExceeded,
18+
RunErrorDetails,
19+
)
1520
from .guardrail import InputGuardrailResult, OutputGuardrailResult
1621
from .items import ItemHelpers, ModelResponse, RunItem, TResponseInputItem
1722
from .logger import logger
1823
from .run_context import RunContextWrapper
1924
from .stream_events import StreamEvent
2025
from .tracing import Trace
21-
from .util._pretty_print import pretty_print_result, pretty_print_run_result_streaming
26+
from .util._pretty_print import (
27+
pretty_print_result,
28+
pretty_print_run_result_streaming,
29+
)
2230

2331
if TYPE_CHECKING:
2432
from ._run_impl import QueueCompleteSentinel
@@ -206,31 +214,53 @@ async def stream_events(self) -> AsyncIterator[StreamEvent]:
206214
if self._stored_exception:
207215
raise self._stored_exception
208216

217+
def _create_error_details(self) -> RunErrorDetails:
218+
"""Return a `RunErrorDetails` object considering the current attributes of the class."""
219+
return RunErrorDetails(
220+
input=self.input,
221+
new_items=self.new_items,
222+
raw_responses=self.raw_responses,
223+
last_agent=self.current_agent,
224+
context_wrapper=self.context_wrapper,
225+
input_guardrail_results=self.input_guardrail_results,
226+
output_guardrail_results=self.output_guardrail_results,
227+
)
228+
209229
def _check_errors(self):
210230
if self.current_turn > self.max_turns:
211-
self._stored_exception = MaxTurnsExceeded(f"Max turns ({self.max_turns}) exceeded")
231+
max_turns_exc = MaxTurnsExceeded(f"Max turns ({self.max_turns}) exceeded")
232+
max_turns_exc.run_data = self._create_error_details()
233+
self._stored_exception = max_turns_exc
212234

213235
# Fetch all the completed guardrail results from the queue and raise if needed
214236
while not self._input_guardrail_queue.empty():
215237
guardrail_result = self._input_guardrail_queue.get_nowait()
216238
if guardrail_result.output.tripwire_triggered:
217-
self._stored_exception = InputGuardrailTripwireTriggered(guardrail_result)
239+
tripwire_exc = InputGuardrailTripwireTriggered(guardrail_result)
240+
tripwire_exc.run_data = self._create_error_details()
241+
self._stored_exception = tripwire_exc
218242

219243
# Check the tasks for any exceptions
220244
if self._run_impl_task and self._run_impl_task.done():
221-
exc = self._run_impl_task.exception()
222-
if exc and isinstance(exc, Exception):
223-
self._stored_exception = exc
245+
run_impl_exc = self._run_impl_task.exception()
246+
if run_impl_exc and isinstance(run_impl_exc, Exception):
247+
if isinstance(run_impl_exc, AgentsException) and run_impl_exc.run_data is None:
248+
run_impl_exc.run_data = self._create_error_details()
249+
self._stored_exception = run_impl_exc
224250

225251
if self._input_guardrails_task and self._input_guardrails_task.done():
226-
exc = self._input_guardrails_task.exception()
227-
if exc and isinstance(exc, Exception):
228-
self._stored_exception = exc
252+
in_guard_exc = self._input_guardrails_task.exception()
253+
if in_guard_exc and isinstance(in_guard_exc, Exception):
254+
if isinstance(in_guard_exc, AgentsException) and in_guard_exc.run_data is None:
255+
in_guard_exc.run_data = self._create_error_details()
256+
self._stored_exception = in_guard_exc
229257

230258
if self._output_guardrails_task and self._output_guardrails_task.done():
231-
exc = self._output_guardrails_task.exception()
232-
if exc and isinstance(exc, Exception):
233-
self._stored_exception = exc
259+
out_guard_exc = self._output_guardrails_task.exception()
260+
if out_guard_exc and isinstance(out_guard_exc, Exception):
261+
if isinstance(out_guard_exc, AgentsException) and out_guard_exc.run_data is None:
262+
out_guard_exc.run_data = self._create_error_details()
263+
self._stored_exception = out_guard_exc
234264

235265
def _cleanup_tasks(self):
236266
if self._run_impl_task and not self._run_impl_task.done():
@@ -244,3 +274,4 @@ def _cleanup_tasks(self):
244274

245275
def __str__(self) -> str:
246276
return pretty_print_run_result_streaming(self)
277+

src/agents/run.py

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
12
from __future__ import annotations
23

34
import asyncio
@@ -26,6 +27,7 @@
2627
MaxTurnsExceeded,
2728
ModelBehaviorError,
2829
OutputGuardrailTripwireTriggered,
30+
RunErrorDetails,
2931
)
3032
from .guardrail import InputGuardrail, InputGuardrailResult, OutputGuardrail, OutputGuardrailResult
3133
from .handoffs import Handoff, HandoffInputFilter, handoff
@@ -208,7 +210,9 @@ async def run(
208210
data={"max_turns": max_turns},
209211
),
210212
)
211-
raise MaxTurnsExceeded(f"Max turns ({max_turns}) exceeded")
213+
raise MaxTurnsExceeded(
214+
f"Max turns ({max_turns}) exceeded"
215+
)
212216

213217
logger.debug(
214218
f"Running agent {current_agent.name} (turn {current_turn})",
@@ -283,6 +287,17 @@ async def run(
283287
raise AgentsException(
284288
f"Unknown next step type: {type(turn_result.next_step)}"
285289
)
290+
except AgentsException as exc:
291+
exc.run_data = RunErrorDetails(
292+
input=original_input,
293+
new_items=generated_items,
294+
raw_responses=model_responses,
295+
last_agent=current_agent,
296+
context_wrapper=context_wrapper,
297+
input_guardrail_results=input_guardrail_results,
298+
output_guardrail_results=[]
299+
)
300+
raise
286301
finally:
287302
if current_span:
288303
current_span.finish(reset_current=True)
@@ -609,6 +624,19 @@ async def _run_streamed_impl(
609624
streamed_result._event_queue.put_nowait(QueueCompleteSentinel())
610625
elif isinstance(turn_result.next_step, NextStepRunAgain):
611626
pass
627+
except AgentsException as exc:
628+
streamed_result.is_complete = True
629+
streamed_result._event_queue.put_nowait(QueueCompleteSentinel())
630+
exc.run_data = RunErrorDetails(
631+
input=streamed_result.input,
632+
new_items=streamed_result.new_items,
633+
raw_responses=streamed_result.raw_responses,
634+
last_agent=current_agent,
635+
context_wrapper=context_wrapper,
636+
input_guardrail_results=streamed_result.input_guardrail_results,
637+
output_guardrail_results=streamed_result.output_guardrail_results,
638+
)
639+
raise
612640
except Exception as e:
613641
if current_span:
614642
_error_tracing.attach_error_to_span(

src/agents/util/_pretty_print.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from pydantic import BaseModel
44

55
if TYPE_CHECKING:
6+
from ..exceptions import RunErrorDetails
67
from ..result import RunResult, RunResultBase, RunResultStreaming
78

89

@@ -38,6 +39,17 @@ def pretty_print_result(result: "RunResult") -> str:
3839
return output
3940

4041

42+
def pretty_print_run_error_details(result: "RunErrorDetails") -> str:
43+
output = "RunErrorDetails:"
44+
output += f'\n- Last agent: Agent(name="{result.last_agent.name}", ...)'
45+
output += f"\n- {len(result.new_items)} new item(s)"
46+
output += f"\n- {len(result.raw_responses)} raw response(s)"
47+
output += f"\n- {len(result.input_guardrail_results)} input guardrail result(s)"
48+
output += "\n(See `RunErrorDetails` for more details)"
49+
50+
return output
51+
52+
4153
def pretty_print_run_result_streaming(result: "RunResultStreaming") -> str:
4254
output = "RunResultStreaming:"
4355
output += f'\n- Current agent: Agent(name="{result.current_agent.name}", ...)'

tests/test_run_error_details.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
import json
2+
3+
import pytest
4+
5+
from agents import Agent, MaxTurnsExceeded, RunErrorDetails, Runner
6+
7+
from .fake_model import FakeModel
8+
from .test_responses import get_function_tool, get_function_tool_call, get_text_message
9+
10+
11+
@pytest.mark.asyncio
12+
async def test_run_error_includes_data():
13+
model = FakeModel()
14+
agent = Agent(name="test", model=model, tools=[get_function_tool("foo", "res")])
15+
model.add_multiple_turn_outputs([
16+
[get_text_message("1"), get_function_tool_call("foo", json.dumps({"a": "b"}))],
17+
[get_text_message("done")],
18+
])
19+
with pytest.raises(MaxTurnsExceeded) as exc:
20+
await Runner.run(agent, input="hello", max_turns=1)
21+
data = exc.value.run_data
22+
assert isinstance(data, RunErrorDetails)
23+
assert data.last_agent == agent
24+
assert len(data.raw_responses) == 1
25+
assert len(data.new_items) > 0
26+
27+
28+
@pytest.mark.asyncio
29+
async def test_streamed_run_error_includes_data():
30+
model = FakeModel()
31+
agent = Agent(name="test", model=model, tools=[get_function_tool("foo", "res")])
32+
model.add_multiple_turn_outputs([
33+
[get_text_message("1"), get_function_tool_call("foo", json.dumps({"a": "b"}))],
34+
[get_text_message("done")],
35+
])
36+
result = Runner.run_streamed(agent, input="hello", max_turns=1)
37+
with pytest.raises(MaxTurnsExceeded) as exc:
38+
async for _ in result.stream_events():
39+
pass
40+
data = exc.value.run_data
41+
assert isinstance(data, RunErrorDetails)
42+
assert data.last_agent == agent
43+
assert len(data.raw_responses) == 1
44+
assert len(data.new_items) > 0

tests/test_tracing_errors_streamed.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -168,10 +168,6 @@ async def test_tool_call_error():
168168
"children": [
169169
{
170170
"type": "agent",
171-
"error": {
172-
"message": "Error in agent run",
173-
"data": {"error": "Invalid JSON input for tool foo: bad_json"},
174-
},
175171
"data": {
176172
"name": "test_agent",
177173
"handoffs": [],

0 commit comments

Comments
 (0)