diff --git a/src/agents/guardrail.py b/src/agents/guardrail.py index a96f0f7d7..d2ce79400 100644 --- a/src/agents/guardrail.py +++ b/src/agents/guardrail.py @@ -109,7 +109,9 @@ async def run( context: RunContextWrapper[TContext], ) -> InputGuardrailResult: if not callable(self.guardrail_function): - raise UserError(f"Guardrail function must be callable, got {self.guardrail_function}") + raise UserError( + f"Guardrail function must be callable, got {self.guardrail_function}" + ) output = self.guardrail_function(context, agent, input) if inspect.isawaitable(output): @@ -160,7 +162,9 @@ async def run( self, context: RunContextWrapper[TContext], agent: Agent[Any], agent_output: Any ) -> OutputGuardrailResult: if not callable(self.guardrail_function): - raise UserError(f"Guardrail function must be callable, got {self.guardrail_function}") + raise UserError( + f"Guardrail function must be callable, got {self.guardrail_function}" + ) output = self.guardrail_function(context, agent, agent_output) if inspect.isawaitable(output): @@ -183,11 +187,19 @@ async def run( # For InputGuardrail _InputGuardrailFuncSync = Callable[ - [RunContextWrapper[TContext_co], "Agent[Any]", Union[str, list[TResponseInputItem]]], + [ + RunContextWrapper[TContext_co], + "Agent[Any]", + Union[str, list[TResponseInputItem]], + ], GuardrailFunctionOutput, ] _InputGuardrailFuncAsync = Callable[ - [RunContextWrapper[TContext_co], "Agent[Any]", Union[str, list[TResponseInputItem]]], + [ + RunContextWrapper[TContext_co], + "Agent[Any]", + Union[str, list[TResponseInputItem]], + ], Awaitable[GuardrailFunctionOutput], ] @@ -215,9 +227,11 @@ def input_guardrail( def input_guardrail( - func: _InputGuardrailFuncSync[TContext_co] - | _InputGuardrailFuncAsync[TContext_co] - | None = None, + func: ( + _InputGuardrailFuncSync[TContext_co] + | _InputGuardrailFuncAsync[TContext_co] + | None + ) = None, *, name: str | None = None, ) -> ( @@ -284,15 +298,20 @@ def output_guardrail( def output_guardrail( - func: _OutputGuardrailFuncSync[TContext_co] - | _OutputGuardrailFuncAsync[TContext_co] - | None = None, + func: ( + _OutputGuardrailFuncSync[TContext_co] + | _OutputGuardrailFuncAsync[TContext_co] + | None + ) = None, *, name: str | None = None, ) -> ( OutputGuardrail[TContext_co] | Callable[ - [_OutputGuardrailFuncSync[TContext_co] | _OutputGuardrailFuncAsync[TContext_co]], + [ + _OutputGuardrailFuncSync[TContext_co] + | _OutputGuardrailFuncAsync[TContext_co] + ], OutputGuardrail[TContext_co], ] ): @@ -308,9 +327,17 @@ async def my_async_guardrail(...): ... """ def decorator( - f: _OutputGuardrailFuncSync[TContext_co] | _OutputGuardrailFuncAsync[TContext_co], + f: ( + _OutputGuardrailFuncSync[TContext_co] + | _OutputGuardrailFuncAsync[TContext_co] + ), ) -> OutputGuardrail[TContext_co]: - return OutputGuardrail(guardrail_function=f, name=name) + + return OutputGuardrail( + guardrail_function=f, + # Guardrail name defaults to function name when not specified (None). + name=name if name else f.__name__ + ) if func is not None: # Decorator was used without parentheses