diff --git a/src/agents/tool.py b/src/agents/tool.py index 3aab47752..9b85cf6fb 100644 --- a/src/agents/tool.py +++ b/src/agents/tool.py @@ -1,5 +1,6 @@ from __future__ import annotations +import asyncio import inspect import json from collections.abc import Awaitable @@ -30,7 +31,6 @@ from .util._types import MaybeAwaitable if TYPE_CHECKING: - from .agent import Agent ToolParams = ParamSpec("ToolParams") @@ -302,6 +302,7 @@ def function_tool( failure_error_function: ToolErrorFunction | None = None, strict_mode: bool = True, is_enabled: bool | Callable[[RunContextWrapper[Any], Agent[Any]], MaybeAwaitable[bool]] = True, + run_in_thread: bool = False, ) -> FunctionTool: """Overload for usage as @function_tool (no parentheses).""" ... @@ -317,6 +318,7 @@ def function_tool( failure_error_function: ToolErrorFunction | None = None, strict_mode: bool = True, is_enabled: bool | Callable[[RunContextWrapper[Any], Agent[Any]], MaybeAwaitable[bool]] = True, + run_in_thread: bool = False, ) -> Callable[[ToolFunction[...]], FunctionTool]: """Overload for usage as @function_tool(...).""" ... @@ -332,6 +334,7 @@ def function_tool( failure_error_function: ToolErrorFunction | None = default_tool_error_function, strict_mode: bool = True, is_enabled: bool | Callable[[RunContextWrapper[Any], Agent[Any]], MaybeAwaitable[bool]] = True, + run_in_thread: bool = False, ) -> FunctionTool | Callable[[ToolFunction[...]], FunctionTool]: """ Decorator to create a FunctionTool from a function. By default, we will: @@ -363,6 +366,9 @@ def function_tool( is_enabled: Whether the tool is enabled. Can be a bool or a callable that takes the run context and agent and returns whether the tool is enabled. Disabled tools are hidden from the LLM at runtime. + run_in_thread: Whether to run the tool in a thread. This only applies to non-async functions. + If True, the tool will be run in a thread, which can be useful to avoid blocking the + main thread. """ def _create_function_tool(the_func: ToolFunction[...]) -> FunctionTool: @@ -413,9 +419,15 @@ async def _on_invoke_tool_impl(ctx: ToolContext[Any], input: str) -> Any: result = await the_func(*args, **kwargs_dict) else: if schema.takes_context: - result = the_func(ctx, *args, **kwargs_dict) + if run_in_thread: + result = await asyncio.to_thread(the_func, ctx, *args, **kwargs_dict) + else: + result = the_func(ctx, *args, **kwargs_dict) else: - result = the_func(*args, **kwargs_dict) + if run_in_thread: + result = await asyncio.to_thread(the_func, *args, **kwargs_dict) + else: + result = the_func(*args, **kwargs_dict) if _debug.DONT_LOG_TOOL_DATA: logger.debug(f"Tool {schema.name} completed.")