From 7db156dda67f99f7cda3eaece16b1395b9350834 Mon Sep 17 00:00:00 2001
From: Rohan Mehta
Date: Tue, 3 Jun 2025 10:57:56 -0400
Subject: [PATCH] Add is_enabled to FunctionTool

### Summary:
Allows a user to do `function_tool(is_enabled=...)`; the callable is called when the agent runs. This allows you to dynamically enable/disable a tool based on the context/env.

### Test Plan:
Unit tests
---
 src/agents/agent.py               | 22 +++++++++++++---
 src/agents/run.py                 | 10 ++++---
 src/agents/tool.py                | 17 +++++++++++-
 tests/test_function_tool.py       | 43 ++++++++++++++++++++++++++++++-
 tests/test_run_step_execution.py  |  2 +-
 tests/test_run_step_processing.py | 32 +++++++++++++----------
 6 files changed, 102 insertions(+), 24 deletions(-)

diff --git a/src/agents/agent.py b/src/agents/agent.py
index e22f579fa..6adccedd5 100644
--- a/src/agents/agent.py
+++ b/src/agents/agent.py
@@ -1,5 +1,6 @@
 from __future__ import annotations

+import asyncio
 import dataclasses
 import inspect
 from collections.abc import Awaitable
@@ -17,7 +18,7 @@
 from .model_settings import ModelSettings
 from .models.interface import Model
 from .run_context import RunContextWrapper, TContext
-from .tool import FunctionToolResult, Tool, function_tool
+from .tool import FunctionTool, FunctionToolResult, Tool, function_tool
 from .util import _transforms
 from .util._types import MaybeAwaitable
@@ -246,7 +247,22 @@ async def get_mcp_tools(self) -> list[Tool]:
         convert_schemas_to_strict = self.mcp_config.get("convert_schemas_to_strict", False)
         return await MCPUtil.get_all_function_tools(self.mcp_servers, convert_schemas_to_strict)

-    async def get_all_tools(self) -> list[Tool]:
+    async def get_all_tools(self, run_context: RunContextWrapper[Any]) -> list[Tool]:
         """All agent tools, including MCP tools and function tools."""
         mcp_tools = await self.get_mcp_tools()
-        return mcp_tools + self.tools
+
+        async def _check_tool_enabled(tool: Tool) -> bool:
+            if not isinstance(tool, FunctionTool):
+                return True
+
+            attr = tool.is_enabled
+            if isinstance(attr, bool):
+                return attr
+            res = attr(run_context, self)
+            if inspect.isawaitable(res):
+                return bool(await res)
+            return bool(res)
+
+        results = await asyncio.gather(*(_check_tool_enabled(t) for t in self.tools))
+        enabled: list[Tool] = [t for t, ok in zip(self.tools, results) if ok]
+        return [*mcp_tools, *enabled]
diff --git a/src/agents/run.py b/src/agents/run.py
index bfbdacd45..f375a0b89 100644
--- a/src/agents/run.py
+++ b/src/agents/run.py
@@ -181,7 +181,7 @@ async def run(
         try:
             while True:
-                all_tools = await cls._get_all_tools(current_agent)
+                all_tools = await cls._get_all_tools(current_agent, context_wrapper)

                 # Start an agent span if we don't have one. This span is ended if the current
                 # agent changes, or if the agent loop ends.
@@ -525,7 +525,7 @@ async def _run_streamed_impl(
             if streamed_result.is_complete:
                 break

-            all_tools = await cls._get_all_tools(current_agent)
+            all_tools = await cls._get_all_tools(current_agent, context_wrapper)

             # Start an agent span if we don't have one. This span is ended if the current
             # agent changes, or if the agent loop ends.
@@ -980,8 +980,10 @@ def _get_handoffs(cls, agent: Agent[Any]) -> list[Handoff]:
         return handoffs

     @classmethod
-    async def _get_all_tools(cls, agent: Agent[Any]) -> list[Tool]:
-        return await agent.get_all_tools()
+    async def _get_all_tools(
+        cls, agent: Agent[Any], context_wrapper: RunContextWrapper[Any]
+    ) -> list[Tool]:
+        return await agent.get_all_tools(context_wrapper)

     @classmethod
     def _get_model(cls, agent: Agent[Any], run_config: RunConfig) -> Model:
diff --git a/src/agents/tool.py b/src/agents/tool.py
index fd5a21c89..57272f9f4 100644
--- a/src/agents/tool.py
+++ b/src/agents/tool.py
@@ -4,7 +4,7 @@
 import json
 from collections.abc import Awaitable
 from dataclasses import dataclass
-from typing import Any, Callable, Literal, Union, overload
+from typing import TYPE_CHECKING, Any, Callable, Literal, Union, overload

 from openai.types.responses.file_search_tool_param import Filters, RankingOptions
 from openai.types.responses.response_output_item import LocalShellCall, McpApprovalRequest
@@ -24,6 +24,9 @@
 from .util import _error_tracing
 from .util._types import MaybeAwaitable

+if TYPE_CHECKING:
+    from .agent import Agent
+
 ToolParams = ParamSpec("ToolParams")

 ToolFunctionWithoutContext = Callable[ToolParams, Any]
@@ -74,6 +77,11 @@ class FunctionTool:
     """Whether the JSON schema is in strict mode. We **strongly** recommend setting this to True,
     as it increases the likelihood of correct JSON input."""

+    is_enabled: bool | Callable[[RunContextWrapper[Any], Agent[Any]], MaybeAwaitable[bool]] = True
+    """Whether the tool is enabled. Either a bool or a Callable that takes the run context and agent
+    and returns whether the tool is enabled. You can use this to dynamically enable/disable a tool
+    based on your context/state."""
+

 @dataclass
 class FileSearchTool:
@@ -262,6 +270,7 @@ def function_tool(
     use_docstring_info: bool = True,
     failure_error_function: ToolErrorFunction | None = None,
     strict_mode: bool = True,
+    is_enabled: bool | Callable[[RunContextWrapper[Any], Agent[Any]], MaybeAwaitable[bool]] = True,
 ) -> FunctionTool:
     """Overload for usage as @function_tool (no parentheses)."""
     ...
@@ -276,6 +285,7 @@ def function_tool(
     use_docstring_info: bool = True,
     failure_error_function: ToolErrorFunction | None = None,
     strict_mode: bool = True,
+    is_enabled: bool | Callable[[RunContextWrapper[Any], Agent[Any]], MaybeAwaitable[bool]] = True,
 ) -> Callable[[ToolFunction[...]], FunctionTool]:
     """Overload for usage as @function_tool(...)."""
     ...
@@ -290,6 +300,7 @@ def function_tool(
     use_docstring_info: bool = True,
     failure_error_function: ToolErrorFunction | None = default_tool_error_function,
     strict_mode: bool = True,
+    is_enabled: bool | Callable[[RunContextWrapper[Any], Agent[Any]], MaybeAwaitable[bool]] = True,
 ) -> FunctionTool | Callable[[ToolFunction[...]], FunctionTool]:
     """
     Decorator to create a FunctionTool from a function. By default, we will:
@@ -318,6 +329,9 @@ def function_tool(
             If False, it allows non-strict JSON schemas. For example, if a parameter has a default
             value, it will be optional, additional properties are allowed, etc. See here for more:
            https://platform.openai.com/docs/guides/structured-outputs?api-mode=responses#supported-schemas
+        is_enabled: Whether the tool is enabled. Can be a bool or a callable that takes the run
+            context and agent and returns whether the tool is enabled. Disabled tools are hidden
+            from the LLM at runtime.
""" def _create_function_tool(the_func: ToolFunction[...]) -> FunctionTool: @@ -407,6 +421,7 @@ async def _on_invoke_tool(ctx: RunContextWrapper[Any], input: str) -> Any: params_json_schema=schema.params_json_schema, on_invoke_tool=_on_invoke_tool, strict_json_schema=strict_mode, + is_enabled=is_enabled, ) # If func is actually a callable, we were used as @function_tool with no parentheses diff --git a/tests/test_function_tool.py b/tests/test_function_tool.py index 0a57aea87..c5d7da649 100644 --- a/tests/test_function_tool.py +++ b/tests/test_function_tool.py @@ -5,7 +5,7 @@ from pydantic import BaseModel from typing_extensions import TypedDict -from agents import FunctionTool, ModelBehaviorError, RunContextWrapper, function_tool +from agents import Agent, FunctionTool, ModelBehaviorError, RunContextWrapper, function_tool from agents.tool import default_tool_error_function @@ -255,3 +255,44 @@ def custom_sync_error_function(ctx: RunContextWrapper[Any], error: Exception) -> result = await tool.on_invoke_tool(ctx, '{"a": 1, "b": 2}') assert result == "error_ValueError" + + +class BoolCtx(BaseModel): + enable_tools: bool + + +@pytest.mark.asyncio +async def test_is_enabled_bool_and_callable(): + @function_tool(is_enabled=False) + def disabled_tool(): + return "nope" + + async def cond_enabled(ctx: RunContextWrapper[BoolCtx], agent: Agent[Any]) -> bool: + return ctx.context.enable_tools + + @function_tool(is_enabled=cond_enabled) + def another_tool(): + return "hi" + + async def third_tool_on_invoke_tool(ctx: RunContextWrapper[Any], args: str) -> str: + return "third" + + third_tool = FunctionTool( + name="third_tool", + description="third tool", + on_invoke_tool=third_tool_on_invoke_tool, + is_enabled=lambda ctx, agent: ctx.context.enable_tools, + params_json_schema={}, + ) + + agent = Agent(name="t", tools=[disabled_tool, another_tool, third_tool]) + context_1 = RunContextWrapper(BoolCtx(enable_tools=False)) + context_2 = RunContextWrapper(BoolCtx(enable_tools=True)) + + tools_with_ctx = await agent.get_all_tools(context_1) + assert tools_with_ctx == [] + + tools_with_ctx = await agent.get_all_tools(context_2) + assert len(tools_with_ctx) == 2 + assert tools_with_ctx[0].name == "another_tool" + assert tools_with_ctx[1].name == "third_tool" diff --git a/tests/test_run_step_execution.py b/tests/test_run_step_execution.py index 6ae25fbd5..b4c83d015 100644 --- a/tests/test_run_step_execution.py +++ b/tests/test_run_step_execution.py @@ -290,7 +290,7 @@ async def get_execute_result( processed_response = RunImpl.process_model_response( agent=agent, - all_tools=await agent.get_all_tools(), + all_tools=await agent.get_all_tools(context_wrapper or RunContextWrapper(None)), response=response, output_schema=output_schema, handoffs=handoffs, diff --git a/tests/test_run_step_processing.py b/tests/test_run_step_processing.py index 2ea98f06a..3cc1231c1 100644 --- a/tests/test_run_step_processing.py +++ b/tests/test_run_step_processing.py @@ -34,6 +34,10 @@ ) +def _dummy_ctx() -> RunContextWrapper[None]: + return RunContextWrapper(context=None) + + def test_empty_response(): agent = Agent(name="test") response = ModelResponse( @@ -83,7 +87,7 @@ async def test_single_tool_call(): response=response, output_schema=None, handoffs=[], - all_tools=await agent.get_all_tools(), + all_tools=await agent.get_all_tools(_dummy_ctx()), ) assert not result.handoffs assert result.functions and len(result.functions) == 1 @@ -111,7 +115,7 @@ async def test_missing_tool_call_raises_error(): response=response, 
         output_schema=None,
         handoffs=[],
-        all_tools=await agent.get_all_tools(),
+        all_tools=await agent.get_all_tools(_dummy_ctx()),
     )
@@ -140,7 +144,7 @@ async def test_multiple_tool_calls():
         response=response,
         output_schema=None,
         handoffs=[],
-        all_tools=await agent.get_all_tools(),
+        all_tools=await agent.get_all_tools(_dummy_ctx()),
     )
     assert not result.handoffs
     assert result.functions and len(result.functions) == 2
@@ -169,7 +173,7 @@ async def test_handoffs_parsed_correctly():
         response=response,
         output_schema=None,
         handoffs=[],
-        all_tools=await agent_3.get_all_tools(),
+        all_tools=await agent_3.get_all_tools(_dummy_ctx()),
     )
     assert not result.handoffs, "Shouldn't have a handoff here"
@@ -183,7 +187,7 @@ async def test_handoffs_parsed_correctly():
         response=response,
         output_schema=None,
         handoffs=Runner._get_handoffs(agent_3),
-        all_tools=await agent_3.get_all_tools(),
+        all_tools=await agent_3.get_all_tools(_dummy_ctx()),
     )
     assert len(result.handoffs) == 1, "Should have a handoff here"
     handoff = result.handoffs[0]
@@ -213,7 +217,7 @@ async def test_missing_handoff_fails():
         response=response,
         output_schema=None,
         handoffs=Runner._get_handoffs(agent_3),
-        all_tools=await agent_3.get_all_tools(),
+        all_tools=await agent_3.get_all_tools(_dummy_ctx()),
     )
@@ -236,7 +240,7 @@ async def test_multiple_handoffs_doesnt_error():
         response=response,
         output_schema=None,
         handoffs=Runner._get_handoffs(agent_3),
-        all_tools=await agent_3.get_all_tools(),
+        all_tools=await agent_3.get_all_tools(_dummy_ctx()),
     )
     assert len(result.handoffs) == 2, "Should have multiple handoffs here"
@@ -262,7 +266,7 @@ async def test_final_output_parsed_correctly():
         response=response,
         output_schema=Runner._get_output_schema(agent),
         handoffs=[],
-        all_tools=await agent.get_all_tools(),
+        all_tools=await agent.get_all_tools(_dummy_ctx()),
     )
@@ -288,7 +292,7 @@ async def test_file_search_tool_call_parsed_correctly():
         response=response,
         output_schema=None,
         handoffs=[],
-        all_tools=await agent.get_all_tools(),
+        all_tools=await agent.get_all_tools(_dummy_ctx()),
     )
     # The final item should be a ToolCallItem for the file search call
     assert any(
@@ -313,7 +317,7 @@ async def test_function_web_search_tool_call_parsed_correctly():
         response=response,
         output_schema=None,
         handoffs=[],
-        all_tools=await agent.get_all_tools(),
+        all_tools=await agent.get_all_tools(_dummy_ctx()),
     )
     assert any(
         isinstance(item, ToolCallItem) and item.raw_item is web_search_call
@@ -340,7 +344,7 @@ async def test_reasoning_item_parsed_correctly():
         response=response,
         output_schema=None,
         handoffs=[],
-        all_tools=await Agent(name="test").get_all_tools(),
+        all_tools=await Agent(name="test").get_all_tools(_dummy_ctx()),
     )
     assert any(
         isinstance(item, ReasoningItem) and item.raw_item is reasoning for item in result.new_items
@@ -409,7 +413,7 @@ async def test_computer_tool_call_without_computer_tool_raises_error():
         response=response,
         output_schema=None,
         handoffs=[],
-        all_tools=await Agent(name="test").get_all_tools(),
+        all_tools=await Agent(name="test").get_all_tools(_dummy_ctx()),
     )
@@ -437,7 +441,7 @@ async def test_computer_tool_call_with_computer_tool_parsed_correctly():
         response=response,
         output_schema=None,
         handoffs=[],
-        all_tools=await agent.get_all_tools(),
+        all_tools=await agent.get_all_tools(_dummy_ctx()),
     )
     assert any(
         isinstance(item, ToolCallItem) and item.raw_item is computer_call
@@ -468,7 +472,7 @@ async def test_tool_and_handoff_parsed_correctly():
         response=response,
         output_schema=None,
         handoffs=Runner._get_handoffs(agent_3),
-        all_tools=await agent_3.get_all_tools(),
+        all_tools=await agent_3.get_all_tools(_dummy_ctx()),
     )
     assert result.functions and len(result.functions) == 1
     assert len(result.handoffs) == 1, "Should have a handoff here"
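
A minimal usage sketch of the `is_enabled` parameter this patch adds, assuming the `agents` package from this repository. The `FeatureFlags` context class and the tool names are illustrative, not part of the change; `function_tool(is_enabled=...)`, `Agent`, `RunContextWrapper`, and `Agent.get_all_tools(run_context)` are the pieces the patch introduces or touches:

```python
# Illustrative sketch: hide a tool from the model unless the run context enables it.
import asyncio
from dataclasses import dataclass

from agents import Agent, RunContextWrapper, function_tool


@dataclass
class FeatureFlags:  # hypothetical per-run context object
    beta_enabled: bool


@function_tool  # is_enabled defaults to True, so this tool is always exposed
def get_time() -> str:
    return "12:00"


@function_tool(is_enabled=lambda ctx, agent: ctx.context.beta_enabled)
def beta_search(query: str) -> str:
    return f"results for {query}"


async def main() -> None:
    agent = Agent(name="assistant", tools=[get_time, beta_search])

    # get_all_tools() now takes the run context and filters out disabled tools,
    # so the callable is evaluated per run.
    off = RunContextWrapper(FeatureFlags(beta_enabled=False))
    on = RunContextWrapper(FeatureFlags(beta_enabled=True))

    print([t.name for t in await agent.get_all_tools(off)])  # ['get_time']
    print([t.name for t in await agent.get_all_tools(on)])   # ['get_time', 'beta_search']


asyncio.run(main())
```

Because `get_all_tools` is evaluated against the run context, the same agent definition can expose a different tool set per run without being rebuilt.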