diff --git a/src/agents/__init__.py b/src/agents/__init__.py index 6d7c90b4..65e44244 100644 --- a/src/agents/__init__.py +++ b/src/agents/__init__.py @@ -45,7 +45,7 @@ from .models.openai_provider import OpenAIProvider from .models.openai_responses import OpenAIResponsesModel from .result import RunResult, RunResultStreaming -from .run import RunConfig, Runner +from .run import DefaultRunner, RunConfig, Runner, set_default_runner from .run_context import RunContextWrapper, TContext from .stream_events import ( AgentUpdatedStreamEvent, @@ -92,6 +92,7 @@ handoff_span, mcp_tools_span, set_trace_processors, + set_trace_provider, set_tracing_disabled, set_tracing_export_api_key, speech_group_span, @@ -150,6 +151,7 @@ def enable_verbose_stdout_logging(): "ToolsToFinalOutputFunction", "ToolsToFinalOutputResult", "Runner", + "DefaultRunner", "Model", "ModelProvider", "ModelTracing", @@ -220,6 +222,7 @@ def enable_verbose_stdout_logging(): "guardrail_span", "handoff_span", "set_trace_processors", + "set_trace_provider", "set_tracing_disabled", "speech_group_span", "transcription_span", @@ -244,6 +247,7 @@ def enable_verbose_stdout_logging(): "set_default_openai_key", "set_default_openai_client", "set_default_openai_api", + "set_default_runner", "set_tracing_export_api_key", "enable_verbose_stdout_logging", "gen_trace_id", diff --git a/src/agents/run.py b/src/agents/run.py index 849da7bf..045f6f0b 100644 --- a/src/agents/run.py +++ b/src/agents/run.py @@ -1,5 +1,6 @@ from __future__ import annotations +import abc import asyncio import copy from dataclasses import dataclass, field @@ -45,6 +46,15 @@ from .util import _coro, _error_tracing DEFAULT_MAX_TURNS = 10 +DEFAULT_RUNNER: Runner | None = None + + +def set_default_runner(runner: Runner | None) -> None: + """ + Set the default runner to use for the agent run. + """ + global DEFAULT_RUNNER + DEFAULT_RUNNER = runner @dataclass @@ -106,7 +116,48 @@ class RunConfig: """ -class Runner: +class Runner(abc.ABC): + @abc.abstractmethod + async def _run_impl( + self, + starting_agent: Agent[TContext], + input: str | list[TResponseInputItem], + *, + context: TContext | None = None, + max_turns: int = DEFAULT_MAX_TURNS, + hooks: RunHooks[TContext] | None = None, + run_config: RunConfig | None = None, + previous_response_id: str | None = None, + ) -> RunResult: + pass + + @abc.abstractmethod + def _run_sync_impl( + self, + starting_agent: Agent[TContext], + input: str | list[TResponseInputItem], + *, + context: TContext | None = None, + max_turns: int = DEFAULT_MAX_TURNS, + hooks: RunHooks[TContext] | None = None, + run_config: RunConfig | None = None, + previous_response_id: str | None = None, + ) -> RunResult: + pass + + @abc.abstractmethod + def _run_streamed_impl( + self, + starting_agent: Agent[TContext], + input: str | list[TResponseInputItem], + context: TContext | None = None, + max_turns: int = DEFAULT_MAX_TURNS, + hooks: RunHooks[TContext] | None = None, + run_config: RunConfig | None = None, + previous_response_id: str | None = None, + ) -> RunResultStreaming: + pass + @classmethod async def run( cls, @@ -126,13 +177,10 @@ async def run( `agent.output_type`, the loop terminates. 3. If there's a handoff, we run the loop again, with the new agent. 4. Else, we run tool calls (if any), and re-run the loop. - In two cases, the agent may raise an exception: 1. If the max_turns is exceeded, a MaxTurnsExceeded exception is raised. 2. If a guardrail tripwire is triggered, a GuardrailTripwireTriggered exception is raised. - Note that only the first agent's input guardrails are run. - Args: starting_agent: The starting agent to run. input: The initial input to the agent. You can pass a single string for a user message, @@ -144,11 +192,164 @@ async def run( run_config: Global settings for the entire agent run. previous_response_id: The ID of the previous response, if using OpenAI models via the Responses API, this allows you to skip passing in input from the previous turn. + Returns: + A run result containing all the inputs, guardrail results and the output of the last + agent. Agents may perform handoffs, so we don't know the specific type of the output. + """ + runner = DEFAULT_RUNNER or DefaultRunner() + return await runner._run_impl( + starting_agent, + input, + context=context, + max_turns=max_turns, + hooks=hooks, + run_config=run_config, + previous_response_id=previous_response_id, + ) + @classmethod + def run_sync( + cls, + starting_agent: Agent[TContext], + input: str | list[TResponseInputItem], + *, + context: TContext | None = None, + max_turns: int = DEFAULT_MAX_TURNS, + hooks: RunHooks[TContext] | None = None, + run_config: RunConfig | None = None, + previous_response_id: str | None = None, + ) -> RunResult: + """Run a workflow synchronously, starting at the given agent. Note that this just wraps the + `run` method, so it will not work if there's already an event loop (e.g. inside an async + function, or in a Jupyter notebook or async context like FastAPI). For those cases, use + the `run` method instead. + The agent will run in a loop until a final output is generated. The loop runs like so: + 1. The agent is invoked with the given input. + 2. If there is a final output (i.e. the agent produces something of type + `agent.output_type`, the loop terminates. + 3. If there's a handoff, we run the loop again, with the new agent. + 4. Else, we run tool calls (if any), and re-run the loop. + In two cases, the agent may raise an exception: + 1. If the max_turns is exceeded, a MaxTurnsExceeded exception is raised. + 2. If a guardrail tripwire is triggered, a GuardrailTripwireTriggered exception is raised. + Note that only the first agent's input guardrails are run. + Args: + starting_agent: The starting agent to run. + input: The initial input to the agent. You can pass a single string for a user message, + or a list of input items. + context: The context to run the agent with. + max_turns: The maximum number of turns to run the agent for. A turn is defined as one + AI invocation (including any tool calls that might occur). + hooks: An object that receives callbacks on various lifecycle events. + run_config: Global settings for the entire agent run. + previous_response_id: The ID of the previous response, if using OpenAI models via the + Responses API, this allows you to skip passing in input from the previous turn. Returns: A run result containing all the inputs, guardrail results and the output of the last agent. Agents may perform handoffs, so we don't know the specific type of the output. """ + runner = DEFAULT_RUNNER or DefaultRunner() + return runner._run_sync_impl( + starting_agent, + input, + context=context, + max_turns=max_turns, + hooks=hooks, + run_config=run_config, + previous_response_id=previous_response_id, + ) + + @classmethod + def run_streamed( + cls, + starting_agent: Agent[TContext], + input: str | list[TResponseInputItem], + context: TContext | None = None, + max_turns: int = DEFAULT_MAX_TURNS, + hooks: RunHooks[TContext] | None = None, + run_config: RunConfig | None = None, + previous_response_id: str | None = None, + ) -> RunResultStreaming: + """Run a workflow starting at the given agent in streaming mode. The returned result object + contains a method you can use to stream semantic events as they are generated. + The agent will run in a loop until a final output is generated. The loop runs like so: + 1. The agent is invoked with the given input. + 2. If there is a final output (i.e. the agent produces something of type + `agent.output_type`, the loop terminates. + 3. If there's a handoff, we run the loop again, with the new agent. + 4. Else, we run tool calls (if any), and re-run the loop. + In two cases, the agent may raise an exception: + 1. If the max_turns is exceeded, a MaxTurnsExceeded exception is raised. + 2. If a guardrail tripwire is triggered, a GuardrailTripwireTriggered exception is raised. + Note that only the first agent's input guardrails are run. + Args: + starting_agent: The starting agent to run. + input: The initial input to the agent. You can pass a single string for a user message, + or a list of input items. + context: The context to run the agent with. + max_turns: The maximum number of turns to run the agent for. A turn is defined as one + AI invocation (including any tool calls that might occur). + hooks: An object that receives callbacks on various lifecycle events. + run_config: Global settings for the entire agent run. + previous_response_id: The ID of the previous response, if using OpenAI models via the + Responses API, this allows you to skip passing in input from the previous turn. + Returns: + A result object that contains data about the run, as well as a method to stream events. + """ + runner = DEFAULT_RUNNER or DefaultRunner() + return runner._run_streamed_impl( + starting_agent, + input, + context=context, + max_turns=max_turns, + hooks=hooks, + run_config=run_config, + previous_response_id=previous_response_id, + ) + + @classmethod + def _get_output_schema(cls, agent: Agent[Any]) -> AgentOutputSchemaBase | None: + if agent.output_type is None or agent.output_type is str: + return None + elif isinstance(agent.output_type, AgentOutputSchemaBase): + return agent.output_type + + return AgentOutputSchema(agent.output_type) + + @classmethod + def _get_handoffs(cls, agent: Agent[Any]) -> list[Handoff]: + handoffs = [] + for handoff_item in agent.handoffs: + if isinstance(handoff_item, Handoff): + handoffs.append(handoff_item) + elif isinstance(handoff_item, Agent): + handoffs.append(handoff(handoff_item)) + return handoffs + + @classmethod + def _get_model(cls, agent: Agent[Any], run_config: RunConfig) -> Model: + if isinstance(run_config.model, Model): + return run_config.model + elif isinstance(run_config.model, str): + return run_config.model_provider.get_model(run_config.model) + elif isinstance(agent.model, Model): + return agent.model + + return run_config.model_provider.get_model(agent.model) + + +class DefaultRunner(Runner): + async def _run_impl( + self, + starting_agent: Agent[TContext], + input: str | list[TResponseInputItem], + *, + context: TContext | None = None, + max_turns: int = DEFAULT_MAX_TURNS, + hooks: RunHooks[TContext] | None = None, + run_config: RunConfig | None = None, + previous_response_id: str | None = None, + ) -> RunResult: if hooks is None: hooks = RunHooks[Any]() if run_config is None: @@ -183,8 +384,8 @@ async def run( # Start an agent span if we don't have one. This span is ended if the current # agent changes, or if the agent loop ends. if current_span is None: - handoff_names = [h.agent_name for h in cls._get_handoffs(current_agent)] - if output_schema := cls._get_output_schema(current_agent): + handoff_names = [h.agent_name for h in self._get_handoffs(current_agent)] + if output_schema := self._get_output_schema(current_agent): output_type_name = output_schema.name() else: output_type_name = "str" @@ -196,7 +397,7 @@ async def run( ) current_span.start(mark_as_current=True) - all_tools = await cls._get_all_tools(current_agent) + all_tools = await self._get_all_tools(current_agent) current_span.span_data.tools = [t.name for t in all_tools] current_turn += 1 @@ -216,14 +417,14 @@ async def run( if current_turn == 1: input_guardrail_results, turn_result = await asyncio.gather( - cls._run_input_guardrails( + self._run_input_guardrails( starting_agent, starting_agent.input_guardrails + (run_config.input_guardrails or []), copy.deepcopy(input), context_wrapper, ), - cls._run_single_turn( + self._run_single_turn( agent=current_agent, all_tools=all_tools, original_input=original_input, @@ -237,7 +438,7 @@ async def run( ), ) else: - turn_result = await cls._run_single_turn( + turn_result = await self._run_single_turn( agent=current_agent, all_tools=all_tools, original_input=original_input, @@ -256,7 +457,7 @@ async def run( generated_items = turn_result.generated_items if isinstance(turn_result.next_step, NextStepFinalOutput): - output_guardrail_results = await cls._run_output_guardrails( + output_guardrail_results = await self._run_output_guardrails( current_agent.output_guardrails + (run_config.output_guardrails or []), current_agent, turn_result.next_step.output, @@ -287,9 +488,8 @@ async def run( if current_span: current_span.finish(reset_current=True) - @classmethod - def run_sync( - cls, + def _run_sync_impl( + self, starting_agent: Agent[TContext], input: str | list[TResponseInputItem], *, @@ -299,42 +499,8 @@ def run_sync( run_config: RunConfig | None = None, previous_response_id: str | None = None, ) -> RunResult: - """Run a workflow synchronously, starting at the given agent. Note that this just wraps the - `run` method, so it will not work if there's already an event loop (e.g. inside an async - function, or in a Jupyter notebook or async context like FastAPI). For those cases, use - the `run` method instead. - - The agent will run in a loop until a final output is generated. The loop runs like so: - 1. The agent is invoked with the given input. - 2. If there is a final output (i.e. the agent produces something of type - `agent.output_type`, the loop terminates. - 3. If there's a handoff, we run the loop again, with the new agent. - 4. Else, we run tool calls (if any), and re-run the loop. - - In two cases, the agent may raise an exception: - 1. If the max_turns is exceeded, a MaxTurnsExceeded exception is raised. - 2. If a guardrail tripwire is triggered, a GuardrailTripwireTriggered exception is raised. - - Note that only the first agent's input guardrails are run. - - Args: - starting_agent: The starting agent to run. - input: The initial input to the agent. You can pass a single string for a user message, - or a list of input items. - context: The context to run the agent with. - max_turns: The maximum number of turns to run the agent for. A turn is defined as one - AI invocation (including any tool calls that might occur). - hooks: An object that receives callbacks on various lifecycle events. - run_config: Global settings for the entire agent run. - previous_response_id: The ID of the previous response, if using OpenAI models via the - Responses API, this allows you to skip passing in input from the previous turn. - - Returns: - A run result containing all the inputs, guardrail results and the output of the last - agent. Agents may perform handoffs, so we don't know the specific type of the output. - """ return asyncio.get_event_loop().run_until_complete( - cls.run( + self.run( starting_agent, input, context=context, @@ -345,9 +511,8 @@ def run_sync( ) ) - @classmethod - def run_streamed( - cls, + def _run_streamed_impl( + self, starting_agent: Agent[TContext], input: str | list[TResponseInputItem], context: TContext | None = None, @@ -356,36 +521,6 @@ def run_streamed( run_config: RunConfig | None = None, previous_response_id: str | None = None, ) -> RunResultStreaming: - """Run a workflow starting at the given agent in streaming mode. The returned result object - contains a method you can use to stream semantic events as they are generated. - - The agent will run in a loop until a final output is generated. The loop runs like so: - 1. The agent is invoked with the given input. - 2. If there is a final output (i.e. the agent produces something of type - `agent.output_type`, the loop terminates. - 3. If there's a handoff, we run the loop again, with the new agent. - 4. Else, we run tool calls (if any), and re-run the loop. - - In two cases, the agent may raise an exception: - 1. If the max_turns is exceeded, a MaxTurnsExceeded exception is raised. - 2. If a guardrail tripwire is triggered, a GuardrailTripwireTriggered exception is raised. - - Note that only the first agent's input guardrails are run. - - Args: - starting_agent: The starting agent to run. - input: The initial input to the agent. You can pass a single string for a user message, - or a list of input items. - context: The context to run the agent with. - max_turns: The maximum number of turns to run the agent for. A turn is defined as one - AI invocation (including any tool calls that might occur). - hooks: An object that receives callbacks on various lifecycle events. - run_config: Global settings for the entire agent run. - previous_response_id: The ID of the previous response, if using OpenAI models via the - Responses API, this allows you to skip passing in input from the previous turn. - Returns: - A result object that contains data about the run, as well as a method to stream events. - """ if hooks is None: hooks = RunHooks[Any]() if run_config is None: @@ -406,7 +541,7 @@ def run_streamed( ) ) - output_schema = cls._get_output_schema(starting_agent) + output_schema = self._get_output_schema(starting_agent) context_wrapper: RunContextWrapper[TContext] = RunContextWrapper( context=context # type: ignore ) @@ -429,7 +564,7 @@ def run_streamed( # Kick off the actual agent loop in the background and return the streamed result object. streamed_result._run_impl_task = asyncio.create_task( - cls._run_streamed_impl( + self._start_streaming( starting_input=input, streamed_result=streamed_result, starting_agent=starting_agent, @@ -486,7 +621,7 @@ async def _run_input_guardrails_with_queue( streamed_result.input_guardrail_results = guardrail_results @classmethod - async def _run_streamed_impl( + async def _start_streaming( cls, starting_input: str | list[TResponseInputItem], streamed_result: RunResultStreaming, @@ -933,36 +1068,6 @@ async def _get_new_response( return new_response - @classmethod - def _get_output_schema(cls, agent: Agent[Any]) -> AgentOutputSchemaBase | None: - if agent.output_type is None or agent.output_type is str: - return None - elif isinstance(agent.output_type, AgentOutputSchemaBase): - return agent.output_type - - return AgentOutputSchema(agent.output_type) - - @classmethod - def _get_handoffs(cls, agent: Agent[Any]) -> list[Handoff]: - handoffs = [] - for handoff_item in agent.handoffs: - if isinstance(handoff_item, Handoff): - handoffs.append(handoff_item) - elif isinstance(handoff_item, Agent): - handoffs.append(handoff(handoff_item)) - return handoffs - @classmethod async def _get_all_tools(cls, agent: Agent[Any]) -> list[Tool]: return await agent.get_all_tools() - - @classmethod - def _get_model(cls, agent: Agent[Any], run_config: RunConfig) -> Model: - if isinstance(run_config.model, Model): - return run_config.model - elif isinstance(run_config.model, str): - return run_config.model_provider.get_model(run_config.model) - elif isinstance(agent.model, Model): - return agent.model - - return run_config.model_provider.get_model(agent.model) diff --git a/src/agents/tracing/__init__.py b/src/agents/tracing/__init__.py index 9df94426..4281c29f 100644 --- a/src/agents/tracing/__init__.py +++ b/src/agents/tracing/__init__.py @@ -1,5 +1,7 @@ import atexit +from agents.tracing.provider import TraceProvider + from .create import ( agent_span, custom_span, @@ -18,7 +20,7 @@ ) from .processor_interface import TracingProcessor from .processors import default_exporter, default_processor -from .setup import GLOBAL_TRACE_PROVIDER +from .setup import get_trace_provider, set_trace_provider from .span_data import ( AgentSpanData, CustomSpanData, @@ -49,6 +51,7 @@ "handoff_span", "response_span", "set_trace_processors", + "set_trace_provider", "set_tracing_disabled", "trace", "Trace", @@ -80,21 +83,21 @@ def add_trace_processor(span_processor: TracingProcessor) -> None: """ Adds a new trace processor. This processor will receive all traces/spans. """ - GLOBAL_TRACE_PROVIDER.register_processor(span_processor) + get_trace_provider().register_processor(span_processor) def set_trace_processors(processors: list[TracingProcessor]) -> None: """ Set the list of trace processors. This will replace the current list of processors. """ - GLOBAL_TRACE_PROVIDER.set_processors(processors) + get_trace_provider().set_processors(processors) def set_tracing_disabled(disabled: bool) -> None: """ Set whether tracing is globally disabled. """ - GLOBAL_TRACE_PROVIDER.set_disabled(disabled) + get_trace_provider().set_disabled(disabled) def set_tracing_export_api_key(api_key: str) -> None: @@ -104,10 +107,11 @@ def set_tracing_export_api_key(api_key: str) -> None: default_exporter().set_api_key(api_key) +set_trace_provider(TraceProvider()) # Add the default processor, which exports traces and spans to the backend in batches. You can # change the default behavior by either: # 1. calling add_trace_processor(), which adds additional processors, or # 2. calling set_trace_processors(), which replaces the default processor. add_trace_processor(default_processor()) -atexit.register(GLOBAL_TRACE_PROVIDER.shutdown) +atexit.register(get_trace_provider().shutdown) diff --git a/src/agents/tracing/create.py b/src/agents/tracing/create.py index b6fe4610..ac451abf 100644 --- a/src/agents/tracing/create.py +++ b/src/agents/tracing/create.py @@ -4,7 +4,7 @@ from typing import TYPE_CHECKING, Any from ..logger import logger -from .setup import GLOBAL_TRACE_PROVIDER +from .setup import get_trace_provider from .span_data import ( AgentSpanData, CustomSpanData, @@ -56,13 +56,13 @@ def trace( Returns: The newly created trace object. """ - current_trace = GLOBAL_TRACE_PROVIDER.get_current_trace() + current_trace = get_trace_provider().get_current_trace() if current_trace: logger.warning( "Trace already exists. Creating a new trace, but this is probably a mistake." ) - return GLOBAL_TRACE_PROVIDER.create_trace( + return get_trace_provider().create_trace( name=workflow_name, trace_id=trace_id, group_id=group_id, @@ -73,12 +73,12 @@ def trace( def get_current_trace() -> Trace | None: """Returns the currently active trace, if present.""" - return GLOBAL_TRACE_PROVIDER.get_current_trace() + return get_trace_provider().get_current_trace() def get_current_span() -> Span[Any] | None: """Returns the currently active span, if present.""" - return GLOBAL_TRACE_PROVIDER.get_current_span() + return get_trace_provider().get_current_span() def agent_span( @@ -108,7 +108,7 @@ def agent_span( Returns: The newly created agent span. """ - return GLOBAL_TRACE_PROVIDER.create_span( + return get_trace_provider().create_span( span_data=AgentSpanData(name=name, handoffs=handoffs, tools=tools, output_type=output_type), span_id=span_id, parent=parent, @@ -141,7 +141,7 @@ def function_span( Returns: The newly created function span. """ - return GLOBAL_TRACE_PROVIDER.create_span( + return get_trace_provider().create_span( span_data=FunctionSpanData(name=name, input=input, output=output), span_id=span_id, parent=parent, @@ -183,7 +183,7 @@ def generation_span( Returns: The newly created generation span. """ - return GLOBAL_TRACE_PROVIDER.create_span( + return get_trace_provider().create_span( span_data=GenerationSpanData( input=input, output=output, @@ -215,7 +215,7 @@ def response_span( trace/span as the parent. disabled: If True, we will return a Span but the Span will not be recorded. """ - return GLOBAL_TRACE_PROVIDER.create_span( + return get_trace_provider().create_span( span_data=ResponseSpanData(response=response), span_id=span_id, parent=parent, @@ -246,7 +246,7 @@ def handoff_span( Returns: The newly created handoff span. """ - return GLOBAL_TRACE_PROVIDER.create_span( + return get_trace_provider().create_span( span_data=HandoffSpanData(from_agent=from_agent, to_agent=to_agent), span_id=span_id, parent=parent, @@ -278,7 +278,7 @@ def custom_span( Returns: The newly created custom span. """ - return GLOBAL_TRACE_PROVIDER.create_span( + return get_trace_provider().create_span( span_data=CustomSpanData(name=name, data=data or {}), span_id=span_id, parent=parent, @@ -306,7 +306,7 @@ def guardrail_span( trace/span as the parent. disabled: If True, we will return a Span but the Span will not be recorded. """ - return GLOBAL_TRACE_PROVIDER.create_span( + return get_trace_provider().create_span( span_data=GuardrailSpanData(name=name, triggered=triggered), span_id=span_id, parent=parent, @@ -344,7 +344,7 @@ def transcription_span( Returns: The newly created speech-to-text span. """ - return GLOBAL_TRACE_PROVIDER.create_span( + return get_trace_provider().create_span( span_data=TranscriptionSpanData( input=input, input_format=input_format, @@ -386,7 +386,7 @@ def speech_span( trace/span as the parent. disabled: If True, we will return a Span but the Span will not be recorded. """ - return GLOBAL_TRACE_PROVIDER.create_span( + return get_trace_provider().create_span( span_data=SpeechSpanData( model=model, input=input, @@ -419,7 +419,7 @@ def speech_group_span( trace/span as the parent. disabled: If True, we will return a Span but the Span will not be recorded. """ - return GLOBAL_TRACE_PROVIDER.create_span( + return get_trace_provider().create_span( span_data=SpeechGroupSpanData(input=input), span_id=span_id, parent=parent, @@ -447,7 +447,7 @@ def mcp_tools_span( trace/span as the parent. disabled: If True, we will return a Span but the Span will not be recorded. """ - return GLOBAL_TRACE_PROVIDER.create_span( + return get_trace_provider().create_span( span_data=MCPListToolsSpanData(server=server, result=result), span_id=span_id, parent=parent, diff --git a/src/agents/tracing/provider.py b/src/agents/tracing/provider.py new file mode 100644 index 00000000..b9f4c63e --- /dev/null +++ b/src/agents/tracing/provider.py @@ -0,0 +1,228 @@ +from __future__ import annotations + +import os +import threading +import uuid +from datetime import datetime, timezone +from typing import Any + +from ..logger import logger +from .processor_interface import TracingProcessor +from .scope import Scope +from .spans import NoOpSpan, Span, SpanImpl, TSpanData +from .traces import NoOpTrace, Trace, TraceImpl + + +class SynchronousMultiTracingProcessor(TracingProcessor): + """ + Forwards all calls to a list of TracingProcessors, in order of registration. + """ + + def __init__(self): + # Using a tuple to avoid race conditions when iterating over processors + self._processors: tuple[TracingProcessor, ...] = () + self._lock = threading.Lock() + + def add_tracing_processor(self, tracing_processor: TracingProcessor): + """ + Add a processor to the list of processors. Each processor will receive all traces/spans. + """ + with self._lock: + self._processors += (tracing_processor,) + + def set_processors(self, processors: list[TracingProcessor]): + """ + Set the list of processors. This will replace the current list of processors. + """ + with self._lock: + self._processors = tuple(processors) + + def on_trace_start(self, trace: Trace) -> None: + """ + Called when a trace is started. + """ + for processor in self._processors: + processor.on_trace_start(trace) + + def on_trace_end(self, trace: Trace) -> None: + """ + Called when a trace is finished. + """ + for processor in self._processors: + processor.on_trace_end(trace) + + def on_span_start(self, span: Span[Any]) -> None: + """ + Called when a span is started. + """ + for processor in self._processors: + processor.on_span_start(span) + + def on_span_end(self, span: Span[Any]) -> None: + """ + Called when a span is finished. + """ + for processor in self._processors: + processor.on_span_end(span) + + def shutdown(self) -> None: + """ + Called when the application stops. + """ + for processor in self._processors: + logger.debug(f"Shutting down trace processor {processor}") + processor.shutdown() + + def force_flush(self): + """ + Force the processors to flush their buffers. + """ + for processor in self._processors: + processor.force_flush() + + +class TraceProvider: + def __init__(self): + self._multi_processor = SynchronousMultiTracingProcessor() + self._disabled = os.environ.get("OPENAI_AGENTS_DISABLE_TRACING", "false").lower() in ( + "true", + "1", + ) + + def register_processor(self, processor: TracingProcessor): + """ + Add a processor to the list of processors. Each processor will receive all traces/spans. + """ + self._multi_processor.add_tracing_processor(processor) + + def set_processors(self, processors: list[TracingProcessor]): + """ + Set the list of processors. This will replace the current list of processors. + """ + self._multi_processor.set_processors(processors) + + def get_current_trace(self) -> Trace | None: + """ + Returns the currently active trace, if any. + """ + return Scope.get_current_trace() + + def get_current_span(self) -> Span[Any] | None: + """ + Returns the currently active span, if any. + """ + return Scope.get_current_span() + + def set_disabled(self, disabled: bool) -> None: + """ + Set whether tracing is disabled. + """ + self._disabled = disabled + + def time_iso(self) -> str: + """Return the current time in ISO 8601 format.""" + return datetime.now(timezone.utc).isoformat() + + def gen_trace_id(self) -> str: + """Generate a new trace ID.""" + return f"trace_{uuid.uuid4().hex}" + + def gen_span_id(self) -> str: + """Generate a new span ID.""" + return f"span_{uuid.uuid4().hex[:24]}" + + def gen_group_id(self) -> str: + """Generate a new group ID.""" + return f"group_{uuid.uuid4().hex[:24]}" + + def create_trace( + self, + name: str, + trace_id: str | None = None, + group_id: str | None = None, + metadata: dict[str, Any] | None = None, + disabled: bool = False, + ) -> Trace: + """ + Create a new trace. + """ + if self._disabled or disabled: + logger.debug(f"Tracing is disabled. Not creating trace {name}") + return NoOpTrace() + + trace_id = trace_id or self.gen_trace_id() + + logger.debug(f"Creating trace {name} with id {trace_id}") + + return TraceImpl( + name=name, + trace_id=trace_id, + group_id=group_id, + metadata=metadata, + processor=self._multi_processor, + ) + + def create_span( + self, + span_data: TSpanData, + span_id: str | None = None, + parent: Trace | Span[Any] | None = None, + disabled: bool = False, + ) -> Span[TSpanData]: + """ + Create a new span. + """ + if self._disabled or disabled: + logger.debug(f"Tracing is disabled. Not creating span {span_data}") + return NoOpSpan(span_data) + + if not parent: + current_span = Scope.get_current_span() + current_trace = Scope.get_current_trace() + if current_trace is None: + logger.error( + "No active trace. Make sure to start a trace with `trace()` first" + "Returning NoOpSpan." + ) + return NoOpSpan(span_data) + elif isinstance(current_trace, NoOpTrace) or isinstance(current_span, NoOpSpan): + logger.debug( + f"Parent {current_span} or {current_trace} is no-op, returning NoOpSpan" + ) + return NoOpSpan(span_data) + + parent_id = current_span.span_id if current_span else None + trace_id = current_trace.trace_id + + elif isinstance(parent, Trace): + if isinstance(parent, NoOpTrace): + logger.debug(f"Parent {parent} is no-op, returning NoOpSpan") + return NoOpSpan(span_data) + trace_id = parent.trace_id + parent_id = None + elif isinstance(parent, Span): + if isinstance(parent, NoOpSpan): + logger.debug(f"Parent {parent} is no-op, returning NoOpSpan") + return NoOpSpan(span_data) + parent_id = parent.span_id + trace_id = parent.trace_id + + logger.debug(f"Creating span {span_data} with id {span_id}") + + return SpanImpl( + trace_id=trace_id, + span_id=span_id or self.gen_span_id(), + parent_id=parent_id, + processor=self._multi_processor, + span_data=span_data, + ) + + def shutdown(self) -> None: + if self._disabled: + return + + try: + logger.debug("Shutting down trace provider") + self._multi_processor.shutdown() + except Exception as e: + logger.error(f"Error shutting down trace provider: {e}") diff --git a/src/agents/tracing/setup.py b/src/agents/tracing/setup.py index 9e27d210..3a56b728 100644 --- a/src/agents/tracing/setup.py +++ b/src/agents/tracing/setup.py @@ -1,214 +1,21 @@ from __future__ import annotations -import os -import threading -from typing import Any +from typing import TYPE_CHECKING -from ..logger import logger -from . import util -from .processor_interface import TracingProcessor -from .scope import Scope -from .spans import NoOpSpan, Span, SpanImpl, TSpanData -from .traces import NoOpTrace, Trace, TraceImpl +if TYPE_CHECKING: + from .provider import TraceProvider +GLOBAL_TRACE_PROVIDER: TraceProvider | None = None -class SynchronousMultiTracingProcessor(TracingProcessor): - """ - Forwards all calls to a list of TracingProcessors, in order of registration. - """ - def __init__(self): - # Using a tuple to avoid race conditions when iterating over processors - self._processors: tuple[TracingProcessor, ...] = () - self._lock = threading.Lock() +def set_trace_provider(provider: TraceProvider) -> None: + """Set the global trace provider used by tracing utilities.""" + global GLOBAL_TRACE_PROVIDER + GLOBAL_TRACE_PROVIDER = provider - def add_tracing_processor(self, tracing_processor: TracingProcessor): - """ - Add a processor to the list of processors. Each processor will receive all traces/spans. - """ - with self._lock: - self._processors += (tracing_processor,) - def set_processors(self, processors: list[TracingProcessor]): - """ - Set the list of processors. This will replace the current list of processors. - """ - with self._lock: - self._processors = tuple(processors) - - def on_trace_start(self, trace: Trace) -> None: - """ - Called when a trace is started. - """ - for processor in self._processors: - processor.on_trace_start(trace) - - def on_trace_end(self, trace: Trace) -> None: - """ - Called when a trace is finished. - """ - for processor in self._processors: - processor.on_trace_end(trace) - - def on_span_start(self, span: Span[Any]) -> None: - """ - Called when a span is started. - """ - for processor in self._processors: - processor.on_span_start(span) - - def on_span_end(self, span: Span[Any]) -> None: - """ - Called when a span is finished. - """ - for processor in self._processors: - processor.on_span_end(span) - - def shutdown(self) -> None: - """ - Called when the application stops. - """ - for processor in self._processors: - logger.debug(f"Shutting down trace processor {processor}") - processor.shutdown() - - def force_flush(self): - """ - Force the processors to flush their buffers. - """ - for processor in self._processors: - processor.force_flush() - - -class TraceProvider: - def __init__(self): - self._multi_processor = SynchronousMultiTracingProcessor() - self._disabled = os.environ.get("OPENAI_AGENTS_DISABLE_TRACING", "false").lower() in ( - "true", - "1", - ) - - def register_processor(self, processor: TracingProcessor): - """ - Add a processor to the list of processors. Each processor will receive all traces/spans. - """ - self._multi_processor.add_tracing_processor(processor) - - def set_processors(self, processors: list[TracingProcessor]): - """ - Set the list of processors. This will replace the current list of processors. - """ - self._multi_processor.set_processors(processors) - - def get_current_trace(self) -> Trace | None: - """ - Returns the currently active trace, if any. - """ - return Scope.get_current_trace() - - def get_current_span(self) -> Span[Any] | None: - """ - Returns the currently active span, if any. - """ - return Scope.get_current_span() - - def set_disabled(self, disabled: bool) -> None: - """ - Set whether tracing is disabled. - """ - self._disabled = disabled - - def create_trace( - self, - name: str, - trace_id: str | None = None, - group_id: str | None = None, - metadata: dict[str, Any] | None = None, - disabled: bool = False, - ) -> Trace: - """ - Create a new trace. - """ - if self._disabled or disabled: - logger.debug(f"Tracing is disabled. Not creating trace {name}") - return NoOpTrace() - - trace_id = trace_id or util.gen_trace_id() - - logger.debug(f"Creating trace {name} with id {trace_id}") - - return TraceImpl( - name=name, - trace_id=trace_id, - group_id=group_id, - metadata=metadata, - processor=self._multi_processor, - ) - - def create_span( - self, - span_data: TSpanData, - span_id: str | None = None, - parent: Trace | Span[Any] | None = None, - disabled: bool = False, - ) -> Span[TSpanData]: - """ - Create a new span. - """ - if self._disabled or disabled: - logger.debug(f"Tracing is disabled. Not creating span {span_data}") - return NoOpSpan(span_data) - - if not parent: - current_span = Scope.get_current_span() - current_trace = Scope.get_current_trace() - if current_trace is None: - logger.error( - "No active trace. Make sure to start a trace with `trace()` first" - "Returning NoOpSpan." - ) - return NoOpSpan(span_data) - elif isinstance(current_trace, NoOpTrace) or isinstance(current_span, NoOpSpan): - logger.debug( - f"Parent {current_span} or {current_trace} is no-op, returning NoOpSpan" - ) - return NoOpSpan(span_data) - - parent_id = current_span.span_id if current_span else None - trace_id = current_trace.trace_id - - elif isinstance(parent, Trace): - if isinstance(parent, NoOpTrace): - logger.debug(f"Parent {parent} is no-op, returning NoOpSpan") - return NoOpSpan(span_data) - trace_id = parent.trace_id - parent_id = None - elif isinstance(parent, Span): - if isinstance(parent, NoOpSpan): - logger.debug(f"Parent {parent} is no-op, returning NoOpSpan") - return NoOpSpan(span_data) - parent_id = parent.span_id - trace_id = parent.trace_id - - logger.debug(f"Creating span {span_data} with id {span_id}") - - return SpanImpl( - trace_id=trace_id, - span_id=span_id, - parent_id=parent_id, - processor=self._multi_processor, - span_data=span_data, - ) - - def shutdown(self) -> None: - if self._disabled: - return - - try: - logger.debug("Shutting down trace provider") - self._multi_processor.shutdown() - except Exception as e: - logger.error(f"Error shutting down trace provider: {e}") - - -GLOBAL_TRACE_PROVIDER = TraceProvider() +def get_trace_provider() -> TraceProvider: + """Get the global trace provider used by tracing utilities.""" + if GLOBAL_TRACE_PROVIDER is None: + raise RuntimeError("Trace provider not set") + return GLOBAL_TRACE_PROVIDER diff --git a/src/agents/tracing/util.py b/src/agents/tracing/util.py index f546b4e5..7f436d01 100644 --- a/src/agents/tracing/util.py +++ b/src/agents/tracing/util.py @@ -1,22 +1,21 @@ -import uuid -from datetime import datetime, timezone +from .setup import get_trace_provider def time_iso() -> str: - """Returns the current time in ISO 8601 format.""" - return datetime.now(timezone.utc).isoformat() + """Return the current time in ISO 8601 format.""" + return get_trace_provider().time_iso() def gen_trace_id() -> str: - """Generates a new trace ID.""" - return f"trace_{uuid.uuid4().hex}" + """Generate a new trace ID.""" + return get_trace_provider().gen_trace_id() def gen_span_id() -> str: - """Generates a new span ID.""" - return f"span_{uuid.uuid4().hex[:24]}" + """Generate a new span ID.""" + return get_trace_provider().gen_span_id() def gen_group_id() -> str: - """Generates a new group ID.""" - return f"group_{uuid.uuid4().hex[:24]}" + """Generate a new group ID.""" + return get_trace_provider().gen_group_id() diff --git a/tests/conftest.py b/tests/conftest.py index ba0d8822..f87e8559 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -5,8 +5,9 @@ from agents.models import _openai_shared from agents.models.openai_chatcompletions import OpenAIChatCompletionsModel from agents.models.openai_responses import OpenAIResponsesModel +from agents.run import set_default_runner from agents.tracing import set_trace_processors -from agents.tracing.setup import GLOBAL_TRACE_PROVIDER +from agents.tracing.setup import get_trace_provider from .testing_processor import SPAN_PROCESSOR_TESTING @@ -33,11 +34,16 @@ def clear_openai_settings(): _openai_shared._use_responses_by_default = True +@pytest.fixture(autouse=True) +def clear_default_runner(): + set_default_runner(None) + + # This fixture will run after all tests end @pytest.fixture(autouse=True, scope="session") def shutdown_trace_provider(): yield - GLOBAL_TRACE_PROVIDER.shutdown() + get_trace_provider().shutdown() @pytest.fixture(autouse=True) diff --git a/tests/test_run.py b/tests/test_run.py new file mode 100644 index 00000000..57e33d50 --- /dev/null +++ b/tests/test_run.py @@ -0,0 +1,26 @@ +from __future__ import annotations + +from unittest import mock + +import pytest + +from agents import Agent, Runner +from agents.run import set_default_runner + +from .fake_model import FakeModel + + +@pytest.mark.asyncio +async def test_static_run_methods_call_into_default_runner() -> None: + runner = mock.Mock(spec=Runner) + set_default_runner(runner) + + agent = Agent(name="test", model=FakeModel()) + await Runner.run(agent, input="test") + runner._run_impl.assert_called_once() + + Runner.run_streamed(agent, input="test") + runner._run_streamed_impl.assert_called_once() + + Runner.run_sync(agent, input="test") + runner._run_sync_impl.assert_called_once()