From 41975827fb1cbf124cc4153b29b72185af1acfca Mon Sep 17 00:00:00 2001
From: "will.yang"
Date: Thu, 10 Apr 2025 09:01:21 +0800
Subject: [PATCH] add overwrite mechanism for stream_options

---
 src/agents/model_settings.py                |  4 ++++
 src/agents/models/openai_chatcompletions.py | 21 +++++++++++++++++++--
 tests/test_openai_chatcompletions.py        |  2 +-
 3 files changed, 24 insertions(+), 3 deletions(-)

diff --git a/src/agents/model_settings.py b/src/agents/model_settings.py
index bac71f58..f29cfa4a 100644
--- a/src/agents/model_settings.py
+++ b/src/agents/model_settings.py
@@ -54,6 +54,10 @@ class ModelSettings:
     """Whether to store the generated model response for later retrieval.
     Defaults to True if not provided."""
 
+    include_usage: bool | None = None
+    """Whether to include a usage chunk in streaming responses.
+    Defaults to True for the official OpenAI API when not provided; otherwise omitted."""
+
     def resolve(self, override: ModelSettings | None) -> ModelSettings:
         """Produce a new ModelSettings by overlaying any non-None values from the
         override on top of this instance."""
diff --git a/src/agents/models/openai_chatcompletions.py b/src/agents/models/openai_chatcompletions.py
index e0aafad0..807c6512 100644
--- a/src/agents/models/openai_chatcompletions.py
+++ b/src/agents/models/openai_chatcompletions.py
@@ -521,6 +521,8 @@ async def _fetch_response(
         reasoning_effort = model_settings.reasoning.effort if model_settings.reasoning else None
         store = _Converter.get_store_param(self._get_client(), model_settings)
 
+        stream_options = _Converter.get_stream_options_param(self._get_client(), model_settings)
+
         ret = await self._get_client().chat.completions.create(
             model=self.model,
             messages=converted_messages,
@@ -534,7 +536,7 @@ async def _fetch_response(
             response_format=response_format,
             parallel_tool_calls=parallel_tool_calls,
             stream=stream,
-            stream_options={"include_usage": True} if stream else NOT_GIVEN,
+            stream_options=self._non_null_or_not_given(stream_options),
             store=self._non_null_or_not_given(store),
             reasoning_effort=self._non_null_or_not_given(reasoning_effort),
             extra_headers=_HEADERS,
@@ -568,12 +570,27 @@ def _get_client(self) -> AsyncOpenAI:
 
 
 class _Converter:
+
+    @classmethod
+    def is_openai(cls, client: AsyncOpenAI):
+        return str(client.base_url).startswith("https://api.openai.com")
+
     @classmethod
     def get_store_param(cls, client: AsyncOpenAI, model_settings: ModelSettings) -> bool | None:
         # Match the behavior of Responses where store is True when not given
-        default_store = True if str(client.base_url).startswith("https://api.openai.com") else None
+        default_store = True if cls.is_openai(client) else None
         return model_settings.store if model_settings.store is not None else default_store
 
+    @classmethod
+    def get_stream_options_param(
+        cls, client: AsyncOpenAI, model_settings: ModelSettings
+    ) -> dict[str, bool] | None:
+        default_include_usage = True if cls.is_openai(client) else None
+        include_usage = model_settings.include_usage if model_settings.include_usage is not None \
+            else default_include_usage
+        stream_options = {"include_usage": include_usage} if include_usage is not None else None
+        return stream_options
+
     @classmethod
     def convert_tool_choice(
         cls, tool_choice: Literal["auto", "required", "none"] | str | None
diff --git a/tests/test_openai_chatcompletions.py b/tests/test_openai_chatcompletions.py
index a3198d33..281d7b41 100644
--- a/tests/test_openai_chatcompletions.py
+++ b/tests/test_openai_chatcompletions.py
@@ -282,7 +282,7 @@ def __init__(self, completions: DummyCompletions) -> None:
     # Check OpenAI client was called for streaming
     assert completions.kwargs["stream"] is True
     assert completions.kwargs["store"] is NOT_GIVEN
-    assert completions.kwargs["stream_options"] == {"include_usage": True}
+    assert completions.kwargs["stream_options"] is NOT_GIVEN
     # Response is a proper openai Response
     assert isinstance(response, Response)
     assert response.id == FAKE_RESPONSES_ID