3
3
import inspect
4
4
from collections .abc import Awaitable
5
5
from dataclasses import dataclass
6
- from typing import TYPE_CHECKING , Any , Callable , Generic , Union , overload
6
+ from typing import TYPE_CHECKING , Any , Callable , Generic , Union , cast , overload
7
7
8
- from typing_extensions import TypeVar
8
+ from typing_extensions import TypeAlias , TypeVar
9
9
10
10
from .exceptions import UserError
11
11
from .items import TResponseInputItem
@@ -68,6 +68,31 @@ class OutputGuardrailResult:
68
68
"""The output of the guardrail function."""
69
69
70
70
71
+ InputGuardrailFunctionLegacy : TypeAlias = Callable [
72
+ [RunContextWrapper [TContext ], "Agent[Any]" , Union [str , list [TResponseInputItem ]]],
73
+ MaybeAwaitable [GuardrailFunctionOutput ],
74
+ ]
75
+ """The legacy guardrail function signature, retained for backwards compatibility. Of the form:
76
+ def my_guardrail(ctx, agent, input)
77
+ """
78
+
79
+
80
+ @dataclass
81
+ class InputGuardailInputs :
82
+ agent : Agent [Any ]
83
+ input : str | list [TResponseInputItem ]
84
+ previous_response_id : str | None
85
+
86
+
87
+ InputGuardrailFunction : TypeAlias = Callable [
88
+ [RunContextWrapper [TContext ], InputGuardailInputs ],
89
+ MaybeAwaitable [GuardrailFunctionOutput ],
90
+ ]
91
+ """The new guardrail function signature, of the form:
92
+ def my_guardrail(ctx, inputs)
93
+ """
94
+
95
+
71
96
@dataclass
72
97
class InputGuardrail (Generic [TContext ]):
73
98
"""Input guardrails are checks that run in parallel to the agent's execution.
@@ -82,10 +107,7 @@ class InputGuardrail(Generic[TContext]):
82
107
execution will immediately stop and a `InputGuardrailTripwireTriggered` exception will be raised
83
108
"""
84
109
85
- guardrail_function : Callable [
86
- [RunContextWrapper [TContext ], Agent [Any ], str | list [TResponseInputItem ]],
87
- MaybeAwaitable [GuardrailFunctionOutput ],
88
- ]
110
+ guardrail_function : InputGuardrailFunction [TContext ] | InputGuardrailFunctionLegacy [TContext ]
89
111
"""A function that receives the agent input and the context, and returns a
90
112
`GuardrailResult`. The result marks whether the tripwire was triggered, and can optionally
91
113
include information about the guardrail's output.
@@ -107,11 +129,21 @@ async def run(
107
129
agent : Agent [Any ],
108
130
input : str | list [TResponseInputItem ],
109
131
context : RunContextWrapper [TContext ],
132
+ previous_response_id : str | None ,
110
133
) -> InputGuardrailResult :
111
134
if not callable (self .guardrail_function ):
112
135
raise UserError (f"Guardrail function must be callable, got { self .guardrail_function } " )
113
136
114
- output = self .guardrail_function (context , agent , input )
137
+ sig = inspect .signature (self .guardrail_function )
138
+ if len (sig .parameters ) == 3 :
139
+ # Legacy guardrail function
140
+ legacy_function = cast (InputGuardrailFunctionLegacy [TContext ], self .guardrail_function )
141
+ output = legacy_function (context , agent , input )
142
+ else :
143
+ # New guardrail function
144
+ new_function = cast (InputGuardrailFunction [TContext ], self .guardrail_function )
145
+ output = new_function (context , InputGuardailInputs (agent , input , previous_response_id ))
146
+
115
147
if inspect .isawaitable (output ):
116
148
return InputGuardrailResult (
117
149
guardrail = self ,
@@ -182,13 +214,11 @@ async def run(
182
214
TContext_co = TypeVar ("TContext_co" , bound = Any , covariant = True )
183
215
184
216
# For InputGuardrail
185
- _InputGuardrailFuncSync = Callable [
186
- [RunContextWrapper [TContext_co ], "Agent[Any]" , Union [str , list [TResponseInputItem ]]],
187
- GuardrailFunctionOutput ,
217
+ _InputGuardrailFuncSync = Union [
218
+ InputGuardrailFunctionLegacy [TContext_co ], InputGuardrailFunction [TContext_co ]
188
219
]
189
- _InputGuardrailFuncAsync = Callable [
190
- [RunContextWrapper [TContext_co ], "Agent[Any]" , Union [str , list [TResponseInputItem ]]],
191
- Awaitable [GuardrailFunctionOutput ],
220
+ _InputGuardrailFuncAsync = Union [
221
+ InputGuardrailFunctionLegacy [TContext_co ], InputGuardrailFunction [TContext_co ]
192
222
]
193
223
194
224
0 commit comments