Skip to content

Commit d6688b6

Browse files
committed
feat: add run_in_thread to function tools
1 parent 846be16 commit d6688b6

File tree

1 file changed

+14
-3
lines changed

1 file changed

+14
-3
lines changed

src/agents/tool.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,6 @@
3131
from .util._types import MaybeAwaitable
3232

3333
if TYPE_CHECKING:
34-
3534
from .agent import Agent
3635

3736
ToolParams = ParamSpec("ToolParams")
@@ -303,6 +302,7 @@ def function_tool(
303302
failure_error_function: ToolErrorFunction | None = None,
304303
strict_mode: bool = True,
305304
is_enabled: bool | Callable[[RunContextWrapper[Any], Agent[Any]], MaybeAwaitable[bool]] = True,
305+
run_in_thread: bool = False,
306306
) -> FunctionTool:
307307
"""Overload for usage as @function_tool (no parentheses)."""
308308
...
@@ -318,6 +318,7 @@ def function_tool(
318318
failure_error_function: ToolErrorFunction | None = None,
319319
strict_mode: bool = True,
320320
is_enabled: bool | Callable[[RunContextWrapper[Any], Agent[Any]], MaybeAwaitable[bool]] = True,
321+
run_in_thread: bool = False,
321322
) -> Callable[[ToolFunction[...]], FunctionTool]:
322323
"""Overload for usage as @function_tool(...)."""
323324
...
@@ -333,6 +334,7 @@ def function_tool(
333334
failure_error_function: ToolErrorFunction | None = default_tool_error_function,
334335
strict_mode: bool = True,
335336
is_enabled: bool | Callable[[RunContextWrapper[Any], Agent[Any]], MaybeAwaitable[bool]] = True,
337+
run_in_thread: bool = False,
336338
) -> FunctionTool | Callable[[ToolFunction[...]], FunctionTool]:
337339
"""
338340
Decorator to create a FunctionTool from a function. By default, we will:
@@ -364,6 +366,9 @@ def function_tool(
364366
is_enabled: Whether the tool is enabled. Can be a bool or a callable that takes the run
365367
context and agent and returns whether the tool is enabled. Disabled tools are hidden
366368
from the LLM at runtime.
369+
run_in_thread: Whether to run the tool in a thread. This only applies to non-async functions.
370+
If True, the tool will be run in a thread, which can be useful to avoid blocking the
371+
main thread.
367372
"""
368373

369374
def _create_function_tool(the_func: ToolFunction[...]) -> FunctionTool:
@@ -414,9 +419,15 @@ async def _on_invoke_tool_impl(ctx: ToolContext[Any], input: str) -> Any:
414419
result = await the_func(*args, **kwargs_dict)
415420
else:
416421
if schema.takes_context:
417-
result = await asyncio.to_thread(the_func, ctx, *args, **kwargs_dict)
422+
if run_in_thread:
423+
result = await asyncio.to_thread(the_func, ctx, *args, **kwargs_dict)
424+
else:
425+
result = the_func(ctx, *args, **kwargs_dict)
418426
else:
419-
result = await asyncio.to_thread(the_func, *args, **kwargs_dict)
427+
if run_in_thread:
428+
result = await asyncio.to_thread(the_func, *args, **kwargs_dict)
429+
else:
430+
result = the_func(*args, **kwargs_dict)
420431

421432
if _debug.DONT_LOG_TOOL_DATA:
422433
logger.debug(f"Tool {schema.name} completed.")

0 commit comments

Comments
 (0)