Skip to content

Remove redundant weaker tracing assertions #261

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits on Mar 21, 2025 (source and target branch names are missing from this capture)
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ sync:
.PHONY: format
format:
uv run ruff format
+ uv run ruff check --fix

.PHONY: lint
lint:
Expand Down
41 changes: 1 addition & 40 deletions tests/test_agent_tracing.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

from .fake_model import FakeModel
from .test_responses import get_text_message
- from .testing_processor import fetch_normalized_spans, fetch_ordered_spans, fetch_traces
+ from .testing_processor import fetch_normalized_spans, fetch_traces


@pytest.mark.asyncio
Expand All @@ -23,9 +23,6 @@ async def test_single_run_is_single_trace():

await Runner.run(agent, input="first_test")

traces = fetch_traces()
assert len(traces) == 1, f"Expected 1 trace, got {len(traces)}"

assert fetch_normalized_spans() == snapshot(
[
{
Expand All @@ -45,12 +42,6 @@ async def test_single_run_is_single_trace():
]
)

spans = fetch_ordered_spans()
assert len(spans) == 1, (
f"Got {len(spans)}, but expected 1: the agent span. data:"
f"{[span.span_data for span in spans]}"
)


@pytest.mark.asyncio
async def test_multiple_runs_are_multiple_traces():
Expand All @@ -69,9 +60,6 @@ async def test_multiple_runs_are_multiple_traces():
await Runner.run(agent, input="first_test")
await Runner.run(agent, input="second_test")

traces = fetch_traces()
assert len(traces) == 2, f"Expected 2 traces, got {len(traces)}"

assert fetch_normalized_spans() == snapshot(
[
{
Expand Down Expand Up @@ -105,9 +93,6 @@ async def test_multiple_runs_are_multiple_traces():
]
)

spans = fetch_ordered_spans()
assert len(spans) == 2, f"Got {len(spans)}, but expected 2: agent span per run"


@pytest.mark.asyncio
async def test_wrapped_trace_is_single_trace():
Expand All @@ -129,9 +114,6 @@ async def test_wrapped_trace_is_single_trace():
await Runner.run(agent, input="second_test")
await Runner.run(agent, input="third_test")

traces = fetch_traces()
assert len(traces) == 1, f"Expected 1 trace, got {len(traces)}"

assert fetch_normalized_spans() == snapshot(
[
{
Expand Down Expand Up @@ -169,9 +151,6 @@ async def test_wrapped_trace_is_single_trace():
]
)

spans = fetch_ordered_spans()
assert len(spans) == 3, f"Got {len(spans)}, but expected 3: the agent span per run"


@pytest.mark.asyncio
async def test_parent_disabled_trace_disabled_agent_trace():
Expand All @@ -185,15 +164,8 @@ async def test_parent_disabled_trace_disabled_agent_trace():

await Runner.run(agent, input="first_test")

traces = fetch_traces()
assert len(traces) == 0, f"Expected 0 traces, got {len(traces)}"
assert fetch_normalized_spans() == snapshot([])

spans = fetch_ordered_spans()
assert len(spans) == 0, (
f"Expected no spans, got {len(spans)}, with {[x.span_data for x in spans]}"
)


@pytest.mark.asyncio
async def test_manual_disabling_works():
Expand All @@ -206,13 +178,8 @@ async def test_manual_disabling_works():

await Runner.run(agent, input="first_test", run_config=RunConfig(tracing_disabled=True))

traces = fetch_traces()
assert len(traces) == 0, f"Expected 0 traces, got {len(traces)}"
assert fetch_normalized_spans() == snapshot([])

spans = fetch_ordered_spans()
assert len(spans) == 0, f"Got {len(spans)}, but expected no spans"


@pytest.mark.asyncio
async def test_trace_config_works():
Expand Down Expand Up @@ -255,9 +222,6 @@ async def test_not_starting_streaming_creates_trace():
break
await asyncio.sleep(0.1)

traces = fetch_traces()
assert len(traces) == 1, f"Expected 1 trace, got {len(traces)}"

assert fetch_normalized_spans() == snapshot(
[
{
Expand All @@ -277,9 +241,6 @@ async def test_not_starting_streaming_creates_trace():
]
)

spans = fetch_ordered_spans()
assert len(spans) == 1, f"Got {len(spans)}, but expected 1: the agent span"

# Await the stream to avoid warnings about it not being awaited
async for _ in result.stream_events():
pass
Expand Down
13 changes: 0 additions & 13 deletions tests/test_responses_tracing.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,13 +64,6 @@ async def dummy_fetch_response(
]
)

spans = fetch_ordered_spans()
assert len(spans) == 1

assert isinstance(spans[0].span_data, ResponseSpanData)
assert spans[0].span_data.response is not None
assert spans[0].span_data.response.id == "dummy-id"


@pytest.mark.allow_call_model_methods
@pytest.mark.asyncio
Expand Down Expand Up @@ -164,12 +157,6 @@ async def __aiter__(self):
]
)

spans = fetch_ordered_spans()
assert len(spans) == 1
assert isinstance(spans[0].span_data, ResponseSpanData)
assert spans[0].span_data.response is not None
assert spans[0].span_data.response.id == "dummy-id-123"


@pytest.mark.allow_call_model_methods
@pytest.mark.asyncio
Expand Down
88 changes: 1 addition & 87 deletions tests/test_tracing_errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
Runner,
TResponseInputItem,
)
- from agents.tracing import AgentSpanData, FunctionSpanData, GenerationSpanData

