Skip to content

Commit 3b874f3

Browse files
committed
Start and finish streaming trace in impl metod
1 parent 5639606 commit 3b874f3

File tree

8 files changed

+85
-9
lines changed

8 files changed

+85
-9
lines changed

.gitignore

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -141,4 +141,5 @@ cython_debug/
141141
.ruff_cache/
142142

143143
# PyPI configuration file
144-
.pypirc
144+
.pypirc
145+
.aider*

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@ dev = [
6161
"graphviz",
6262
"mkdocs-static-i18n>=1.3.0",
6363
"eval-type-backport>=0.2.2",
64+
"fastapi >= 0.110.0, <1",
6465
]
6566

6667
[tool.uv.workspace]

src/agents/result.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -185,9 +185,6 @@ async def stream_events(self) -> AsyncIterator[StreamEvent]:
185185
yield item
186186
self._event_queue.task_done()
187187

188-
if self._trace:
189-
self._trace.finish(reset_current=True)
190-
191188
self._cleanup_tasks()
192189

193190
if self._stored_exception:

src/agents/run.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -404,10 +404,6 @@ def run_streamed(
404404
disabled=run_config.tracing_disabled,
405405
)
406406
)
407-
# Need to start the trace here, because the current trace contextvar is captured at
408-
# asyncio.create_task time
409-
if new_trace:
410-
new_trace.start(mark_as_current=True)
411407

412408
output_schema = cls._get_output_schema(starting_agent)
413409
context_wrapper: RunContextWrapper[TContext] = RunContextWrapper(
@@ -499,6 +495,9 @@ async def _run_streamed_impl(
499495
run_config: RunConfig,
500496
previous_response_id: str | None,
501497
):
498+
if streamed_result._trace:
499+
streamed_result._trace.start(mark_as_current=True)
500+
502501
current_span: Span[AgentSpanData] | None = None
503502
current_agent = starting_agent
504503
current_turn = 0
@@ -625,6 +624,8 @@ async def _run_streamed_impl(
625624
finally:
626625
if current_span:
627626
current_span.finish(reset_current=True)
627+
if streamed_result._trace:
628+
streamed_result._trace.finish(reset_current=True)
628629

629630
@classmethod
630631
async def _run_single_turn_streamed(

tests/fastapi/__init__.py

Whitespace-only changes.

tests/fastapi/streaming_app.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
from collections.abc import AsyncIterator
2+
3+
from fastapi import FastAPI
4+
from starlette.responses import StreamingResponse
5+
6+
from agents import Agent, Runner, RunResultStreaming
7+
8+
agent = Agent(
9+
name="Assistant",
10+
instructions="You are a helpful assistant.",
11+
)
12+
13+
14+
app = FastAPI()
15+
16+
17+
@app.post("/stream")
18+
async def stream():
19+
result = Runner.run_streamed(agent, input="Tell me a joke")
20+
stream_handler = StreamHandler(result)
21+
return StreamingResponse(stream_handler.stream_events(), media_type="application/x-ndjson")
22+
23+
24+
class StreamHandler:
25+
def __init__(self, result: RunResultStreaming):
26+
self.result = result
27+
28+
async def stream_events(self) -> AsyncIterator[str]:
29+
async for event in self.result.stream_events():
30+
yield f"{event.type}\n\n"
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
import pytest
2+
from httpx import ASGITransport, AsyncClient
3+
from inline_snapshot import snapshot
4+
5+
from ..fake_model import FakeModel
6+
from ..test_responses import get_text_message
7+
from .streaming_app import agent, app
8+
9+
10+
@pytest.mark.asyncio
11+
async def test_streaming_context():
12+
"""This ensures that FastAPI streaming works. The context for this test is that the Runner
13+
method was called in one async context, and the streaming was ended in another context,
14+
leading to a tracing error because the context was closed in the wrong context. This test
15+
ensures that this actually works.
16+
"""
17+
model = FakeModel()
18+
agent.model = model
19+
model.set_next_output([get_text_message("done")])
20+
21+
transport = ASGITransport(app)
22+
async with AsyncClient(transport=transport, base_url="http://test") as ac:
23+
async with ac.stream("POST", "/stream") as r:
24+
assert r.status_code == 200
25+
body = (await r.aread()).decode("utf-8")
26+
lines = [line for line in body.splitlines() if line]
27+
assert lines == snapshot(
28+
["agent_updated_stream_event", "raw_response_event", "run_item_stream_event"]
29+
)

uv.lock

Lines changed: 18 additions & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)