From 616192ea8b91d8d47b1ed9d17af41906d93d9de9 Mon Sep 17 00:00:00 2001 From: pakrym-oai Date: Mon, 19 May 2025 17:45:57 -0700 Subject: [PATCH 1/8] Runner --- src/agents/run.py | 294 ++++++++++++++++++++++++---------------------- 1 file changed, 153 insertions(+), 141 deletions(-) diff --git a/src/agents/run.py b/src/agents/run.py index 849da7bf..018200a9 100644 --- a/src/agents/run.py +++ b/src/agents/run.py @@ -1,5 +1,6 @@ from __future__ import annotations +import abc import asyncio import copy from dataclasses import dataclass, field @@ -106,7 +107,48 @@ class RunConfig: """ -class Runner: +class Runner(abc.ABC): + @abc.abstractmethod + async def run_impl( + self, + starting_agent: Agent[TContext], + input: str | list[TResponseInputItem], + *, + context: TContext | None = None, + max_turns: int = DEFAULT_MAX_TURNS, + hooks: RunHooks[TContext] | None = None, + run_config: RunConfig | None = None, + previous_response_id: str | None = None, + ) -> RunResult: + pass + + @abc.abstractmethod + def run_sync_impl( + self, + starting_agent: Agent[TContext], + input: str | list[TResponseInputItem], + *, + context: TContext | None = None, + max_turns: int = DEFAULT_MAX_TURNS, + hooks: RunHooks[TContext] | None = None, + run_config: RunConfig | None = None, + previous_response_id: str | None = None, + ) -> RunResult: + pass + + @abc.abstractmethod + def run_streaming_impl( + self, + starting_agent: Agent[TContext], + input: str | list[TResponseInputItem], + context: TContext | None = None, + max_turns: int = DEFAULT_MAX_TURNS, + hooks: RunHooks[TContext] | None = None, + run_config: RunConfig | None = None, + previous_response_id: str | None = None, + ) -> RunResultStreaming: + pass + @classmethod async def run( cls, @@ -119,36 +161,102 @@ async def run( run_config: RunConfig | None = None, previous_response_id: str | None = None, ) -> RunResult: - """Run a workflow starting at the given agent. The agent will run in a loop until a final - output is generated. The loop runs like so: - 1. The agent is invoked with the given input. - 2. If there is a final output (i.e. the agent produces something of type - `agent.output_type`, the loop terminates. - 3. If there's a handoff, we run the loop again, with the new agent. - 4. Else, we run tool calls (if any), and re-run the loop. - - In two cases, the agent may raise an exception: - 1. If the max_turns is exceeded, a MaxTurnsExceeded exception is raised. - 2. If a guardrail tripwire is triggered, a GuardrailTripwireTriggered exception is raised. - - Note that only the first agent's input guardrails are run. - - Args: - starting_agent: The starting agent to run. - input: The initial input to the agent. You can pass a single string for a user message, - or a list of input items. - context: The context to run the agent with. - max_turns: The maximum number of turns to run the agent for. A turn is defined as one - AI invocation (including any tool calls that might occur). - hooks: An object that receives callbacks on various lifecycle events. - run_config: Global settings for the entire agent run. - previous_response_id: The ID of the previous response, if using OpenAI models via the - Responses API, this allows you to skip passing in input from the previous turn. - - Returns: - A run result containing all the inputs, guardrail results and the output of the last - agent. Agents may perform handoffs, so we don't know the specific type of the output. - """ + return await DefaultRunner().run_impl( + starting_agent, + input, + context=context, + max_turns=max_turns, + hooks=hooks, + run_config=run_config, + previous_response_id=previous_response_id, + ) + + @classmethod + def run_sync( + cls, + starting_agent: Agent[TContext], + input: str | list[TResponseInputItem], + *, + context: TContext | None = None, + max_turns: int = DEFAULT_MAX_TURNS, + hooks: RunHooks[TContext] | None = None, + run_config: RunConfig | None = None, + previous_response_id: str | None = None, + ) -> RunResult: + return DefaultRunner().run_sync_impl( + starting_agent, + input, + context=context, + max_turns=max_turns, + hooks=hooks, + run_config=run_config, + previous_response_id=previous_response_id, + ) + + @classmethod + def run_streamed( + cls, + starting_agent: Agent[TContext], + input: str | list[TResponseInputItem], + context: TContext | None = None, + max_turns: int = DEFAULT_MAX_TURNS, + hooks: RunHooks[TContext] | None = None, + run_config: RunConfig | None = None, + previous_response_id: str | None = None, + ) -> RunResultStreaming: + return DefaultRunner().run_streaming_impl( + starting_agent, + input, + context=context, + max_turns=max_turns, + hooks=hooks, + run_config=run_config, + previous_response_id=previous_response_id, + ) + + @classmethod + def _get_output_schema(cls, agent: Agent[Any]) -> AgentOutputSchemaBase | None: + if agent.output_type is None or agent.output_type is str: + return None + elif isinstance(agent.output_type, AgentOutputSchemaBase): + return agent.output_type + + return AgentOutputSchema(agent.output_type) + + @classmethod + def _get_handoffs(cls, agent: Agent[Any]) -> list[Handoff]: + handoffs = [] + for handoff_item in agent.handoffs: + if isinstance(handoff_item, Handoff): + handoffs.append(handoff_item) + elif isinstance(handoff_item, Agent): + handoffs.append(handoff(handoff_item)) + return handoffs + + @classmethod + def _get_model(cls, agent: Agent[Any], run_config: RunConfig) -> Model: + if isinstance(run_config.model, Model): + return run_config.model + elif isinstance(run_config.model, str): + return run_config.model_provider.get_model(run_config.model) + elif isinstance(agent.model, Model): + return agent.model + + return run_config.model_provider.get_model(agent.model) + + +class DefaultRunner(Runner): + async def run_impl( + self, + starting_agent: Agent[TContext], + input: str | list[TResponseInputItem], + *, + context: TContext | None = None, + max_turns: int = DEFAULT_MAX_TURNS, + hooks: RunHooks[TContext] | None = None, + run_config: RunConfig | None = None, + previous_response_id: str | None = None, + ) -> RunResult: if hooks is None: hooks = RunHooks[Any]() if run_config is None: @@ -183,8 +291,8 @@ async def run( # Start an agent span if we don't have one. This span is ended if the current # agent changes, or if the agent loop ends. if current_span is None: - handoff_names = [h.agent_name for h in cls._get_handoffs(current_agent)] - if output_schema := cls._get_output_schema(current_agent): + handoff_names = [h.agent_name for h in self._get_handoffs(current_agent)] + if output_schema := self._get_output_schema(current_agent): output_type_name = output_schema.name() else: output_type_name = "str" @@ -196,7 +304,7 @@ async def run( ) current_span.start(mark_as_current=True) - all_tools = await cls._get_all_tools(current_agent) + all_tools = await self._get_all_tools(current_agent) current_span.span_data.tools = [t.name for t in all_tools] current_turn += 1 @@ -216,14 +324,14 @@ async def run( if current_turn == 1: input_guardrail_results, turn_result = await asyncio.gather( - cls._run_input_guardrails( + self._run_input_guardrails( starting_agent, starting_agent.input_guardrails + (run_config.input_guardrails or []), copy.deepcopy(input), context_wrapper, ), - cls._run_single_turn( + self._run_single_turn( agent=current_agent, all_tools=all_tools, original_input=original_input, @@ -237,7 +345,7 @@ async def run( ), ) else: - turn_result = await cls._run_single_turn( + turn_result = await self._run_single_turn( agent=current_agent, all_tools=all_tools, original_input=original_input, @@ -256,7 +364,7 @@ async def run( generated_items = turn_result.generated_items if isinstance(turn_result.next_step, NextStepFinalOutput): - output_guardrail_results = await cls._run_output_guardrails( + output_guardrail_results = await self._run_output_guardrails( current_agent.output_guardrails + (run_config.output_guardrails or []), current_agent, turn_result.next_step.output, @@ -287,9 +395,8 @@ async def run( if current_span: current_span.finish(reset_current=True) - @classmethod - def run_sync( - cls, + def run_sync_impl( + self, starting_agent: Agent[TContext], input: str | list[TResponseInputItem], *, @@ -299,42 +406,8 @@ def run_sync( run_config: RunConfig | None = None, previous_response_id: str | None = None, ) -> RunResult: - """Run a workflow synchronously, starting at the given agent. Note that this just wraps the - `run` method, so it will not work if there's already an event loop (e.g. inside an async - function, or in a Jupyter notebook or async context like FastAPI). For those cases, use - the `run` method instead. - - The agent will run in a loop until a final output is generated. The loop runs like so: - 1. The agent is invoked with the given input. - 2. If there is a final output (i.e. the agent produces something of type - `agent.output_type`, the loop terminates. - 3. If there's a handoff, we run the loop again, with the new agent. - 4. Else, we run tool calls (if any), and re-run the loop. - - In two cases, the agent may raise an exception: - 1. If the max_turns is exceeded, a MaxTurnsExceeded exception is raised. - 2. If a guardrail tripwire is triggered, a GuardrailTripwireTriggered exception is raised. - - Note that only the first agent's input guardrails are run. - - Args: - starting_agent: The starting agent to run. - input: The initial input to the agent. You can pass a single string for a user message, - or a list of input items. - context: The context to run the agent with. - max_turns: The maximum number of turns to run the agent for. A turn is defined as one - AI invocation (including any tool calls that might occur). - hooks: An object that receives callbacks on various lifecycle events. - run_config: Global settings for the entire agent run. - previous_response_id: The ID of the previous response, if using OpenAI models via the - Responses API, this allows you to skip passing in input from the previous turn. - - Returns: - A run result containing all the inputs, guardrail results and the output of the last - agent. Agents may perform handoffs, so we don't know the specific type of the output. - """ return asyncio.get_event_loop().run_until_complete( - cls.run( + self.run( starting_agent, input, context=context, @@ -345,9 +418,8 @@ def run_sync( ) ) - @classmethod - def run_streamed( - cls, + def run_streaming_impl( + self, starting_agent: Agent[TContext], input: str | list[TResponseInputItem], context: TContext | None = None, @@ -356,36 +428,6 @@ def run_streamed( run_config: RunConfig | None = None, previous_response_id: str | None = None, ) -> RunResultStreaming: - """Run a workflow starting at the given agent in streaming mode. The returned result object - contains a method you can use to stream semantic events as they are generated. - - The agent will run in a loop until a final output is generated. The loop runs like so: - 1. The agent is invoked with the given input. - 2. If there is a final output (i.e. the agent produces something of type - `agent.output_type`, the loop terminates. - 3. If there's a handoff, we run the loop again, with the new agent. - 4. Else, we run tool calls (if any), and re-run the loop. - - In two cases, the agent may raise an exception: - 1. If the max_turns is exceeded, a MaxTurnsExceeded exception is raised. - 2. If a guardrail tripwire is triggered, a GuardrailTripwireTriggered exception is raised. - - Note that only the first agent's input guardrails are run. - - Args: - starting_agent: The starting agent to run. - input: The initial input to the agent. You can pass a single string for a user message, - or a list of input items. - context: The context to run the agent with. - max_turns: The maximum number of turns to run the agent for. A turn is defined as one - AI invocation (including any tool calls that might occur). - hooks: An object that receives callbacks on various lifecycle events. - run_config: Global settings for the entire agent run. - previous_response_id: The ID of the previous response, if using OpenAI models via the - Responses API, this allows you to skip passing in input from the previous turn. - Returns: - A result object that contains data about the run, as well as a method to stream events. - """ if hooks is None: hooks = RunHooks[Any]() if run_config is None: @@ -406,7 +448,7 @@ def run_streamed( ) ) - output_schema = cls._get_output_schema(starting_agent) + output_schema = self._get_output_schema(starting_agent) context_wrapper: RunContextWrapper[TContext] = RunContextWrapper( context=context # type: ignore ) @@ -429,7 +471,7 @@ def run_streamed( # Kick off the actual agent loop in the background and return the streamed result object. streamed_result._run_impl_task = asyncio.create_task( - cls._run_streamed_impl( + self._run_streamed_impl( starting_input=input, streamed_result=streamed_result, starting_agent=starting_agent, @@ -933,36 +975,6 @@ async def _get_new_response( return new_response - @classmethod - def _get_output_schema(cls, agent: Agent[Any]) -> AgentOutputSchemaBase | None: - if agent.output_type is None or agent.output_type is str: - return None - elif isinstance(agent.output_type, AgentOutputSchemaBase): - return agent.output_type - - return AgentOutputSchema(agent.output_type) - - @classmethod - def _get_handoffs(cls, agent: Agent[Any]) -> list[Handoff]: - handoffs = [] - for handoff_item in agent.handoffs: - if isinstance(handoff_item, Handoff): - handoffs.append(handoff_item) - elif isinstance(handoff_item, Agent): - handoffs.append(handoff(handoff_item)) - return handoffs - @classmethod async def _get_all_tools(cls, agent: Agent[Any]) -> list[Tool]: return await agent.get_all_tools() - - @classmethod - def _get_model(cls, agent: Agent[Any], run_config: RunConfig) -> Model: - if isinstance(run_config.model, Model): - return run_config.model - elif isinstance(run_config.model, str): - return run_config.model_provider.get_model(run_config.model) - elif isinstance(agent.model, Model): - return agent.model - - return run_config.model_provider.get_model(agent.model) From 046c2874ee8f50934132347a7e6489a6df5600ed Mon Sep 17 00:00:00 2001 From: pakrym-oai Date: Mon, 19 May 2025 17:51:18 -0700 Subject: [PATCH 2/8] LLM: run.py: add support for configurable default Runner - Introduce a DEFAULT_RUNNER global and a set_default_runner() function to allow callers to specify a default Runner for agent runs. - Update Runner.run, run_sync, and run_streaming to use DEFAULT_RUNNER if set, otherwise fallback to DefaultRunner. - Add detailed docstrings to run, run_sync, and run_streaming methods for clarity on agent execution flow and exceptions. --- src/agents/run.py | 99 +++++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 96 insertions(+), 3 deletions(-) diff --git a/src/agents/run.py b/src/agents/run.py index 018200a9..6b7a2bd3 100644 --- a/src/agents/run.py +++ b/src/agents/run.py @@ -46,6 +46,15 @@ from .util import _coro, _error_tracing DEFAULT_MAX_TURNS = 10 +DEFAULT_RUNNER: Runner | None = None + + +def set_default_runner(runner: Runner) -> None: + """ + Set the default runner to use for the agent run. + """ + global DEFAULT_RUNNER + DEFAULT_RUNNER = runner @dataclass @@ -161,7 +170,34 @@ async def run( run_config: RunConfig | None = None, previous_response_id: str | None = None, ) -> RunResult: - return await DefaultRunner().run_impl( + """Run a workflow starting at the given agent. The agent will run in a loop until a final + output is generated. The loop runs like so: + 1. The agent is invoked with the given input. + 2. If there is a final output (i.e. the agent produces something of type + `agent.output_type`, the loop terminates. + 3. If there's a handoff, we run the loop again, with the new agent. + 4. Else, we run tool calls (if any), and re-run the loop. + In two cases, the agent may raise an exception: + 1. If the max_turns is exceeded, a MaxTurnsExceeded exception is raised. + 2. If a guardrail tripwire is triggered, a GuardrailTripwireTriggered exception is raised. + Note that only the first agent's input guardrails are run. + Args: + starting_agent: The starting agent to run. + input: The initial input to the agent. You can pass a single string for a user message, + or a list of input items. + context: The context to run the agent with. + max_turns: The maximum number of turns to run the agent for. A turn is defined as one + AI invocation (including any tool calls that might occur). + hooks: An object that receives callbacks on various lifecycle events. + run_config: Global settings for the entire agent run. + previous_response_id: The ID of the previous response, if using OpenAI models via the + Responses API, this allows you to skip passing in input from the previous turn. + Returns: + A run result containing all the inputs, guardrail results and the output of the last + agent. Agents may perform handoffs, so we don't know the specific type of the output. + """ + runner = DEFAULT_RUNNER or DefaultRunner() + return await runner.run_impl( starting_agent, input, context=context, @@ -183,7 +219,37 @@ def run_sync( run_config: RunConfig | None = None, previous_response_id: str | None = None, ) -> RunResult: - return DefaultRunner().run_sync_impl( + """Run a workflow synchronously, starting at the given agent. Note that this just wraps the + `run` method, so it will not work if there's already an event loop (e.g. inside an async + function, or in a Jupyter notebook or async context like FastAPI). For those cases, use + the `run` method instead. + The agent will run in a loop until a final output is generated. The loop runs like so: + 1. The agent is invoked with the given input. + 2. If there is a final output (i.e. the agent produces something of type + `agent.output_type`, the loop terminates. + 3. If there's a handoff, we run the loop again, with the new agent. + 4. Else, we run tool calls (if any), and re-run the loop. + In two cases, the agent may raise an exception: + 1. If the max_turns is exceeded, a MaxTurnsExceeded exception is raised. + 2. If a guardrail tripwire is triggered, a GuardrailTripwireTriggered exception is raised. + Note that only the first agent's input guardrails are run. + Args: + starting_agent: The starting agent to run. + input: The initial input to the agent. You can pass a single string for a user message, + or a list of input items. + context: The context to run the agent with. + max_turns: The maximum number of turns to run the agent for. A turn is defined as one + AI invocation (including any tool calls that might occur). + hooks: An object that receives callbacks on various lifecycle events. + run_config: Global settings for the entire agent run. + previous_response_id: The ID of the previous response, if using OpenAI models via the + Responses API, this allows you to skip passing in input from the previous turn. + Returns: + A run result containing all the inputs, guardrail results and the output of the last + agent. Agents may perform handoffs, so we don't know the specific type of the output. + """ + runner = DEFAULT_RUNNER or DefaultRunner() + return runner.run_sync_impl( starting_agent, input, context=context, @@ -204,7 +270,34 @@ def run_streamed( run_config: RunConfig | None = None, previous_response_id: str | None = None, ) -> RunResultStreaming: - return DefaultRunner().run_streaming_impl( + """Run a workflow starting at the given agent in streaming mode. The returned result object + contains a method you can use to stream semantic events as they are generated. + The agent will run in a loop until a final output is generated. The loop runs like so: + 1. The agent is invoked with the given input. + 2. If there is a final output (i.e. the agent produces something of type + `agent.output_type`, the loop terminates. + 3. If there's a handoff, we run the loop again, with the new agent. + 4. Else, we run tool calls (if any), and re-run the loop. + In two cases, the agent may raise an exception: + 1. If the max_turns is exceeded, a MaxTurnsExceeded exception is raised. + 2. If a guardrail tripwire is triggered, a GuardrailTripwireTriggered exception is raised. + Note that only the first agent's input guardrails are run. + Args: + starting_agent: The starting agent to run. + input: The initial input to the agent. You can pass a single string for a user message, + or a list of input items. + context: The context to run the agent with. + max_turns: The maximum number of turns to run the agent for. A turn is defined as one + AI invocation (including any tool calls that might occur). + hooks: An object that receives callbacks on various lifecycle events. + run_config: Global settings for the entire agent run. + previous_response_id: The ID of the previous response, if using OpenAI models via the + Responses API, this allows you to skip passing in input from the previous turn. + Returns: + A result object that contains data about the run, as well as a method to stream events. + """ + runner = DEFAULT_RUNNER or DefaultRunner() + return runner.run_streaming_impl( starting_agent, input, context=context, From c7b5053e4e6f7542d65e2c22cdd6456a165f12ac Mon Sep 17 00:00:00 2001 From: pakrym-oai Date: Mon, 19 May 2025 17:52:14 -0700 Subject: [PATCH 3/8] LLM: agents: export DefaultRunner and set_default_runner in __init__ --- src/agents/__init__.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/agents/__init__.py b/src/agents/__init__.py index 6d7c90b4..3c9c8465 100644 --- a/src/agents/__init__.py +++ b/src/agents/__init__.py @@ -45,7 +45,7 @@ from .models.openai_provider import OpenAIProvider from .models.openai_responses import OpenAIResponsesModel from .result import RunResult, RunResultStreaming -from .run import RunConfig, Runner +from .run import DefaultRunner, RunConfig, Runner, set_default_runner from .run_context import RunContextWrapper, TContext from .stream_events import ( AgentUpdatedStreamEvent, @@ -150,6 +150,7 @@ def enable_verbose_stdout_logging(): "ToolsToFinalOutputFunction", "ToolsToFinalOutputResult", "Runner", + "DefaultRunner", "Model", "ModelProvider", "ModelTracing", @@ -244,6 +245,7 @@ def enable_verbose_stdout_logging(): "set_default_openai_key", "set_default_openai_client", "set_default_openai_api", + "set_default_runner", "set_tracing_export_api_key", "enable_verbose_stdout_logging", "gen_trace_id", From a10f9459bf84c1871cb969a902b95a4b354f60a3 Mon Sep 17 00:00:00 2001 From: pakrym-oai Date: Mon, 19 May 2025 18:02:40 -0700 Subject: [PATCH 4/8] test --- src/agents/run.py | 8 ++++---- tests/conftest.py | 6 ++++++ tests/test_run.py | 26 ++++++++++++++++++++++++++ 3 files changed, 36 insertions(+), 4 deletions(-) create mode 100644 tests/test_run.py diff --git a/src/agents/run.py b/src/agents/run.py index 6b7a2bd3..6bf1ed16 100644 --- a/src/agents/run.py +++ b/src/agents/run.py @@ -49,7 +49,7 @@ DEFAULT_RUNNER: Runner | None = None -def set_default_runner(runner: Runner) -> None: +def set_default_runner(runner: Runner | None) -> None: """ Set the default runner to use for the agent run. """ @@ -146,7 +146,7 @@ def run_sync_impl( pass @abc.abstractmethod - def run_streaming_impl( + def run_streamed_impl( self, starting_agent: Agent[TContext], input: str | list[TResponseInputItem], @@ -297,7 +297,7 @@ def run_streamed( A result object that contains data about the run, as well as a method to stream events. """ runner = DEFAULT_RUNNER or DefaultRunner() - return runner.run_streaming_impl( + return runner.run_streamed_impl( starting_agent, input, context=context, @@ -511,7 +511,7 @@ def run_sync_impl( ) ) - def run_streaming_impl( + def run_streamed_impl( self, starting_agent: Agent[TContext], input: str | list[TResponseInputItem], diff --git a/tests/conftest.py b/tests/conftest.py index ba0d8822..622b61b1 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -5,6 +5,7 @@ from agents.models import _openai_shared from agents.models.openai_chatcompletions import OpenAIChatCompletionsModel from agents.models.openai_responses import OpenAIResponsesModel +from agents.run import set_default_runner from agents.tracing import set_trace_processors from agents.tracing.setup import GLOBAL_TRACE_PROVIDER @@ -33,6 +34,11 @@ def clear_openai_settings(): _openai_shared._use_responses_by_default = True +@pytest.fixture(autouse=True) +def clear_default_runner(): + set_default_runner(None) + + # This fixture will run after all tests end @pytest.fixture(autouse=True, scope="session") def shutdown_trace_provider(): diff --git a/tests/test_run.py b/tests/test_run.py new file mode 100644 index 00000000..b01c8605 --- /dev/null +++ b/tests/test_run.py @@ -0,0 +1,26 @@ +from __future__ import annotations + +from unittest import mock + +import pytest + +from agents import Agent, Runner +from agents.run import set_default_runner + +from .fake_model import FakeModel + + +@pytest.mark.asyncio +async def test_static_run_methods_call_into_default_runner() -> None: + runner = mock.Mock(spec=Runner) + set_default_runner(runner) + + agent = Agent(name="test", model=FakeModel()) + await Runner.run(agent, input="test") + runner.run_impl.assert_called_once() + + Runner.run_streamed(agent, input="test") + runner.run_streamed_impl.assert_called_once() + + Runner.run_sync(agent, input="test") + runner.run_sync_impl.assert_called_once() From 2002dd9b525a2ab782cb059a1e8a016f0b2cc6ac Mon Sep 17 00:00:00 2001 From: pakrym-oai Date: Mon, 19 May 2025 18:05:20 -0700 Subject: [PATCH 5/8] LLM: rename Runner execution methods to use leading underscores This commit renames the abstract and concrete methods run_impl, run_sync_impl, and run_streamed_impl in the Runner and DefaultRunner classes to _run_impl, _run_sync_impl, and _run_streamed_impl (or _start_streaming as appropriate) for improved naming consistency. Updates all method calls and test mocks accordingly. No behavior changes. --- src/agents/run.py | 22 +++++++++++----------- tests/test_run.py | 6 +++--- 2 files changed, 14 insertions(+), 14 deletions(-) diff --git a/src/agents/run.py b/src/agents/run.py index 6bf1ed16..045f6f0b 100644 --- a/src/agents/run.py +++ b/src/agents/run.py @@ -118,7 +118,7 @@ class RunConfig: class Runner(abc.ABC): @abc.abstractmethod - async def run_impl( + async def _run_impl( self, starting_agent: Agent[TContext], input: str | list[TResponseInputItem], @@ -132,7 +132,7 @@ async def run_impl( pass @abc.abstractmethod - def run_sync_impl( + def _run_sync_impl( self, starting_agent: Agent[TContext], input: str | list[TResponseInputItem], @@ -146,7 +146,7 @@ def run_sync_impl( pass @abc.abstractmethod - def run_streamed_impl( + def _run_streamed_impl( self, starting_agent: Agent[TContext], input: str | list[TResponseInputItem], @@ -197,7 +197,7 @@ async def run( agent. Agents may perform handoffs, so we don't know the specific type of the output. """ runner = DEFAULT_RUNNER or DefaultRunner() - return await runner.run_impl( + return await runner._run_impl( starting_agent, input, context=context, @@ -249,7 +249,7 @@ def run_sync( agent. Agents may perform handoffs, so we don't know the specific type of the output. """ runner = DEFAULT_RUNNER or DefaultRunner() - return runner.run_sync_impl( + return runner._run_sync_impl( starting_agent, input, context=context, @@ -297,7 +297,7 @@ def run_streamed( A result object that contains data about the run, as well as a method to stream events. """ runner = DEFAULT_RUNNER or DefaultRunner() - return runner.run_streamed_impl( + return runner._run_streamed_impl( starting_agent, input, context=context, @@ -339,7 +339,7 @@ def _get_model(cls, agent: Agent[Any], run_config: RunConfig) -> Model: class DefaultRunner(Runner): - async def run_impl( + async def _run_impl( self, starting_agent: Agent[TContext], input: str | list[TResponseInputItem], @@ -488,7 +488,7 @@ async def run_impl( if current_span: current_span.finish(reset_current=True) - def run_sync_impl( + def _run_sync_impl( self, starting_agent: Agent[TContext], input: str | list[TResponseInputItem], @@ -511,7 +511,7 @@ def run_sync_impl( ) ) - def run_streamed_impl( + def _run_streamed_impl( self, starting_agent: Agent[TContext], input: str | list[TResponseInputItem], @@ -564,7 +564,7 @@ def run_streamed_impl( # Kick off the actual agent loop in the background and return the streamed result object. streamed_result._run_impl_task = asyncio.create_task( - self._run_streamed_impl( + self._start_streaming( starting_input=input, streamed_result=streamed_result, starting_agent=starting_agent, @@ -621,7 +621,7 @@ async def _run_input_guardrails_with_queue( streamed_result.input_guardrail_results = guardrail_results @classmethod - async def _run_streamed_impl( + async def _start_streaming( cls, starting_input: str | list[TResponseInputItem], streamed_result: RunResultStreaming, diff --git a/tests/test_run.py b/tests/test_run.py index b01c8605..57e33d50 100644 --- a/tests/test_run.py +++ b/tests/test_run.py @@ -17,10 +17,10 @@ async def test_static_run_methods_call_into_default_runner() -> None: agent = Agent(name="test", model=FakeModel()) await Runner.run(agent, input="test") - runner.run_impl.assert_called_once() + runner._run_impl.assert_called_once() Runner.run_streamed(agent, input="test") - runner.run_streamed_impl.assert_called_once() + runner._run_streamed_impl.assert_called_once() Runner.run_sync(agent, input="test") - runner.run_sync_impl.assert_called_once() + runner._run_sync_impl.assert_called_once() From 40d1b91410f8ea347f074abc5e9bb813629a64d5 Mon Sep 17 00:00:00 2001 From: pakrym-oai Date: Wed, 21 May 2025 09:19:06 -0700 Subject: [PATCH 6/8] Add ID generation methods to TraceProvider (#729) --- src/agents/__init__.py | 2 ++ src/agents/tracing/__init__.py | 3 ++- src/agents/tracing/setup.py | 24 ++++++++++++++++++++++++ src/agents/tracing/util.py | 19 +++++++++---------- 4 files changed, 37 insertions(+), 11 deletions(-) diff --git a/src/agents/__init__.py b/src/agents/__init__.py index 3c9c8465..65e44244 100644 --- a/src/agents/__init__.py +++ b/src/agents/__init__.py @@ -92,6 +92,7 @@ handoff_span, mcp_tools_span, set_trace_processors, + set_trace_provider, set_tracing_disabled, set_tracing_export_api_key, speech_group_span, @@ -221,6 +222,7 @@ def enable_verbose_stdout_logging(): "guardrail_span", "handoff_span", "set_trace_processors", + "set_trace_provider", "set_tracing_disabled", "speech_group_span", "transcription_span", diff --git a/src/agents/tracing/__init__.py b/src/agents/tracing/__init__.py index 9df94426..07d8af6d 100644 --- a/src/agents/tracing/__init__.py +++ b/src/agents/tracing/__init__.py @@ -18,7 +18,7 @@ ) from .processor_interface import TracingProcessor from .processors import default_exporter, default_processor -from .setup import GLOBAL_TRACE_PROVIDER +from .setup import GLOBAL_TRACE_PROVIDER, set_trace_provider from .span_data import ( AgentSpanData, CustomSpanData, @@ -49,6 +49,7 @@ "handoff_span", "response_span", "set_trace_processors", + "set_trace_provider", "set_tracing_disabled", "trace", "Trace", diff --git a/src/agents/tracing/setup.py b/src/agents/tracing/setup.py index 9e27d210..daa7b86d 100644 --- a/src/agents/tracing/setup.py +++ b/src/agents/tracing/setup.py @@ -2,6 +2,8 @@ import os import threading +import uuid +from datetime import datetime, timezone from typing import Any from ..logger import logger @@ -118,6 +120,22 @@ def set_disabled(self, disabled: bool) -> None: """ self._disabled = disabled + def time_iso(self) -> str: + """Return the current time in ISO 8601 format.""" + return datetime.now(timezone.utc).isoformat() + + def gen_trace_id(self) -> str: + """Generate a new trace ID.""" + return f"trace_{uuid.uuid4().hex}" + + def gen_span_id(self) -> str: + """Generate a new span ID.""" + return f"span_{uuid.uuid4().hex[:24]}" + + def gen_group_id(self) -> str: + """Generate a new group ID.""" + return f"group_{uuid.uuid4().hex[:24]}" + def create_trace( self, name: str, @@ -212,3 +230,9 @@ def shutdown(self) -> None: GLOBAL_TRACE_PROVIDER = TraceProvider() + + +def set_trace_provider(provider: TraceProvider) -> None: + """Set the global trace provider used by tracing utilities.""" + global GLOBAL_TRACE_PROVIDER + GLOBAL_TRACE_PROVIDER = provider diff --git a/src/agents/tracing/util.py b/src/agents/tracing/util.py index f546b4e5..af2f5ff3 100644 --- a/src/agents/tracing/util.py +++ b/src/agents/tracing/util.py @@ -1,22 +1,21 @@ -import uuid -from datetime import datetime, timezone +from .setup import GLOBAL_TRACE_PROVIDER def time_iso() -> str: - """Returns the current time in ISO 8601 format.""" - return datetime.now(timezone.utc).isoformat() + """Return the current time in ISO 8601 format.""" + return GLOBAL_TRACE_PROVIDER.time_iso() def gen_trace_id() -> str: - """Generates a new trace ID.""" - return f"trace_{uuid.uuid4().hex}" + """Generate a new trace ID.""" + return GLOBAL_TRACE_PROVIDER.gen_trace_id() def gen_span_id() -> str: - """Generates a new span ID.""" - return f"span_{uuid.uuid4().hex[:24]}" + """Generate a new span ID.""" + return GLOBAL_TRACE_PROVIDER.gen_span_id() def gen_group_id() -> str: - """Generates a new group ID.""" - return f"group_{uuid.uuid4().hex[:24]}" + """Generate a new group ID.""" + return GLOBAL_TRACE_PROVIDER.gen_group_id() From b9145707c6c9361650d093a3f3a8e5d46ae8e52d Mon Sep 17 00:00:00 2001 From: pakrym-oai Date: Wed, 21 May 2025 09:34:06 -0700 Subject: [PATCH 7/8] LLM: refactor(tracing): move TraceProvider to separate module and replace global with accessor - Move TraceProvider and related logic from setup.py to new provider.py module - Replace direct GLOBAL_TRACE_PROVIDER references with get_trace_provider() accessor throughout tracing code - Update init and util imports to use get_trace_provider - Call set_trace_provider(TraceProvider()) on init - Remove ID and time generation logic from util, delegate to TraceProvider - Update SpanImpl to always use passed-in span_id --- src/agents/tracing/__init__.py | 13 +- src/agents/tracing/create.py | 32 ++--- src/agents/tracing/provider.py | 228 +++++++++++++++++++++++++++++++ src/agents/tracing/setup.py | 239 ++------------------------------- src/agents/tracing/spans.py | 2 +- src/agents/tracing/util.py | 10 +- tests/conftest.py | 4 +- 7 files changed, 271 insertions(+), 257 deletions(-) create mode 100644 src/agents/tracing/provider.py diff --git a/src/agents/tracing/__init__.py b/src/agents/tracing/__init__.py index 07d8af6d..4281c29f 100644 --- a/src/agents/tracing/__init__.py +++ b/src/agents/tracing/__init__.py @@ -1,5 +1,7 @@ import atexit +from agents.tracing.provider import TraceProvider + from .create import ( agent_span, custom_span, @@ -18,7 +20,7 @@ ) from .processor_interface import TracingProcessor from .processors import default_exporter, default_processor -from .setup import GLOBAL_TRACE_PROVIDER, set_trace_provider +from .setup import get_trace_provider, set_trace_provider from .span_data import ( AgentSpanData, CustomSpanData, @@ -81,21 +83,21 @@ def add_trace_processor(span_processor: TracingProcessor) -> None: """ Adds a new trace processor. This processor will receive all traces/spans. """ - GLOBAL_TRACE_PROVIDER.register_processor(span_processor) + get_trace_provider().register_processor(span_processor) def set_trace_processors(processors: list[TracingProcessor]) -> None: """ Set the list of trace processors. This will replace the current list of processors. """ - GLOBAL_TRACE_PROVIDER.set_processors(processors) + get_trace_provider().set_processors(processors) def set_tracing_disabled(disabled: bool) -> None: """ Set whether tracing is globally disabled. """ - GLOBAL_TRACE_PROVIDER.set_disabled(disabled) + get_trace_provider().set_disabled(disabled) def set_tracing_export_api_key(api_key: str) -> None: @@ -105,10 +107,11 @@ def set_tracing_export_api_key(api_key: str) -> None: default_exporter().set_api_key(api_key) +set_trace_provider(TraceProvider()) # Add the default processor, which exports traces and spans to the backend in batches. You can # change the default behavior by either: # 1. calling add_trace_processor(), which adds additional processors, or # 2. calling set_trace_processors(), which replaces the default processor. add_trace_processor(default_processor()) -atexit.register(GLOBAL_TRACE_PROVIDER.shutdown) +atexit.register(get_trace_provider().shutdown) diff --git a/src/agents/tracing/create.py b/src/agents/tracing/create.py index b6fe4610..ac451abf 100644 --- a/src/agents/tracing/create.py +++ b/src/agents/tracing/create.py @@ -4,7 +4,7 @@ from typing import TYPE_CHECKING, Any from ..logger import logger -from .setup import GLOBAL_TRACE_PROVIDER +from .setup import get_trace_provider from .span_data import ( AgentSpanData, CustomSpanData, @@ -56,13 +56,13 @@ def trace( Returns: The newly created trace object. """ - current_trace = GLOBAL_TRACE_PROVIDER.get_current_trace() + current_trace = get_trace_provider().get_current_trace() if current_trace: logger.warning( "Trace already exists. Creating a new trace, but this is probably a mistake." ) - return GLOBAL_TRACE_PROVIDER.create_trace( + return get_trace_provider().create_trace( name=workflow_name, trace_id=trace_id, group_id=group_id, @@ -73,12 +73,12 @@ def trace( def get_current_trace() -> Trace | None: """Returns the currently active trace, if present.""" - return GLOBAL_TRACE_PROVIDER.get_current_trace() + return get_trace_provider().get_current_trace() def get_current_span() -> Span[Any] | None: """Returns the currently active span, if present.""" - return GLOBAL_TRACE_PROVIDER.get_current_span() + return get_trace_provider().get_current_span() def agent_span( @@ -108,7 +108,7 @@ def agent_span( Returns: The newly created agent span. """ - return GLOBAL_TRACE_PROVIDER.create_span( + return get_trace_provider().create_span( span_data=AgentSpanData(name=name, handoffs=handoffs, tools=tools, output_type=output_type), span_id=span_id, parent=parent, @@ -141,7 +141,7 @@ def function_span( Returns: The newly created function span. """ - return GLOBAL_TRACE_PROVIDER.create_span( + return get_trace_provider().create_span( span_data=FunctionSpanData(name=name, input=input, output=output), span_id=span_id, parent=parent, @@ -183,7 +183,7 @@ def generation_span( Returns: The newly created generation span. """ - return GLOBAL_TRACE_PROVIDER.create_span( + return get_trace_provider().create_span( span_data=GenerationSpanData( input=input, output=output, @@ -215,7 +215,7 @@ def response_span( trace/span as the parent. disabled: If True, we will return a Span but the Span will not be recorded. """ - return GLOBAL_TRACE_PROVIDER.create_span( + return get_trace_provider().create_span( span_data=ResponseSpanData(response=response), span_id=span_id, parent=parent, @@ -246,7 +246,7 @@ def handoff_span( Returns: The newly created handoff span. """ - return GLOBAL_TRACE_PROVIDER.create_span( + return get_trace_provider().create_span( span_data=HandoffSpanData(from_agent=from_agent, to_agent=to_agent), span_id=span_id, parent=parent, @@ -278,7 +278,7 @@ def custom_span( Returns: The newly created custom span. """ - return GLOBAL_TRACE_PROVIDER.create_span( + return get_trace_provider().create_span( span_data=CustomSpanData(name=name, data=data or {}), span_id=span_id, parent=parent, @@ -306,7 +306,7 @@ def guardrail_span( trace/span as the parent. disabled: If True, we will return a Span but the Span will not be recorded. """ - return GLOBAL_TRACE_PROVIDER.create_span( + return get_trace_provider().create_span( span_data=GuardrailSpanData(name=name, triggered=triggered), span_id=span_id, parent=parent, @@ -344,7 +344,7 @@ def transcription_span( Returns: The newly created speech-to-text span. """ - return GLOBAL_TRACE_PROVIDER.create_span( + return get_trace_provider().create_span( span_data=TranscriptionSpanData( input=input, input_format=input_format, @@ -386,7 +386,7 @@ def speech_span( trace/span as the parent. disabled: If True, we will return a Span but the Span will not be recorded. """ - return GLOBAL_TRACE_PROVIDER.create_span( + return get_trace_provider().create_span( span_data=SpeechSpanData( model=model, input=input, @@ -419,7 +419,7 @@ def speech_group_span( trace/span as the parent. disabled: If True, we will return a Span but the Span will not be recorded. """ - return GLOBAL_TRACE_PROVIDER.create_span( + return get_trace_provider().create_span( span_data=SpeechGroupSpanData(input=input), span_id=span_id, parent=parent, @@ -447,7 +447,7 @@ def mcp_tools_span( trace/span as the parent. disabled: If True, we will return a Span but the Span will not be recorded. """ - return GLOBAL_TRACE_PROVIDER.create_span( + return get_trace_provider().create_span( span_data=MCPListToolsSpanData(server=server, result=result), span_id=span_id, parent=parent, diff --git a/src/agents/tracing/provider.py b/src/agents/tracing/provider.py new file mode 100644 index 00000000..b9f4c63e --- /dev/null +++ b/src/agents/tracing/provider.py @@ -0,0 +1,228 @@ +from __future__ import annotations + +import os +import threading +import uuid +from datetime import datetime, timezone +from typing import Any + +from ..logger import logger +from .processor_interface import TracingProcessor +from .scope import Scope +from .spans import NoOpSpan, Span, SpanImpl, TSpanData +from .traces import NoOpTrace, Trace, TraceImpl + + +class SynchronousMultiTracingProcessor(TracingProcessor): + """ + Forwards all calls to a list of TracingProcessors, in order of registration. + """ + + def __init__(self): + # Using a tuple to avoid race conditions when iterating over processors + self._processors: tuple[TracingProcessor, ...] = () + self._lock = threading.Lock() + + def add_tracing_processor(self, tracing_processor: TracingProcessor): + """ + Add a processor to the list of processors. Each processor will receive all traces/spans. + """ + with self._lock: + self._processors += (tracing_processor,) + + def set_processors(self, processors: list[TracingProcessor]): + """ + Set the list of processors. This will replace the current list of processors. + """ + with self._lock: + self._processors = tuple(processors) + + def on_trace_start(self, trace: Trace) -> None: + """ + Called when a trace is started. + """ + for processor in self._processors: + processor.on_trace_start(trace) + + def on_trace_end(self, trace: Trace) -> None: + """ + Called when a trace is finished. + """ + for processor in self._processors: + processor.on_trace_end(trace) + + def on_span_start(self, span: Span[Any]) -> None: + """ + Called when a span is started. + """ + for processor in self._processors: + processor.on_span_start(span) + + def on_span_end(self, span: Span[Any]) -> None: + """ + Called when a span is finished. + """ + for processor in self._processors: + processor.on_span_end(span) + + def shutdown(self) -> None: + """ + Called when the application stops. + """ + for processor in self._processors: + logger.debug(f"Shutting down trace processor {processor}") + processor.shutdown() + + def force_flush(self): + """ + Force the processors to flush their buffers. + """ + for processor in self._processors: + processor.force_flush() + + +class TraceProvider: + def __init__(self): + self._multi_processor = SynchronousMultiTracingProcessor() + self._disabled = os.environ.get("OPENAI_AGENTS_DISABLE_TRACING", "false").lower() in ( + "true", + "1", + ) + + def register_processor(self, processor: TracingProcessor): + """ + Add a processor to the list of processors. Each processor will receive all traces/spans. + """ + self._multi_processor.add_tracing_processor(processor) + + def set_processors(self, processors: list[TracingProcessor]): + """ + Set the list of processors. This will replace the current list of processors. + """ + self._multi_processor.set_processors(processors) + + def get_current_trace(self) -> Trace | None: + """ + Returns the currently active trace, if any. + """ + return Scope.get_current_trace() + + def get_current_span(self) -> Span[Any] | None: + """ + Returns the currently active span, if any. + """ + return Scope.get_current_span() + + def set_disabled(self, disabled: bool) -> None: + """ + Set whether tracing is disabled. + """ + self._disabled = disabled + + def time_iso(self) -> str: + """Return the current time in ISO 8601 format.""" + return datetime.now(timezone.utc).isoformat() + + def gen_trace_id(self) -> str: + """Generate a new trace ID.""" + return f"trace_{uuid.uuid4().hex}" + + def gen_span_id(self) -> str: + """Generate a new span ID.""" + return f"span_{uuid.uuid4().hex[:24]}" + + def gen_group_id(self) -> str: + """Generate a new group ID.""" + return f"group_{uuid.uuid4().hex[:24]}" + + def create_trace( + self, + name: str, + trace_id: str | None = None, + group_id: str | None = None, + metadata: dict[str, Any] | None = None, + disabled: bool = False, + ) -> Trace: + """ + Create a new trace. + """ + if self._disabled or disabled: + logger.debug(f"Tracing is disabled. Not creating trace {name}") + return NoOpTrace() + + trace_id = trace_id or self.gen_trace_id() + + logger.debug(f"Creating trace {name} with id {trace_id}") + + return TraceImpl( + name=name, + trace_id=trace_id, + group_id=group_id, + metadata=metadata, + processor=self._multi_processor, + ) + + def create_span( + self, + span_data: TSpanData, + span_id: str | None = None, + parent: Trace | Span[Any] | None = None, + disabled: bool = False, + ) -> Span[TSpanData]: + """ + Create a new span. + """ + if self._disabled or disabled: + logger.debug(f"Tracing is disabled. Not creating span {span_data}") + return NoOpSpan(span_data) + + if not parent: + current_span = Scope.get_current_span() + current_trace = Scope.get_current_trace() + if current_trace is None: + logger.error( + "No active trace. Make sure to start a trace with `trace()` first" + "Returning NoOpSpan." + ) + return NoOpSpan(span_data) + elif isinstance(current_trace, NoOpTrace) or isinstance(current_span, NoOpSpan): + logger.debug( + f"Parent {current_span} or {current_trace} is no-op, returning NoOpSpan" + ) + return NoOpSpan(span_data) + + parent_id = current_span.span_id if current_span else None + trace_id = current_trace.trace_id + + elif isinstance(parent, Trace): + if isinstance(parent, NoOpTrace): + logger.debug(f"Parent {parent} is no-op, returning NoOpSpan") + return NoOpSpan(span_data) + trace_id = parent.trace_id + parent_id = None + elif isinstance(parent, Span): + if isinstance(parent, NoOpSpan): + logger.debug(f"Parent {parent} is no-op, returning NoOpSpan") + return NoOpSpan(span_data) + parent_id = parent.span_id + trace_id = parent.trace_id + + logger.debug(f"Creating span {span_data} with id {span_id}") + + return SpanImpl( + trace_id=trace_id, + span_id=span_id or self.gen_span_id(), + parent_id=parent_id, + processor=self._multi_processor, + span_data=span_data, + ) + + def shutdown(self) -> None: + if self._disabled: + return + + try: + logger.debug("Shutting down trace provider") + self._multi_processor.shutdown() + except Exception as e: + logger.error(f"Error shutting down trace provider: {e}") diff --git a/src/agents/tracing/setup.py b/src/agents/tracing/setup.py index daa7b86d..3a56b728 100644 --- a/src/agents/tracing/setup.py +++ b/src/agents/tracing/setup.py @@ -1,238 +1,21 @@ from __future__ import annotations -import os -import threading -import uuid -from datetime import datetime, timezone -from typing import Any +from typing import TYPE_CHECKING -from ..logger import logger -from . import util -from .processor_interface import TracingProcessor -from .scope import Scope -from .spans import NoOpSpan, Span, SpanImpl, TSpanData -from .traces import NoOpTrace, Trace, TraceImpl +if TYPE_CHECKING: + from .provider import TraceProvider - -class SynchronousMultiTracingProcessor(TracingProcessor): - """ - Forwards all calls to a list of TracingProcessors, in order of registration. - """ - - def __init__(self): - # Using a tuple to avoid race conditions when iterating over processors - self._processors: tuple[TracingProcessor, ...] = () - self._lock = threading.Lock() - - def add_tracing_processor(self, tracing_processor: TracingProcessor): - """ - Add a processor to the list of processors. Each processor will receive all traces/spans. - """ - with self._lock: - self._processors += (tracing_processor,) - - def set_processors(self, processors: list[TracingProcessor]): - """ - Set the list of processors. This will replace the current list of processors. - """ - with self._lock: - self._processors = tuple(processors) - - def on_trace_start(self, trace: Trace) -> None: - """ - Called when a trace is started. - """ - for processor in self._processors: - processor.on_trace_start(trace) - - def on_trace_end(self, trace: Trace) -> None: - """ - Called when a trace is finished. - """ - for processor in self._processors: - processor.on_trace_end(trace) - - def on_span_start(self, span: Span[Any]) -> None: - """ - Called when a span is started. - """ - for processor in self._processors: - processor.on_span_start(span) - - def on_span_end(self, span: Span[Any]) -> None: - """ - Called when a span is finished. - """ - for processor in self._processors: - processor.on_span_end(span) - - def shutdown(self) -> None: - """ - Called when the application stops. - """ - for processor in self._processors: - logger.debug(f"Shutting down trace processor {processor}") - processor.shutdown() - - def force_flush(self): - """ - Force the processors to flush their buffers. - """ - for processor in self._processors: - processor.force_flush() - - -class TraceProvider: - def __init__(self): - self._multi_processor = SynchronousMultiTracingProcessor() - self._disabled = os.environ.get("OPENAI_AGENTS_DISABLE_TRACING", "false").lower() in ( - "true", - "1", - ) - - def register_processor(self, processor: TracingProcessor): - """ - Add a processor to the list of processors. Each processor will receive all traces/spans. - """ - self._multi_processor.add_tracing_processor(processor) - - def set_processors(self, processors: list[TracingProcessor]): - """ - Set the list of processors. This will replace the current list of processors. - """ - self._multi_processor.set_processors(processors) - - def get_current_trace(self) -> Trace | None: - """ - Returns the currently active trace, if any. - """ - return Scope.get_current_trace() - - def get_current_span(self) -> Span[Any] | None: - """ - Returns the currently active span, if any. - """ - return Scope.get_current_span() - - def set_disabled(self, disabled: bool) -> None: - """ - Set whether tracing is disabled. - """ - self._disabled = disabled - - def time_iso(self) -> str: - """Return the current time in ISO 8601 format.""" - return datetime.now(timezone.utc).isoformat() - - def gen_trace_id(self) -> str: - """Generate a new trace ID.""" - return f"trace_{uuid.uuid4().hex}" - - def gen_span_id(self) -> str: - """Generate a new span ID.""" - return f"span_{uuid.uuid4().hex[:24]}" - - def gen_group_id(self) -> str: - """Generate a new group ID.""" - return f"group_{uuid.uuid4().hex[:24]}" - - def create_trace( - self, - name: str, - trace_id: str | None = None, - group_id: str | None = None, - metadata: dict[str, Any] | None = None, - disabled: bool = False, - ) -> Trace: - """ - Create a new trace. - """ - if self._disabled or disabled: - logger.debug(f"Tracing is disabled. Not creating trace {name}") - return NoOpTrace() - - trace_id = trace_id or util.gen_trace_id() - - logger.debug(f"Creating trace {name} with id {trace_id}") - - return TraceImpl( - name=name, - trace_id=trace_id, - group_id=group_id, - metadata=metadata, - processor=self._multi_processor, - ) - - def create_span( - self, - span_data: TSpanData, - span_id: str | None = None, - parent: Trace | Span[Any] | None = None, - disabled: bool = False, - ) -> Span[TSpanData]: - """ - Create a new span. - """ - if self._disabled or disabled: - logger.debug(f"Tracing is disabled. Not creating span {span_data}") - return NoOpSpan(span_data) - - if not parent: - current_span = Scope.get_current_span() - current_trace = Scope.get_current_trace() - if current_trace is None: - logger.error( - "No active trace. Make sure to start a trace with `trace()` first" - "Returning NoOpSpan." - ) - return NoOpSpan(span_data) - elif isinstance(current_trace, NoOpTrace) or isinstance(current_span, NoOpSpan): - logger.debug( - f"Parent {current_span} or {current_trace} is no-op, returning NoOpSpan" - ) - return NoOpSpan(span_data) - - parent_id = current_span.span_id if current_span else None - trace_id = current_trace.trace_id - - elif isinstance(parent, Trace): - if isinstance(parent, NoOpTrace): - logger.debug(f"Parent {parent} is no-op, returning NoOpSpan") - return NoOpSpan(span_data) - trace_id = parent.trace_id - parent_id = None - elif isinstance(parent, Span): - if isinstance(parent, NoOpSpan): - logger.debug(f"Parent {parent} is no-op, returning NoOpSpan") - return NoOpSpan(span_data) - parent_id = parent.span_id - trace_id = parent.trace_id - - logger.debug(f"Creating span {span_data} with id {span_id}") - - return SpanImpl( - trace_id=trace_id, - span_id=span_id, - parent_id=parent_id, - processor=self._multi_processor, - span_data=span_data, - ) - - def shutdown(self) -> None: - if self._disabled: - return - - try: - logger.debug("Shutting down trace provider") - self._multi_processor.shutdown() - except Exception as e: - logger.error(f"Error shutting down trace provider: {e}") - - -GLOBAL_TRACE_PROVIDER = TraceProvider() +GLOBAL_TRACE_PROVIDER: TraceProvider | None = None def set_trace_provider(provider: TraceProvider) -> None: """Set the global trace provider used by tracing utilities.""" global GLOBAL_TRACE_PROVIDER GLOBAL_TRACE_PROVIDER = provider + + +def get_trace_provider() -> TraceProvider: + """Get the global trace provider used by tracing utilities.""" + if GLOBAL_TRACE_PROVIDER is None: + raise RuntimeError("Trace provider not set") + return GLOBAL_TRACE_PROVIDER diff --git a/src/agents/tracing/spans.py b/src/agents/tracing/spans.py index ee933e73..129c468d 100644 --- a/src/agents/tracing/spans.py +++ b/src/agents/tracing/spans.py @@ -178,7 +178,7 @@ def __init__( span_data: TSpanData, ): self._trace_id = trace_id - self._span_id = span_id or util.gen_span_id() + self._span_id = span_id self._parent_id = parent_id self._started_at: str | None = None self._ended_at: str | None = None diff --git a/src/agents/tracing/util.py b/src/agents/tracing/util.py index af2f5ff3..7f436d01 100644 --- a/src/agents/tracing/util.py +++ b/src/agents/tracing/util.py @@ -1,21 +1,21 @@ -from .setup import GLOBAL_TRACE_PROVIDER +from .setup import get_trace_provider def time_iso() -> str: """Return the current time in ISO 8601 format.""" - return GLOBAL_TRACE_PROVIDER.time_iso() + return get_trace_provider().time_iso() def gen_trace_id() -> str: """Generate a new trace ID.""" - return GLOBAL_TRACE_PROVIDER.gen_trace_id() + return get_trace_provider().gen_trace_id() def gen_span_id() -> str: """Generate a new span ID.""" - return GLOBAL_TRACE_PROVIDER.gen_span_id() + return get_trace_provider().gen_span_id() def gen_group_id() -> str: """Generate a new group ID.""" - return GLOBAL_TRACE_PROVIDER.gen_group_id() + return get_trace_provider().gen_group_id() diff --git a/tests/conftest.py b/tests/conftest.py index 622b61b1..f87e8559 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -7,7 +7,7 @@ from agents.models.openai_responses import OpenAIResponsesModel from agents.run import set_default_runner from agents.tracing import set_trace_processors -from agents.tracing.setup import GLOBAL_TRACE_PROVIDER +from agents.tracing.setup import get_trace_provider from .testing_processor import SPAN_PROCESSOR_TESTING @@ -43,7 +43,7 @@ def clear_default_runner(): @pytest.fixture(autouse=True, scope="session") def shutdown_trace_provider(): yield - GLOBAL_TRACE_PROVIDER.shutdown() + get_trace_provider().shutdown() @pytest.fixture(autouse=True) From 27646ee1c31dc191f1cf1268563de8ca6c6f111b Mon Sep 17 00:00:00 2001 From: pakrym-oai Date: Wed, 21 May 2025 09:54:06 -0700 Subject: [PATCH 8/8] LLM: tracing/spans: generate span_id if not provided in SpanImpl constructor --- src/agents/tracing/spans.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/agents/tracing/spans.py b/src/agents/tracing/spans.py index 129c468d..ee933e73 100644 --- a/src/agents/tracing/spans.py +++ b/src/agents/tracing/spans.py @@ -178,7 +178,7 @@ def __init__( span_data: TSpanData, ): self._trace_id = trace_id - self._span_id = span_id + self._span_id = span_id or util.gen_span_id() self._parent_id = parent_id self._started_at: str | None = None self._ended_at: str | None = None