diff --git a/.gitignore b/.gitignore index 7dd22b88..2e9b9237 100644 --- a/.gitignore +++ b/.gitignore @@ -141,4 +141,5 @@ cython_debug/ .ruff_cache/ # PyPI configuration file -.pypirc \ No newline at end of file +.pypirc +.aider* diff --git a/pyproject.toml b/pyproject.toml index f94eb859..c06a8416 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -61,6 +61,7 @@ dev = [ "graphviz", "mkdocs-static-i18n>=1.3.0", "eval-type-backport>=0.2.2", + "fastapi >= 0.110.0, <1", ] [tool.uv.workspace] diff --git a/src/agents/result.py b/src/agents/result.py index a2a6cc4a..de60298b 100644 --- a/src/agents/result.py +++ b/src/agents/result.py @@ -126,7 +126,7 @@ class RunResultStreaming(RunResultBase): _current_agent_output_schema: AgentOutputSchema | None = field(repr=False) - _trace: Trace | None = field(repr=False) + trace: Trace | None = field(repr=False) is_complete: bool = False """Whether the agent has finished running.""" @@ -185,9 +185,6 @@ async def stream_events(self) -> AsyncIterator[StreamEvent]: yield item self._event_queue.task_done() - if self._trace: - self._trace.finish(reset_current=True) - self._cleanup_tasks() if self._stored_exception: diff --git a/src/agents/run.py b/src/agents/run.py index e2b0dbce..6396e5c9 100644 --- a/src/agents/run.py +++ b/src/agents/run.py @@ -404,10 +404,6 @@ def run_streamed( disabled=run_config.tracing_disabled, ) ) - # Need to start the trace here, because the current trace contextvar is captured at - # asyncio.create_task time - if new_trace: - new_trace.start(mark_as_current=True) output_schema = cls._get_output_schema(starting_agent) context_wrapper: RunContextWrapper[TContext] = RunContextWrapper( @@ -426,7 +422,7 @@ def run_streamed( input_guardrail_results=[], output_guardrail_results=[], _current_agent_output_schema=output_schema, - _trace=new_trace, + trace=new_trace, ) # Kick off the actual agent loop in the background and return the streamed result object. @@ -499,6 +495,9 @@ async def _run_streamed_impl( run_config: RunConfig, previous_response_id: str | None, ): + if streamed_result.trace: + streamed_result.trace.start(mark_as_current=True) + current_span: Span[AgentSpanData] | None = None current_agent = starting_agent current_turn = 0 @@ -625,6 +624,8 @@ async def _run_streamed_impl( finally: if current_span: current_span.finish(reset_current=True) + if streamed_result.trace: + streamed_result.trace.finish(reset_current=True) @classmethod async def _run_single_turn_streamed( diff --git a/tests/fastapi/__init__.py b/tests/fastapi/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/fastapi/streaming_app.py b/tests/fastapi/streaming_app.py new file mode 100644 index 00000000..b93ccf3f --- /dev/null +++ b/tests/fastapi/streaming_app.py @@ -0,0 +1,30 @@ +from collections.abc import AsyncIterator + +from fastapi import FastAPI +from starlette.responses import StreamingResponse + +from agents import Agent, Runner, RunResultStreaming + +agent = Agent( + name="Assistant", + instructions="You are a helpful assistant.", +) + + +app = FastAPI() + + +@app.post("/stream") +async def stream(): + result = Runner.run_streamed(agent, input="Tell me a joke") + stream_handler = StreamHandler(result) + return StreamingResponse(stream_handler.stream_events(), media_type="application/x-ndjson") + + +class StreamHandler: + def __init__(self, result: RunResultStreaming): + self.result = result + + async def stream_events(self) -> AsyncIterator[str]: + async for event in self.result.stream_events(): + yield f"{event.type}\n\n" diff --git a/tests/fastapi/test_streaming_context.py b/tests/fastapi/test_streaming_context.py new file mode 100644 index 00000000..ee13045e --- /dev/null +++ b/tests/fastapi/test_streaming_context.py @@ -0,0 +1,29 @@ +import pytest +from httpx import ASGITransport, AsyncClient +from inline_snapshot import snapshot + +from ..fake_model import FakeModel +from ..test_responses import get_text_message +from .streaming_app import agent, app + + +@pytest.mark.asyncio +async def test_streaming_context(): + """This ensures that FastAPI streaming works. The context for this test is that the Runner + method was called in one async context, and the streaming was ended in another context, + leading to a tracing error because the context was closed in the wrong context. This test + ensures that this actually works. + """ + model = FakeModel() + agent.model = model + model.set_next_output([get_text_message("done")]) + + transport = ASGITransport(app) + async with AsyncClient(transport=transport, base_url="http://test") as ac: + async with ac.stream("POST", "/stream") as r: + assert r.status_code == 200 + body = (await r.aread()).decode("utf-8") + lines = [line for line in body.splitlines() if line] + assert lines == snapshot( + ["agent_updated_stream_event", "raw_response_event", "run_item_stream_event"] + ) diff --git a/uv.lock b/uv.lock index 24d089a5..3ee65d21 100644 --- a/uv.lock +++ b/uv.lock @@ -483,6 +483,20 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/7b/8f/c4d9bafc34ad7ad5d8dc16dd1347ee0e507a52c3adb6bfa8887e1c6a26ba/executing-2.2.0-py2.py3-none-any.whl", hash = "sha256:11387150cad388d62750327a53d3339fad4888b39a6fe233c3afbb54ecffd3aa", size = 26702 }, ] +[[package]] +name = "fastapi" +version = "0.115.12" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "pydantic" }, + { name = "starlette" }, + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/f4/55/ae499352d82338331ca1e28c7f4a63bfd09479b16395dce38cf50a39e2c2/fastapi-0.115.12.tar.gz", hash = "sha256:1e2c2a2646905f9e83d32f04a3f86aff4a286669c6c950ca95b5fd68c2602681", size = 295236 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/50/b3/b51f09c2ba432a576fe63758bddc81f78f0c6309d9e5c10d194313bf021e/fastapi-0.115.12-py3-none-any.whl", hash = "sha256:e94613d6c05e27be7ffebdd6ea5f388112e5e430c8f7d6494a9d1d88d43e814d", size = 95164 }, +] + [[package]] name = "filelock" version = "3.18.0" @@ -1496,6 +1510,7 @@ voice = [ dev = [ { name = "coverage" }, { name = "eval-type-backport" }, + { name = "fastapi" }, { name = "graphviz" }, { name = "inline-snapshot" }, { name = "mkdocs" }, @@ -1536,6 +1551,7 @@ provides-extras = ["voice", "viz", "litellm"] dev = [ { name = "coverage", specifier = ">=7.6.12" }, { name = "eval-type-backport", specifier = ">=0.2.2" }, + { name = "fastapi", specifier = ">=0.110.0,<1" }, { name = "graphviz" }, { name = "inline-snapshot", specifier = ">=0.20.7" }, { name = "mkdocs", specifier = ">=1.6.0" }, @@ -2474,7 +2490,8 @@ name = "starlette" version = "0.46.2" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "anyio", marker = "python_full_version >= '3.10'" }, + { name = "anyio" }, + { name = "typing-extensions", marker = "python_full_version < '3.10'" }, ] sdist = { url = "https://files.pythonhosted.org/packages/ce/20/08dfcd9c983f6a6f4a1000d934b9e6d626cff8d2eeb77a89a68eef20a2b7/starlette-0.46.2.tar.gz", hash = "sha256:7f7361f34eed179294600af672f565727419830b54b7b084efe44bb82d2fccd5", size = 2580846 } wheels = [