31
31
from .util ._types import MaybeAwaitable
32
32
33
33
if TYPE_CHECKING :
34
-
35
34
from .agent import Agent
36
35
37
36
ToolParams = ParamSpec ("ToolParams" )
@@ -303,6 +302,7 @@ def function_tool(
303
302
failure_error_function : ToolErrorFunction | None = None ,
304
303
strict_mode : bool = True ,
305
304
is_enabled : bool | Callable [[RunContextWrapper [Any ], Agent [Any ]], MaybeAwaitable [bool ]] = True ,
305
+ run_in_thread : bool = False ,
306
306
) -> FunctionTool :
307
307
"""Overload for usage as @function_tool (no parentheses)."""
308
308
...
@@ -318,6 +318,7 @@ def function_tool(
318
318
failure_error_function : ToolErrorFunction | None = None ,
319
319
strict_mode : bool = True ,
320
320
is_enabled : bool | Callable [[RunContextWrapper [Any ], Agent [Any ]], MaybeAwaitable [bool ]] = True ,
321
+ run_in_thread : bool = False ,
321
322
) -> Callable [[ToolFunction [...]], FunctionTool ]:
322
323
"""Overload for usage as @function_tool(...)."""
323
324
...
@@ -333,6 +334,7 @@ def function_tool(
333
334
failure_error_function : ToolErrorFunction | None = default_tool_error_function ,
334
335
strict_mode : bool = True ,
335
336
is_enabled : bool | Callable [[RunContextWrapper [Any ], Agent [Any ]], MaybeAwaitable [bool ]] = True ,
337
+ run_in_thread : bool = False ,
336
338
) -> FunctionTool | Callable [[ToolFunction [...]], FunctionTool ]:
337
339
"""
338
340
Decorator to create a FunctionTool from a function. By default, we will:
@@ -364,6 +366,9 @@ def function_tool(
364
366
is_enabled: Whether the tool is enabled. Can be a bool or a callable that takes the run
365
367
context and agent and returns whether the tool is enabled. Disabled tools are hidden
366
368
from the LLM at runtime.
369
+ run_in_thread: Whether to run the tool in a thread. This only applies to non-async functions.
370
+ If True, the tool will be run in a thread, which can be useful to avoid blocking the
371
+ main thread.
367
372
"""
368
373
369
374
def _create_function_tool (the_func : ToolFunction [...]) -> FunctionTool :
@@ -414,9 +419,15 @@ async def _on_invoke_tool_impl(ctx: ToolContext[Any], input: str) -> Any:
414
419
result = await the_func (* args , ** kwargs_dict )
415
420
else :
416
421
if schema .takes_context :
417
- result = await asyncio .to_thread (the_func , ctx , * args , ** kwargs_dict )
422
+ if run_in_thread :
423
+ result = await asyncio .to_thread (the_func , ctx , * args , ** kwargs_dict )
424
+ else :
425
+ result = the_func (ctx , * args , ** kwargs_dict )
418
426
else :
419
- result = await asyncio .to_thread (the_func , * args , ** kwargs_dict )
427
+ if run_in_thread :
428
+ result = await asyncio .to_thread (the_func , * args , ** kwargs_dict )
429
+ else :
430
+ result = the_func (* args , ** kwargs_dict )
420
431
421
432
if _debug .DONT_LOG_TOOL_DATA :
422
433
logger .debug (f"Tool { schema .name } completed." )
0 commit comments