Skip to content

Commit e9667b3

Browse files
committed
Allow previous_response_id to be passed to input guardrails
1 parent af80e3a commit e9667b3

File tree

4 files changed

+61
-16
lines changed

4 files changed

+61
-16
lines changed

src/agents/__init__.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,10 @@
1818
)
1919
from .guardrail import (
2020
GuardrailFunctionOutput,
21+
InputGuardailInputs,
2122
InputGuardrail,
23+
InputGuardrailFunction,
24+
InputGuardrailFunctionLegacy,
2225
InputGuardrailResult,
2326
OutputGuardrail,
2427
OutputGuardrailResult,
@@ -174,6 +177,9 @@ def enable_verbose_stdout_logging():
174177
"OutputGuardrail",
175178
"OutputGuardrailResult",
176179
"GuardrailFunctionOutput",
180+
"InputGuardailInputs",
181+
"InputGuardrailFunction",
182+
"InputGuardrailFunctionLegacy",
177183
"input_guardrail",
178184
"output_guardrail",
179185
"handoff",

src/agents/_run_impl.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -688,10 +688,11 @@ async def run_single_input_guardrail(
688688
agent: Agent[Any],
689689
guardrail: InputGuardrail[TContext],
690690
input: str | list[TResponseInputItem],
691+
previous_response_id: str | None,
691692
context: RunContextWrapper[TContext],
692693
) -> InputGuardrailResult:
693694
with guardrail_span(guardrail.get_name()) as span_guardrail:
694-
result = await guardrail.run(agent, input, context)
695+
result = await guardrail.run(agent, input, context, previous_response_id)
695696
span_guardrail.span_data.triggered = result.output.tripwire_triggered
696697
return result
697698

src/agents/guardrail.py

Lines changed: 43 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,9 @@
33
import inspect
44
from collections.abc import Awaitable
55
from dataclasses import dataclass
6-
from typing import TYPE_CHECKING, Any, Callable, Generic, Union, overload
6+
from typing import TYPE_CHECKING, Any, Callable, Generic, Union, cast, overload
77

8-
from typing_extensions import TypeVar
8+
from typing_extensions import TypeAlias, TypeVar
99

1010
from .exceptions import UserError
1111
from .items import TResponseInputItem
@@ -68,6 +68,31 @@ class OutputGuardrailResult:
6868
"""The output of the guardrail function."""
6969

7070

71+
InputGuardrailFunctionLegacy: TypeAlias = Callable[
72+
[RunContextWrapper[TContext], "Agent[Any]", Union[str, list[TResponseInputItem]]],
73+
MaybeAwaitable[GuardrailFunctionOutput],
74+
]
75+
"""The legacy guardrail function signature, retained for backwards compatibility. Of the form:
76+
def my_guardrail(ctx, agent, input)
77+
"""
78+
79+
80+
@dataclass
81+
class InputGuardailInputs:
82+
agent: Agent[Any]
83+
input: str | list[TResponseInputItem]
84+
previous_response_id: str | None
85+
86+
87+
InputGuardrailFunction: TypeAlias = Callable[
88+
[RunContextWrapper[TContext], InputGuardailInputs],
89+
MaybeAwaitable[GuardrailFunctionOutput],
90+
]
91+
"""The new guardrail function signature, of the form:
92+
def my_guardrail(ctx, inputs)
93+
"""
94+
95+
7196
@dataclass
7297
class InputGuardrail(Generic[TContext]):
7398
"""Input guardrails are checks that run in parallel to the agent's execution.
@@ -82,10 +107,7 @@ class InputGuardrail(Generic[TContext]):
82107
execution will immediately stop and a `InputGuardrailTripwireTriggered` exception will be raised
83108
"""
84109

85-
guardrail_function: Callable[
86-
[RunContextWrapper[TContext], Agent[Any], str | list[TResponseInputItem]],
87-
MaybeAwaitable[GuardrailFunctionOutput],
88-
]
110+
guardrail_function: InputGuardrailFunction[TContext] | InputGuardrailFunctionLegacy[TContext]
89111
"""A function that receives the agent input and the context, and returns a
90112
`GuardrailResult`. The result marks whether the tripwire was triggered, and can optionally
91113
include information about the guardrail's output.
@@ -107,11 +129,21 @@ async def run(
107129
agent: Agent[Any],
108130
input: str | list[TResponseInputItem],
109131
context: RunContextWrapper[TContext],
132+
previous_response_id: str | None,
110133
) -> InputGuardrailResult:
111134
if not callable(self.guardrail_function):
112135
raise UserError(f"Guardrail function must be callable, got {self.guardrail_function}")
113136

114-
output = self.guardrail_function(context, agent, input)
137+
sig = inspect.signature(self.guardrail_function)
138+
if len(sig.parameters) == 3:
139+
# Legacy guardrail function
140+
legacy_function = cast(InputGuardrailFunctionLegacy[TContext], self.guardrail_function)
141+
output = legacy_function(context, agent, input)
142+
else:
143+
# New guardrail function
144+
new_function = cast(InputGuardrailFunction[TContext], self.guardrail_function)
145+
output = new_function(context, InputGuardailInputs(agent, input, previous_response_id))
146+
115147
if inspect.isawaitable(output):
116148
return InputGuardrailResult(
117149
guardrail=self,
@@ -182,13 +214,11 @@ async def run(
182214
TContext_co = TypeVar("TContext_co", bound=Any, covariant=True)
183215

184216
# For InputGuardrail
185-
_InputGuardrailFuncSync = Callable[
186-
[RunContextWrapper[TContext_co], "Agent[Any]", Union[str, list[TResponseInputItem]]],
187-
GuardrailFunctionOutput,
217+
_InputGuardrailFuncSync = Union[
218+
InputGuardrailFunctionLegacy[TContext_co], InputGuardrailFunction[TContext_co]
188219
]
189-
_InputGuardrailFuncAsync = Callable[
190-
[RunContextWrapper[TContext_co], "Agent[Any]", Union[str, list[TResponseInputItem]]],
191-
Awaitable[GuardrailFunctionOutput],
220+
_InputGuardrailFuncAsync = Union[
221+
InputGuardrailFunctionLegacy[TContext_co], InputGuardrailFunction[TContext_co]
192222
]
193223

194224

src/agents/run.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -221,6 +221,7 @@ async def run(
221221
starting_agent.input_guardrails
222222
+ (run_config.input_guardrails or []),
223223
copy.deepcopy(input),
224+
previous_response_id,
224225
context_wrapper,
225226
),
226227
cls._run_single_turn(
@@ -446,6 +447,7 @@ async def _run_input_guardrails_with_queue(
446447
agent: Agent[Any],
447448
guardrails: list[InputGuardrail[TContext]],
448449
input: str | list[TResponseInputItem],
450+
previous_response_id: str | None,
449451
context: RunContextWrapper[TContext],
450452
streamed_result: RunResultStreaming,
451453
parent_span: Span[Any],
@@ -455,7 +457,9 @@ async def _run_input_guardrails_with_queue(
455457
# We'll run the guardrails and push them onto the queue as they complete
456458
guardrail_tasks = [
457459
asyncio.create_task(
458-
RunImpl.run_single_input_guardrail(agent, guardrail, input, context)
460+
RunImpl.run_single_input_guardrail(
461+
agent, guardrail, input, previous_response_id, context
462+
)
459463
)
460464
for guardrail in guardrails
461465
]
@@ -551,6 +555,7 @@ async def _run_streamed_impl(
551555
starting_agent,
552556
starting_agent.input_guardrails + (run_config.input_guardrails or []),
553557
copy.deepcopy(ItemHelpers.input_to_new_input_list(starting_input)),
558+
previous_response_id,
554559
context_wrapper,
555560
streamed_result,
556561
current_span,
@@ -825,14 +830,17 @@ async def _run_input_guardrails(
825830
agent: Agent[Any],
826831
guardrails: list[InputGuardrail[TContext]],
827832
input: str | list[TResponseInputItem],
833+
previous_response_id: str | None,
828834
context: RunContextWrapper[TContext],
829835
) -> list[InputGuardrailResult]:
830836
if not guardrails:
831837
return []
832838

833839
guardrail_tasks = [
834840
asyncio.create_task(
835-
RunImpl.run_single_input_guardrail(agent, guardrail, input, context)
841+
RunImpl.run_single_input_guardrail(
842+
agent, guardrail, input, previous_response_id, context
843+
)
836844
)
837845
for guardrail in guardrails
838846
]

0 commit comments

Comments
 (0)