From 428c9a65bf0c17198d9a2a616159b9eb8badb2b6 Mon Sep 17 00:00:00 2001 From: franz101 Date: Mon, 19 May 2025 21:41:11 +0200 Subject: [PATCH 01/14] Add Galileo to external tracing processors list (#662) --- docs/tracing.md | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/tracing.md b/docs/tracing.md index dd883c5aa..4a9c1bd90 100644 --- a/docs/tracing.md +++ b/docs/tracing.md @@ -115,3 +115,4 @@ To customize this default setup, to send traces to alternative or additional bac - [Langfuse](https://langfuse.com/docs/integrations/openaiagentssdk/openai-agents) - [Langtrace](https://docs.langtrace.ai/supported-integrations/llm-frameworks/openai-agents-sdk) - [Okahu-Monocle](https://github.com/monocle2ai/monocle) +- [Galileo](https://v2docs.galileo.ai/integrations/openai-agent-integration#openai-agent-integration) From 466b44df180718a5d53c45293db2f57b6e719f95 Mon Sep 17 00:00:00 2001 From: WJPBProjects <76624567+WJPBProjects@users.noreply.github.com> Date: Tue, 20 May 2025 18:23:56 +0100 Subject: [PATCH 02/14] Dev/add usage details to Usage class (#726) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit PR to enhance the `Usage` object and related logic, to support more granular token accounting, matching the details available in the [OpenAI Responses API](https://platform.openai.com/docs/api-reference/responses) . Specifically, it: - Adds `input_tokens_details` and `output_tokens_details` fields to the `Usage` dataclass, storing detailed token breakdowns (e.g., `cached_tokens`, `reasoning_tokens`). - Flows this change through - Updates and extends tests to match - Adds a test for the Usage.add method ### Motivation - Aligns the SDK’s usage with the latest OpenAI responses API Usage object - Supports downstream use cases that require fine-grained token usage data (e.g., billing, analytics, optimization) requested by startups --------- Co-authored-by: Wulfie Bain --- src/agents/extensions/models/litellm_model.py | 11 ++++ src/agents/models/openai_chatcompletions.py | 15 +++++- src/agents/models/openai_responses.py | 2 + src/agents/run.py | 2 + src/agents/usage.py | 22 +++++++- .../test_litellm_chatcompletions_stream.py | 16 +++++- tests/test_extra_headers.py | 20 ++++--- tests/test_openai_chatcompletions.py | 17 +++++- tests/test_openai_chatcompletions_stream.py | 16 +++++- tests/test_responses_tracing.py | 20 ++++++- tests/test_usage.py | 52 +++++++++++++++++++ 11 files changed, 178 insertions(+), 15 deletions(-) create mode 100644 tests/test_usage.py diff --git a/src/agents/extensions/models/litellm_model.py b/src/agents/extensions/models/litellm_model.py index d3b25a198..ffb2c3c1c 100644 --- a/src/agents/extensions/models/litellm_model.py +++ b/src/agents/extensions/models/litellm_model.py @@ -6,6 +6,7 @@ from typing import Any, Literal, cast, overload import litellm.types +from openai.types.responses.response_usage import InputTokensDetails, OutputTokensDetails from agents.exceptions import ModelBehaviorError @@ -107,6 +108,16 @@ async def get_response( input_tokens=response_usage.prompt_tokens, output_tokens=response_usage.completion_tokens, total_tokens=response_usage.total_tokens, + input_tokens_details=InputTokensDetails( + cached_tokens=getattr( + response_usage.prompt_tokens_details, "cached_tokens", 0 + ) + ), + output_tokens_details=OutputTokensDetails( + reasoning_tokens=getattr( + response_usage.completion_tokens_details, "reasoning_tokens", 0 + ) + ), ) if response.usage else Usage() diff --git 
a/src/agents/models/openai_chatcompletions.py b/src/agents/models/openai_chatcompletions.py index 89619f838..4465ff2fd 100644 --- a/src/agents/models/openai_chatcompletions.py +++ b/src/agents/models/openai_chatcompletions.py @@ -9,6 +9,7 @@ from openai.types import ChatModel from openai.types.chat import ChatCompletion, ChatCompletionChunk from openai.types.responses import Response +from openai.types.responses.response_usage import InputTokensDetails, OutputTokensDetails from .. import _debug from ..agent_output import AgentOutputSchemaBase @@ -83,6 +84,18 @@ async def get_response( input_tokens=response.usage.prompt_tokens, output_tokens=response.usage.completion_tokens, total_tokens=response.usage.total_tokens, + input_tokens_details=InputTokensDetails( + cached_tokens=getattr( + response.usage.prompt_tokens_details, "cached_tokens", 0 + ) + or 0, + ), + output_tokens_details=OutputTokensDetails( + reasoning_tokens=getattr( + response.usage.completion_tokens_details, "reasoning_tokens", 0 + ) + or 0, + ), ) if response.usage else Usage() @@ -252,7 +265,7 @@ async def _fetch_response( stream_options=self._non_null_or_not_given(stream_options), store=self._non_null_or_not_given(store), reasoning_effort=self._non_null_or_not_given(reasoning_effort), - extra_headers={ **HEADERS, **(model_settings.extra_headers or {}) }, + extra_headers={**HEADERS, **(model_settings.extra_headers or {})}, extra_query=model_settings.extra_query, extra_body=model_settings.extra_body, metadata=self._non_null_or_not_given(model_settings.metadata), diff --git a/src/agents/models/openai_responses.py b/src/agents/models/openai_responses.py index c1ff85b98..6ec8f8f7b 100644 --- a/src/agents/models/openai_responses.py +++ b/src/agents/models/openai_responses.py @@ -98,6 +98,8 @@ async def get_response( input_tokens=response.usage.input_tokens, output_tokens=response.usage.output_tokens, total_tokens=response.usage.total_tokens, + input_tokens_details=response.usage.input_tokens_details, + output_tokens_details=response.usage.output_tokens_details, ) if response.usage else Usage() diff --git a/src/agents/run.py b/src/agents/run.py index 849da7bfc..b196c3bf1 100644 --- a/src/agents/run.py +++ b/src/agents/run.py @@ -689,6 +689,8 @@ async def _run_single_turn_streamed( input_tokens=event.response.usage.input_tokens, output_tokens=event.response.usage.output_tokens, total_tokens=event.response.usage.total_tokens, + input_tokens_details=event.response.usage.input_tokens_details, + output_tokens_details=event.response.usage.output_tokens_details, ) if event.response.usage else Usage() diff --git a/src/agents/usage.py b/src/agents/usage.py index 23d989b4b..843f62937 100644 --- a/src/agents/usage.py +++ b/src/agents/usage.py @@ -1,4 +1,6 @@ -from dataclasses import dataclass +from dataclasses import dataclass, field + +from openai.types.responses.response_usage import InputTokensDetails, OutputTokensDetails @dataclass @@ -9,9 +11,18 @@ class Usage: input_tokens: int = 0 """Total input tokens sent, across all requests.""" + input_tokens_details: InputTokensDetails = field( + default_factory=lambda: InputTokensDetails(cached_tokens=0) + ) + """Details about the input tokens, matching responses API usage details.""" output_tokens: int = 0 """Total output tokens received, across all requests.""" + output_tokens_details: OutputTokensDetails = field( + default_factory=lambda: OutputTokensDetails(reasoning_tokens=0) + ) + """Details about the output tokens, matching responses API usage details.""" + total_tokens: int = 0 """Total 
tokens sent and received, across all requests.""" @@ -20,3 +31,12 @@ def add(self, other: "Usage") -> None: self.input_tokens += other.input_tokens if other.input_tokens else 0 self.output_tokens += other.output_tokens if other.output_tokens else 0 self.total_tokens += other.total_tokens if other.total_tokens else 0 + self.input_tokens_details = InputTokensDetails( + cached_tokens=self.input_tokens_details.cached_tokens + + other.input_tokens_details.cached_tokens + ) + + self.output_tokens_details = OutputTokensDetails( + reasoning_tokens=self.output_tokens_details.reasoning_tokens + + other.output_tokens_details.reasoning_tokens + ) diff --git a/tests/models/test_litellm_chatcompletions_stream.py b/tests/models/test_litellm_chatcompletions_stream.py index 80bd8ea22..06e46b39c 100644 --- a/tests/models/test_litellm_chatcompletions_stream.py +++ b/tests/models/test_litellm_chatcompletions_stream.py @@ -8,7 +8,11 @@ ChoiceDeltaToolCall, ChoiceDeltaToolCallFunction, ) -from openai.types.completion_usage import CompletionUsage +from openai.types.completion_usage import ( + CompletionTokensDetails, + CompletionUsage, + PromptTokensDetails, +) from openai.types.responses import ( Response, ResponseFunctionToolCall, @@ -46,7 +50,13 @@ async def test_stream_response_yields_events_for_text_content(monkeypatch) -> No model="fake", object="chat.completion.chunk", choices=[Choice(index=0, delta=ChoiceDelta(content="llo"))], - usage=CompletionUsage(completion_tokens=5, prompt_tokens=7, total_tokens=12), + usage=CompletionUsage( + completion_tokens=5, + prompt_tokens=7, + total_tokens=12, + completion_tokens_details=CompletionTokensDetails(reasoning_tokens=2), + prompt_tokens_details=PromptTokensDetails(cached_tokens=6), + ), ) async def fake_stream() -> AsyncIterator[ChatCompletionChunk]: @@ -112,6 +122,8 @@ async def patched_fetch_response(self, *args, **kwargs): assert completed_resp.usage.input_tokens == 7 assert completed_resp.usage.output_tokens == 5 assert completed_resp.usage.total_tokens == 12 + assert completed_resp.usage.input_tokens_details.cached_tokens == 6 + assert completed_resp.usage.output_tokens_details.reasoning_tokens == 2 @pytest.mark.allow_call_model_methods diff --git a/tests/test_extra_headers.py b/tests/test_extra_headers.py index f29c25408..a6af30077 100644 --- a/tests/test_extra_headers.py +++ b/tests/test_extra_headers.py @@ -1,6 +1,7 @@ import pytest from openai.types.chat.chat_completion import ChatCompletion, Choice from openai.types.chat.chat_completion_message import ChatCompletionMessage +from openai.types.responses.response_usage import InputTokensDetails, OutputTokensDetails from agents import ModelSettings, ModelTracing, OpenAIChatCompletionsModel, OpenAIResponsesModel @@ -17,21 +18,29 @@ class DummyResponses: async def create(self, **kwargs): nonlocal called_kwargs called_kwargs = kwargs + class DummyResponse: id = "dummy" output = [] usage = type( - "Usage", (), {"input_tokens": 0, "output_tokens": 0, "total_tokens": 0} + "Usage", + (), + { + "input_tokens": 0, + "output_tokens": 0, + "total_tokens": 0, + "input_tokens_details": InputTokensDetails(cached_tokens=0), + "output_tokens_details": OutputTokensDetails(reasoning_tokens=0), + }, )() + return DummyResponse() class DummyClient: def __init__(self): self.responses = DummyResponses() - - - model = OpenAIResponsesModel(model="gpt-4", openai_client=DummyClient()) # type: ignore + model = OpenAIResponsesModel(model="gpt-4", openai_client=DummyClient()) # type: ignore extra_headers = {"X-Test-Header": 
"test-value"} await model.get_response( system_instructions=None, @@ -47,7 +56,6 @@ def __init__(self): assert called_kwargs["extra_headers"]["X-Test-Header"] == "test-value" - @pytest.mark.allow_call_model_methods @pytest.mark.asyncio async def test_extra_headers_passed_to_openai_client(): @@ -76,7 +84,7 @@ def __init__(self): self.chat = type("_Chat", (), {"completions": DummyCompletions()})() self.base_url = "https://api.openai.com" - model = OpenAIChatCompletionsModel(model="gpt-4", openai_client=DummyClient()) # type: ignore + model = OpenAIChatCompletionsModel(model="gpt-4", openai_client=DummyClient()) # type: ignore extra_headers = {"X-Test-Header": "test-value"} await model.get_response( system_instructions=None, diff --git a/tests/test_openai_chatcompletions.py b/tests/test_openai_chatcompletions.py index ba3ec68d0..ba4605d08 100644 --- a/tests/test_openai_chatcompletions.py +++ b/tests/test_openai_chatcompletions.py @@ -13,7 +13,10 @@ ChatCompletionMessageToolCall, Function, ) -from openai.types.completion_usage import CompletionUsage +from openai.types.completion_usage import ( + CompletionUsage, + PromptTokensDetails, +) from openai.types.responses import ( Response, ResponseFunctionToolCall, @@ -51,7 +54,13 @@ async def test_get_response_with_text_message(monkeypatch) -> None: model="fake", object="chat.completion", choices=[choice], - usage=CompletionUsage(completion_tokens=5, prompt_tokens=7, total_tokens=12), + usage=CompletionUsage( + completion_tokens=5, + prompt_tokens=7, + total_tokens=12, + # completion_tokens_details left blank to test default + prompt_tokens_details=PromptTokensDetails(cached_tokens=3), + ), ) async def patched_fetch_response(self, *args, **kwargs): @@ -81,6 +90,8 @@ async def patched_fetch_response(self, *args, **kwargs): assert resp.usage.input_tokens == 7 assert resp.usage.output_tokens == 5 assert resp.usage.total_tokens == 12 + assert resp.usage.input_tokens_details.cached_tokens == 3 + assert resp.usage.output_tokens_details.reasoning_tokens == 0 assert resp.response_id is None @@ -127,6 +138,8 @@ async def patched_fetch_response(self, *args, **kwargs): assert resp.usage.requests == 0 assert resp.usage.input_tokens == 0 assert resp.usage.output_tokens == 0 + assert resp.usage.input_tokens_details.cached_tokens == 0 + assert resp.usage.output_tokens_details.reasoning_tokens == 0 @pytest.mark.allow_call_model_methods diff --git a/tests/test_openai_chatcompletions_stream.py b/tests/test_openai_chatcompletions_stream.py index b82f24303..5c8bb9e3a 100644 --- a/tests/test_openai_chatcompletions_stream.py +++ b/tests/test_openai_chatcompletions_stream.py @@ -8,7 +8,11 @@ ChoiceDeltaToolCall, ChoiceDeltaToolCallFunction, ) -from openai.types.completion_usage import CompletionUsage +from openai.types.completion_usage import ( + CompletionTokensDetails, + CompletionUsage, + PromptTokensDetails, +) from openai.types.responses import ( Response, ResponseFunctionToolCall, @@ -46,7 +50,13 @@ async def test_stream_response_yields_events_for_text_content(monkeypatch) -> No model="fake", object="chat.completion.chunk", choices=[Choice(index=0, delta=ChoiceDelta(content="llo"))], - usage=CompletionUsage(completion_tokens=5, prompt_tokens=7, total_tokens=12), + usage=CompletionUsage( + completion_tokens=5, + prompt_tokens=7, + total_tokens=12, + prompt_tokens_details=PromptTokensDetails(cached_tokens=2), + completion_tokens_details=CompletionTokensDetails(reasoning_tokens=3), + ), ) async def fake_stream() -> AsyncIterator[ChatCompletionChunk]: @@ -112,6 +122,8 
@@ async def patched_fetch_response(self, *args, **kwargs): assert completed_resp.usage.input_tokens == 7 assert completed_resp.usage.output_tokens == 5 assert completed_resp.usage.total_tokens == 12 + assert completed_resp.usage.input_tokens_details.cached_tokens == 2 + assert completed_resp.usage.output_tokens_details.reasoning_tokens == 3 @pytest.mark.allow_call_model_methods diff --git a/tests/test_responses_tracing.py b/tests/test_responses_tracing.py index 0bc97a953..dfac74bb9 100644 --- a/tests/test_responses_tracing.py +++ b/tests/test_responses_tracing.py @@ -1,7 +1,10 @@ +from typing import Optional + import pytest from inline_snapshot import snapshot from openai import AsyncOpenAI from openai.types.responses import ResponseCompletedEvent +from openai.types.responses.response_usage import InputTokensDetails, OutputTokensDetails from agents import ModelSettings, ModelTracing, OpenAIResponsesModel, trace from agents.tracing.span_data import ResponseSpanData @@ -16,10 +19,25 @@ def is_disabled(self): class DummyUsage: - def __init__(self, input_tokens=1, output_tokens=1, total_tokens=2): + def __init__( + self, + input_tokens: int = 1, + input_tokens_details: Optional[InputTokensDetails] = None, + output_tokens: int = 1, + output_tokens_details: Optional[OutputTokensDetails] = None, + total_tokens: int = 2, + ): self.input_tokens = input_tokens self.output_tokens = output_tokens self.total_tokens = total_tokens + self.input_tokens_details = ( + input_tokens_details if input_tokens_details else InputTokensDetails(cached_tokens=0) + ) + self.output_tokens_details = ( + output_tokens_details + if output_tokens_details + else OutputTokensDetails(reasoning_tokens=0) + ) class DummyResponse: diff --git a/tests/test_usage.py b/tests/test_usage.py new file mode 100644 index 000000000..405f99ddf --- /dev/null +++ b/tests/test_usage.py @@ -0,0 +1,52 @@ +from openai.types.responses.response_usage import InputTokensDetails, OutputTokensDetails + +from agents.usage import Usage + + +def test_usage_add_aggregates_all_fields(): + u1 = Usage( + requests=1, + input_tokens=10, + input_tokens_details=InputTokensDetails(cached_tokens=3), + output_tokens=20, + output_tokens_details=OutputTokensDetails(reasoning_tokens=5), + total_tokens=30, + ) + u2 = Usage( + requests=2, + input_tokens=7, + input_tokens_details=InputTokensDetails(cached_tokens=4), + output_tokens=8, + output_tokens_details=OutputTokensDetails(reasoning_tokens=6), + total_tokens=15, + ) + + u1.add(u2) + + assert u1.requests == 3 + assert u1.input_tokens == 17 + assert u1.output_tokens == 28 + assert u1.total_tokens == 45 + assert u1.input_tokens_details.cached_tokens == 7 + assert u1.output_tokens_details.reasoning_tokens == 11 + + +def test_usage_add_aggregates_with_none_values(): + u1 = Usage() + u2 = Usage( + requests=2, + input_tokens=7, + input_tokens_details=InputTokensDetails(cached_tokens=4), + output_tokens=8, + output_tokens_details=OutputTokensDetails(reasoning_tokens=6), + total_tokens=15, + ) + + u1.add(u2) + + assert u1.requests == 2 + assert u1.input_tokens == 7 + assert u1.output_tokens == 8 + assert u1.total_tokens == 15 + assert u1.input_tokens_details.cached_tokens == 4 + assert u1.output_tokens_details.reasoning_tokens == 6 From ce2e2a4571c2b176e8641c558fedaa7bc1692013 Mon Sep 17 00:00:00 2001 From: Rohan Mehta Date: Wed, 21 May 2025 15:17:58 -0400 Subject: [PATCH 03/14] Upgrade openAI sdk version (#730) --- [//]: # (BEGIN SAPLING FOOTER) * #732 * #731 * __->__ #730 --- pyproject.toml | 2 +- 
src/agents/models/chatcmpl_stream_handler.py | 26 +++++++++++++++++++- src/agents/models/openai_responses.py | 16 ++++-------- tests/fake_model.py | 1 + tests/test_responses_tracing.py | 4 +++ tests/voice/test_workflow.py | 2 ++ uv.lock | 10 ++++---- 7 files changed, 43 insertions(+), 18 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 672258c42..200ac2485 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -7,7 +7,7 @@ requires-python = ">=3.9" license = "MIT" authors = [{ name = "OpenAI", email = "support@openai.com" }] dependencies = [ - "openai>=1.76.0", + "openai>=1.81.0", "pydantic>=2.10, <3", "griffe>=1.5.6, <2", "typing-extensions>=4.12.2, <5", diff --git a/src/agents/models/chatcmpl_stream_handler.py b/src/agents/models/chatcmpl_stream_handler.py index c71adeb55..d18f5912a 100644 --- a/src/agents/models/chatcmpl_stream_handler.py +++ b/src/agents/models/chatcmpl_stream_handler.py @@ -38,6 +38,16 @@ class StreamingState: function_calls: dict[int, ResponseFunctionToolCall] = field(default_factory=dict) +class SequenceNumber: + def __init__(self): + self._sequence_number = 0 + + def get_and_increment(self) -> int: + num = self._sequence_number + self._sequence_number += 1 + return num + + class ChatCmplStreamHandler: @classmethod async def handle_stream( @@ -47,13 +57,14 @@ async def handle_stream( ) -> AsyncIterator[TResponseStreamEvent]: usage: CompletionUsage | None = None state = StreamingState() - + sequence_number = SequenceNumber() async for chunk in stream: if not state.started: state.started = True yield ResponseCreatedEvent( response=response, type="response.created", + sequence_number=sequence_number.get_and_increment(), ) # This is always set by the OpenAI API, but not by others e.g. LiteLLM @@ -89,6 +100,7 @@ async def handle_stream( item=assistant_item, output_index=0, type="response.output_item.added", + sequence_number=sequence_number.get_and_increment(), ) yield ResponseContentPartAddedEvent( content_index=state.text_content_index_and_output[0], @@ -100,6 +112,7 @@ async def handle_stream( annotations=[], ), type="response.content_part.added", + sequence_number=sequence_number.get_and_increment(), ) # Emit the delta for this segment of content yield ResponseTextDeltaEvent( @@ -108,6 +121,7 @@ async def handle_stream( item_id=FAKE_RESPONSES_ID, output_index=0, type="response.output_text.delta", + sequence_number=sequence_number.get_and_increment(), ) # Accumulate the text into the response part state.text_content_index_and_output[1].text += delta.content @@ -134,6 +148,7 @@ async def handle_stream( item=assistant_item, output_index=0, type="response.output_item.added", + sequence_number=sequence_number.get_and_increment(), ) yield ResponseContentPartAddedEvent( content_index=state.refusal_content_index_and_output[0], @@ -145,6 +160,7 @@ async def handle_stream( annotations=[], ), type="response.content_part.added", + sequence_number=sequence_number.get_and_increment(), ) # Emit the delta for this segment of refusal yield ResponseRefusalDeltaEvent( @@ -153,6 +169,7 @@ async def handle_stream( item_id=FAKE_RESPONSES_ID, output_index=0, type="response.refusal.delta", + sequence_number=sequence_number.get_and_increment(), ) # Accumulate the refusal string in the output part state.refusal_content_index_and_output[1].refusal += delta.refusal @@ -190,6 +207,7 @@ async def handle_stream( output_index=0, part=state.text_content_index_and_output[1], type="response.content_part.done", + sequence_number=sequence_number.get_and_increment(), ) if 
state.refusal_content_index_and_output: @@ -201,6 +219,7 @@ async def handle_stream( output_index=0, part=state.refusal_content_index_and_output[1], type="response.content_part.done", + sequence_number=sequence_number.get_and_increment(), ) # Actually send events for the function calls @@ -216,6 +235,7 @@ async def handle_stream( ), output_index=function_call_starting_index, type="response.output_item.added", + sequence_number=sequence_number.get_and_increment(), ) # Then, yield the args yield ResponseFunctionCallArgumentsDeltaEvent( @@ -223,6 +243,7 @@ async def handle_stream( item_id=FAKE_RESPONSES_ID, output_index=function_call_starting_index, type="response.function_call_arguments.delta", + sequence_number=sequence_number.get_and_increment(), ) # Finally, the ResponseOutputItemDone yield ResponseOutputItemDoneEvent( @@ -235,6 +256,7 @@ async def handle_stream( ), output_index=function_call_starting_index, type="response.output_item.done", + sequence_number=sequence_number.get_and_increment(), ) # Finally, send the Response completed event @@ -258,6 +280,7 @@ async def handle_stream( item=assistant_msg, output_index=0, type="response.output_item.done", + sequence_number=sequence_number.get_and_increment(), ) for function_call in state.function_calls.values(): @@ -289,4 +312,5 @@ async def handle_stream( yield ResponseCompletedEvent( response=final_response, type="response.completed", + sequence_number=sequence_number.get_and_increment(), ) diff --git a/src/agents/models/openai_responses.py b/src/agents/models/openai_responses.py index 6ec8f8f7b..cb6567909 100644 --- a/src/agents/models/openai_responses.py +++ b/src/agents/models/openai_responses.py @@ -10,6 +10,7 @@ from openai.types.responses import ( Response, ResponseCompletedEvent, + ResponseIncludable, ResponseStreamEvent, ResponseTextConfigParam, ToolParam, @@ -36,13 +37,6 @@ _USER_AGENT = f"Agents/Python {__version__}" _HEADERS = {"User-Agent": _USER_AGENT} -# From the Responses API -IncludeLiteral = Literal[ - "file_search_call.results", - "message.input_image.image_url", - "computer_call_output.output.image_url", -] - class OpenAIResponsesModel(Model): """ @@ -273,7 +267,7 @@ def _get_client(self) -> AsyncOpenAI: @dataclass class ConvertedTools: tools: list[ToolParam] - includes: list[IncludeLiteral] + includes: list[ResponseIncludable] class Converter: @@ -330,7 +324,7 @@ def convert_tools( handoffs: list[Handoff[Any]], ) -> ConvertedTools: converted_tools: list[ToolParam] = [] - includes: list[IncludeLiteral] = [] + includes: list[ResponseIncludable] = [] computer_tools = [tool for tool in tools if isinstance(tool, ComputerTool)] if len(computer_tools) > 1: @@ -348,7 +342,7 @@ def convert_tools( return ConvertedTools(tools=converted_tools, includes=includes) @classmethod - def _convert_tool(cls, tool: Tool) -> tuple[ToolParam, IncludeLiteral | None]: + def _convert_tool(cls, tool: Tool) -> tuple[ToolParam, ResponseIncludable | None]: """Returns converted tool and includes""" if isinstance(tool, FunctionTool): @@ -359,7 +353,7 @@ def _convert_tool(cls, tool: Tool) -> tuple[ToolParam, IncludeLiteral | None]: "type": "function", "description": tool.description, } - includes: IncludeLiteral | None = None + includes: ResponseIncludable | None = None elif isinstance(tool, WebSearchTool): ws: WebSearchToolParam = { "type": "web_search_preview", diff --git a/tests/fake_model.py b/tests/fake_model.py index 32f919ef1..9f0c83a2f 100644 --- a/tests/fake_model.py +++ b/tests/fake_model.py @@ -129,6 +129,7 @@ async def stream_response( 
yield ResponseCompletedEvent( type="response.completed", response=get_response_obj(output, usage=self.hardcoded_usage), + sequence_number=0, ) diff --git a/tests/test_responses_tracing.py b/tests/test_responses_tracing.py index dfac74bb9..db24fe496 100644 --- a/tests/test_responses_tracing.py +++ b/tests/test_responses_tracing.py @@ -50,6 +50,7 @@ def __aiter__(self): yield ResponseCompletedEvent( type="response.completed", response=fake_model.get_response_obj(self.output), + sequence_number=0, ) @@ -201,6 +202,7 @@ async def __aiter__(self): yield ResponseCompletedEvent( type="response.completed", response=fake_model.get_response_obj([], "dummy-id-123"), + sequence_number=0, ) return DummyStream() @@ -253,6 +255,7 @@ async def __aiter__(self): yield ResponseCompletedEvent( type="response.completed", response=fake_model.get_response_obj([], "dummy-id-123"), + sequence_number=0, ) return DummyStream() @@ -304,6 +307,7 @@ async def __aiter__(self): yield ResponseCompletedEvent( type="response.completed", response=fake_model.get_response_obj([], "dummy-id-123"), + sequence_number=0, ) return DummyStream() diff --git a/tests/voice/test_workflow.py b/tests/voice/test_workflow.py index 2bdf2a657..035a05d56 100644 --- a/tests/voice/test_workflow.py +++ b/tests/voice/test_workflow.py @@ -81,11 +81,13 @@ async def stream_response( type="response.output_text.delta", output_index=0, item_id=item.id, + sequence_number=0, ) yield ResponseCompletedEvent( type="response.completed", response=get_response_obj(output), + sequence_number=1, ) diff --git a/uv.lock b/uv.lock index 6ccc19966..7a0cb1e6b 100644 --- a/uv.lock +++ b/uv.lock @@ -1461,7 +1461,7 @@ wheels = [ [[package]] name = "openai" -version = "1.76.0" +version = "1.81.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "anyio" }, @@ -1473,14 +1473,14 @@ dependencies = [ { name = "tqdm" }, { name = "typing-extensions" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/84/51/817969ec969b73d8ddad085670ecd8a45ef1af1811d8c3b8a177ca4d1309/openai-1.76.0.tar.gz", hash = "sha256:fd2bfaf4608f48102d6b74f9e11c5ecaa058b60dad9c36e409c12477dfd91fb2", size = 434660 } +sdist = { url = "https://files.pythonhosted.org/packages/1c/89/a1e4f3fa7ca4f7fec90dbf47d93b7cd5ff65924926733af15044e302a192/openai-1.81.0.tar.gz", hash = "sha256:349567a8607e0bcffd28e02f96b5c2397d0d25d06732d90ab3ecbf97abf030f9", size = 456861 } wheels = [ - { url = "https://files.pythonhosted.org/packages/59/aa/84e02ab500ca871eb8f62784426963a1c7c17a72fea3c7f268af4bbaafa5/openai-1.76.0-py3-none-any.whl", hash = "sha256:a712b50e78cf78e6d7b2a8f69c4978243517c2c36999756673e07a14ce37dc0a", size = 661201 }, + { url = "https://files.pythonhosted.org/packages/02/66/bcc7f9bf48e8610a33e3b5c96a5a644dad032d92404ea2a5e8b43ba067e8/openai-1.81.0-py3-none-any.whl", hash = "sha256:1c71572e22b43876c5d7d65ade0b7b516bb527c3d44ae94111267a09125f7bae", size = 717529 }, ] [[package]] name = "openai-agents" -version = "0.0.14" +version = "0.0.15" source = { editable = "." 
} dependencies = [ { name = "griffe" }, From 9fa5c39d69937a215a6f247883243fe38c5a39c2 Mon Sep 17 00:00:00 2001 From: Rohan Mehta Date: Wed, 21 May 2025 15:21:37 -0400 Subject: [PATCH 04/14] Hosted MCP support (#731) --- [//]: # (BEGIN SAPLING FOOTER) * #732 * __->__ #731 --- examples/hosted_mcp/__init__.py | 0 examples/hosted_mcp/approvals.py | 61 ++++++++++++++ examples/hosted_mcp/simple.py | 47 +++++++++++ src/agents/__init__.py | 8 ++ src/agents/_run_impl.py | 114 ++++++++++++++++++++++++-- src/agents/items.py | 41 ++++++++- src/agents/models/openai_responses.py | 6 +- src/agents/stream_events.py | 2 + src/agents/tool.py | 54 ++++++++++++++- 9 files changed, 322 insertions(+), 11 deletions(-) create mode 100644 examples/hosted_mcp/__init__.py create mode 100644 examples/hosted_mcp/approvals.py create mode 100644 examples/hosted_mcp/simple.py diff --git a/examples/hosted_mcp/__init__.py b/examples/hosted_mcp/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/examples/hosted_mcp/approvals.py b/examples/hosted_mcp/approvals.py new file mode 100644 index 000000000..2cabb3ee2 --- /dev/null +++ b/examples/hosted_mcp/approvals.py @@ -0,0 +1,61 @@ +import argparse +import asyncio + +from agents import ( + Agent, + HostedMCPTool, + MCPToolApprovalFunctionResult, + MCPToolApprovalRequest, + Runner, +) + +"""This example demonstrates how to use the hosted MCP support in the OpenAI Responses API, with +approval callbacks.""" + + +def approval_callback(request: MCPToolApprovalRequest) -> MCPToolApprovalFunctionResult: + answer = input(f"Approve running the tool `{request.data.name}`? 
(y/n) ") + result: MCPToolApprovalFunctionResult = {"approve": answer == "y"} + if not result["approve"]: + result["reason"] = "User denied" + return result + + +async def main(verbose: bool, stream: bool): + agent = Agent( + name="Assistant", + tools=[ + HostedMCPTool( + tool_config={ + "type": "mcp", + "server_label": "gitmcp", + "server_url": "https://gitmcp.io/openai/codex", + "require_approval": "always", + }, + on_approval_request=approval_callback, + ) + ], + ) + + if stream: + result = Runner.run_streamed(agent, "Which language is this repo written in?") + async for event in result.stream_events(): + if event.type == "run_item_stream_event": + print(f"Got event of type {event.item.__class__.__name__}") + print(f"Done streaming; final result: {result.final_output}") + else: + res = await Runner.run(agent, "Which language is this repo written in?") + print(res.final_output) + + if verbose: + for item in result.new_items: + print(item) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--verbose", action="store_true", default=False) + parser.add_argument("--stream", action="store_true", default=False) + args = parser.parse_args() + + asyncio.run(main(args.verbose, args.stream)) diff --git a/examples/hosted_mcp/simple.py b/examples/hosted_mcp/simple.py new file mode 100644 index 000000000..508c3a7ae --- /dev/null +++ b/examples/hosted_mcp/simple.py @@ -0,0 +1,47 @@ +import argparse +import asyncio + +from agents import Agent, HostedMCPTool, Runner + +"""This example demonstrates how to use the hosted MCP support in the OpenAI Responses API, with +approvals not required for any tools. You should only use this for trusted MCP servers.""" + + +async def main(verbose: bool, stream: bool): + agent = Agent( + name="Assistant", + tools=[ + HostedMCPTool( + tool_config={ + "type": "mcp", + "server_label": "gitmcp", + "server_url": "https://gitmcp.io/openai/codex", + "require_approval": "never", + } + ) + ], + ) + + if stream: + result = Runner.run_streamed(agent, "Which language is this repo written in?") + async for event in result.stream_events(): + if event.type == "run_item_stream_event": + print(f"Got event of type {event.item.__class__.__name__}") + print(f"Done streaming; final result: {result.final_output}") + else: + res = await Runner.run(agent, "Which language is this repo written in?") + print(res.final_output) + # The repository is primarily written in multiple languages, including Rust and TypeScript... 
+ + if verbose: + for item in result.new_items: + print(item) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--verbose", action="store_true", default=False) + parser.add_argument("--stream", action="store_true", default=False) + args = parser.parse_args() + + asyncio.run(main(args.verbose, args.stream)) diff --git a/src/agents/__init__.py b/src/agents/__init__.py index 6d7c90b4f..36c26b80d 100644 --- a/src/agents/__init__.py +++ b/src/agents/__init__.py @@ -58,6 +58,10 @@ FileSearchTool, FunctionTool, FunctionToolResult, + HostedMCPTool, + MCPToolApprovalFunction, + MCPToolApprovalFunctionResult, + MCPToolApprovalRequest, Tool, WebSearchTool, default_tool_error_function, @@ -208,6 +212,10 @@ def enable_verbose_stdout_logging(): "FileSearchTool", "Tool", "WebSearchTool", + "HostedMCPTool", + "MCPToolApprovalFunction", + "MCPToolApprovalRequest", + "MCPToolApprovalFunctionResult", "function_tool", "Usage", "add_trace_processor", diff --git a/src/agents/_run_impl.py b/src/agents/_run_impl.py index b5a83685c..ab1e78797 100644 --- a/src/agents/_run_impl.py +++ b/src/agents/_run_impl.py @@ -25,7 +25,8 @@ ActionType, ActionWait, ) -from openai.types.responses.response_input_param import ComputerCallOutput +from openai.types.responses.response_input_param import ComputerCallOutput, McpApprovalResponse +from openai.types.responses.response_output_item import McpApprovalRequest, McpCall, McpListTools from openai.types.responses.response_reasoning_item import ResponseReasoningItem from .agent import Agent, ToolsToFinalOutputResult @@ -38,6 +39,9 @@ HandoffCallItem, HandoffOutputItem, ItemHelpers, + MCPApprovalRequestItem, + MCPApprovalResponseItem, + MCPListToolsItem, MessageOutputItem, ModelResponse, ReasoningItem, @@ -52,7 +56,14 @@ from .models.interface import ModelTracing from .run_context import RunContextWrapper, TContext from .stream_events import RunItemStreamEvent, StreamEvent -from .tool import ComputerTool, FunctionTool, FunctionToolResult, Tool +from .tool import ( + ComputerTool, + FunctionTool, + FunctionToolResult, + HostedMCPTool, + MCPToolApprovalRequest, + Tool, +) from .tracing import ( SpanError, Trace, @@ -112,6 +123,12 @@ class ToolRunComputerAction: computer_tool: ComputerTool +@dataclass +class ToolRunMCPApprovalRequest: + request_item: McpApprovalRequest + mcp_tool: HostedMCPTool + + @dataclass class ProcessedResponse: new_items: list[RunItem] @@ -119,8 +136,9 @@ class ProcessedResponse: functions: list[ToolRunFunction] computer_actions: list[ToolRunComputerAction] tools_used: list[str] # Names of all tools used, including hosted tools + mcp_approval_requests: list[ToolRunMCPApprovalRequest] # Only requests with callbacks - def has_tools_to_run(self) -> bool: + def has_tools_or_approvals_to_run(self) -> bool: # Handoffs, functions and computer actions need local processing # Hosted tools have already run, so there's nothing to do. 
return any( @@ -128,6 +146,7 @@ def has_tools_to_run(self) -> bool: self.handoffs, self.functions, self.computer_actions, + self.mcp_approval_requests, ] ) @@ -226,7 +245,16 @@ async def execute_tools_and_side_effects( new_step_items.extend([result.run_item for result in function_results]) new_step_items.extend(computer_results) - # Second, check if there are any handoffs + # Next, run the MCP approval requests + if processed_response.mcp_approval_requests: + approval_results = await cls.execute_mcp_approval_requests( + agent=agent, + approval_requests=processed_response.mcp_approval_requests, + context_wrapper=context_wrapper, + ) + new_step_items.extend(approval_results) + + # Next, check if there are any handoffs if run_handoffs := processed_response.handoffs: return await cls.execute_handoffs( agent=agent, @@ -240,7 +268,7 @@ async def execute_tools_and_side_effects( run_config=run_config, ) - # Third, we'll check if the tool use should result in a final output + # Next, we'll check if the tool use should result in a final output check_tool_use = await cls._check_for_final_output_from_tools( agent=agent, tool_results=function_results, @@ -295,7 +323,7 @@ async def execute_tools_and_side_effects( ) elif ( not output_schema or output_schema.is_plain_text() - ) and not processed_response.has_tools_to_run(): + ) and not processed_response.has_tools_or_approvals_to_run(): return await cls.execute_final_output( agent=agent, original_input=original_input, @@ -343,10 +371,16 @@ def process_model_response( run_handoffs = [] functions = [] computer_actions = [] + mcp_approval_requests = [] tools_used: list[str] = [] handoff_map = {handoff.tool_name: handoff for handoff in handoffs} function_map = {tool.name: tool for tool in all_tools if isinstance(tool, FunctionTool)} computer_tool = next((tool for tool in all_tools if isinstance(tool, ComputerTool)), None) + hosted_mcp_server_map = { + tool.tool_config["server_label"]: tool + for tool in all_tools + if isinstance(tool, HostedMCPTool) + } for output in response.output: if isinstance(output, ResponseOutputMessage): @@ -375,6 +409,34 @@ def process_model_response( computer_actions.append( ToolRunComputerAction(tool_call=output, computer_tool=computer_tool) ) + elif isinstance(output, McpApprovalRequest): + items.append(MCPApprovalRequestItem(raw_item=output, agent=agent)) + if output.server_label not in hosted_mcp_server_map: + _error_tracing.attach_error_to_current_span( + SpanError( + message="MCP server label not found", + data={"server_label": output.server_label}, + ) + ) + raise ModelBehaviorError(f"MCP server label {output.server_label} not found") + else: + server = hosted_mcp_server_map[output.server_label] + if server.on_approval_request: + mcp_approval_requests.append( + ToolRunMCPApprovalRequest( + request_item=output, + mcp_tool=server, + ) + ) + else: + logger.warning( + f"MCP server {output.server_label} has no on_approval_request hook" + ) + elif isinstance(output, McpListTools): + items.append(MCPListToolsItem(raw_item=output, agent=agent)) + elif isinstance(output, McpCall): + items.append(ToolCallItem(raw_item=output, agent=agent)) + tools_used.append(output.name) elif not isinstance(output, ResponseFunctionToolCall): logger.warning(f"Unexpected output type, ignoring: {type(output)}") continue @@ -417,6 +479,7 @@ def process_model_response( functions=functions, computer_actions=computer_actions, tools_used=tools_used, + mcp_approval_requests=mcp_approval_requests, ) @classmethod @@ -643,6 +706,40 @@ async def execute_handoffs( 
next_step=NextStepHandoff(new_agent), ) + @classmethod + async def execute_mcp_approval_requests( + cls, + *, + agent: Agent[TContext], + approval_requests: list[ToolRunMCPApprovalRequest], + context_wrapper: RunContextWrapper[TContext], + ) -> list[RunItem]: + async def run_single_approval(approval_request: ToolRunMCPApprovalRequest) -> RunItem: + callback = approval_request.mcp_tool.on_approval_request + assert callback is not None, "Callback is required for MCP approval requests" + maybe_awaitable_result = callback( + MCPToolApprovalRequest(context_wrapper, approval_request.request_item) + ) + if inspect.isawaitable(maybe_awaitable_result): + result = await maybe_awaitable_result + else: + result = maybe_awaitable_result + reason = result.get("reason", None) + raw_item: McpApprovalResponse = { + "approval_request_id": approval_request.request_item.id, + "approve": result["approve"], + "type": "mcp_approval_response", + } + if not result["approve"] and reason: + raw_item["reason"] = reason + return MCPApprovalResponseItem( + raw_item=raw_item, + agent=agent, + ) + + tasks = [run_single_approval(approval_request) for approval_request in approval_requests] + return await asyncio.gather(*tasks) + @classmethod async def execute_final_output( cls, @@ -727,6 +824,11 @@ def stream_step_result_to_queue( event = RunItemStreamEvent(item=item, name="tool_output") elif isinstance(item, ReasoningItem): event = RunItemStreamEvent(item=item, name="reasoning_item_created") + elif isinstance(item, MCPApprovalRequestItem): + event = RunItemStreamEvent(item=item, name="mcp_approval_requested") + elif isinstance(item, MCPListToolsItem): + event = RunItemStreamEvent(item=item, name="mcp_list_tools") + else: logger.warning(f"Unexpected item type: {type(item)}") event = None diff --git a/src/agents/items.py b/src/agents/items.py index 8fb2b52a3..65a911798 100644 --- a/src/agents/items.py +++ b/src/agents/items.py @@ -18,7 +18,12 @@ ResponseOutputText, ResponseStreamEvent, ) -from openai.types.responses.response_input_item_param import ComputerCallOutput, FunctionCallOutput +from openai.types.responses.response_input_item_param import ( + ComputerCallOutput, + FunctionCallOutput, + McpApprovalResponse, +) +from openai.types.responses.response_output_item import McpApprovalRequest, McpCall, McpListTools from openai.types.responses.response_reasoning_item import ResponseReasoningItem from pydantic import BaseModel from typing_extensions import TypeAlias @@ -108,6 +113,7 @@ class HandoffOutputItem(RunItemBase[TResponseInputItem]): ResponseComputerToolCall, ResponseFileSearchToolCall, ResponseFunctionWebSearch, + McpCall, ] """A type that represents a tool call item.""" @@ -147,6 +153,36 @@ class ReasoningItem(RunItemBase[ResponseReasoningItem]): type: Literal["reasoning_item"] = "reasoning_item" +@dataclass +class MCPListToolsItem(RunItemBase[McpListTools]): + """Represents a call to an MCP server to list tools.""" + + raw_item: McpListTools + """The raw MCP list tools call.""" + + type: Literal["mcp_list_tools_item"] = "mcp_list_tools_item" + + +@dataclass +class MCPApprovalRequestItem(RunItemBase[McpApprovalRequest]): + """Represents a request for MCP approval.""" + + raw_item: McpApprovalRequest + """The raw MCP approval request.""" + + type: Literal["mcp_approval_request_item"] = "mcp_approval_request_item" + + +@dataclass +class MCPApprovalResponseItem(RunItemBase[McpApprovalResponse]): + """Represents a response to an MCP approval request.""" + + raw_item: McpApprovalResponse + """The raw MCP approval 
response.""" + + type: Literal["mcp_approval_response_item"] = "mcp_approval_response_item" + + RunItem: TypeAlias = Union[ MessageOutputItem, HandoffCallItem, @@ -154,6 +190,9 @@ class ReasoningItem(RunItemBase[ResponseReasoningItem]): ToolCallItem, ToolCallOutputItem, ReasoningItem, + MCPListToolsItem, + MCPApprovalRequestItem, + MCPApprovalResponseItem, ] """An item generated by an agent.""" diff --git a/src/agents/models/openai_responses.py b/src/agents/models/openai_responses.py index cb6567909..65a4f5caf 100644 --- a/src/agents/models/openai_responses.py +++ b/src/agents/models/openai_responses.py @@ -24,7 +24,7 @@ from ..handoffs import Handoff from ..items import ItemHelpers, ModelResponse, TResponseInputItem from ..logger import logger -from ..tool import ComputerTool, FileSearchTool, FunctionTool, Tool, WebSearchTool +from ..tool import ComputerTool, FileSearchTool, FunctionTool, HostedMCPTool, Tool, WebSearchTool from ..tracing import SpanError, response_span from ..usage import Usage from ..version import __version__ @@ -383,7 +383,9 @@ def _convert_tool(cls, tool: Tool) -> tuple[ToolParam, ResponseIncludable | None "display_height": tool.computer.dimensions[1], } includes = None - + elif isinstance(tool, HostedMCPTool): + converted_tool = tool.tool_config + includes = None else: raise UserError(f"Unknown tool type: {type(tool)}, tool") diff --git a/src/agents/stream_events.py b/src/agents/stream_events.py index bd37d11f3..111d0b951 100644 --- a/src/agents/stream_events.py +++ b/src/agents/stream_events.py @@ -35,6 +35,8 @@ class RunItemStreamEvent: "tool_called", "tool_output", "reasoning_item_created", + "mcp_approval_requested", + "mcp_list_tools", ] """The name of the event.""" diff --git a/src/agents/tool.py b/src/agents/tool.py index c1c162423..3bcd57c2e 100644 --- a/src/agents/tool.py +++ b/src/agents/tool.py @@ -7,9 +7,11 @@ from typing import Any, Callable, Literal, Union, overload from openai.types.responses.file_search_tool_param import Filters, RankingOptions +from openai.types.responses.response_output_item import McpApprovalRequest +from openai.types.responses.tool_param import Mcp from openai.types.responses.web_search_tool_param import UserLocation from pydantic import ValidationError -from typing_extensions import Concatenate, ParamSpec +from typing_extensions import Concatenate, NotRequired, ParamSpec, TypedDict from . import _debug from .computer import AsyncComputer, Computer @@ -130,7 +132,55 @@ def name(self): return "computer_use_preview" -Tool = Union[FunctionTool, FileSearchTool, WebSearchTool, ComputerTool] +@dataclass +class MCPToolApprovalRequest: + """A request to approve a tool call.""" + + ctx_wrapper: RunContextWrapper[Any] + """The run context.""" + + data: McpApprovalRequest + """The data from the MCP tool approval request.""" + + +class MCPToolApprovalFunctionResult(TypedDict): + """The result of an MCP tool approval function.""" + + approve: bool + """Whether to approve the tool call.""" + + reason: NotRequired[str] + """An optional reason, if rejected.""" + + +MCPToolApprovalFunction = Callable[ + [MCPToolApprovalRequest], MaybeAwaitable[MCPToolApprovalFunctionResult] +] +"""A function that approves or rejects a tool call.""" + + +@dataclass +class HostedMCPTool: + """A tool that allows the LLM to use a remote MCP server. The LLM will automatically list and + call tools, without requiring a a round trip back to your code. 
+ If you want to run MCP servers locally via stdio, in a VPC or other non-publicly-accessible + environment, or you just prefer to run tool calls locally, then you can instead use the servers + in `agents.mcp` and pass `Agent(mcp_servers=[...])` to the agent.""" + + tool_config: Mcp + """The MCP tool config, which includes the server URL and other settings.""" + + on_approval_request: MCPToolApprovalFunction | None = None + """An optional function that will be called if approval is requested for an MCP tool. If not + provided, you will need to manually add approvals/rejections to the input and call + `Runner.run(...)` again.""" + + @property + def name(self): + return "hosted_mcp" + + +Tool = Union[FunctionTool, FileSearchTool, WebSearchTool, ComputerTool, HostedMCPTool] """A tool that can be used in an agent.""" From 079764f0ab463fda9ecf397b0a5d8e466e87a86c Mon Sep 17 00:00:00 2001 From: Rohan Mehta Date: Wed, 21 May 2025 15:26:22 -0400 Subject: [PATCH 05/14] Add support for local shell, image generator, code interpreter tools (#732) --- examples/tools/code_interpreter.py | 34 +++++++ examples/tools/image_generator.py | 54 +++++++++++ src/agents/__init__.py | 10 +++ src/agents/_run_impl.py | 124 +++++++++++++++++++++++++- src/agents/items.py | 21 ++++- src/agents/models/openai_responses.py | 35 +++++++- src/agents/tool.py | 66 ++++++++++++++-- 7 files changed, 334 insertions(+), 10 deletions(-) create mode 100644 examples/tools/code_interpreter.py create mode 100644 examples/tools/image_generator.py diff --git a/examples/tools/code_interpreter.py b/examples/tools/code_interpreter.py new file mode 100644 index 000000000..a5843ce3f --- /dev/null +++ b/examples/tools/code_interpreter.py @@ -0,0 +1,34 @@ +import asyncio + +from agents import Agent, CodeInterpreterTool, Runner, trace + + +async def main(): + agent = Agent( + name="Code interpreter", + instructions="You love doing math.", + tools=[ + CodeInterpreterTool( + tool_config={"type": "code_interpreter", "container": {"type": "auto"}}, + ) + ], + ) + + with trace("Code interpreter example"): + print("Solving math problem...") + result = Runner.run_streamed(agent, "What is the square root of 273 * 312821 plus 1782?") + async for event in result.stream_events(): + if ( + event.type == "run_item_stream_event" + and event.item.type == "tool_call_item" + and event.item.raw_item.type == "code_interpreter_call" + ): + print(f"Code interpreter code:\n```\n{event.item.raw_item.code}\n```\n") + elif event.type == "run_item_stream_event": + print(f"Other event: {event.item.type}") + + print(f"Final output: {result.final_output}") + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/examples/tools/image_generator.py b/examples/tools/image_generator.py new file mode 100644 index 000000000..fd6fcc6ba --- /dev/null +++ b/examples/tools/image_generator.py @@ -0,0 +1,54 @@ +import asyncio +import base64 +import os +import subprocess +import sys +import tempfile + +from agents import Agent, ImageGenerationTool, Runner, trace + + +def open_file(path: str) -> None: + if sys.platform.startswith("darwin"): + subprocess.run(["open", path], check=False) # macOS + elif os.name == "nt": # Windows + os.startfile(path) # type: 
ignore + elif os.name == "posix": + subprocess.run(["xdg-open", path], check=False) # Linux/Unix + else: + print(f"Don't know how to open files on this platform: {sys.platform}") + + +async def main(): + agent = Agent( + name="Image generator", + instructions="You are a helpful agent.", + tools=[ + ImageGenerationTool( + tool_config={"type": "image_generation", "quality": "low"}, + ) + ], + ) + + with trace("Image generation example"): + print("Generating image, this may take a while...") + result = await Runner.run( + agent, "Create an image of a frog eating a pizza, comic book style." + ) + print(result.final_output) + for item in result.new_items: + if ( + item.type == "tool_call_item" + and item.raw_item.type == "image_generation_call" + and (img_result := item.raw_item.result) + ): + with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp: + tmp.write(base64.b64decode(img_result)) + temp_path = tmp.name + + # Open the image + open_file(temp_path) + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/src/agents/__init__.py b/src/agents/__init__.py index 36c26b80d..58949157a 100644 --- a/src/agents/__init__.py +++ b/src/agents/__init__.py @@ -54,11 +54,16 @@ StreamEvent, ) from .tool import ( + CodeInterpreterTool, ComputerTool, FileSearchTool, FunctionTool, FunctionToolResult, HostedMCPTool, + ImageGenerationTool, + LocalShellCommandRequest, + LocalShellExecutor, + LocalShellTool, MCPToolApprovalFunction, MCPToolApprovalFunctionResult, MCPToolApprovalRequest, @@ -210,6 +215,11 @@ def enable_verbose_stdout_logging(): "FunctionToolResult", "ComputerTool", "FileSearchTool", + "CodeInterpreterTool", + "ImageGenerationTool", + "LocalShellCommandRequest", + "LocalShellExecutor", + "LocalShellTool", "Tool", "WebSearchTool", "HostedMCPTool", diff --git a/src/agents/_run_impl.py b/src/agents/_run_impl.py index ab1e78797..2cfa270e0 100644 --- a/src/agents/_run_impl.py +++ b/src/agents/_run_impl.py @@ -14,6 +14,9 @@ ResponseFunctionWebSearch, ResponseOutputMessage, ) +from openai.types.responses.response_code_interpreter_tool_call import ( + ResponseCodeInterpreterToolCall, +) from openai.types.responses.response_computer_tool_call import ( ActionClick, ActionDoubleClick, @@ -26,7 +29,12 @@ ActionWait, ) from openai.types.responses.response_input_param import ComputerCallOutput, McpApprovalResponse -from openai.types.responses.response_output_item import McpApprovalRequest, McpCall, McpListTools +from openai.types.responses.response_output_item import ( + ImageGenerationCall, + LocalShellCall, + McpApprovalRequest, + McpListTools, +) from openai.types.responses.response_reasoning_item import ResponseReasoningItem from .agent import Agent, ToolsToFinalOutputResult @@ -61,6 +69,8 @@ FunctionTool, FunctionToolResult, HostedMCPTool, + LocalShellCommandRequest, + LocalShellTool, MCPToolApprovalRequest, Tool, ) @@ -129,12 +139,19 @@ class ToolRunMCPApprovalRequest: mcp_tool: HostedMCPTool +@dataclass +class ToolRunLocalShellCall: + tool_call: LocalShellCall + local_shell_tool: LocalShellTool + + @dataclass class ProcessedResponse: new_items: list[RunItem] handoffs: list[ToolRunHandoff] functions: list[ToolRunFunction] computer_actions: list[ToolRunComputerAction] + local_shell_calls: list[ToolRunLocalShellCall] tools_used: list[str] # Names of all tools used, including hosted tools mcp_approval_requests: list[ToolRunMCPApprovalRequest] # Only requests with callbacks @@ -146,6 +163,7 @@ def has_tools_or_approvals_to_run(self) -> bool: self.handoffs, self.functions, 
self.computer_actions, + self.local_shell_calls, self.mcp_approval_requests, ] ) @@ -371,11 +389,15 @@ def process_model_response( run_handoffs = [] functions = [] computer_actions = [] + local_shell_calls = [] mcp_approval_requests = [] tools_used: list[str] = [] handoff_map = {handoff.tool_name: handoff for handoff in handoffs} function_map = {tool.name: tool for tool in all_tools if isinstance(tool, FunctionTool)} computer_tool = next((tool for tool in all_tools if isinstance(tool, ComputerTool)), None) + local_shell_tool = next( + (tool for tool in all_tools if isinstance(tool, LocalShellTool)), None + ) hosted_mcp_server_map = { tool.tool_config["server_label"]: tool for tool in all_tools @@ -434,9 +456,29 @@ def process_model_response( ) elif isinstance(output, McpListTools): items.append(MCPListToolsItem(raw_item=output, agent=agent)) - elif isinstance(output, McpCall): + elif isinstance(output, ImageGenerationCall): + items.append(ToolCallItem(raw_item=output, agent=agent)) + tools_used.append("image_generation") + elif isinstance(output, ResponseCodeInterpreterToolCall): items.append(ToolCallItem(raw_item=output, agent=agent)) - tools_used.append(output.name) + tools_used.append("code_interpreter") + elif isinstance(output, LocalShellCall): + items.append(ToolCallItem(raw_item=output, agent=agent)) + tools_used.append("local_shell") + if not local_shell_tool: + _error_tracing.attach_error_to_current_span( + SpanError( + message="Local shell tool not found", + data={}, + ) + ) + raise ModelBehaviorError( + "Model produced local shell call without a local shell tool." + ) + local_shell_calls.append( + ToolRunLocalShellCall(tool_call=output, local_shell_tool=local_shell_tool) + ) + elif not isinstance(output, ResponseFunctionToolCall): logger.warning(f"Unexpected output type, ignoring: {type(output)}") continue @@ -478,6 +520,7 @@ def process_model_response( handoffs=run_handoffs, functions=functions, computer_actions=computer_actions, + local_shell_calls=local_shell_calls, tools_used=tools_used, mcp_approval_requests=mcp_approval_requests, ) @@ -552,6 +595,30 @@ async def run_single_tool( for tool_run, result in zip(tool_runs, results) ] + @classmethod + async def execute_local_shell_calls( + cls, + *, + agent: Agent[TContext], + calls: list[ToolRunLocalShellCall], + context_wrapper: RunContextWrapper[TContext], + hooks: RunHooks[TContext], + config: RunConfig, + ) -> list[RunItem]: + results: list[RunItem] = [] + # Need to run these serially, because each call can affect the local shell state + for call in calls: + results.append( + await LocalShellAction.execute( + agent=agent, + call=call, + hooks=hooks, + context_wrapper=context_wrapper, + config=config, + ) + ) + return results + @classmethod async def execute_computer_actions( cls, @@ -1021,3 +1088,54 @@ async def _get_screenshot_async( await computer.wait() return await computer.screenshot() + + +class LocalShellAction: + @classmethod + async def execute( + cls, + *, + agent: Agent[TContext], + call: ToolRunLocalShellCall, + hooks: RunHooks[TContext], + context_wrapper: RunContextWrapper[TContext], + config: RunConfig, + ) -> RunItem: + await asyncio.gather( + hooks.on_tool_start(context_wrapper, agent, call.local_shell_tool), + ( + agent.hooks.on_tool_start(context_wrapper, agent, call.local_shell_tool) + if agent.hooks + else _coro.noop_coroutine() + ), + ) + + request = LocalShellCommandRequest( + ctx_wrapper=context_wrapper, + data=call.tool_call, + ) + output = call.local_shell_tool.executor(request) + if 
inspect.isawaitable(output): + result = await output + else: + result = output + + await asyncio.gather( + hooks.on_tool_end(context_wrapper, agent, call.local_shell_tool, result), + ( + agent.hooks.on_tool_end(context_wrapper, agent, call.local_shell_tool, result) + if agent.hooks + else _coro.noop_coroutine() + ), + ) + + return ToolCallOutputItem( + agent=agent, + output=output, + raw_item={ + "type": "local_shell_call_output", + "id": call.tool_call.call_id, + "output": result, + # "id": "out" + call.tool_call.id, # TODO remove this, it should be optional + }, + ) diff --git a/src/agents/items.py b/src/agents/items.py index 65a911798..64797ad22 100644 --- a/src/agents/items.py +++ b/src/agents/items.py @@ -18,12 +18,22 @@ ResponseOutputText, ResponseStreamEvent, ) +from openai.types.responses.response_code_interpreter_tool_call import ( + ResponseCodeInterpreterToolCall, +) from openai.types.responses.response_input_item_param import ( ComputerCallOutput, FunctionCallOutput, + LocalShellCallOutput, McpApprovalResponse, ) -from openai.types.responses.response_output_item import McpApprovalRequest, McpCall, McpListTools +from openai.types.responses.response_output_item import ( + ImageGenerationCall, + LocalShellCall, + McpApprovalRequest, + McpCall, + McpListTools, +) from openai.types.responses.response_reasoning_item import ResponseReasoningItem from pydantic import BaseModel from typing_extensions import TypeAlias @@ -114,6 +124,9 @@ class HandoffOutputItem(RunItemBase[TResponseInputItem]): ResponseFileSearchToolCall, ResponseFunctionWebSearch, McpCall, + ResponseCodeInterpreterToolCall, + ImageGenerationCall, + LocalShellCall, ] """A type that represents a tool call item.""" @@ -129,10 +142,12 @@ class ToolCallItem(RunItemBase[ToolCallItemTypes]): @dataclass -class ToolCallOutputItem(RunItemBase[Union[FunctionCallOutput, ComputerCallOutput]]): +class ToolCallOutputItem( + RunItemBase[Union[FunctionCallOutput, ComputerCallOutput, LocalShellCallOutput]] +): """Represents the output of a tool call.""" - raw_item: FunctionCallOutput | ComputerCallOutput + raw_item: FunctionCallOutput | ComputerCallOutput | LocalShellCallOutput """The raw item from the model.""" output: Any diff --git a/src/agents/models/openai_responses.py b/src/agents/models/openai_responses.py index 65a4f5caf..86c8e69cb 100644 --- a/src/agents/models/openai_responses.py +++ b/src/agents/models/openai_responses.py @@ -24,7 +24,17 @@ from ..handoffs import Handoff from ..items import ItemHelpers, ModelResponse, TResponseInputItem from ..logger import logger -from ..tool import ComputerTool, FileSearchTool, FunctionTool, HostedMCPTool, Tool, WebSearchTool +from ..tool import ( + CodeInterpreterTool, + ComputerTool, + FileSearchTool, + FunctionTool, + HostedMCPTool, + ImageGenerationTool, + LocalShellTool, + Tool, + WebSearchTool, +) from ..tracing import SpanError, response_span from ..usage import Usage from ..version import __version__ @@ -295,6 +305,18 @@ def convert_tool_choice( return { "type": "computer_use_preview", } + elif tool_choice == "image_generation": + return { + "type": "image_generation", + } + elif tool_choice == "code_interpreter": + return { + "type": "code_interpreter", + } + elif tool_choice == "mcp": + return { + "type": "mcp", + } else: return { "type": "function", @@ -386,6 +408,17 @@ def _convert_tool(cls, tool: Tool) -> tuple[ToolParam, ResponseIncludable | None elif isinstance(tool, HostedMCPTool): converted_tool = tool.tool_config includes = None + elif isinstance(tool, ImageGenerationTool): 
+ converted_tool = tool.tool_config + includes = None + elif isinstance(tool, CodeInterpreterTool): + converted_tool = tool.tool_config + includes = None + elif isinstance(tool, LocalShellTool): + converted_tool = { + "type": "local_shell", + } + includes = None else: raise UserError(f"Unknown tool type: {type(tool)}, tool") diff --git a/src/agents/tool.py b/src/agents/tool.py index 3bcd57c2e..fd5a21c89 100644 --- a/src/agents/tool.py +++ b/src/agents/tool.py @@ -7,8 +7,8 @@ from typing import Any, Callable, Literal, Union, overload from openai.types.responses.file_search_tool_param import Filters, RankingOptions -from openai.types.responses.response_output_item import McpApprovalRequest -from openai.types.responses.tool_param import Mcp +from openai.types.responses.response_output_item import LocalShellCall, McpApprovalRequest +from openai.types.responses.tool_param import CodeInterpreter, ImageGeneration, Mcp from openai.types.responses.web_search_tool_param import UserLocation from pydantic import ValidationError from typing_extensions import Concatenate, NotRequired, ParamSpec, TypedDict @@ -180,7 +180,67 @@ def name(self): return "hosted_mcp" -Tool = Union[FunctionTool, FileSearchTool, WebSearchTool, ComputerTool, HostedMCPTool] +@dataclass +class CodeInterpreterTool: + """A tool that allows the LLM to execute code in a sandboxed environment.""" + + tool_config: CodeInterpreter + """The tool config, which includes the container and other settings.""" + + @property + def name(self): + return "code_interpreter" + + +@dataclass +class ImageGenerationTool: + """A tool that allows the LLM to generate images.""" + + tool_config: ImageGeneration + """The tool config, which image generation settings.""" + + @property + def name(self): + return "image_generation" + + +@dataclass +class LocalShellCommandRequest: + """A request to execute a command on a shell.""" + + ctx_wrapper: RunContextWrapper[Any] + """The run context.""" + + data: LocalShellCall + """The data from the local shell tool call.""" + + +LocalShellExecutor = Callable[[LocalShellCommandRequest], MaybeAwaitable[str]] +"""A function that executes a command on a shell.""" + + +@dataclass +class LocalShellTool: + """A tool that allows the LLM to execute commands on a shell.""" + + executor: LocalShellExecutor + """A function that executes a command on a shell.""" + + @property + def name(self): + return "local_shell" + + +Tool = Union[ + FunctionTool, + FileSearchTool, + WebSearchTool, + ComputerTool, + HostedMCPTool, + LocalShellTool, + ImageGenerationTool, + CodeInterpreterTool, +] """A tool that can be used in an agent.""" @@ -358,13 +418,3 @@ def decorator(real_func: ToolFunction[...]) -> FunctionTool: return _create_function_tool(real_func) return decorator - return decorator - return decorator - return decorator - return decorator - return decorator - return decorator - return decorator - return decorator - return decorator - return decorator From 1992be3e8d1746164f0f47f3e2001de4ab4059b9 Mon Sep 17 00:00:00 2001 From: Rohan Mehta Date: Wed, 21 May 2025 16:06:13 -0400 Subject: [PATCH 06/14] v0.0.16 (#733) --- pyproject.toml | 2 +- uv.lock | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 200ac2485..38a2f2b64 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "openai-agents" -version = "0.0.15" +version = "0.0.16" description = "OpenAI Agents SDK" readme = "README.md" requires-python = ">=3.9" diff --git a/uv.lock b/uv.lock index 
7a0cb1e6b..6f2f3f843 100644 --- a/uv.lock +++ b/uv.lock @@ -1480,7 +1480,7 @@ wheels = [ [[package]] name = "openai-agents" -version = "0.0.15" +version = "0.0.16" source = { editable = "." } dependencies = [ { name = "griffe" }, From 1364f4408e8e0d8bec1d6ac8337f68b3b888a4a7 Mon Sep 17 00:00:00 2001 From: Andrew Han Date: Wed, 21 May 2025 14:59:47 -0700 Subject: [PATCH 07/14] fix Gemini token validation issue with LiteLLM (#735) Fix for #734 --- src/agents/extensions/models/litellm_model.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/agents/extensions/models/litellm_model.py b/src/agents/extensions/models/litellm_model.py index ffb2c3c1c..49e2d42d7 100644 --- a/src/agents/extensions/models/litellm_model.py +++ b/src/agents/extensions/models/litellm_model.py @@ -111,12 +111,12 @@ async def get_response( input_tokens_details=InputTokensDetails( cached_tokens=getattr( response_usage.prompt_tokens_details, "cached_tokens", 0 - ) + ) or 0 ), output_tokens_details=OutputTokensDetails( reasoning_tokens=getattr( response_usage.completion_tokens_details, "reasoning_tokens", 0 - ) + ) or 0 ), ) if response.usage From 63c85bb720462416ab012ef575d4b92fd2191d8f Mon Sep 17 00:00:00 2001 From: Copilot <198982749+Copilot@users.noreply.github.com> Date: Thu, 22 May 2025 14:21:50 +1000 Subject: [PATCH 08/14] Initial plan for issue (#13) Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com> From e6aaeb765cdb164675aed66500b5ae2272342342 Mon Sep 17 00:00:00 2001 From: Copilot <198982749+Copilot@users.noreply.github.com> Date: Thu, 22 May 2025 21:51:29 +1000 Subject: [PATCH 09/14] Initial plan for issue (#14) Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com> From 0ae9ed3a5a0ac5eae6145459e3c0a358f232fb9c Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Sat, 24 May 2025 03:23:57 +0000 Subject: [PATCH 10/14] Initial plan for issue From 6b107ce497442a627d523523749778587b69d71a Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Sat, 24 May 2025 03:41:07 +0000 Subject: [PATCH 11/14] Add executive assistant and deepgram voice support Co-authored-by: TGreen87 <170493160+TGreen87@users.noreply.github.com> --- src/agents/__init__.py | 3 ++ src/agents/executive_assistant/__init__.py | 37 ++++++++++++++ src/agents/executive_assistant/memory.py | 40 +++++++++++++++ src/agents/executive_assistant/rag.py | 18 +++++++ src/agents/executive_assistant/tools.py | 17 +++++++ src/agents/voice/__init__.py | 6 +++ .../voice/models/deepgram_model_provider.py | 33 ++++++++++++ src/agents/voice/models/deepgram_stt.py | 51 +++++++++++++++++++ src/agents/voice/models/deepgram_tts.py | 29 +++++++++++ tests/executive_assistant/__init__.py | 1 + tests/executive_assistant/test_agent.py | 27 ++++++++++ tests/voice/test_deepgram_models.py | 14 +++++ 12 files changed, 276 insertions(+) create mode 100644 src/agents/executive_assistant/__init__.py create mode 100644 src/agents/executive_assistant/memory.py create mode 100644 src/agents/executive_assistant/rag.py create mode 100644 src/agents/executive_assistant/tools.py create mode 100644 src/agents/voice/models/deepgram_model_provider.py create mode 100644 src/agents/voice/models/deepgram_stt.py create mode 100644 src/agents/voice/models/deepgram_tts.py create mode 100644 tests/executive_assistant/__init__.py create mode 100644 tests/executive_assistant/test_agent.py create mode 100644
tests/voice/test_deepgram_models.py diff --git a/src/agents/__init__.py b/src/agents/__init__.py index 58949157a..76efb8a59 100644 --- a/src/agents/__init__.py +++ b/src/agents/__init__.py @@ -6,6 +6,7 @@ from . import _config from .agent import Agent, ToolsToFinalOutputFunction, ToolsToFinalOutputResult +from .executive_assistant import ExecutiveAssistantState, executive_assistant_agent from .agent_output import AgentOutputSchema, AgentOutputSchemaBase from .computer import AsyncComputer, Button, Computer, Environment from .exceptions import ( @@ -158,6 +159,8 @@ def enable_verbose_stdout_logging(): "Agent", "ToolsToFinalOutputFunction", "ToolsToFinalOutputResult", + "ExecutiveAssistantState", + "executive_assistant_agent", "Runner", "Model", "ModelProvider", diff --git a/src/agents/executive_assistant/__init__.py b/src/agents/executive_assistant/__init__.py new file mode 100644 index 000000000..1d91ac91d --- /dev/null +++ b/src/agents/executive_assistant/__init__.py @@ -0,0 +1,37 @@ +from __future__ import annotations + +# mypy: ignore-errors + +"""Executive Assistant agent. + +This agent orchestrates other agents and provides voice capabilities using +Deepgram STT and TTS models. It maintains short-term and long-term memory and +can retrieve information via a simple RAG component. +""" +from agents import Agent # noqa: E402 + +from .memory import LongTermMemory, ShortTermMemory # noqa: E402 +from .rag import Retriever # noqa: E402 +from .tools import get_calendar_events, send_email # noqa: E402 + + +class ExecutiveAssistantState: + """Holds resources used by the Executive Assistant.""" + + def __init__(self, memory_path: str = "memory.json") -> None: + self.short_memory = ShortTermMemory() + self.long_memory = LongTermMemory(memory_path) + self.retriever = Retriever() + + +executive_assistant_agent = Agent( + name="ExecutiveAssistant", + instructions=( + "You are an executive assistant. Use the available tools to help the user. " + "Remember important facts during the conversation for later retrieval." 
+ ), + model="gpt-4o-mini", + tools=[get_calendar_events, send_email], +) + +__all__ = ["ExecutiveAssistantState", "executive_assistant_agent"] \ No newline at end of file diff --git a/src/agents/executive_assistant/memory.py b/src/agents/executive_assistant/memory.py new file mode 100644 index 000000000..4a73573d6 --- /dev/null +++ b/src/agents/executive_assistant/memory.py @@ -0,0 +1,40 @@ +from __future__ import annotations + +import json +from pathlib import Path +from typing import Any + + +class ShortTermMemory: + """In-memory store for conversation turns.""" + + def __init__(self) -> None: + self._messages: list[dict[str, str]] = [] + + def add(self, role: str, content: str) -> None: + """Add a message to memory.""" + self._messages.append({"role": role, "content": content}) + + def to_list(self) -> list[dict[str, str]]: + """Return the last 20 messages.""" + return self._messages[-20:] + + +class LongTermMemory: + """Simple file backed memory store.""" + + def __init__(self, path: str | Path) -> None: + self._path = Path(path) + if self._path.exists(): + self._data = json.loads(self._path.read_text()) + else: + self._data = [] + + def add(self, item: Any) -> None: + """Persist an item to disk.""" + self._data.append(item) + self._path.write_text(json.dumps(self._data)) + + def all(self) -> list[Any]: + """Return all persisted items.""" + return list(self._data) \ No newline at end of file diff --git a/src/agents/executive_assistant/rag.py b/src/agents/executive_assistant/rag.py new file mode 100644 index 000000000..c1395163d --- /dev/null +++ b/src/agents/executive_assistant/rag.py @@ -0,0 +1,18 @@ +from __future__ import annotations + +from collections.abc import Iterable + + +class Retriever: + """Very small RAG retriever stub.""" + + def __init__(self, corpus: Iterable[str] | None = None) -> None: + self._corpus = list(corpus or []) + + def add(self, document: str) -> None: + """Add a document to the corpus.""" + self._corpus.append(document) + + def search(self, query: str) -> list[str]: + """Return documents containing the query string.""" + return [doc for doc in self._corpus if query.lower() in doc.lower()] \ No newline at end of file diff --git a/src/agents/executive_assistant/tools.py b/src/agents/executive_assistant/tools.py new file mode 100644 index 000000000..24f38a0d8 --- /dev/null +++ b/src/agents/executive_assistant/tools.py @@ -0,0 +1,17 @@ +from __future__ import annotations + +from agents.tool import function_tool + + +@function_tool +def get_calendar_events(date: str) -> str: + """Retrieve calendar events for a given date.""" + # TODO: Integrate with calendar API. + return f"No events found for {date}." + + +@function_tool +def send_email(recipient: str, subject: str, body: str) -> str: + """Send a simple email.""" + # TODO: Integrate with email service. + return "Email sent." 
\ No newline at end of file diff --git a/src/agents/voice/__init__.py b/src/agents/voice/__init__.py index e11ee4467..abaa98db7 100644 --- a/src/agents/voice/__init__.py +++ b/src/agents/voice/__init__.py @@ -10,6 +10,9 @@ TTSVoice, VoiceModelProvider, ) +from .models.deepgram_model_provider import DeepgramVoiceModelProvider +from .models.deepgram_stt import DeepgramSTTModel +from .models.deepgram_tts import DeepgramTTSModel from .models.openai_model_provider import OpenAIVoiceModelProvider from .models.openai_stt import OpenAISTTModel, OpenAISTTTranscriptionSession from .models.openai_tts import OpenAITTSModel @@ -38,6 +41,9 @@ "OpenAIVoiceModelProvider", "OpenAISTTModel", "OpenAITTSModel", + "DeepgramVoiceModelProvider", + "DeepgramSTTModel", + "DeepgramTTSModel", "VoiceStreamEventAudio", "VoiceStreamEventLifecycle", "VoiceStreamEvent", diff --git a/src/agents/voice/models/deepgram_model_provider.py b/src/agents/voice/models/deepgram_model_provider.py new file mode 100644 index 000000000..de6c2100a --- /dev/null +++ b/src/agents/voice/models/deepgram_model_provider.py @@ -0,0 +1,33 @@ +from __future__ import annotations + +import httpx # type: ignore + +from ..model import STTModel, TTSModel, VoiceModelProvider +from .deepgram_stt import DeepgramSTTModel +from .deepgram_tts import DeepgramTTSModel + +DEFAULT_STT_MODEL = "nova-3" +DEFAULT_TTS_MODEL = "aura-2" + + +class DeepgramVoiceModelProvider(VoiceModelProvider): + """Voice model provider for Deepgram APIs.""" + + def __init__(self, api_key: str, *, client: httpx.AsyncClient | None = None) -> None: + self._api_key = api_key + self._client = client + + def _get_client(self) -> httpx.AsyncClient: + if self._client is None: + self._client = httpx.AsyncClient() + return self._client + + def get_stt_model(self, model_name: str | None) -> STTModel: + return DeepgramSTTModel( + model_name or DEFAULT_STT_MODEL, self._api_key, client=self._get_client() + ) + + def get_tts_model(self, model_name: str | None) -> TTSModel: + return DeepgramTTSModel( + model_name or DEFAULT_TTS_MODEL, self._api_key, client=self._get_client() + ) \ No newline at end of file diff --git a/src/agents/voice/models/deepgram_stt.py b/src/agents/voice/models/deepgram_stt.py new file mode 100644 index 000000000..aee004182 --- /dev/null +++ b/src/agents/voice/models/deepgram_stt.py @@ -0,0 +1,51 @@ +from __future__ import annotations + +from typing import Any + +import httpx # type: ignore + +from ..input import AudioInput, StreamedAudioInput +from ..model import StreamedTranscriptionSession, STTModel, STTModelSettings + + +class DeepgramSTTModel(STTModel): + """Speech-to-text model using Deepgram Nova 3.""" + + def __init__( + self, model: str, api_key: str, *, client: httpx.AsyncClient | None = None + ) -> None: + self.model = model + self.api_key = api_key + self._client = client or httpx.AsyncClient() + + @property + def model_name(self) -> str: + return self.model + + async def transcribe( + self, + input: AudioInput, + settings: STTModelSettings, + trace_include_sensitive_data: bool, + trace_include_sensitive_audio_data: bool, + ) -> str: + url = f"https://api.deepgram.com/v1/listen?model={self.model}" + headers = {"Authorization": f"Token {self.api_key}"} + filename, data, content_type = input.to_audio_file() + response = await self._client.post(url, headers=headers, content=data.getvalue()) + payload: dict[str, Any] = response.json() + return ( + payload.get("results", {}) + .get("channels", [{}])[0] + .get("alternatives", [{}])[0] + .get("transcript", "") + ) + + 
async def create_session( + self, + input: StreamedAudioInput, + settings: STTModelSettings, + trace_include_sensitive_data: bool, + trace_include_sensitive_audio_data: bool, + ) -> StreamedTranscriptionSession: + raise NotImplementedError("Streaming transcription is not implemented.") \ No newline at end of file diff --git a/src/agents/voice/models/deepgram_tts.py b/src/agents/voice/models/deepgram_tts.py new file mode 100644 index 000000000..56dc340ea --- /dev/null +++ b/src/agents/voice/models/deepgram_tts.py @@ -0,0 +1,29 @@ +from __future__ import annotations + +from collections.abc import AsyncIterator + +import httpx # type: ignore + +from ..model import TTSModel, TTSModelSettings + + +class DeepgramTTSModel(TTSModel): + """Text-to-speech model using Deepgram Aura 2.""" + + def __init__( + self, model: str, api_key: str, *, client: httpx.AsyncClient | None = None + ) -> None: + self.model = model + self.api_key = api_key + self._client = client or httpx.AsyncClient() + + @property + def model_name(self) -> str: + return self.model + + async def run(self, text: str, settings: TTSModelSettings) -> AsyncIterator[bytes]: + url = "https://api.deepgram.com/v1/speak" + headers = {"Authorization": f"Token {self.api_key}", "Content-Type": "application/json"} + payload = {"text": text, "model": self.model, "voice": settings.voice or "aura-2"} + response = await self._client.post(url, headers=headers, json=payload) + yield response.content \ No newline at end of file diff --git a/tests/executive_assistant/__init__.py b/tests/executive_assistant/__init__.py new file mode 100644 index 000000000..e4396e9cb --- /dev/null +++ b/tests/executive_assistant/__init__.py @@ -0,0 +1 @@ +# Executive Assistant tests \ No newline at end of file diff --git a/tests/executive_assistant/test_agent.py b/tests/executive_assistant/test_agent.py new file mode 100644 index 000000000..ff63bd2f0 --- /dev/null +++ b/tests/executive_assistant/test_agent.py @@ -0,0 +1,27 @@ +from __future__ import annotations + +import pytest + +from agents import Agent, Runner +from agents.agent import ToolsToFinalOutputResult +from agents.executive_assistant import executive_assistant_agent +from tests.fake_model import FakeModel + + +@pytest.mark.asyncio +async def test_agent_runs_with_fake_model() -> None: + model = FakeModel() + agent = Agent( + name=executive_assistant_agent.name, + instructions=executive_assistant_agent.instructions, + tools=executive_assistant_agent.tools, + model=model, + ) + model.set_next_output( + [ + {"role": "assistant", "content": "Hello"}, + ] + ) + + result: ToolsToFinalOutputResult = await Runner.run(agent, "hi") + assert result.final_output == "Hello" \ No newline at end of file diff --git a/tests/voice/test_deepgram_models.py b/tests/voice/test_deepgram_models.py new file mode 100644 index 000000000..1dd0e584e --- /dev/null +++ b/tests/voice/test_deepgram_models.py @@ -0,0 +1,14 @@ +from __future__ import annotations + +import pytest + +from agents.voice.models.deepgram_model_provider import DeepgramVoiceModelProvider + + +@pytest.mark.asyncio +async def test_provider_returns_models() -> None: + provider = DeepgramVoiceModelProvider(api_key="key") + stt = provider.get_stt_model(None) + tts = provider.get_tts_model(None) + assert stt.model_name == "nova-3" + assert tts.model_name == "aura-2" \ No newline at end of file From f6828e5912537a2e23c629e28c7e44ef0b5359e3 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Sun, 1 Jun 2025 11:45:07 +0000 
Subject: [PATCH 12/14] Fix circular import in executive assistant module Co-authored-by: TGreen87 <170493160+TGreen87@users.noreply.github.com> --- src/agents/executive_assistant/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/agents/executive_assistant/__init__.py b/src/agents/executive_assistant/__init__.py index 1d91ac91d..8b73dc5c7 100644 --- a/src/agents/executive_assistant/__init__.py +++ b/src/agents/executive_assistant/__init__.py @@ -8,7 +8,7 @@ Deepgram STT and TTS models. It maintains short-term and long-term memory and can retrieve information via a simple RAG component. """ -from agents import Agent # noqa: E402 +from ..agent import Agent # noqa: E402 from .memory import LongTermMemory, ShortTermMemory # noqa: E402 from .rag import Retriever # noqa: E402 From a8b731b72d0aeaa0a4b8ad0ec7fa614f911153f9 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Sun, 1 Jun 2025 11:46:29 +0000 Subject: [PATCH 13/14] Fix code style and formatting issues with ruff Co-authored-by: TGreen87 <170493160+TGreen87@users.noreply.github.com> --- src/agents/executive_assistant/__init__.py | 3 +-- src/agents/executive_assistant/memory.py | 2 +- src/agents/executive_assistant/rag.py | 2 +- src/agents/executive_assistant/tools.py | 2 +- src/agents/voice/models/deepgram_model_provider.py | 2 +- src/agents/voice/models/deepgram_stt.py | 2 +- src/agents/voice/models/deepgram_tts.py | 2 +- tests/executive_assistant/__init__.py | 2 +- tests/executive_assistant/test_agent.py | 2 +- tests/voice/test_deepgram_models.py | 2 +- 10 files changed, 10 insertions(+), 11 deletions(-) diff --git a/src/agents/executive_assistant/__init__.py b/src/agents/executive_assistant/__init__.py index 8b73dc5c7..a60dd9295 100644 --- a/src/agents/executive_assistant/__init__.py +++ b/src/agents/executive_assistant/__init__.py @@ -9,7 +9,6 @@ can retrieve information via a simple RAG component. 
""" from ..agent import Agent # noqa: E402 - from .memory import LongTermMemory, ShortTermMemory # noqa: E402 from .rag import Retriever # noqa: E402 from .tools import get_calendar_events, send_email # noqa: E402 @@ -34,4 +33,4 @@ def __init__(self, memory_path: str = "memory.json") -> None: tools=[get_calendar_events, send_email], ) -__all__ = ["ExecutiveAssistantState", "executive_assistant_agent"] \ No newline at end of file +__all__ = ["ExecutiveAssistantState", "executive_assistant_agent"] diff --git a/src/agents/executive_assistant/memory.py b/src/agents/executive_assistant/memory.py index 4a73573d6..5b2867181 100644 --- a/src/agents/executive_assistant/memory.py +++ b/src/agents/executive_assistant/memory.py @@ -37,4 +37,4 @@ def add(self, item: Any) -> None: def all(self) -> list[Any]: """Return all persisted items.""" - return list(self._data) \ No newline at end of file + return list(self._data) diff --git a/src/agents/executive_assistant/rag.py b/src/agents/executive_assistant/rag.py index c1395163d..c009fffae 100644 --- a/src/agents/executive_assistant/rag.py +++ b/src/agents/executive_assistant/rag.py @@ -15,4 +15,4 @@ def add(self, document: str) -> None: def search(self, query: str) -> list[str]: """Return documents containing the query string.""" - return [doc for doc in self._corpus if query.lower() in doc.lower()] \ No newline at end of file + return [doc for doc in self._corpus if query.lower() in doc.lower()] diff --git a/src/agents/executive_assistant/tools.py b/src/agents/executive_assistant/tools.py index 24f38a0d8..5fa2862a9 100644 --- a/src/agents/executive_assistant/tools.py +++ b/src/agents/executive_assistant/tools.py @@ -14,4 +14,4 @@ def get_calendar_events(date: str) -> str: def send_email(recipient: str, subject: str, body: str) -> str: """Send a simple email.""" # TODO: Integrate with email service. - return "Email sent." \ No newline at end of file + return "Email sent." 
diff --git a/src/agents/voice/models/deepgram_model_provider.py b/src/agents/voice/models/deepgram_model_provider.py index de6c2100a..cd9e9e198 100644 --- a/src/agents/voice/models/deepgram_model_provider.py +++ b/src/agents/voice/models/deepgram_model_provider.py @@ -30,4 +30,4 @@ def get_stt_model(self, model_name: str | None) -> STTModel: def get_tts_model(self, model_name: str | None) -> TTSModel: return DeepgramTTSModel( model_name or DEFAULT_TTS_MODEL, self._api_key, client=self._get_client() - ) \ No newline at end of file + ) diff --git a/src/agents/voice/models/deepgram_stt.py b/src/agents/voice/models/deepgram_stt.py index aee004182..e2f633d83 100644 --- a/src/agents/voice/models/deepgram_stt.py +++ b/src/agents/voice/models/deepgram_stt.py @@ -48,4 +48,4 @@ async def create_session( trace_include_sensitive_data: bool, trace_include_sensitive_audio_data: bool, ) -> StreamedTranscriptionSession: - raise NotImplementedError("Streaming transcription is not implemented.") \ No newline at end of file + raise NotImplementedError("Streaming transcription is not implemented.") diff --git a/src/agents/voice/models/deepgram_tts.py b/src/agents/voice/models/deepgram_tts.py index 56dc340ea..b035f35d4 100644 --- a/src/agents/voice/models/deepgram_tts.py +++ b/src/agents/voice/models/deepgram_tts.py @@ -26,4 +26,4 @@ async def run(self, text: str, settings: TTSModelSettings) -> AsyncIterator[byte headers = {"Authorization": f"Token {self.api_key}", "Content-Type": "application/json"} payload = {"text": text, "model": self.model, "voice": settings.voice or "aura-2"} response = await self._client.post(url, headers=headers, json=payload) - yield response.content \ No newline at end of file + yield response.content diff --git a/tests/executive_assistant/__init__.py b/tests/executive_assistant/__init__.py index e4396e9cb..fa3eb6165 100644 --- a/tests/executive_assistant/__init__.py +++ b/tests/executive_assistant/__init__.py @@ -1 +1 @@ -# Executive Assistant tests \ No newline at end of file +# Executive Assistant tests diff --git a/tests/executive_assistant/test_agent.py b/tests/executive_assistant/test_agent.py index ff63bd2f0..b8e76e95d 100644 --- a/tests/executive_assistant/test_agent.py +++ b/tests/executive_assistant/test_agent.py @@ -24,4 +24,4 @@ async def test_agent_runs_with_fake_model() -> None: ) result: ToolsToFinalOutputResult = await Runner.run(agent, "hi") - assert result.final_output == "Hello" \ No newline at end of file + assert result.final_output == "Hello" diff --git a/tests/voice/test_deepgram_models.py b/tests/voice/test_deepgram_models.py index 1dd0e584e..15d7cea3a 100644 --- a/tests/voice/test_deepgram_models.py +++ b/tests/voice/test_deepgram_models.py @@ -11,4 +11,4 @@ async def test_provider_returns_models() -> None: stt = provider.get_stt_model(None) tts = provider.get_tts_model(None) assert stt.model_name == "nova-3" - assert tts.model_name == "aura-2" \ No newline at end of file + assert tts.model_name == "aura-2" From dac1e396e04dd706dc358fbe825f3e6066c63320 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Sun, 1 Jun 2025 11:47:30 +0000 Subject: [PATCH 14/14] Fix another circular import in executive assistant tools Co-authored-by: TGreen87 <170493160+TGreen87@users.noreply.github.com> --- src/agents/executive_assistant/tools.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/agents/executive_assistant/tools.py b/src/agents/executive_assistant/tools.py index 5fa2862a9..5f9102288 100644 
--- a/src/agents/executive_assistant/tools.py +++ b/src/agents/executive_assistant/tools.py @@ -1,6 +1,6 @@ from __future__ import annotations -from agents.tool import function_tool +from ..tool import function_tool @function_tool
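PATCH 12 and PATCH 14 fix the same failure mode twice: `agents/__init__.py` imports the executive assistant package, so any module inside it that imports back through the `agents` package root can hit a partially initialized module. A minimal reproduction of that pattern, with hypothetical module names:

    # pkg/util.py
    thing = object()

    # pkg/__init__.py
    from .sub import helper   # pkg is only partially initialized past this line
    from .util import thing   # not reached yet while pkg/sub.py is loading

    # pkg/sub.py -- broken: goes through the package root, whose
    # "from .util import thing" line has not executed yet
    from pkg import thing     # ImportError: cannot import name 'thing' from
                              # partially initialized module 'pkg'

    # pkg/sub.py -- fixed, mirroring both patches: import the sibling module
    # directly, so nothing depends on __init__.py having finished
    from .util import thing

    def helper() -> str:
        return str(thing)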
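A note on PATCH 07 above, since the tiny diff carries a subtle point: `getattr(obj, name, default)` only falls back to the default when the attribute is missing. Gemini responses surfaced through LiteLLM can apparently carry a details object whose fields are present but `None` (my reading of #734), and `InputTokensDetails(cached_tokens=None)` then fails validation, hence the trailing `or 0`. A quick illustration, using `SimpleNamespace` as a stand-in for the provider's usage payload:

    from types import SimpleNamespace

    details = SimpleNamespace(cached_tokens=None)  # attribute exists, value is None

    getattr(details, "cached_tokens", 0)        # -> None: the default is not used
    getattr(details, "cached_tokens", 0) or 0   # -> 0: PATCH 07's coercion

    # The default still covers the details object itself being absent:
    getattr(None, "cached_tokens", 0) or 0      # -> 0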
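Lastly, the `LocalShellTool` added earlier in this series is deliberately thin: the SDK invokes your `LocalShellExecutor` (sync or async, per the `MaybeAwaitable` alias) with a `LocalShellCommandRequest` and forwards the returned string as the `local_shell_call_output`. A minimal executor sketch; it assumes `request.data.action.command` carries the argv list of the Responses API `local_shell_call` item, and the 30-second timeout is an arbitrary choice, not an SDK default:

    import subprocess

    from agents.tool import LocalShellCommandRequest, LocalShellTool


    def run_shell(request: LocalShellCommandRequest) -> str:
        # request.data is the raw LocalShellCall item produced by the model.
        action = request.data.action
        completed = subprocess.run(
            action.command,      # argv list chosen by the model
            capture_output=True,
            text=True,
            timeout=30,
        )
        return completed.stdout or completed.stderr


    shell_tool = LocalShellTool(executor=run_shell)

Note there is no sandboxing here; a production executor would want to validate the command before running it.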