from .fake_model import FakeModel
from .test_responses import (
Expand All @@ -28,7 +27,7 @@
get_handoff_tool_call,
get_text_message,
)
- from .testing_processor import fetch_normalized_spans, fetch_ordered_spans, fetch_traces
+ from .testing_processor import fetch_normalized_spans


@pytest.mark.asyncio
Expand All @@ -43,9 +42,6 @@ async def test_single_turn_model_error():
with pytest.raises(ValueError):
await Runner.run(agent, input="first_test")

traces = fetch_traces()
assert len(traces) == 1, f"Expected 1 trace, got {len(traces)}"

assert fetch_normalized_spans() == snapshot(
[
{
Expand Down Expand Up @@ -74,13 +70,6 @@ async def test_single_turn_model_error():
]
)

spans = fetch_ordered_spans()
assert len(spans) == 2, f"should have agent and generation spans, got {len(spans)}"

generation_span = spans[1]
assert isinstance(generation_span.span_data, GenerationSpanData)
assert generation_span.error, "should have error"


@pytest.mark.asyncio
async def test_multi_turn_no_handoffs():
Expand All @@ -106,9 +95,6 @@ async def test_multi_turn_no_handoffs():
with pytest.raises(ValueError):
await Runner.run(agent, input="first_test")

traces = fetch_traces()
assert len(traces) == 1, f"Expected 1 trace, got {len(traces)}"

assert fetch_normalized_spans() == snapshot(
[
{
Expand Down Expand Up @@ -146,15 +132,6 @@ async def test_multi_turn_no_handoffs():
]
)

spans = fetch_ordered_spans()
assert len(spans) == 4, (
f"should have agent, generation, tool, generation, got {len(spans)} with data: "
f"{[x.span_data for x in spans]}"
)

last_generation_span = [x for x in spans if isinstance(x.span_data, GenerationSpanData)][-1]
assert last_generation_span.error, "should have error"


@pytest.mark.asyncio
async def test_tool_call_error():
Expand All @@ -173,9 +150,6 @@ async def test_tool_call_error():
with pytest.raises(ModelBehaviorError):
await Runner.run(agent, input="first_test")

traces = fetch_traces()
assert len(traces) == 1, f"Expected 1 trace, got {len(traces)}"

assert fetch_normalized_spans() == snapshot(
[
{
Expand Down Expand Up @@ -209,15 +183,6 @@ async def test_tool_call_error():
]
)

spans = fetch_ordered_spans()
assert len(spans) == 3, (
f"should have agent, generation, tool spans, got {len(spans)} with data: "
f"{[x.span_data for x in spans]}"
)

function_span = [x for x in spans if isinstance(x.span_data, FunctionSpanData)][0]
assert function_span.error, "should have error"


@pytest.mark.asyncio
async def test_multiple_handoff_doesnt_error():
Expand Down Expand Up @@ -255,9 +220,6 @@ async def test_multiple_handoff_doesnt_error():
result = await Runner.run(agent_3, input="user_message")
assert result.last_agent == agent_1, "should have picked first handoff"

traces = fetch_traces()
assert len(traces) == 1, f"Expected 1 trace, got {len(traces)}"

assert fetch_normalized_spans() == snapshot(
[
{
Expand Down Expand Up @@ -295,12 +257,6 @@ async def test_multiple_handoff_doesnt_error():
]
)

spans = fetch_ordered_spans()
assert len(spans) == 7, (
f"should have 2 agent, 1 function, 3 generation, 1 handoff, got {len(spans)} with data: "
f"{[x.span_data for x in spans]}"
)


class Foo(TypedDict):
bar: str
Expand All @@ -326,9 +282,6 @@ async def test_multiple_final_output_doesnt_error():
result = await Runner.run(agent_1, input="user_message")
assert result.final_output == Foo(bar="abc")

traces = fetch_traces()
assert len(traces) == 1, f"Expected 1 trace, got {len(traces)}"

assert fetch_normalized_spans() == snapshot(
[
{
Expand All @@ -344,12 +297,6 @@ async def test_multiple_final_output_doesnt_error():
]
)

spans = fetch_ordered_spans()
assert len(spans) == 2, (
f"should have 1 agent, 1 generation, got {len(spans)} with data: "
f"{[x.span_data for x in spans]}"
)


@pytest.mark.asyncio
async def test_handoffs_lead_to_correct_agent_spans():
Expand Down Expand Up @@ -399,9 +346,6 @@ async def test_handoffs_lead_to_correct_agent_spans():
f"should have ended on the third agent, got {result.last_agent.name}"
)

traces = fetch_traces()
assert len(traces) == 1, f"Expected 1 trace, got {len(traces)}"

assert fetch_normalized_spans() == snapshot(
[
{
Expand Down Expand Up @@ -472,12 +416,6 @@ async def test_handoffs_lead_to_correct_agent_spans():
]
)

spans = fetch_ordered_spans()
assert len(spans) == 12, (
f"should have 3 agents, 2 function, 5 generation, 2 handoff, got {len(spans)} with data: "
f"{[x.span_data for x in spans]}"
)


@pytest.mark.asyncio
async def test_max_turns_exceeded():
Expand All @@ -503,9 +441,6 @@ async def test_max_turns_exceeded():
with pytest.raises(MaxTurnsExceeded):
await Runner.run(agent, input="user_message", max_turns=2)

traces = fetch_traces()
assert len(traces) == 1, f"Expected 1 trace, got {len(traces)}"

assert fetch_normalized_spans() == snapshot(
[
{
Expand Down Expand Up @@ -538,15 +473,6 @@ async def test_max_turns_exceeded():
]
)

spans = fetch_ordered_spans()
assert len(spans) == 5, (
f"should have 1 agent span, 2 generations, 2 function calls, got "
f"{len(spans)} with data: {[x.span_data for x in spans]}"
)

agent_span = [x for x in spans if isinstance(x.span_data, AgentSpanData)][-1]
assert agent_span.error, "last agent should have error"


def guardrail_function(
context: RunContextWrapper[Any], agent: Agent[Any], input: str | list[TResponseInputItem]
Expand All @@ -568,9 +494,6 @@ async def test_guardrail_error():
with pytest.raises(InputGuardrailTripwireTriggered):
await Runner.run(agent, input="user_message")

traces = fetch_traces()
assert len(traces) == 1, f"Expected 1 trace, got {len(traces)}"

assert fetch_normalized_spans() == snapshot(
[
{
Expand All @@ -594,12 +517,3 @@ async def test_guardrail_error():
}
]
)

spans = fetch_ordered_spans()
assert len(spans) == 2, (
f"should have 1 agent, 1 guardrail, got {len(spans)} with data: "
f"{[x.span_data for x in spans]}"
)

agent_span = [x for x in spans if isinstance(x.span_data, AgentSpanData)][-1]
assert agent_span.error, "last agent should have error"
Loading