Skip to content

Make Runner an abstract base class #720

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 8 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion src/agents/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@
from .models.openai_provider import OpenAIProvider
from .models.openai_responses import OpenAIResponsesModel
from .result import RunResult, RunResultStreaming
from .run import RunConfig, Runner
from .run import DefaultRunner, RunConfig, Runner, set_default_runner
from .run_context import RunContextWrapper, TContext
from .stream_events import (
AgentUpdatedStreamEvent,
Expand Down Expand Up @@ -92,6 +92,7 @@
handoff_span,
mcp_tools_span,
set_trace_processors,
set_trace_provider,
set_tracing_disabled,
set_tracing_export_api_key,
speech_group_span,
Expand Down Expand Up @@ -150,6 +151,7 @@ def enable_verbose_stdout_logging():
"ToolsToFinalOutputFunction",
"ToolsToFinalOutputResult",
"Runner",
"DefaultRunner",
"Model",
"ModelProvider",
"ModelTracing",
Expand Down Expand Up @@ -220,6 +222,7 @@ def enable_verbose_stdout_logging():
"guardrail_span",
"handoff_span",
"set_trace_processors",
"set_trace_provider",
"set_tracing_disabled",
"speech_group_span",
"transcription_span",
Expand All @@ -244,6 +247,7 @@ def enable_verbose_stdout_logging():
"set_default_openai_key",
"set_default_openai_client",
"set_default_openai_api",
"set_default_runner",
"set_tracing_export_api_key",
"enable_verbose_stdout_logging",
"gen_trace_id",
Expand Down
335 changes: 220 additions & 115 deletions src/agents/run.py

Large diffs are not rendered by default.

14 changes: 9 additions & 5 deletions src/agents/tracing/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import atexit

from agents.tracing.provider import TraceProvider

from .create import (
agent_span,
custom_span,
Expand All @@ -18,7 +20,7 @@
)
from .processor_interface import TracingProcessor
from .processors import default_exporter, default_processor
from .setup import GLOBAL_TRACE_PROVIDER
from .setup import get_trace_provider, set_trace_provider
from .span_data import (
AgentSpanData,
CustomSpanData,
Expand Down Expand Up @@ -49,6 +51,7 @@
"handoff_span",
"response_span",
"set_trace_processors",
"set_trace_provider",
"set_tracing_disabled",
"trace",
"Trace",
Expand Down Expand Up @@ -80,21 +83,21 @@ def add_trace_processor(span_processor: TracingProcessor) -> None:
"""
Adds a new trace processor. This processor will receive all traces/spans.
"""
GLOBAL_TRACE_PROVIDER.register_processor(span_processor)
get_trace_provider().register_processor(span_processor)


def set_trace_processors(processors: list[TracingProcessor]) -> None:
"""
Set the list of trace processors. This will replace the current list of processors.
"""
GLOBAL_TRACE_PROVIDER.set_processors(processors)
get_trace_provider().set_processors(processors)


def set_tracing_disabled(disabled: bool) -> None:
"""
Set whether tracing is globally disabled.
"""
GLOBAL_TRACE_PROVIDER.set_disabled(disabled)
get_trace_provider().set_disabled(disabled)


def set_tracing_export_api_key(api_key: str) -> None:
Expand All @@ -104,10 +107,11 @@ def set_tracing_export_api_key(api_key: str) -> None:
default_exporter().set_api_key(api_key)


set_trace_provider(TraceProvider())
# Add the default processor, which exports traces and spans to the backend in batches. You can
# change the default behavior by either:
# 1. calling add_trace_processor(), which adds additional processors, or
# 2. calling set_trace_processors(), which replaces the default processor.
add_trace_processor(default_processor())

atexit.register(GLOBAL_TRACE_PROVIDER.shutdown)
atexit.register(get_trace_provider().shutdown)
32 changes: 16 additions & 16 deletions src/agents/tracing/create.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from typing import TYPE_CHECKING, Any

from ..logger import logger
from .setup import GLOBAL_TRACE_PROVIDER
from .setup import get_trace_provider
from .span_data import (
AgentSpanData,
CustomSpanData,
Expand Down Expand Up @@ -56,13 +56,13 @@ def trace(
Returns:
The newly created trace object.
"""
current_trace = GLOBAL_TRACE_PROVIDER.get_current_trace()
current_trace = get_trace_provider().get_current_trace()
if current_trace:
logger.warning(
"Trace already exists. Creating a new trace, but this is probably a mistake."
)

return GLOBAL_TRACE_PROVIDER.create_trace(
return get_trace_provider().create_trace(
name=workflow_name,
trace_id=trace_id,
group_id=group_id,
Expand All @@ -73,12 +73,12 @@ def trace(

def get_current_trace() -> Trace | None:
"""Returns the currently active trace, if present."""
return GLOBAL_TRACE_PROVIDER.get_current_trace()
return get_trace_provider().get_current_trace()


def get_current_span() -> Span[Any] | None:
"""Returns the currently active span, if present."""
return GLOBAL_TRACE_PROVIDER.get_current_span()
return get_trace_provider().get_current_span()


def agent_span(
Expand Down Expand Up @@ -108,7 +108,7 @@ def agent_span(
Returns:
The newly created agent span.
"""
return GLOBAL_TRACE_PROVIDER.create_span(
return get_trace_provider().create_span(
span_data=AgentSpanData(name=name, handoffs=handoffs, tools=tools, output_type=output_type),
span_id=span_id,
parent=parent,
Expand Down Expand Up @@ -141,7 +141,7 @@ def function_span(
Returns:
The newly created function span.
"""
return GLOBAL_TRACE_PROVIDER.create_span(
return get_trace_provider().create_span(
span_data=FunctionSpanData(name=name, input=input, output=output),
span_id=span_id,
parent=parent,
Expand Down Expand Up @@ -183,7 +183,7 @@ def generation_span(
Returns:
The newly created generation span.
"""
return GLOBAL_TRACE_PROVIDER.create_span(
return get_trace_provider().create_span(
span_data=GenerationSpanData(
input=input,
output=output,
Expand Down Expand Up @@ -215,7 +215,7 @@ def response_span(
trace/span as the parent.
disabled: If True, we will return a Span but the Span will not be recorded.
"""
return GLOBAL_TRACE_PROVIDER.create_span(
return get_trace_provider().create_span(
span_data=ResponseSpanData(response=response),
span_id=span_id,
parent=parent,
Expand Down Expand Up @@ -246,7 +246,7 @@ def handoff_span(
Returns:
The newly created handoff span.
"""
return GLOBAL_TRACE_PROVIDER.create_span(
return get_trace_provider().create_span(
span_data=HandoffSpanData(from_agent=from_agent, to_agent=to_agent),
span_id=span_id,
parent=parent,
Expand Down Expand Up @@ -278,7 +278,7 @@ def custom_span(
Returns:
The newly created custom span.
"""
return GLOBAL_TRACE_PROVIDER.create_span(
return get_trace_provider().create_span(
span_data=CustomSpanData(name=name, data=data or {}),
span_id=span_id,
parent=parent,
Expand Down Expand Up @@ -306,7 +306,7 @@ def guardrail_span(
trace/span as the parent.
disabled: If True, we will return a Span but the Span will not be recorded.
"""
return GLOBAL_TRACE_PROVIDER.create_span(
return get_trace_provider().create_span(
span_data=GuardrailSpanData(name=name, triggered=triggered),
span_id=span_id,
parent=parent,
Expand Down Expand Up @@ -344,7 +344,7 @@ def transcription_span(
Returns:
The newly created speech-to-text span.
"""
return GLOBAL_TRACE_PROVIDER.create_span(
return get_trace_provider().create_span(
span_data=TranscriptionSpanData(
input=input,
input_format=input_format,
Expand Down Expand Up @@ -386,7 +386,7 @@ def speech_span(
trace/span as the parent.
disabled: If True, we will return a Span but the Span will not be recorded.
"""
return GLOBAL_TRACE_PROVIDER.create_span(
return get_trace_provider().create_span(
span_data=SpeechSpanData(
model=model,
input=input,
Expand Down Expand Up @@ -419,7 +419,7 @@ def speech_group_span(
trace/span as the parent.
disabled: If True, we will return a Span but the Span will not be recorded.
"""
return GLOBAL_TRACE_PROVIDER.create_span(
return get_trace_provider().create_span(
span_data=SpeechGroupSpanData(input=input),
span_id=span_id,
parent=parent,
Expand Down Expand Up @@ -447,7 +447,7 @@ def mcp_tools_span(
trace/span as the parent.
disabled: If True, we will return a Span but the Span will not be recorded.
"""
return GLOBAL_TRACE_PROVIDER.create_span(
return get_trace_provider().create_span(
span_data=MCPListToolsSpanData(server=server, result=result),
span_id=span_id,
parent=parent,
Expand Down
Loading