From 6fb8e829477c1a503e3fb85b64be1c3dd09ca368 Mon Sep 17 00:00:00 2001
From: Rohan Mehta
Date: Mon, 14 Apr 2025 12:00:38 -0400
Subject: [PATCH] Example for streaming guardrails

---
 .../agent_patterns/streaming_guardrails.py   | 93 +++++++++++++++++++
 src/agents/models/openai_chatcompletions.py  |  8 +-
 src/agents/models/openai_responses.py        |  2 +-
 3 files changed, 99 insertions(+), 4 deletions(-)
 create mode 100644 examples/agent_patterns/streaming_guardrails.py

diff --git a/examples/agent_patterns/streaming_guardrails.py b/examples/agent_patterns/streaming_guardrails.py
new file mode 100644
index 00000000..f4db2869
--- /dev/null
+++ b/examples/agent_patterns/streaming_guardrails.py
@@ -0,0 +1,93 @@
+from __future__ import annotations
+
+import asyncio
+
+from openai.types.responses import ResponseTextDeltaEvent
+from pydantic import BaseModel, Field
+
+from agents import Agent, Runner
+
+"""
+This example shows how to use guardrails as the model is streaming. Output guardrails run after the
+final output has been generated; this example runs guardrails every N characters, allowing for
+early termination if bad output is detected.
+
+The expected output is that you'll see a bunch of tokens stream in, then the guardrail will trigger
+and stop the streaming.
+"""
+
+
+agent = Agent(
+    name="Assistant",
+    instructions=(
+        "You are a helpful assistant. You ALWAYS write long responses, making sure to be verbose "
+        "and detailed."
+    ),
+)
+
+
+class GuardrailOutput(BaseModel):
+    reasoning: str = Field(
+        description="Reasoning about whether the response could be understood by a ten year old."
+    )
+    is_readable_by_ten_year_old: bool = Field(
+        description="Whether the response is understandable by a ten year old."
+    )
+
+
+guardrail_agent = Agent(
+    name="Checker",
+    instructions=(
+        "You will be given a question and a response. Your goal is to judge whether the response "
+        "is simple enough to be understood by a ten year old."
+    ),
+    output_type=GuardrailOutput,
+    model="gpt-4o-mini",
+)
+
+
+async def check_guardrail(text: str) -> GuardrailOutput:
+    result = await Runner.run(guardrail_agent, text)
+    return result.final_output_as(GuardrailOutput)
+
+
+async def main():
+    question = "What is a black hole, and how does it behave?"
+    result = Runner.run_streamed(agent, question)
+    current_text = ""
+
+    # We will check the guardrail every N characters
+    next_guardrail_check_len = 300
+    guardrail_task = None
+
+    async for event in result.stream_events():
+        if event.type == "raw_response_event" and isinstance(event.data, ResponseTextDeltaEvent):
+            print(event.data.delta, end="", flush=True)
+            current_text += event.data.delta
+
+            # Check if it's time to run the guardrail check
+            # Note that we don't run the guardrail check if there's already a task running. An
+            # alternate implementation is to have N guardrails running, or cancel the previous
+            # one.
+            if len(current_text) >= next_guardrail_check_len and not guardrail_task:
+                print("Running guardrail check")
+                guardrail_task = asyncio.create_task(check_guardrail(current_text))
+                next_guardrail_check_len += 300
+
+        # Every iteration of the loop, check if the guardrail has been triggered
+        if guardrail_task and guardrail_task.done():
+            guardrail_result = guardrail_task.result()
+            if not guardrail_result.is_readable_by_ten_year_old:
+                print("\n\n================\n\n")
+                print(f"Guardrail triggered. Reasoning:\n{guardrail_result.reasoning}")
+                break
+
+    # Do one final check on the final output
+    guardrail_result = await check_guardrail(current_text)
+    if not guardrail_result.is_readable_by_ten_year_old:
+        print("\n\n================\n\n")
+        print(f"Guardrail triggered. Reasoning:\n{guardrail_result.reasoning}")
+
+
+if __name__ == "__main__":
+    asyncio.run(main())
diff --git a/src/agents/models/openai_chatcompletions.py b/src/agents/models/openai_chatcompletions.py
index 267efcaf..6978ee30 100644
--- a/src/agents/models/openai_chatcompletions.py
+++ b/src/agents/models/openai_chatcompletions.py
@@ -572,7 +572,6 @@ def _get_client(self) -> AsyncOpenAI:
 
 
 class _Converter:
-
     @classmethod
     def is_openai(cls, client: AsyncOpenAI):
         return str(client.base_url).startswith("https://api.openai.com")
@@ -585,11 +584,14 @@ def get_store_param(cls, client: AsyncOpenAI, model_settings: ModelSettings) ->
 
     @classmethod
     def get_stream_options_param(
-            cls, client: AsyncOpenAI, model_settings: ModelSettings
+        cls, client: AsyncOpenAI, model_settings: ModelSettings
     ) -> dict[str, bool] | None:
         default_include_usage = True if cls.is_openai(client) else None
-        include_usage = model_settings.include_usage if model_settings.include_usage is not None \
+        include_usage = (
+            model_settings.include_usage
+            if model_settings.include_usage is not None
             else default_include_usage
+        )
         stream_options = {"include_usage": include_usage} if include_usage is not None else None
         return stream_options
 
diff --git a/src/agents/models/openai_responses.py b/src/agents/models/openai_responses.py
index cb5a603f..055ab79b 100644
--- a/src/agents/models/openai_responses.py
+++ b/src/agents/models/openai_responses.py
@@ -250,7 +250,7 @@ async def _fetch_response(
             text=response_format,
             store=self._non_null_or_not_given(model_settings.store),
             reasoning=self._non_null_or_not_given(model_settings.reasoning),
-            metadata=self._non_null_or_not_given(model_settings.metadata)
+            metadata=self._non_null_or_not_given(model_settings.metadata),
         )
 
     def _get_client(self) -> AsyncOpenAI:
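
Note for reviewers: as written, the loop in the example launches a single mid-stream guardrail
check, because guardrail_task is never reset once created; later thresholds are only covered by
the final post-stream check. The in-code comment already points at the alternatives (N concurrent
checks, or cancelling the previous one). Below is a minimal sketch of the cancel-and-re-arm
variant, not part of the patch; it assumes the `agent` and `check_guardrail` definitions from
streaming_guardrails.py above, and `main_with_cancellation` is a hypothetical name.

# Sketch only: the cancel-and-re-arm variant hinted at in the example's comment.
# Assumes `agent` and `check_guardrail` from streaming_guardrails.py are in scope.
import asyncio

from openai.types.responses import ResponseTextDeltaEvent

from agents import Runner


async def main_with_cancellation():
    result = Runner.run_streamed(agent, "What is a black hole, and how does it behave?")
    current_text = ""
    next_guardrail_check_len = 300
    guardrail_task: asyncio.Task | None = None

    async for event in result.stream_events():
        if event.type == "raw_response_event" and isinstance(event.data, ResponseTextDeltaEvent):
            print(event.data.delta, end="", flush=True)
            current_text += event.data.delta

            if len(current_text) >= next_guardrail_check_len:
                # A newer text snapshot supersedes any in-flight check, so cancel it.
                if guardrail_task is not None and not guardrail_task.done():
                    guardrail_task.cancel()
                guardrail_task = asyncio.create_task(check_guardrail(current_text))
                next_guardrail_check_len += 300

        if guardrail_task is not None and guardrail_task.done():
            guardrail_result = guardrail_task.result()
            guardrail_task = None  # re-arm so later thresholds get checked too
            if not guardrail_result.is_readable_by_ten_year_old:
                print(f"\n\nGuardrail triggered. Reasoning:\n{guardrail_result.reasoning}")
                break

The trade-off: cancelling keeps at most one checker call in flight and always judges the newest
text, at the cost of discarding partial work; re-arming after a passing check is what lets checks
fire at every threshold instead of only the first.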
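
For the openai_chatcompletions.py change, the resolution order after this patch is: an explicit
model_settings.include_usage always wins; otherwise usage reporting defaults to on only for the
hosted OpenAI endpoint, and stream_options is omitted entirely elsewhere. A quick behavioral
sketch (not part of the patch; assumes ModelSettings defaults include_usage to None as the diff
implies, and imports _Converter from its private module path purely for illustration):

# Behavioral sketch of the include_usage resolution in this patch.
from openai import AsyncOpenAI

from agents import ModelSettings
from agents.models.openai_chatcompletions import _Converter

openai_client = AsyncOpenAI(api_key="sk-test")  # default base_url: https://api.openai.com/v1
other_client = AsyncOpenAI(api_key="sk-test", base_url="http://localhost:11434/v1")

# Hosted OpenAI endpoint: usage reporting is on by default.
assert _Converter.get_stream_options_param(openai_client, ModelSettings()) == {
    "include_usage": True
}
# Other endpoints: no stream_options at all unless the caller opts in.
assert _Converter.get_stream_options_param(other_client, ModelSettings()) is None
assert _Converter.get_stream_options_param(
    other_client, ModelSettings(include_usage=True)
) == {"include_usage": True}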