From 95b3a0f6eaa508111374d4ee33a4e8aa498a71df Mon Sep 17 00:00:00 2001 From: devtalker Date: Fri, 13 Jun 2025 16:26:55 +0800 Subject: [PATCH 1/7] feat: add MCP tool filtering support --- docs/mcp.md | 51 +++++++ src/agents/agent.py | 13 +- src/agents/mcp/server.py | 71 ++++++++-- src/agents/mcp/util.py | 34 ++++- tests/mcp/test_tool_filtering.py | 223 +++++++++++++++++++++++++++++++ 5 files changed, 374 insertions(+), 18 deletions(-) create mode 100644 tests/mcp/test_tool_filtering.py diff --git a/docs/mcp.md b/docs/mcp.md index 76d142029..fb498a6b4 100644 --- a/docs/mcp.md +++ b/docs/mcp.md @@ -41,6 +41,57 @@ agent=Agent( ) ``` +## Tool filtering + +You can filter which tools are available to your Agent in two ways: + +### Server-level filtering + +Each MCP server instance can be configured with `allowed_tools` and `excluded_tools` parameters to control which tools it exposes: + +```python +# Only expose specific tools from this server +server = MCPServerStdio( + params={ + "command": "npx", + "args": ["-y", "@modelcontextprotocol/server-filesystem", samples_dir], + }, + allowed_tools=["read_file", "write_file"], # Only these tools will be available +) + +# Exclude specific tools from this server +server = MCPServerStdio( + params={ + "command": "npx", + "args": ["-y", "@modelcontextprotocol/server-filesystem", samples_dir], + }, + excluded_tools=["delete_file"], # This tool will be filtered out +) +``` + +### Agent-level filtering + +You can also filter tools at the Agent level using the `mcp_config` parameter. This allows you to control which tools are available across all MCP servers: + +```python +agent = Agent( + name="Assistant", + instructions="Use the tools to achieve the task", + mcp_servers=[server1, server2, server3], + mcp_config={ + "allowed_tools": { + "server1": ["read_file", "write_file"], # Only these tools from server1 + "server2": ["search"], # Only search tool from server2 + }, + "excluded_tools": { + "server3": ["dangerous_tool"], # Exclude this tool from server3 + } + } +) +``` + +**Filtering priority**: Server-level filtering is applied first, then Agent-level filtering. This allows for fine-grained control where servers can limit their exposed tools, and Agents can further restrict which tools they use. + ## Caching Every time an Agent runs, it calls `list_tools()` on the MCP server. This can be a latency hit, especially if the server is a remote server. To automatically cache the list of tools, you can pass `cache_tools_list=True` to [`MCPServerStdio`][agents.mcp.server.MCPServerStdio], [`MCPServerSse`][agents.mcp.server.MCPServerSse], and [`MCPServerStreamableHttp`][agents.mcp.server.MCPServerStreamableHttp]. You should only do this if you're certain the tool list will not change. diff --git a/src/agents/agent.py b/src/agents/agent.py index 61a9abe0c..dd790615d 100644 --- a/src/agents/agent.py +++ b/src/agents/agent.py @@ -64,6 +64,10 @@ class MCPConfig(TypedDict): """If True, we will attempt to convert the MCP schemas to strict-mode schemas. This is a best-effort conversion, so some schemas may not be convertible. Defaults to False. """ + allowed_tools: NotRequired[dict[str, list[str]]] + """Optional: server_name -> allowed tool names (whitelist)""" + excluded_tools: NotRequired[dict[str, list[str]]] + """Optional: server_name -> excluded tool names (blacklist)""" @dataclass @@ -259,7 +263,14 @@ async def get_prompt( async def get_mcp_tools(self) -> list[Tool]: """Fetches the available tools from the MCP servers.""" convert_schemas_to_strict = self.mcp_config.get("convert_schemas_to_strict", False) - return await MCPUtil.get_all_function_tools(self.mcp_servers, convert_schemas_to_strict) + allowed_tools_map = self.mcp_config.get("allowed_tools", {}) + excluded_tools_map = self.mcp_config.get("excluded_tools", {}) + return await MCPUtil.get_all_function_tools( + self.mcp_servers, + convert_schemas_to_strict, + allowed_tools_map, + excluded_tools_map, + ) async def get_all_tools(self, run_context: RunContextWrapper[Any]) -> list[Tool]: """All agent tools, including MCP tools and function tools.""" diff --git a/src/agents/mcp/server.py b/src/agents/mcp/server.py index 3d1e17790..ceff36320 100644 --- a/src/agents/mcp/server.py +++ b/src/agents/mcp/server.py @@ -57,7 +57,13 @@ async def call_tool(self, tool_name: str, arguments: dict[str, Any] | None) -> C class _MCPServerWithClientSession(MCPServer, abc.ABC): """Base class for MCP servers that use a `ClientSession` to communicate with the server.""" - def __init__(self, cache_tools_list: bool, client_session_timeout_seconds: float | None): + def __init__( + self, + cache_tools_list: bool, + client_session_timeout_seconds: float | None, + allowed_tools: list[str] | None = None, + excluded_tools: list[str] | None = None, + ): """ Args: cache_tools_list: Whether to cache the tools list. If `True`, the tools list will be @@ -68,6 +74,10 @@ def __init__(self, cache_tools_list: bool, client_session_timeout_seconds: float (by avoiding a round-trip to the server every time). client_session_timeout_seconds: the read timeout passed to the MCP ClientSession. + allowed_tools: Optional list of tool names to allow (whitelist). + If set, only these tools will be available. + excluded_tools: Optional list of tool names to exclude (blacklist). + If set, these tools will be filtered out. """ self.session: ClientSession | None = None self.exit_stack: AsyncExitStack = AsyncExitStack() @@ -81,6 +91,9 @@ def __init__(self, cache_tools_list: bool, client_session_timeout_seconds: float self._cache_dirty = True self._tools_list: list[MCPTool] | None = None + self.allowed_tools = allowed_tools + self.excluded_tools = excluded_tools + @abc.abstractmethod def create_streams( self, @@ -138,14 +151,21 @@ async def list_tools(self) -> list[MCPTool]: # Return from cache if caching is enabled, we have tools, and the cache is not dirty if self.cache_tools_list and not self._cache_dirty and self._tools_list: - return self._tools_list - - # Reset the cache dirty to False - self._cache_dirty = False - - # Fetch the tools from the server - self._tools_list = (await self.session.list_tools()).tools - return self._tools_list + tools = self._tools_list + else: + # Reset the cache dirty to False + self._cache_dirty = False + # Fetch the tools from the server + self._tools_list = (await self.session.list_tools()).tools + tools = self._tools_list + + # Filter tools based on allowed and excluded tools + filtered_tools = tools + if self.allowed_tools is not None: + filtered_tools = [t for t in filtered_tools if t.name in self.allowed_tools] + if self.excluded_tools is not None: + filtered_tools = [t for t in filtered_tools if t.name not in self.excluded_tools] + return filtered_tools async def call_tool(self, tool_name: str, arguments: dict[str, Any] | None) -> CallToolResult: """Invoke a tool on the server.""" @@ -206,6 +226,8 @@ def __init__( cache_tools_list: bool = False, name: str | None = None, client_session_timeout_seconds: float | None = 5, + allowed_tools: list[str] | None = None, + excluded_tools: list[str] | None = None, ): """Create a new MCP server based on the stdio transport. @@ -223,8 +245,15 @@ def __init__( name: A readable name for the server. If not provided, we'll create one from the command. client_session_timeout_seconds: the read timeout passed to the MCP ClientSession. + allowed_tools: Optional list of tool names to allow (whitelist). + excluded_tools: Optional list of tool names to exclude (blacklist). """ - super().__init__(cache_tools_list, client_session_timeout_seconds) + super().__init__( + cache_tools_list, + client_session_timeout_seconds, + allowed_tools, + excluded_tools, + ) self.params = StdioServerParameters( command=params["command"], @@ -283,6 +312,8 @@ def __init__( cache_tools_list: bool = False, name: str | None = None, client_session_timeout_seconds: float | None = 5, + allowed_tools: list[str] | None = None, + excluded_tools: list[str] | None = None, ): """Create a new MCP server based on the HTTP with SSE transport. @@ -302,8 +333,15 @@ def __init__( URL. client_session_timeout_seconds: the read timeout passed to the MCP ClientSession. + allowed_tools: Optional list of tool names to allow (whitelist). + excluded_tools: Optional list of tool names to exclude (blacklist). """ - super().__init__(cache_tools_list, client_session_timeout_seconds) + super().__init__( + cache_tools_list, + client_session_timeout_seconds, + allowed_tools, + excluded_tools, + ) self.params = params self._name = name or f"sse: {self.params['url']}" @@ -362,6 +400,8 @@ def __init__( cache_tools_list: bool = False, name: str | None = None, client_session_timeout_seconds: float | None = 5, + allowed_tools: list[str] | None = None, + excluded_tools: list[str] | None = None, ): """Create a new MCP server based on the Streamable HTTP transport. @@ -382,8 +422,15 @@ def __init__( URL. client_session_timeout_seconds: the read timeout passed to the MCP ClientSession. + allowed_tools: Optional list of tool names to allow (whitelist). + excluded_tools: Optional list of tool names to exclude (blacklist). """ - super().__init__(cache_tools_list, client_session_timeout_seconds) + super().__init__( + cache_tools_list, + client_session_timeout_seconds, + allowed_tools, + excluded_tools, + ) self.params = params self._name = name or f"streamable_http: {self.params['url']}" diff --git a/src/agents/mcp/util.py b/src/agents/mcp/util.py index 5a963bc01..1b7d9d588 100644 --- a/src/agents/mcp/util.py +++ b/src/agents/mcp/util.py @@ -1,6 +1,6 @@ import functools import json -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, Optional from agents.strict_schema import ensure_strict_json_schema @@ -22,13 +22,23 @@ class MCPUtil: @classmethod async def get_all_function_tools( - cls, servers: list["MCPServer"], convert_schemas_to_strict: bool + cls, + servers: list["MCPServer"], + convert_schemas_to_strict: bool, + allowed_tools_map: Optional[dict[str, list[str]]] = None, + excluded_tools_map: Optional[dict[str, list[str]]] = None, ) -> list[Tool]: """Get all function tools from a list of MCP servers.""" tools = [] tool_names: set[str] = set() + allowed_tools_map = allowed_tools_map or {} + excluded_tools_map = excluded_tools_map or {} for server in servers: - server_tools = await cls.get_function_tools(server, convert_schemas_to_strict) + allowed = allowed_tools_map.get(server.name) + excluded = excluded_tools_map.get(server.name) + server_tools = await cls.get_function_tools( + server, convert_schemas_to_strict, allowed, excluded + ) server_tool_names = {tool.name for tool in server_tools} if len(server_tool_names & tool_names) > 0: raise UserError( @@ -42,7 +52,11 @@ async def get_all_function_tools( @classmethod async def get_function_tools( - cls, server: "MCPServer", convert_schemas_to_strict: bool + cls, + server: "MCPServer", + convert_schemas_to_strict: bool, + allowed_tools: Optional[list[str]] = None, + excluded_tools: Optional[list[str]] = None, ) -> list[Tool]: """Get all function tools from a single MCP server.""" @@ -50,7 +64,17 @@ async def get_function_tools( tools = await server.list_tools() span.span_data.result = [tool.name for tool in tools] - return [cls.to_function_tool(tool, server, convert_schemas_to_strict) for tool in tools] + # Apply Agent-level filtering (additional filtering on top of server-level filtering) + filtered_tools = tools + if allowed_tools is not None: + filtered_tools = [t for t in filtered_tools if t.name in allowed_tools] + if excluded_tools is not None: + filtered_tools = [t for t in filtered_tools if t.name not in excluded_tools] + + return [ + cls.to_function_tool(tool, server, convert_schemas_to_strict) + for tool in filtered_tools + ] @classmethod def to_function_tool( diff --git a/tests/mcp/test_tool_filtering.py b/tests/mcp/test_tool_filtering.py new file mode 100644 index 000000000..99168a551 --- /dev/null +++ b/tests/mcp/test_tool_filtering.py @@ -0,0 +1,223 @@ +import pytest + +from agents import Agent +from agents.mcp import MCPUtil + +from .helpers import FakeMCPServer + + +class FilterableFakeMCPServer(FakeMCPServer): + """Extended FakeMCPServer that supports tool filtering""" + + def __init__(self, tools=None, allowed_tools=None, excluded_tools=None, server_name=None): + super().__init__(tools) + self.allowed_tools = allowed_tools + self.excluded_tools = excluded_tools + self._server_name = server_name + + async def list_tools(self): + tools = await super().list_tools() + + # Apply filtering logic similar to _MCPServerWithClientSession + filtered_tools = tools + if self.allowed_tools is not None: + filtered_tools = [t for t in filtered_tools if t.name in self.allowed_tools] + if self.excluded_tools is not None: + filtered_tools = [t for t in filtered_tools if t.name not in self.excluded_tools] + return filtered_tools + + @property + def name(self) -> str: + return self._server_name or "filterable_fake_server" + + +@pytest.mark.asyncio +async def test_server_allowed_tools(): + """Test that server-level allowed_tools filters tools correctly""" + server = FilterableFakeMCPServer(server_name="test_server") + server.add_tool("tool1", {}) + server.add_tool("tool2", {}) + server.add_tool("tool3", {}) + + # Set allowed_tools to only include tool1 and tool2 + server.allowed_tools = ["tool1", "tool2"] + + # Get tools and verify filtering + tools = await server.list_tools() + assert len(tools) == 2 + assert {t.name for t in tools} == {"tool1", "tool2"} + + +@pytest.mark.asyncio +async def test_server_excluded_tools(): + """Test that server-level excluded_tools filters tools correctly""" + server = FilterableFakeMCPServer(server_name="test_server") + server.add_tool("tool1", {}) + server.add_tool("tool2", {}) + server.add_tool("tool3", {}) + + # Set excluded_tools to exclude tool3 + server.excluded_tools = ["tool3"] + + # Get tools and verify filtering + tools = await server.list_tools() + assert len(tools) == 2 + assert {t.name for t in tools} == {"tool1", "tool2"} + + +@pytest.mark.asyncio +async def test_server_both_filters(): + """Test that server-level allowed_tools and excluded_tools work together correctly""" + server = FilterableFakeMCPServer(server_name="test_server") + server.add_tool("tool1", {}) + server.add_tool("tool2", {}) + server.add_tool("tool3", {}) + server.add_tool("tool4", {}) + + # Set both filters + server.allowed_tools = ["tool1", "tool2", "tool3"] + server.excluded_tools = ["tool3"] + + # Get tools and verify filtering (allowed_tools applied first, then excluded_tools) + tools = await server.list_tools() + assert len(tools) == 2 + assert {t.name for t in tools} == {"tool1", "tool2"} + + +@pytest.mark.asyncio +async def test_agent_allowed_tools(): + """Test that agent-level allowed_tools filters tools correctly""" + server1 = FilterableFakeMCPServer(server_name="server1") + server1.add_tool("tool1", {}) + server1.add_tool("tool2", {}) + + server2 = FilterableFakeMCPServer(server_name="server2") + server2.add_tool("tool3", {}) + server2.add_tool("tool4", {}) + + # Create agent with allowed_tools in mcp_config + agent = Agent( + name="test_agent", + mcp_servers=[server1, server2], + mcp_config={ + "allowed_tools": { + "server1": ["tool1"], + "server2": ["tool3"], + } + } + ) + + # Get tools and verify filtering + tools = await agent.get_mcp_tools() + assert len(tools) == 2 + assert {t.name for t in tools} == {"tool1", "tool3"} + + +@pytest.mark.asyncio +async def test_agent_excluded_tools(): + """Test that agent-level excluded_tools filters tools correctly""" + server1 = FilterableFakeMCPServer(server_name="server1") + server1.add_tool("tool1", {}) + server1.add_tool("tool2", {}) + + server2 = FilterableFakeMCPServer(server_name="server2") + server2.add_tool("tool3", {}) + server2.add_tool("tool4", {}) + + # Create agent with excluded_tools in mcp_config + agent = Agent( + name="test_agent", + mcp_servers=[server1, server2], + mcp_config={ + "excluded_tools": { + "server1": ["tool2"], + "server2": ["tool4"], + } + } + ) + + # Get tools and verify filtering + tools = await agent.get_mcp_tools() + assert len(tools) == 2 + assert {t.name for t in tools} == {"tool1", "tool3"} + + +@pytest.mark.asyncio +async def test_combined_filtering(): + """Test that server-level and agent-level filtering work together correctly""" + # Server with its own filtering + server = FilterableFakeMCPServer(server_name="test_server") + server.add_tool("tool1", {}) + server.add_tool("tool2", {}) + server.add_tool("tool3", {}) + server.add_tool("tool4", {}) + server.allowed_tools = ["tool1", "tool2", "tool3"] # Server only exposes these + + # Agent with additional filtering + agent = Agent( + name="test_agent", + mcp_servers=[server], + mcp_config={ + "excluded_tools": { + "test_server": ["tool3"], # Agent excludes this one + } + } + ) + + # Get tools and verify filtering + tools = await agent.get_mcp_tools() + assert len(tools) == 2 + assert {t.name for t in tools} == {"tool1", "tool2"} + + +@pytest.mark.asyncio +async def test_util_direct_filtering(): + """Test MCPUtil.get_all_function_tools with filtering parameters""" + server1 = FilterableFakeMCPServer(server_name="server1") + server1.add_tool("tool1", {}) + server1.add_tool("tool2", {}) + + server2 = FilterableFakeMCPServer(server_name="server2") + server2.add_tool("tool3", {}) + server2.add_tool("tool4", {}) + + # Test direct filtering through MCPUtil + allowed_tools_map = {"server1": ["tool1"]} + excluded_tools_map = {"server2": ["tool4"]} + + tools = await MCPUtil.get_all_function_tools( + [server1, server2], + convert_schemas_to_strict=False, + allowed_tools_map=allowed_tools_map, + excluded_tools_map=excluded_tools_map + ) + + assert len(tools) == 2 + assert {t.name for t in tools} == {"tool1", "tool3"} + + +@pytest.mark.asyncio +async def test_filtering_priority(): + """Test that server-level filtering takes priority over agent-level filtering""" + # Server only exposes tool1 and tool2 + server = FilterableFakeMCPServer(server_name="test_server") + server.add_tool("tool1", {}) + server.add_tool("tool2", {}) + server.add_tool("tool3", {}) + server.allowed_tools = ["tool1", "tool2"] + + # Agent tries to allow tool3 (which server doesn't expose) + agent = Agent( + name="test_agent", + mcp_servers=[server], + mcp_config={ + "allowed_tools": { + "test_server": ["tool2", "tool3"], # tool3 isn't available from server + } + } + ) + + # Get tools and verify filtering + tools = await agent.get_mcp_tools() + assert len(tools) == 1 + assert tools[0].name == "tool2" # Only tool2 passes both filters From ffb5b2a3c167bcfbecc66bc1751cedbb4ce86078 Mon Sep 17 00:00:00 2001 From: devtalker Date: Mon, 16 Jun 2025 21:03:59 +0800 Subject: [PATCH 2/7] refactor: remove Agent-level tool filtering logic --- docs/mcp.md | 29 +------ src/agents/agent.py | 13 +-- src/agents/mcp/util.py | 34 ++------ tests/mcp/test_tool_filtering.py | 142 ------------------------------- 4 files changed, 7 insertions(+), 211 deletions(-) diff --git a/docs/mcp.md b/docs/mcp.md index fb498a6b4..3524db605 100644 --- a/docs/mcp.md +++ b/docs/mcp.md @@ -43,11 +43,7 @@ agent=Agent( ## Tool filtering -You can filter which tools are available to your Agent in two ways: - -### Server-level filtering - -Each MCP server instance can be configured with `allowed_tools` and `excluded_tools` parameters to control which tools it exposes: +You can filter which tools are available to your Agent using server-level filtering: ```python # Only expose specific tools from this server @@ -69,29 +65,6 @@ server = MCPServerStdio( ) ``` -### Agent-level filtering - -You can also filter tools at the Agent level using the `mcp_config` parameter. This allows you to control which tools are available across all MCP servers: - -```python -agent = Agent( - name="Assistant", - instructions="Use the tools to achieve the task", - mcp_servers=[server1, server2, server3], - mcp_config={ - "allowed_tools": { - "server1": ["read_file", "write_file"], # Only these tools from server1 - "server2": ["search"], # Only search tool from server2 - }, - "excluded_tools": { - "server3": ["dangerous_tool"], # Exclude this tool from server3 - } - } -) -``` - -**Filtering priority**: Server-level filtering is applied first, then Agent-level filtering. This allows for fine-grained control where servers can limit their exposed tools, and Agents can further restrict which tools they use. - ## Caching Every time an Agent runs, it calls `list_tools()` on the MCP server. This can be a latency hit, especially if the server is a remote server. To automatically cache the list of tools, you can pass `cache_tools_list=True` to [`MCPServerStdio`][agents.mcp.server.MCPServerStdio], [`MCPServerSse`][agents.mcp.server.MCPServerSse], and [`MCPServerStreamableHttp`][agents.mcp.server.MCPServerStreamableHttp]. You should only do this if you're certain the tool list will not change. diff --git a/src/agents/agent.py b/src/agents/agent.py index dd790615d..61a9abe0c 100644 --- a/src/agents/agent.py +++ b/src/agents/agent.py @@ -64,10 +64,6 @@ class MCPConfig(TypedDict): """If True, we will attempt to convert the MCP schemas to strict-mode schemas. This is a best-effort conversion, so some schemas may not be convertible. Defaults to False. """ - allowed_tools: NotRequired[dict[str, list[str]]] - """Optional: server_name -> allowed tool names (whitelist)""" - excluded_tools: NotRequired[dict[str, list[str]]] - """Optional: server_name -> excluded tool names (blacklist)""" @dataclass @@ -263,14 +259,7 @@ async def get_prompt( async def get_mcp_tools(self) -> list[Tool]: """Fetches the available tools from the MCP servers.""" convert_schemas_to_strict = self.mcp_config.get("convert_schemas_to_strict", False) - allowed_tools_map = self.mcp_config.get("allowed_tools", {}) - excluded_tools_map = self.mcp_config.get("excluded_tools", {}) - return await MCPUtil.get_all_function_tools( - self.mcp_servers, - convert_schemas_to_strict, - allowed_tools_map, - excluded_tools_map, - ) + return await MCPUtil.get_all_function_tools(self.mcp_servers, convert_schemas_to_strict) async def get_all_tools(self, run_context: RunContextWrapper[Any]) -> list[Tool]: """All agent tools, including MCP tools and function tools.""" diff --git a/src/agents/mcp/util.py b/src/agents/mcp/util.py index 1b7d9d588..5a963bc01 100644 --- a/src/agents/mcp/util.py +++ b/src/agents/mcp/util.py @@ -1,6 +1,6 @@ import functools import json -from typing import TYPE_CHECKING, Any, Optional +from typing import TYPE_CHECKING, Any from agents.strict_schema import ensure_strict_json_schema @@ -22,23 +22,13 @@ class MCPUtil: @classmethod async def get_all_function_tools( - cls, - servers: list["MCPServer"], - convert_schemas_to_strict: bool, - allowed_tools_map: Optional[dict[str, list[str]]] = None, - excluded_tools_map: Optional[dict[str, list[str]]] = None, + cls, servers: list["MCPServer"], convert_schemas_to_strict: bool ) -> list[Tool]: """Get all function tools from a list of MCP servers.""" tools = [] tool_names: set[str] = set() - allowed_tools_map = allowed_tools_map or {} - excluded_tools_map = excluded_tools_map or {} for server in servers: - allowed = allowed_tools_map.get(server.name) - excluded = excluded_tools_map.get(server.name) - server_tools = await cls.get_function_tools( - server, convert_schemas_to_strict, allowed, excluded - ) + server_tools = await cls.get_function_tools(server, convert_schemas_to_strict) server_tool_names = {tool.name for tool in server_tools} if len(server_tool_names & tool_names) > 0: raise UserError( @@ -52,11 +42,7 @@ async def get_all_function_tools( @classmethod async def get_function_tools( - cls, - server: "MCPServer", - convert_schemas_to_strict: bool, - allowed_tools: Optional[list[str]] = None, - excluded_tools: Optional[list[str]] = None, + cls, server: "MCPServer", convert_schemas_to_strict: bool ) -> list[Tool]: """Get all function tools from a single MCP server.""" @@ -64,17 +50,7 @@ async def get_function_tools( tools = await server.list_tools() span.span_data.result = [tool.name for tool in tools] - # Apply Agent-level filtering (additional filtering on top of server-level filtering) - filtered_tools = tools - if allowed_tools is not None: - filtered_tools = [t for t in filtered_tools if t.name in allowed_tools] - if excluded_tools is not None: - filtered_tools = [t for t in filtered_tools if t.name not in excluded_tools] - - return [ - cls.to_function_tool(tool, server, convert_schemas_to_strict) - for tool in filtered_tools - ] + return [cls.to_function_tool(tool, server, convert_schemas_to_strict) for tool in tools] @classmethod def to_function_tool( diff --git a/tests/mcp/test_tool_filtering.py b/tests/mcp/test_tool_filtering.py index 99168a551..845eb5aef 100644 --- a/tests/mcp/test_tool_filtering.py +++ b/tests/mcp/test_tool_filtering.py @@ -1,8 +1,5 @@ import pytest -from agents import Agent -from agents.mcp import MCPUtil - from .helpers import FakeMCPServer @@ -82,142 +79,3 @@ async def test_server_both_filters(): tools = await server.list_tools() assert len(tools) == 2 assert {t.name for t in tools} == {"tool1", "tool2"} - - -@pytest.mark.asyncio -async def test_agent_allowed_tools(): - """Test that agent-level allowed_tools filters tools correctly""" - server1 = FilterableFakeMCPServer(server_name="server1") - server1.add_tool("tool1", {}) - server1.add_tool("tool2", {}) - - server2 = FilterableFakeMCPServer(server_name="server2") - server2.add_tool("tool3", {}) - server2.add_tool("tool4", {}) - - # Create agent with allowed_tools in mcp_config - agent = Agent( - name="test_agent", - mcp_servers=[server1, server2], - mcp_config={ - "allowed_tools": { - "server1": ["tool1"], - "server2": ["tool3"], - } - } - ) - - # Get tools and verify filtering - tools = await agent.get_mcp_tools() - assert len(tools) == 2 - assert {t.name for t in tools} == {"tool1", "tool3"} - - -@pytest.mark.asyncio -async def test_agent_excluded_tools(): - """Test that agent-level excluded_tools filters tools correctly""" - server1 = FilterableFakeMCPServer(server_name="server1") - server1.add_tool("tool1", {}) - server1.add_tool("tool2", {}) - - server2 = FilterableFakeMCPServer(server_name="server2") - server2.add_tool("tool3", {}) - server2.add_tool("tool4", {}) - - # Create agent with excluded_tools in mcp_config - agent = Agent( - name="test_agent", - mcp_servers=[server1, server2], - mcp_config={ - "excluded_tools": { - "server1": ["tool2"], - "server2": ["tool4"], - } - } - ) - - # Get tools and verify filtering - tools = await agent.get_mcp_tools() - assert len(tools) == 2 - assert {t.name for t in tools} == {"tool1", "tool3"} - - -@pytest.mark.asyncio -async def test_combined_filtering(): - """Test that server-level and agent-level filtering work together correctly""" - # Server with its own filtering - server = FilterableFakeMCPServer(server_name="test_server") - server.add_tool("tool1", {}) - server.add_tool("tool2", {}) - server.add_tool("tool3", {}) - server.add_tool("tool4", {}) - server.allowed_tools = ["tool1", "tool2", "tool3"] # Server only exposes these - - # Agent with additional filtering - agent = Agent( - name="test_agent", - mcp_servers=[server], - mcp_config={ - "excluded_tools": { - "test_server": ["tool3"], # Agent excludes this one - } - } - ) - - # Get tools and verify filtering - tools = await agent.get_mcp_tools() - assert len(tools) == 2 - assert {t.name for t in tools} == {"tool1", "tool2"} - - -@pytest.mark.asyncio -async def test_util_direct_filtering(): - """Test MCPUtil.get_all_function_tools with filtering parameters""" - server1 = FilterableFakeMCPServer(server_name="server1") - server1.add_tool("tool1", {}) - server1.add_tool("tool2", {}) - - server2 = FilterableFakeMCPServer(server_name="server2") - server2.add_tool("tool3", {}) - server2.add_tool("tool4", {}) - - # Test direct filtering through MCPUtil - allowed_tools_map = {"server1": ["tool1"]} - excluded_tools_map = {"server2": ["tool4"]} - - tools = await MCPUtil.get_all_function_tools( - [server1, server2], - convert_schemas_to_strict=False, - allowed_tools_map=allowed_tools_map, - excluded_tools_map=excluded_tools_map - ) - - assert len(tools) == 2 - assert {t.name for t in tools} == {"tool1", "tool3"} - - -@pytest.mark.asyncio -async def test_filtering_priority(): - """Test that server-level filtering takes priority over agent-level filtering""" - # Server only exposes tool1 and tool2 - server = FilterableFakeMCPServer(server_name="test_server") - server.add_tool("tool1", {}) - server.add_tool("tool2", {}) - server.add_tool("tool3", {}) - server.allowed_tools = ["tool1", "tool2"] - - # Agent tries to allow tool3 (which server doesn't expose) - agent = Agent( - name="test_agent", - mcp_servers=[server], - mcp_config={ - "allowed_tools": { - "test_server": ["tool2", "tool3"], # tool3 isn't available from server - } - } - ) - - # Get tools and verify filtering - tools = await agent.get_mcp_tools() - assert len(tools) == 1 - assert tools[0].name == "tool2" # Only tool2 passes both filters From 304b1060801937befeb8e551037a8f5069a0adc1 Mon Sep 17 00:00:00 2001 From: devtalker Date: Mon, 16 Jun 2025 22:12:07 +0800 Subject: [PATCH 3/7] feat: implement tool filter interface for MCP servers --- src/agents/mcp/__init__.py | 14 +++++- src/agents/mcp/server.py | 79 +++++++++++++++++++------------- src/agents/mcp/util.py | 73 ++++++++++++++++++++++++++++- src/agents/tool.py | 1 + tests/mcp/test_tool_filtering.py | 79 ++++++++++++++++++++++++-------- 5 files changed, 194 insertions(+), 52 deletions(-) diff --git a/src/agents/mcp/__init__.py b/src/agents/mcp/__init__.py index d4eb8fa68..da5a68b16 100644 --- a/src/agents/mcp/__init__.py +++ b/src/agents/mcp/__init__.py @@ -11,7 +11,14 @@ except ImportError: pass -from .util import MCPUtil +from .util import ( + MCPUtil, + ToolFilter, + ToolFilterCallable, + ToolFilterContext, + ToolFilterStatic, + create_static_tool_filter, +) __all__ = [ "MCPServer", @@ -22,4 +29,9 @@ "MCPServerStreamableHttp", "MCPServerStreamableHttpParams", "MCPUtil", + "ToolFilter", + "ToolFilterCallable", + "ToolFilterContext", + "ToolFilterStatic", + "create_static_tool_filter", ] diff --git a/src/agents/mcp/server.py b/src/agents/mcp/server.py index ceff36320..561956510 100644 --- a/src/agents/mcp/server.py +++ b/src/agents/mcp/server.py @@ -17,6 +17,7 @@ from ..exceptions import UserError from ..logger import logger +from .util import ToolFilter, ToolFilterStatic class MCPServer(abc.ABC): @@ -61,8 +62,7 @@ def __init__( self, cache_tools_list: bool, client_session_timeout_seconds: float | None, - allowed_tools: list[str] | None = None, - excluded_tools: list[str] | None = None, + tool_filter: ToolFilter = None, ): """ Args: @@ -74,10 +74,7 @@ def __init__( (by avoiding a round-trip to the server every time). client_session_timeout_seconds: the read timeout passed to the MCP ClientSession. - allowed_tools: Optional list of tool names to allow (whitelist). - If set, only these tools will be available. - excluded_tools: Optional list of tool names to exclude (blacklist). - If set, these tools will be filtered out. + tool_filter: The tool filter to use for filtering tools. """ self.session: ClientSession | None = None self.exit_stack: AsyncExitStack = AsyncExitStack() @@ -91,8 +88,39 @@ def __init__( self._cache_dirty = True self._tools_list: list[MCPTool] | None = None - self.allowed_tools = allowed_tools - self.excluded_tools = excluded_tools + self.tool_filter = tool_filter + + def _apply_tool_filter(self, tools: list[MCPTool]) -> list[MCPTool]: + """Apply the tool filter to the list of tools.""" + if self.tool_filter is None: + return tools + + # Handle static tool filter + if isinstance(self.tool_filter, dict): + static_filter: ToolFilterStatic = self.tool_filter + filtered_tools = tools + + # Apply allowed_tool_names filter (whitelist) + if "allowed_tool_names" in static_filter: + allowed_names = static_filter["allowed_tool_names"] + filtered_tools = [t for t in filtered_tools if t.name in allowed_names] + + # Apply blocked_tool_names filter (blacklist) + if "blocked_tool_names" in static_filter: + blocked_names = static_filter["blocked_tool_names"] + filtered_tools = [t for t in filtered_tools if t.name not in blocked_names] + + return filtered_tools + + # Handle callable tool filter + # For now, we can't support callable filters because we don't have access to + # run context and agent in the current list_tools signature. + # This could be enhanced in the future by modifying the call chain. + else: + raise NotImplementedError( + "Callable tool filters are not yet supported. Please use ToolFilterStatic " + "with 'allowed_tool_names' and/or 'blocked_tool_names' for now." + ) @abc.abstractmethod def create_streams( @@ -159,12 +187,10 @@ async def list_tools(self) -> list[MCPTool]: self._tools_list = (await self.session.list_tools()).tools tools = self._tools_list - # Filter tools based on allowed and excluded tools + # Filter tools based on tool_filter filtered_tools = tools - if self.allowed_tools is not None: - filtered_tools = [t for t in filtered_tools if t.name in self.allowed_tools] - if self.excluded_tools is not None: - filtered_tools = [t for t in filtered_tools if t.name not in self.excluded_tools] + if self.tool_filter is not None: + filtered_tools = self._apply_tool_filter(filtered_tools) return filtered_tools async def call_tool(self, tool_name: str, arguments: dict[str, Any] | None) -> CallToolResult: @@ -226,8 +252,7 @@ def __init__( cache_tools_list: bool = False, name: str | None = None, client_session_timeout_seconds: float | None = 5, - allowed_tools: list[str] | None = None, - excluded_tools: list[str] | None = None, + tool_filter: ToolFilter = None, ): """Create a new MCP server based on the stdio transport. @@ -245,14 +270,12 @@ def __init__( name: A readable name for the server. If not provided, we'll create one from the command. client_session_timeout_seconds: the read timeout passed to the MCP ClientSession. - allowed_tools: Optional list of tool names to allow (whitelist). - excluded_tools: Optional list of tool names to exclude (blacklist). + tool_filter: The tool filter to use for filtering tools. """ super().__init__( cache_tools_list, client_session_timeout_seconds, - allowed_tools, - excluded_tools, + tool_filter, ) self.params = StdioServerParameters( @@ -312,8 +335,7 @@ def __init__( cache_tools_list: bool = False, name: str | None = None, client_session_timeout_seconds: float | None = 5, - allowed_tools: list[str] | None = None, - excluded_tools: list[str] | None = None, + tool_filter: ToolFilter = None, ): """Create a new MCP server based on the HTTP with SSE transport. @@ -333,14 +355,12 @@ def __init__( URL. client_session_timeout_seconds: the read timeout passed to the MCP ClientSession. - allowed_tools: Optional list of tool names to allow (whitelist). - excluded_tools: Optional list of tool names to exclude (blacklist). + tool_filter: The tool filter to use for filtering tools. """ super().__init__( cache_tools_list, client_session_timeout_seconds, - allowed_tools, - excluded_tools, + tool_filter, ) self.params = params @@ -400,8 +420,7 @@ def __init__( cache_tools_list: bool = False, name: str | None = None, client_session_timeout_seconds: float | None = 5, - allowed_tools: list[str] | None = None, - excluded_tools: list[str] | None = None, + tool_filter: ToolFilter = None, ): """Create a new MCP server based on the Streamable HTTP transport. @@ -422,14 +441,12 @@ def __init__( URL. client_session_timeout_seconds: the read timeout passed to the MCP ClientSession. - allowed_tools: Optional list of tool names to allow (whitelist). - excluded_tools: Optional list of tool names to exclude (blacklist). + tool_filter: The tool filter to use for filtering tools. """ super().__init__( cache_tools_list, client_session_timeout_seconds, - allowed_tools, - excluded_tools, + tool_filter, ) self.params = params diff --git a/src/agents/mcp/util.py b/src/agents/mcp/util.py index 5a963bc01..7ddce2ab3 100644 --- a/src/agents/mcp/util.py +++ b/src/agents/mcp/util.py @@ -1,6 +1,8 @@ import functools import json -from typing import TYPE_CHECKING, Any +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any, Callable, Union +from typing_extensions import NotRequired, TypedDict from agents.strict_schema import ensure_strict_json_schema @@ -10,13 +12,82 @@ from ..run_context import RunContextWrapper from ..tool import FunctionTool, Tool from ..tracing import FunctionSpanData, get_current_span, mcp_tools_span +from ..util._types import MaybeAwaitable if TYPE_CHECKING: from mcp.types import Tool as MCPTool + from ..agent import Agent from .server import MCPServer +@dataclass +class ToolFilterContext: + """Context information available to tool filter functions.""" + + run_context: RunContextWrapper[Any] + """The current run context.""" + + agent: "Agent[Any]" + """The agent that is requesting the tool list.""" + + server_name: str + """The name of the MCP server.""" + + +ToolFilterCallable = Callable[["ToolFilterContext", "MCPTool"], MaybeAwaitable[bool]] +"""A function that determines whether a tool should be available. + +Args: + context: The context information including run context, agent, and server name. + tool: The MCP tool to filter. + +Returns: + Whether the tool should be available (True) or filtered out (False). +""" + + +class ToolFilterStatic(TypedDict): + """Static tool filter configuration using allowlists and blocklists.""" + + allowed_tool_names: NotRequired[list[str]] + """Optional list of tool names to allow (whitelist). If set, only these tools will be available.""" + + blocked_tool_names: NotRequired[list[str]] + """Optional list of tool names to exclude (blacklist). If set, these tools will be filtered out.""" + + +ToolFilter = Union[ToolFilterCallable, ToolFilterStatic, None] +"""A tool filter that can be either a function, static configuration, or None (no filtering).""" + + +def create_static_tool_filter( + allowed_tool_names: list[str] | None = None, + blocked_tool_names: list[str] | None = None, +) -> ToolFilterStatic | None: + """Create a static tool filter from allowlist and blocklist parameters. + + This is a convenience function for creating a ToolFilterStatic. + + Args: + allowed_tool_names: Optional list of tool names to allow (whitelist). + blocked_tool_names: Optional list of tool names to exclude (blacklist). + + Returns: + A ToolFilterStatic if any filtering is specified, None otherwise. + """ + if allowed_tool_names is None and blocked_tool_names is None: + return None + + filter_dict: ToolFilterStatic = {} + if allowed_tool_names is not None: + filter_dict["allowed_tool_names"] = allowed_tool_names + if blocked_tool_names is not None: + filter_dict["blocked_tool_names"] = blocked_tool_names + + return filter_dict + + class MCPUtil: """Set of utilities for interop between MCP and Agents SDK tools.""" diff --git a/src/agents/tool.py b/src/agents/tool.py index ce66a53ba..c441dd768 100644 --- a/src/agents/tool.py +++ b/src/agents/tool.py @@ -27,6 +27,7 @@ if TYPE_CHECKING: from .agent import Agent + from mcp.types import Tool as MCPTool ToolParams = ParamSpec("ToolParams") diff --git a/tests/mcp/test_tool_filtering.py b/tests/mcp/test_tool_filtering.py index 845eb5aef..84cf729f0 100644 --- a/tests/mcp/test_tool_filtering.py +++ b/tests/mcp/test_tool_filtering.py @@ -1,15 +1,15 @@ import pytest +from agents.mcp import ToolFilterStatic from .helpers import FakeMCPServer class FilterableFakeMCPServer(FakeMCPServer): """Extended FakeMCPServer that supports tool filtering""" - def __init__(self, tools=None, allowed_tools=None, excluded_tools=None, server_name=None): + def __init__(self, tools=None, tool_filter=None, server_name=None): super().__init__(tools) - self.allowed_tools = allowed_tools - self.excluded_tools = excluded_tools + self.tool_filter = tool_filter self._server_name = server_name async def list_tools(self): @@ -17,27 +17,49 @@ async def list_tools(self): # Apply filtering logic similar to _MCPServerWithClientSession filtered_tools = tools - if self.allowed_tools is not None: - filtered_tools = [t for t in filtered_tools if t.name in self.allowed_tools] - if self.excluded_tools is not None: - filtered_tools = [t for t in filtered_tools if t.name not in self.excluded_tools] + if self.tool_filter is not None: + filtered_tools = self._apply_tool_filter(filtered_tools) return filtered_tools + def _apply_tool_filter(self, tools): + """Apply the tool filter to the list of tools.""" + if self.tool_filter is None: + return tools + + # Handle static tool filter + if isinstance(self.tool_filter, dict): + static_filter: ToolFilterStatic = self.tool_filter + filtered_tools = tools + + # Apply allowed_tool_names filter (whitelist) + if "allowed_tool_names" in static_filter: + allowed_names = static_filter["allowed_tool_names"] + filtered_tools = [t for t in filtered_tools if t.name in allowed_names] + + # Apply blocked_tool_names filter (blacklist) + if "blocked_tool_names" in static_filter: + blocked_names = static_filter["blocked_tool_names"] + filtered_tools = [t for t in filtered_tools if t.name not in blocked_names] + + return filtered_tools + + return tools + @property def name(self) -> str: return self._server_name or "filterable_fake_server" @pytest.mark.asyncio -async def test_server_allowed_tools(): - """Test that server-level allowed_tools filters tools correctly""" +async def test_server_allowed_tool_names(): + """Test that server-level allowed_tool_names filters tools correctly""" server = FilterableFakeMCPServer(server_name="test_server") server.add_tool("tool1", {}) server.add_tool("tool2", {}) server.add_tool("tool3", {}) - # Set allowed_tools to only include tool1 and tool2 - server.allowed_tools = ["tool1", "tool2"] + # Set tool_filter to only include tool1 and tool2 + server.tool_filter = {"allowed_tool_names": ["tool1", "tool2"]} # Get tools and verify filtering tools = await server.list_tools() @@ -46,15 +68,15 @@ async def test_server_allowed_tools(): @pytest.mark.asyncio -async def test_server_excluded_tools(): - """Test that server-level excluded_tools filters tools correctly""" +async def test_server_blocked_tool_names(): + """Test that server-level blocked_tool_names filters tools correctly""" server = FilterableFakeMCPServer(server_name="test_server") server.add_tool("tool1", {}) server.add_tool("tool2", {}) server.add_tool("tool3", {}) - # Set excluded_tools to exclude tool3 - server.excluded_tools = ["tool3"] + # Set tool_filter to exclude tool3 + server.tool_filter = {"blocked_tool_names": ["tool3"]} # Get tools and verify filtering tools = await server.list_tools() @@ -64,7 +86,7 @@ async def test_server_excluded_tools(): @pytest.mark.asyncio async def test_server_both_filters(): - """Test that server-level allowed_tools and excluded_tools work together correctly""" + """Test that server-level allowed_tool_names and blocked_tool_names work together correctly""" server = FilterableFakeMCPServer(server_name="test_server") server.add_tool("tool1", {}) server.add_tool("tool2", {}) @@ -72,10 +94,29 @@ async def test_server_both_filters(): server.add_tool("tool4", {}) # Set both filters - server.allowed_tools = ["tool1", "tool2", "tool3"] - server.excluded_tools = ["tool3"] + server.tool_filter = { + "allowed_tool_names": ["tool1", "tool2", "tool3"], + "blocked_tool_names": ["tool3"] + } - # Get tools and verify filtering (allowed_tools applied first, then excluded_tools) + # Get tools and verify filtering (allowed_tool_names applied first, then blocked_tool_names) tools = await server.list_tools() assert len(tools) == 2 assert {t.name for t in tools} == {"tool1", "tool2"} + + +@pytest.mark.asyncio +async def test_server_no_filter(): + """Test that when no filter is set, all tools are returned""" + server = FilterableFakeMCPServer(server_name="test_server") + server.add_tool("tool1", {}) + server.add_tool("tool2", {}) + server.add_tool("tool3", {}) + + # No filter set (None) + server.tool_filter = None + + # Get tools and verify no filtering + tools = await server.list_tools() + assert len(tools) == 3 + assert {t.name for t in tools} == {"tool1", "tool2", "tool3"} From e94c09e52a7e2840b5994c579f75c1429ccb5ec5 Mon Sep 17 00:00:00 2001 From: devtalker Date: Tue, 17 Jun 2025 15:00:13 +0800 Subject: [PATCH 4/7] feat: implement comprehensive MCP tool filtering with static and dynamic support --- docs/mcp.md | 69 ++++++- src/agents/agent.py | 10 +- src/agents/mcp/server.py | 131 +++++++++++--- src/agents/mcp/util.py | 28 ++- src/agents/tool.py | 2 +- tests/mcp/helpers.py | 49 ++++- tests/mcp/test_tool_filtering.py | 300 ++++++++++++++++++++++--------- 7 files changed, 459 insertions(+), 130 deletions(-) diff --git a/docs/mcp.md b/docs/mcp.md index 3524db605..dbaa5e97a 100644 --- a/docs/mcp.md +++ b/docs/mcp.md @@ -43,16 +43,24 @@ agent=Agent( ## Tool filtering -You can filter which tools are available to your Agent using server-level filtering: +You can filter which tools are available to your Agent by configuring tool filters on MCP servers. The SDK supports both static and dynamic tool filtering. + +### Static tool filtering + +For simple allow/block lists, you can use static filtering: ```python +from agents.mcp import create_static_tool_filter + # Only expose specific tools from this server server = MCPServerStdio( params={ "command": "npx", "args": ["-y", "@modelcontextprotocol/server-filesystem", samples_dir], }, - allowed_tools=["read_file", "write_file"], # Only these tools will be available + tool_filter=create_static_tool_filter( + allowed_tool_names=["read_file", "write_file"] + ) ) # Exclude specific tools from this server @@ -61,10 +69,65 @@ server = MCPServerStdio( "command": "npx", "args": ["-y", "@modelcontextprotocol/server-filesystem", samples_dir], }, - excluded_tools=["delete_file"], # This tool will be filtered out + tool_filter=create_static_tool_filter( + blocked_tool_names=["delete_file"] + ) +) + +``` + +**When both `allowed_tool_names` and `blocked_tool_names` are configured, the processing order is:** +1. First apply `allowed_tool_names` (allowlist) - only keep the specified tools +2. Then apply `blocked_tool_names` (blocklist) - exclude specified tools from the remaining tools + +For example, if you configure `allowed_tool_names=["read_file", "write_file", "delete_file"]` and `blocked_tool_names=["delete_file"]`, only `read_file` and `write_file` tools will be available. + +### Dynamic tool filtering + +For more complex filtering logic, you can use dynamic filters with functions: + +```python +from agents.mcp import ToolFilterContext + +# Simple synchronous filter +def custom_filter(context: ToolFilterContext, tool) -> bool: + """Example of a custom tool filter.""" + # Filter logic based on tool name patterns + return tool.name.startswith("allowed_prefix") + +# Context-aware filter +def context_aware_filter(context: ToolFilterContext, tool) -> bool: + """Filter tools based on context information.""" + # Access agent information + agent_name = context.agent.name + + # Access server information + server_name = context.server_name + + # Implement your custom filtering logic here + return some_filtering_logic(agent_name, server_name, tool) + +# Asynchronous filter +async def async_filter(context: ToolFilterContext, tool) -> bool: + """Example of an asynchronous filter.""" + # Perform async operations if needed + result = await some_async_check(context, tool) + return result + +server = MCPServerStdio( + params={ + "command": "npx", + "args": ["-y", "@modelcontextprotocol/server-filesystem", samples_dir], + }, + tool_filter=custom_filter # or context_aware_filter or async_filter ) ``` +The `ToolFilterContext` provides access to: +- `run_context`: The current run context +- `agent`: The agent requesting the tools +- `server_name`: The name of the MCP server + ## Caching Every time an Agent runs, it calls `list_tools()` on the MCP server. This can be a latency hit, especially if the server is a remote server. To automatically cache the list of tools, you can pass `cache_tools_list=True` to [`MCPServerStdio`][agents.mcp.server.MCPServerStdio], [`MCPServerSse`][agents.mcp.server.MCPServerSse], and [`MCPServerStreamableHttp`][agents.mcp.server.MCPServerStreamableHttp]. You should only do this if you're certain the tool list will not change. diff --git a/src/agents/agent.py b/src/agents/agent.py index 61a9abe0c..0884fe3d2 100644 --- a/src/agents/agent.py +++ b/src/agents/agent.py @@ -256,14 +256,18 @@ async def get_prompt( """Get the prompt for the agent.""" return await PromptUtil.to_model_input(self.prompt, run_context, self) - async def get_mcp_tools(self) -> list[Tool]: + async def get_mcp_tools( + self, run_context: RunContextWrapper[TContext] | None = None + ) -> list[Tool]: """Fetches the available tools from the MCP servers.""" convert_schemas_to_strict = self.mcp_config.get("convert_schemas_to_strict", False) - return await MCPUtil.get_all_function_tools(self.mcp_servers, convert_schemas_to_strict) + return await MCPUtil.get_all_function_tools( + self.mcp_servers, convert_schemas_to_strict, run_context, self + ) async def get_all_tools(self, run_context: RunContextWrapper[Any]) -> list[Tool]: """All agent tools, including MCP tools and function tools.""" - mcp_tools = await self.get_mcp_tools() + mcp_tools = await self.get_mcp_tools(run_context) async def _check_tool_enabled(tool: Tool) -> bool: if not isinstance(tool, FunctionTool): diff --git a/src/agents/mcp/server.py b/src/agents/mcp/server.py index 561956510..552c14cea 100644 --- a/src/agents/mcp/server.py +++ b/src/agents/mcp/server.py @@ -2,10 +2,11 @@ import abc import asyncio +import inspect from contextlib import AbstractAsyncContextManager, AsyncExitStack from datetime import timedelta from pathlib import Path -from typing import Any, Literal +from typing import TYPE_CHECKING, Any, Literal, cast from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream from mcp import ClientSession, StdioServerParameters, Tool as MCPTool, stdio_client @@ -17,7 +18,11 @@ from ..exceptions import UserError from ..logger import logger -from .util import ToolFilter, ToolFilterStatic +from ..run_context import RunContextWrapper +from .util import ToolFilter, ToolFilterCallable, ToolFilterContext, ToolFilterStatic + +if TYPE_CHECKING: + from ..agent import Agent class MCPServer(abc.ABC): @@ -45,7 +50,11 @@ async def cleanup(self): pass @abc.abstractmethod - async def list_tools(self) -> list[MCPTool]: + async def list_tools( + self, + run_context: RunContextWrapper[Any] | None = None, + agent: Agent[Any] | None = None, + ) -> list[MCPTool]: """List the tools available on the server.""" pass @@ -90,38 +99,106 @@ def __init__( self.tool_filter = tool_filter - def _apply_tool_filter(self, tools: list[MCPTool]) -> list[MCPTool]: + async def _apply_tool_filter( + self, + tools: list[MCPTool], + run_context: RunContextWrapper[Any] | None, + agent: Agent[Any] | None, + ) -> list[MCPTool]: """Apply the tool filter to the list of tools.""" if self.tool_filter is None: return tools # Handle static tool filter if isinstance(self.tool_filter, dict): - static_filter: ToolFilterStatic = self.tool_filter - filtered_tools = tools + return self._apply_static_tool_filter(tools, self.tool_filter) - # Apply allowed_tool_names filter (whitelist) - if "allowed_tool_names" in static_filter: - allowed_names = static_filter["allowed_tool_names"] - filtered_tools = [t for t in filtered_tools if t.name in allowed_names] + # Handle callable tool filter (dynamic filter) + else: + return await self._apply_dynamic_tool_filter(tools, run_context, agent) - # Apply blocked_tool_names filter (blacklist) - if "blocked_tool_names" in static_filter: - blocked_names = static_filter["blocked_tool_names"] - filtered_tools = [t for t in filtered_tools if t.name not in blocked_names] + def _apply_static_tool_filter( + self, + tools: list[MCPTool], + static_filter: ToolFilterStatic + ) -> list[MCPTool]: + """Apply static tool filtering based on allowlist and blocklist.""" + filtered_tools = tools - return filtered_tools + # Apply allowed_tool_names filter (whitelist) + if "allowed_tool_names" in static_filter: + allowed_names = static_filter["allowed_tool_names"] + filtered_tools = [t for t in filtered_tools if t.name in allowed_names] - # Handle callable tool filter - # For now, we can't support callable filters because we don't have access to - # run context and agent in the current list_tools signature. - # This could be enhanced in the future by modifying the call chain. - else: - raise NotImplementedError( - "Callable tool filters are not yet supported. Please use ToolFilterStatic " - "with 'allowed_tool_names' and/or 'blocked_tool_names' for now." + # Apply blocked_tool_names filter (blacklist) + if "blocked_tool_names" in static_filter: + blocked_names = static_filter["blocked_tool_names"] + filtered_tools = [t for t in filtered_tools if t.name not in blocked_names] + + return filtered_tools + + async def _apply_dynamic_tool_filter( + self, + tools: list[MCPTool], + run_context: RunContextWrapper[Any] | None, + agent: Agent[Any] | None, + ) -> list[MCPTool]: + """Apply dynamic tool filtering using a callable filter function.""" + + # Ensure we have a callable filter and cast to help mypy + if not callable(self.tool_filter): + raise ValueError("Tool filter must be callable for dynamic filtering") + tool_filter_func = cast(ToolFilterCallable, self.tool_filter) + + # Create filter context - it may be None if run_context or agent is None + filter_context = None + if run_context is not None and agent is not None: + filter_context = ToolFilterContext( + run_context=run_context, + agent=agent, + server_name=self.name, ) + filtered_tools = [] + for tool in tools: + try: + # Try to call the filter function + if filter_context is not None: + # We have full context, call with context + result = tool_filter_func(filter_context, tool) + else: + # Try to call without context first to see if it works + try: + # Some filters might not need context parameters at all + result = tool_filter_func(None, tool) + except (TypeError, AttributeError) as e: + # If the filter tries to access context attributes, raise a helpful error + raise UserError( + "Dynamic tool filters require both run_context and agent when the " + "filter function accesses context information. This typically happens " + "when calling list_tools() directly without these parameters. Either " + "provide both parameters or use a static tool filter instead." + ) from e + + if inspect.isawaitable(result): + should_include = await result + else: + should_include = result + + if should_include: + filtered_tools.append(tool) + except UserError: + # Re-raise UserError as-is (this includes our context requirement error) + raise + except Exception as e: + logger.error( + f"Error applying tool filter to tool '{tool.name}' on server '{self.name}': {e}" + ) + # On error, exclude the tool for safety + continue + + return filtered_tools + @abc.abstractmethod def create_streams( self, @@ -172,7 +249,11 @@ async def connect(self): await self.cleanup() raise - async def list_tools(self) -> list[MCPTool]: + async def list_tools( + self, + run_context: RunContextWrapper[Any] | None = None, + agent: Agent[Any] | None = None, + ) -> list[MCPTool]: """List the tools available on the server.""" if not self.session: raise UserError("Server not initialized. Make sure you call `connect()` first.") @@ -190,7 +271,7 @@ async def list_tools(self) -> list[MCPTool]: # Filter tools based on tool_filter filtered_tools = tools if self.tool_filter is not None: - filtered_tools = self._apply_tool_filter(filtered_tools) + filtered_tools = await self._apply_tool_filter(filtered_tools, run_context, agent) return filtered_tools async def call_tool(self, tool_name: str, arguments: dict[str, Any] | None) -> CallToolResult: diff --git a/src/agents/mcp/util.py b/src/agents/mcp/util.py index 7ddce2ab3..af0f24220 100644 --- a/src/agents/mcp/util.py +++ b/src/agents/mcp/util.py @@ -2,6 +2,7 @@ import json from dataclasses import dataclass from typing import TYPE_CHECKING, Any, Callable, Union + from typing_extensions import NotRequired, TypedDict from agents.strict_schema import ensure_strict_json_schema @@ -35,11 +36,12 @@ class ToolFilterContext: """The name of the MCP server.""" -ToolFilterCallable = Callable[["ToolFilterContext", "MCPTool"], MaybeAwaitable[bool]] +ToolFilterCallable = Callable[["ToolFilterContext | None", "MCPTool"], MaybeAwaitable[bool]] """A function that determines whether a tool should be available. Args: context: The context information including run context, agent, and server name. + Can be None if run_context or agent is not available. tool: The MCP tool to filter. Returns: @@ -51,10 +53,12 @@ class ToolFilterStatic(TypedDict): """Static tool filter configuration using allowlists and blocklists.""" allowed_tool_names: NotRequired[list[str]] - """Optional list of tool names to allow (whitelist). If set, only these tools will be available.""" + """Optional list of tool names to allow (whitelist). + If set, only these tools will be available.""" blocked_tool_names: NotRequired[list[str]] - """Optional list of tool names to exclude (blacklist). If set, these tools will be filtered out.""" + """Optional list of tool names to exclude (blacklist). + If set, these tools will be filtered out.""" ToolFilter = Union[ToolFilterCallable, ToolFilterStatic, None] @@ -93,13 +97,19 @@ class MCPUtil: @classmethod async def get_all_function_tools( - cls, servers: list["MCPServer"], convert_schemas_to_strict: bool + cls, + servers: list["MCPServer"], + convert_schemas_to_strict: bool, + run_context: RunContextWrapper[Any] | None = None, + agent: "Agent[Any] | None" = None, ) -> list[Tool]: """Get all function tools from a list of MCP servers.""" tools = [] tool_names: set[str] = set() for server in servers: - server_tools = await cls.get_function_tools(server, convert_schemas_to_strict) + server_tools = await cls.get_function_tools( + server, convert_schemas_to_strict, run_context, agent + ) server_tool_names = {tool.name for tool in server_tools} if len(server_tool_names & tool_names) > 0: raise UserError( @@ -113,12 +123,16 @@ async def get_all_function_tools( @classmethod async def get_function_tools( - cls, server: "MCPServer", convert_schemas_to_strict: bool + cls, + server: "MCPServer", + convert_schemas_to_strict: bool, + run_context: RunContextWrapper[Any] | None = None, + agent: "Agent[Any] | None" = None, ) -> list[Tool]: """Get all function tools from a single MCP server.""" with mcp_tools_span(server=server.name) as span: - tools = await server.list_tools() + tools = await server.list_tools(run_context, agent) span.span_data.result = [tool.name for tool in tools] return [cls.to_function_tool(tool, server, convert_schemas_to_strict) for tool in tools] diff --git a/src/agents/tool.py b/src/agents/tool.py index c441dd768..283a9e24f 100644 --- a/src/agents/tool.py +++ b/src/agents/tool.py @@ -26,8 +26,8 @@ from .util._types import MaybeAwaitable if TYPE_CHECKING: + from .agent import Agent - from mcp.types import Tool as MCPTool ToolParams = ParamSpec("ToolParams") diff --git a/tests/mcp/helpers.py b/tests/mcp/helpers.py index 8ff153c18..e0d8a813d 100644 --- a/tests/mcp/helpers.py +++ b/tests/mcp/helpers.py @@ -1,3 +1,4 @@ +import asyncio import json import shutil from typing import Any @@ -6,6 +7,8 @@ from mcp.types import CallToolResult, TextContent from agents.mcp import MCPServer +from agents.mcp.server import _MCPServerWithClientSession +from agents.mcp.util import ToolFilter tee = shutil.which("tee") or "" assert tee, "tee not found" @@ -28,11 +31,41 @@ async def __aexit__(self, exc_type, exc_val, exc_tb): pass +class _TestFilterServer(_MCPServerWithClientSession): + """Minimal implementation of _MCPServerWithClientSession for testing tool filtering""" + + def __init__(self, tool_filter: ToolFilter, server_name: str): + # Initialize parent class properly to avoid type errors + super().__init__( + cache_tools_list=False, + client_session_timeout_seconds=None, + tool_filter=tool_filter, + ) + self._server_name: str = server_name + # Override some attributes for test isolation + self.session = None + self._cleanup_lock = asyncio.Lock() + + def create_streams(self): + raise NotImplementedError("Not needed for filtering tests") + + @property + def name(self) -> str: + return self._server_name + + class FakeMCPServer(MCPServer): - def __init__(self, tools: list[MCPTool] | None = None): + def __init__( + self, + tools: list[MCPTool] | None = None, + tool_filter: ToolFilter = None, + server_name: str = "fake_mcp_server", + ): self.tools: list[MCPTool] = tools or [] self.tool_calls: list[str] = [] self.tool_results: list[str] = [] + self.tool_filter = tool_filter + self._server_name = server_name def add_tool(self, name: str, input_schema: dict[str, Any]): self.tools.append(MCPTool(name=name, inputSchema=input_schema)) @@ -43,8 +76,16 @@ async def connect(self): async def cleanup(self): pass - async def list_tools(self): - return self.tools + async def list_tools(self, run_context=None, agent=None): + tools = self.tools + + # Apply tool filtering using the REAL implementation + if self.tool_filter is not None: + # Use the real _MCPServerWithClientSession filtering logic + filter_server = _TestFilterServer(self.tool_filter, self.name) + tools = await filter_server._apply_tool_filter(tools, run_context, agent) + + return tools async def call_tool(self, tool_name: str, arguments: dict[str, Any] | None) -> CallToolResult: self.tool_calls.append(tool_name) @@ -55,4 +96,4 @@ async def call_tool(self, tool_name: str, arguments: dict[str, Any] | None) -> C @property def name(self) -> str: - return "fake_mcp_server" + return self._server_name diff --git a/tests/mcp/test_tool_filtering.py b/tests/mcp/test_tool_filtering.py index 84cf729f0..33ba9531e 100644 --- a/tests/mcp/test_tool_filtering.py +++ b/tests/mcp/test_tool_filtering.py @@ -1,122 +1,248 @@ +""" +Tool filtering tests use FakeMCPServer instead of real MCPServer implementations to avoid +external dependencies (processes, network connections) and ensure fast, reliable unit tests. +FakeMCPServer delegates filtering logic to the real _MCPServerWithClientSession implementation. +""" +import asyncio + import pytest +from mcp import Tool as MCPTool + +from agents import Agent +from agents.exceptions import UserError +from agents.mcp import ToolFilterContext, create_static_tool_filter +from agents.run_context import RunContextWrapper -from agents.mcp import ToolFilterStatic from .helpers import FakeMCPServer -class FilterableFakeMCPServer(FakeMCPServer): - """Extended FakeMCPServer that supports tool filtering""" - - def __init__(self, tools=None, tool_filter=None, server_name=None): - super().__init__(tools) - self.tool_filter = tool_filter - self._server_name = server_name - - async def list_tools(self): - tools = await super().list_tools() - - # Apply filtering logic similar to _MCPServerWithClientSession - filtered_tools = tools - if self.tool_filter is not None: - filtered_tools = self._apply_tool_filter(filtered_tools) - return filtered_tools - - def _apply_tool_filter(self, tools): - """Apply the tool filter to the list of tools.""" - if self.tool_filter is None: - return tools - - # Handle static tool filter - if isinstance(self.tool_filter, dict): - static_filter: ToolFilterStatic = self.tool_filter - filtered_tools = tools - - # Apply allowed_tool_names filter (whitelist) - if "allowed_tool_names" in static_filter: - allowed_names = static_filter["allowed_tool_names"] - filtered_tools = [t for t in filtered_tools if t.name in allowed_names] - - # Apply blocked_tool_names filter (blacklist) - if "blocked_tool_names" in static_filter: - blocked_names = static_filter["blocked_tool_names"] - filtered_tools = [t for t in filtered_tools if t.name not in blocked_names] - - return filtered_tools - - return tools - - @property - def name(self) -> str: - return self._server_name or "filterable_fake_server" +def create_test_agent(name: str = "test_agent") -> Agent: + """Create a test agent for filtering tests.""" + return Agent(name=name, instructions="Test agent") +# === Static Tool Filtering Tests === + @pytest.mark.asyncio -async def test_server_allowed_tool_names(): - """Test that server-level allowed_tool_names filters tools correctly""" - server = FilterableFakeMCPServer(server_name="test_server") +async def test_static_tool_filtering(): + """Test all static tool filtering scenarios: allowed, blocked, both, none, etc.""" + server = FakeMCPServer(server_name="test_server") server.add_tool("tool1", {}) server.add_tool("tool2", {}) server.add_tool("tool3", {}) + server.add_tool("tool4", {}) - # Set tool_filter to only include tool1 and tool2 + # Test allowed_tool_names only server.tool_filter = {"allowed_tool_names": ["tool1", "tool2"]} - - # Get tools and verify filtering tools = await server.list_tools() assert len(tools) == 2 assert {t.name for t in tools} == {"tool1", "tool2"} - -@pytest.mark.asyncio -async def test_server_blocked_tool_names(): - """Test that server-level blocked_tool_names filters tools correctly""" - server = FilterableFakeMCPServer(server_name="test_server") - server.add_tool("tool1", {}) - server.add_tool("tool2", {}) - server.add_tool("tool3", {}) - - # Set tool_filter to exclude tool3 - server.tool_filter = {"blocked_tool_names": ["tool3"]} - - # Get tools and verify filtering + # Test blocked_tool_names only + server.tool_filter = {"blocked_tool_names": ["tool3", "tool4"]} tools = await server.list_tools() assert len(tools) == 2 assert {t.name for t in tools} == {"tool1", "tool2"} - -@pytest.mark.asyncio -async def test_server_both_filters(): - """Test that server-level allowed_tool_names and blocked_tool_names work together correctly""" - server = FilterableFakeMCPServer(server_name="test_server") - server.add_tool("tool1", {}) - server.add_tool("tool2", {}) - server.add_tool("tool3", {}) - server.add_tool("tool4", {}) - - # Set both filters + # Test both filters together (allowed first, then blocked) server.tool_filter = { "allowed_tool_names": ["tool1", "tool2", "tool3"], "blocked_tool_names": ["tool3"] } - - # Get tools and verify filtering (allowed_tool_names applied first, then blocked_tool_names) tools = await server.list_tools() assert len(tools) == 2 assert {t.name for t in tools} == {"tool1", "tool2"} + # Test no filter + server.tool_filter = None + tools = await server.list_tools() + assert len(tools) == 4 + + # Test helper function + server.tool_filter = create_static_tool_filter( + allowed_tool_names=["tool1", "tool2"], + blocked_tool_names=["tool2"] + ) + tools = await server.list_tools() + assert len(tools) == 1 + assert tools[0].name == "tool1" + + +# === Dynamic Tool Filtering Core Tests === @pytest.mark.asyncio -async def test_server_no_filter(): - """Test that when no filter is set, all tools are returned""" - server = FilterableFakeMCPServer(server_name="test_server") - server.add_tool("tool1", {}) - server.add_tool("tool2", {}) - server.add_tool("tool3", {}) +async def test_dynamic_filter_sync_and_async(): + """Test both synchronous and asynchronous dynamic filters""" + server = FakeMCPServer(server_name="test_server") + server.add_tool("allowed_tool", {}) + server.add_tool("blocked_tool", {}) + server.add_tool("restricted_tool", {}) + + # Test sync filter + def sync_filter(context: ToolFilterContext | None, tool: MCPTool) -> bool: + return tool.name.startswith("allowed") + + server.tool_filter = sync_filter + tools = await server.list_tools() + assert len(tools) == 1 + assert tools[0].name == "allowed_tool" - # No filter set (None) - server.tool_filter = None + # Test async filter + async def async_filter(context: ToolFilterContext | None, tool: MCPTool) -> bool: + await asyncio.sleep(0.001) # Simulate async operation + return "restricted" not in tool.name - # Get tools and verify no filtering + server.tool_filter = async_filter tools = await server.list_tools() + assert len(tools) == 2 + assert {t.name for t in tools} == {"allowed_tool", "blocked_tool"} + + +@pytest.mark.asyncio +async def test_dynamic_filter_context_handling(): + """Test dynamic filters with and without context access""" + server = FakeMCPServer(server_name="test_server") + server.add_tool("admin_tool", {}) + server.add_tool("user_tool", {}) + server.add_tool("guest_tool", {}) + + # Test context-independent filter (should work without context) + def context_independent_filter(context: ToolFilterContext | None, tool: MCPTool) -> bool: + return not tool.name.startswith("admin") + + server.tool_filter = context_independent_filter + tools = await server.list_tools(None, None) + assert len(tools) == 2 + assert {t.name for t in tools} == {"user_tool", "guest_tool"} + + # Test context-dependent filter (needs context) + def context_dependent_filter(context: ToolFilterContext | None, tool: MCPTool) -> bool: + assert context is not None + assert context.run_context is not None + assert context.agent is not None + assert context.server_name == "test_server" + + # Only admin tools for agents with "admin" in name + if "admin" in context.agent.name.lower(): + return True + else: + return not tool.name.startswith("admin") + + server.tool_filter = context_dependent_filter + + # Should work with context + run_context = RunContextWrapper(context=None) + regular_agent = create_test_agent("regular_user") + tools = await server.list_tools(run_context, regular_agent) + assert len(tools) == 2 + assert {t.name for t in tools} == {"user_tool", "guest_tool"} + + admin_agent = create_test_agent("admin_user") + tools = await server.list_tools(run_context, admin_agent) assert len(tools) == 3 - assert {t.name for t in tools} == {"tool1", "tool2", "tool3"} + + # Should fail without context when trying to access context + def context_accessing_filter(context: ToolFilterContext | None, tool: MCPTool) -> bool: + # This will raise AttributeError when context is None + return "admin" in context.agent.name.lower() # type: ignore[union-attr] + + server.tool_filter = context_accessing_filter + with pytest.raises( + UserError, + match="Dynamic tool filters require both run_context and agent when the filter " + "function accesses context information", + ): + await server.list_tools(None, None) + + +@pytest.mark.asyncio +async def test_dynamic_filter_error_handling(): + """Test error handling in dynamic filters""" + server = FakeMCPServer(server_name="test_server") + server.add_tool("good_tool", {}) + server.add_tool("error_tool", {}) + server.add_tool("another_good_tool", {}) + + def error_prone_filter(context: ToolFilterContext | None, tool: MCPTool) -> bool: + if tool.name == "error_tool": + raise ValueError("Simulated filter error") + return True + + server.tool_filter = error_prone_filter + + # Test with direct server call + tools = await server.list_tools() + assert len(tools) == 2 + assert {t.name for t in tools} == {"good_tool", "another_good_tool"} + + # Test with agent context + run_context = RunContextWrapper(context=None) + agent = create_test_agent() + tools = await server.list_tools(run_context, agent) + assert len(tools) == 2 + assert {t.name for t in tools} == {"good_tool", "another_good_tool"} + + +# === Integration Tests === + +@pytest.mark.asyncio +async def test_agent_dynamic_filtering_integration(): + """Test dynamic filtering integration with Agent methods""" + server = FakeMCPServer() + server.add_tool("file_read", {"type": "object", "properties": {"path": {"type": "string"}}}) + server.add_tool( + "file_write", + { + "type": "object", + "properties": {"path": {"type": "string"}, "content": {"type": "string"}}, + }, + ) + server.add_tool( + "database_query", {"type": "object", "properties": {"query": {"type": "string"}}} + ) + server.add_tool( + "network_request", {"type": "object", "properties": {"url": {"type": "string"}}} + ) + + # Role-based filter for comprehensive testing + async def role_based_filter(context: ToolFilterContext | None, tool: MCPTool) -> bool: + # Simulate async permission check + await asyncio.sleep(0.001) + + assert context is not None + agent_name = context.agent.name.lower() + if "admin" in agent_name: + return True + elif "readonly" in agent_name: + return "read" in tool.name or "query" in tool.name + else: + return tool.name.startswith("file_") + + server.tool_filter = role_based_filter + + # Test admin agent + admin_agent = Agent(name="admin_user", instructions="Admin", mcp_servers=[server]) + run_context = RunContextWrapper(context=None) + admin_tools = await admin_agent.get_mcp_tools(run_context) + assert len(admin_tools) == 4 + + # Test readonly agent + readonly_agent = Agent(name="readonly_viewer", instructions="Read-only", mcp_servers=[server]) + readonly_tools = await readonly_agent.get_mcp_tools(run_context) + assert len(readonly_tools) == 2 + assert {t.name for t in readonly_tools} == {"file_read", "database_query"} + + # Test regular agent + regular_agent = Agent(name="regular_user", instructions="Regular", mcp_servers=[server]) + regular_tools = await regular_agent.get_mcp_tools(run_context) + assert len(regular_tools) == 2 + assert {t.name for t in regular_tools} == {"file_read", "file_write"} + + # Test get_all_tools method + all_tools = await regular_agent.get_all_tools(run_context) + mcp_tool_names = { + t.name + for t in all_tools + if t.name in {"file_read", "file_write", "database_query", "network_request"} + } + assert mcp_tool_names == {"file_read", "file_write"} From 02f0f276969201166f291e4e6355ba44b2ea9a67 Mon Sep 17 00:00:00 2001 From: devtalker Date: Tue, 24 Jun 2025 11:16:22 +0800 Subject: [PATCH 5/7] feat!: make run_context required in Agent.get_mcp_tools() --- src/agents/agent.py | 2 +- tests/mcp/test_mcp_util.py | 9 ++++++--- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/src/agents/agent.py b/src/agents/agent.py index 0884fe3d2..6c87297f1 100644 --- a/src/agents/agent.py +++ b/src/agents/agent.py @@ -257,7 +257,7 @@ async def get_prompt( return await PromptUtil.to_model_input(self.prompt, run_context, self) async def get_mcp_tools( - self, run_context: RunContextWrapper[TContext] | None = None + self, run_context: RunContextWrapper[TContext] ) -> list[Tool]: """Fetches the available tools from the MCP servers.""" convert_schemas_to_strict = self.mcp_config.get("convert_schemas_to_strict", False) diff --git a/tests/mcp/test_mcp_util.py b/tests/mcp/test_mcp_util.py index 74356a16d..8dd888967 100644 --- a/tests/mcp/test_mcp_util.py +++ b/tests/mcp/test_mcp_util.py @@ -144,7 +144,8 @@ async def test_agent_convert_schemas_true(): agent = Agent( name="test_agent", mcp_servers=[server], mcp_config={"convert_schemas_to_strict": True} ) - tools = await agent.get_mcp_tools() + run_context = RunContextWrapper(context=None) + tools = await agent.get_mcp_tools(run_context) foo_tool = next(tool for tool in tools if tool.name == "foo") assert isinstance(foo_tool, FunctionTool) @@ -208,7 +209,8 @@ async def test_agent_convert_schemas_false(): agent = Agent( name="test_agent", mcp_servers=[server], mcp_config={"convert_schemas_to_strict": False} ) - tools = await agent.get_mcp_tools() + run_context = RunContextWrapper(context=None) + tools = await agent.get_mcp_tools(run_context) foo_tool = next(tool for tool in tools if tool.name == "foo") assert isinstance(foo_tool, FunctionTool) @@ -245,7 +247,8 @@ async def test_agent_convert_schemas_unset(): server.add_tool("bar", non_strict_schema) server.add_tool("baz", possible_to_convert_schema) agent = Agent(name="test_agent", mcp_servers=[server]) - tools = await agent.get_mcp_tools() + run_context = RunContextWrapper(context=None) + tools = await agent.get_mcp_tools(run_context) foo_tool = next(tool for tool in tools if tool.name == "foo") assert isinstance(foo_tool, FunctionTool) From e5eb35b9f6d296c905d1fcff3a5e2d3d7e205f6e Mon Sep 17 00:00:00 2001 From: devtalker Date: Tue, 24 Jun 2025 14:48:37 +0800 Subject: [PATCH 6/7] feat!: make run_context and agent required parameters for MCP tools listing BREAKING CHANGE: Agent.get_mcp_tools() and MCPServer.list_tools() now require run_context and agent parameters for API consistency and type safety --- docs/ja/mcp.md | 9 +++- docs/mcp.md | 9 +++- src/agents/mcp/server.py | 52 +++++++---------------- src/agents/mcp/util.py | 11 +++-- tests/mcp/test_caching.py | 22 ++++++---- tests/mcp/test_mcp_util.py | 11 +++-- tests/mcp/test_server_errors.py | 7 +++- tests/mcp/test_tool_filtering.py | 71 +++++++++++++++----------------- 8 files changed, 98 insertions(+), 94 deletions(-) diff --git a/docs/ja/mcp.md b/docs/ja/mcp.md index 09804beb2..1e394a5e6 100644 --- a/docs/ja/mcp.md +++ b/docs/ja/mcp.md @@ -23,13 +23,20 @@ Agents SDK は MCP をサポートしており、これにより幅広い MCP たとえば、[公式 MCP filesystem サーバー](https://www.npmjs.com/package/@modelcontextprotocol/server-filesystem)を利用する場合は次のようになります。 ```python +from agents.run_context import RunContextWrapper + async with MCPServerStdio( params={ "command": "npx", "args": ["-y", "@modelcontextprotocol/server-filesystem", samples_dir], } ) as server: - tools = await server.list_tools() + # 注意:実際には通常は MCP サーバーをエージェントに追加し、 + # フレームワークがツール一覧の取得を自動的に処理するようにします。 + # list_tools() への直接呼び出しには run_context と agent パラメータが必要です。 + run_context = RunContextWrapper(context=None) + agent = Agent(name="test", instructions="test") + tools = await server.list_tools(run_context, agent) ``` ## MCP サーバーの利用 diff --git a/docs/mcp.md b/docs/mcp.md index dbaa5e97a..d30a916ac 100644 --- a/docs/mcp.md +++ b/docs/mcp.md @@ -19,13 +19,20 @@ You can use the [`MCPServerStdio`][agents.mcp.server.MCPServerStdio], [`MCPServe For example, this is how you'd use the [official MCP filesystem server](https://www.npmjs.com/package/@modelcontextprotocol/server-filesystem). ```python +from agents.run_context import RunContextWrapper + async with MCPServerStdio( params={ "command": "npx", "args": ["-y", "@modelcontextprotocol/server-filesystem", samples_dir], } ) as server: - tools = await server.list_tools() + # Note: In practice, you typically add the server to an Agent + # and let the framework handle tool listing automatically. + # Direct calls to list_tools() require run_context and agent parameters. + run_context = RunContextWrapper(context=None) + agent = Agent(name="test", instructions="test") + tools = await server.list_tools(run_context, agent) ``` ## Using MCP servers diff --git a/src/agents/mcp/server.py b/src/agents/mcp/server.py index 552c14cea..f6c2b58ef 100644 --- a/src/agents/mcp/server.py +++ b/src/agents/mcp/server.py @@ -52,8 +52,8 @@ async def cleanup(self): @abc.abstractmethod async def list_tools( self, - run_context: RunContextWrapper[Any] | None = None, - agent: Agent[Any] | None = None, + run_context: RunContextWrapper[Any], + agent: Agent[Any], ) -> list[MCPTool]: """List the tools available on the server.""" pass @@ -102,8 +102,8 @@ def __init__( async def _apply_tool_filter( self, tools: list[MCPTool], - run_context: RunContextWrapper[Any] | None, - agent: Agent[Any] | None, + run_context: RunContextWrapper[Any], + agent: Agent[Any], ) -> list[MCPTool]: """Apply the tool filter to the list of tools.""" if self.tool_filter is None: @@ -140,8 +140,8 @@ def _apply_static_tool_filter( async def _apply_dynamic_tool_filter( self, tools: list[MCPTool], - run_context: RunContextWrapper[Any] | None, - agent: Agent[Any] | None, + run_context: RunContextWrapper[Any], + agent: Agent[Any], ) -> list[MCPTool]: """Apply dynamic tool filtering using a callable filter function.""" @@ -150,35 +150,18 @@ async def _apply_dynamic_tool_filter( raise ValueError("Tool filter must be callable for dynamic filtering") tool_filter_func = cast(ToolFilterCallable, self.tool_filter) - # Create filter context - it may be None if run_context or agent is None - filter_context = None - if run_context is not None and agent is not None: - filter_context = ToolFilterContext( - run_context=run_context, - agent=agent, - server_name=self.name, - ) + # Create filter context + filter_context = ToolFilterContext( + run_context=run_context, + agent=agent, + server_name=self.name, + ) filtered_tools = [] for tool in tools: try: - # Try to call the filter function - if filter_context is not None: - # We have full context, call with context - result = tool_filter_func(filter_context, tool) - else: - # Try to call without context first to see if it works - try: - # Some filters might not need context parameters at all - result = tool_filter_func(None, tool) - except (TypeError, AttributeError) as e: - # If the filter tries to access context attributes, raise a helpful error - raise UserError( - "Dynamic tool filters require both run_context and agent when the " - "filter function accesses context information. This typically happens " - "when calling list_tools() directly without these parameters. Either " - "provide both parameters or use a static tool filter instead." - ) from e + # Call the filter function with context + result = tool_filter_func(filter_context, tool) if inspect.isawaitable(result): should_include = await result @@ -187,9 +170,6 @@ async def _apply_dynamic_tool_filter( if should_include: filtered_tools.append(tool) - except UserError: - # Re-raise UserError as-is (this includes our context requirement error) - raise except Exception as e: logger.error( f"Error applying tool filter to tool '{tool.name}' on server '{self.name}': {e}" @@ -251,8 +231,8 @@ async def connect(self): async def list_tools( self, - run_context: RunContextWrapper[Any] | None = None, - agent: Agent[Any] | None = None, + run_context: RunContextWrapper[Any], + agent: Agent[Any], ) -> list[MCPTool]: """List the tools available on the server.""" if not self.session: diff --git a/src/agents/mcp/util.py b/src/agents/mcp/util.py index af0f24220..91c51ac6d 100644 --- a/src/agents/mcp/util.py +++ b/src/agents/mcp/util.py @@ -36,12 +36,11 @@ class ToolFilterContext: """The name of the MCP server.""" -ToolFilterCallable = Callable[["ToolFilterContext | None", "MCPTool"], MaybeAwaitable[bool]] +ToolFilterCallable = Callable[["ToolFilterContext", "MCPTool"], MaybeAwaitable[bool]] """A function that determines whether a tool should be available. Args: context: The context information including run context, agent, and server name. - Can be None if run_context or agent is not available. tool: The MCP tool to filter. Returns: @@ -100,8 +99,8 @@ async def get_all_function_tools( cls, servers: list["MCPServer"], convert_schemas_to_strict: bool, - run_context: RunContextWrapper[Any] | None = None, - agent: "Agent[Any] | None" = None, + run_context: RunContextWrapper[Any], + agent: "Agent[Any]", ) -> list[Tool]: """Get all function tools from a list of MCP servers.""" tools = [] @@ -126,8 +125,8 @@ async def get_function_tools( cls, server: "MCPServer", convert_schemas_to_strict: bool, - run_context: RunContextWrapper[Any] | None = None, - agent: "Agent[Any] | None" = None, + run_context: RunContextWrapper[Any], + agent: "Agent[Any]", ) -> list[Tool]: """Get all function tools from a single MCP server.""" diff --git a/tests/mcp/test_caching.py b/tests/mcp/test_caching.py index cac409e6e..f31cdf951 100644 --- a/tests/mcp/test_caching.py +++ b/tests/mcp/test_caching.py @@ -3,7 +3,9 @@ import pytest from mcp.types import ListToolsResult, Tool as MCPTool +from agents import Agent from agents.mcp import MCPServerStdio +from agents.run_context import RunContextWrapper from .helpers import DummyStreamsContextManager, tee @@ -33,25 +35,29 @@ async def test_server_caching_works( mock_list_tools.return_value = ListToolsResult(tools=tools) async with server: + # Create test context and agent + run_context = RunContextWrapper(context=None) + agent = Agent(name="test_agent", instructions="Test agent") + # Call list_tools() multiple times - tools = await server.list_tools() - assert tools == tools + result_tools = await server.list_tools(run_context, agent) + assert result_tools == tools assert mock_list_tools.call_count == 1, "list_tools() should have been called once" # Call list_tools() again, should return the cached value - tools = await server.list_tools() - assert tools == tools + result_tools = await server.list_tools(run_context, agent) + assert result_tools == tools assert mock_list_tools.call_count == 1, "list_tools() should not have been called again" # Invalidate the cache and call list_tools() again server.invalidate_tools_cache() - tools = await server.list_tools() - assert tools == tools + result_tools = await server.list_tools(run_context, agent) + assert result_tools == tools assert mock_list_tools.call_count == 2, "list_tools() should be called again" # Without invalidating the cache, calling list_tools() again should return the cached value - tools = await server.list_tools() - assert tools == tools + result_tools = await server.list_tools(run_context, agent) + assert result_tools == tools diff --git a/tests/mcp/test_mcp_util.py b/tests/mcp/test_mcp_util.py index 8dd888967..3230e63dd 100644 --- a/tests/mcp/test_mcp_util.py +++ b/tests/mcp/test_mcp_util.py @@ -57,7 +57,10 @@ async def test_get_all_function_tools(): server3.add_tool(names[4], schemas[4]) servers: list[MCPServer] = [server1, server2, server3] - tools = await MCPUtil.get_all_function_tools(servers, convert_schemas_to_strict=False) + run_context = RunContextWrapper(context=None) + agent = Agent(name="test_agent", instructions="Test agent") + + tools = await MCPUtil.get_all_function_tools(servers, False, run_context, agent) assert len(tools) == 5 assert all(tool.name in names for tool in tools) @@ -70,7 +73,7 @@ async def test_get_all_function_tools(): assert tool.name == names[idx] # Also make sure it works with strict schemas - tools = await MCPUtil.get_all_function_tools(servers, convert_schemas_to_strict=True) + tools = await MCPUtil.get_all_function_tools(servers, True, run_context, agent) assert len(tools) == 5 assert all(tool.name in names for tool in tools) @@ -282,7 +285,9 @@ async def test_util_adds_properties(): server = FakeMCPServer() server.add_tool("test_tool", schema) - tools = await MCPUtil.get_all_function_tools([server], convert_schemas_to_strict=False) + run_context = RunContextWrapper(context=None) + agent = Agent(name="test_agent", instructions="Test agent") + tools = await MCPUtil.get_all_function_tools([server], False, run_context, agent) tool = next(tool for tool in tools if tool.name == "test_tool") assert isinstance(tool, FunctionTool) diff --git a/tests/mcp/test_server_errors.py b/tests/mcp/test_server_errors.py index fbd8db17d..9e0455115 100644 --- a/tests/mcp/test_server_errors.py +++ b/tests/mcp/test_server_errors.py @@ -1,7 +1,9 @@ import pytest +from agents import Agent from agents.exceptions import UserError from agents.mcp.server import _MCPServerWithClientSession +from agents.run_context import RunContextWrapper class CrashingClientSessionServer(_MCPServerWithClientSession): @@ -35,8 +37,11 @@ async def test_server_errors_cause_error_and_cleanup_called(): async def test_not_calling_connect_causes_error(): server = CrashingClientSessionServer() + run_context = RunContextWrapper(context=None) + agent = Agent(name="test_agent", instructions="Test agent") + with pytest.raises(UserError): - await server.list_tools() + await server.list_tools(run_context, agent) with pytest.raises(UserError): await server.call_tool("foo", {}) diff --git a/tests/mcp/test_tool_filtering.py b/tests/mcp/test_tool_filtering.py index 33ba9531e..c1ffff4b8 100644 --- a/tests/mcp/test_tool_filtering.py +++ b/tests/mcp/test_tool_filtering.py @@ -9,7 +9,6 @@ from mcp import Tool as MCPTool from agents import Agent -from agents.exceptions import UserError from agents.mcp import ToolFilterContext, create_static_tool_filter from agents.run_context import RunContextWrapper @@ -21,6 +20,11 @@ def create_test_agent(name: str = "test_agent") -> Agent: return Agent(name=name, instructions="Test agent") +def create_test_context() -> RunContextWrapper: + """Create a test run context for filtering tests.""" + return RunContextWrapper(context=None) + + # === Static Tool Filtering Tests === @pytest.mark.asyncio @@ -32,15 +36,19 @@ async def test_static_tool_filtering(): server.add_tool("tool3", {}) server.add_tool("tool4", {}) + # Create test context and agent for all calls + run_context = create_test_context() + agent = create_test_agent() + # Test allowed_tool_names only server.tool_filter = {"allowed_tool_names": ["tool1", "tool2"]} - tools = await server.list_tools() + tools = await server.list_tools(run_context, agent) assert len(tools) == 2 assert {t.name for t in tools} == {"tool1", "tool2"} # Test blocked_tool_names only server.tool_filter = {"blocked_tool_names": ["tool3", "tool4"]} - tools = await server.list_tools() + tools = await server.list_tools(run_context, agent) assert len(tools) == 2 assert {t.name for t in tools} == {"tool1", "tool2"} @@ -49,13 +57,13 @@ async def test_static_tool_filtering(): "allowed_tool_names": ["tool1", "tool2", "tool3"], "blocked_tool_names": ["tool3"] } - tools = await server.list_tools() + tools = await server.list_tools(run_context, agent) assert len(tools) == 2 assert {t.name for t in tools} == {"tool1", "tool2"} # Test no filter server.tool_filter = None - tools = await server.list_tools() + tools = await server.list_tools(run_context, agent) assert len(tools) == 4 # Test helper function @@ -63,7 +71,7 @@ async def test_static_tool_filtering(): allowed_tool_names=["tool1", "tool2"], blocked_tool_names=["tool2"] ) - tools = await server.list_tools() + tools = await server.list_tools(run_context, agent) assert len(tools) == 1 assert tools[0].name == "tool1" @@ -78,45 +86,51 @@ async def test_dynamic_filter_sync_and_async(): server.add_tool("blocked_tool", {}) server.add_tool("restricted_tool", {}) + # Create test context and agent + run_context = create_test_context() + agent = create_test_agent() + # Test sync filter - def sync_filter(context: ToolFilterContext | None, tool: MCPTool) -> bool: + def sync_filter(context: ToolFilterContext, tool: MCPTool) -> bool: return tool.name.startswith("allowed") server.tool_filter = sync_filter - tools = await server.list_tools() + tools = await server.list_tools(run_context, agent) assert len(tools) == 1 assert tools[0].name == "allowed_tool" # Test async filter - async def async_filter(context: ToolFilterContext | None, tool: MCPTool) -> bool: + async def async_filter(context: ToolFilterContext, tool: MCPTool) -> bool: await asyncio.sleep(0.001) # Simulate async operation return "restricted" not in tool.name server.tool_filter = async_filter - tools = await server.list_tools() + tools = await server.list_tools(run_context, agent) assert len(tools) == 2 assert {t.name for t in tools} == {"allowed_tool", "blocked_tool"} @pytest.mark.asyncio async def test_dynamic_filter_context_handling(): - """Test dynamic filters with and without context access""" + """Test dynamic filters with context access""" server = FakeMCPServer(server_name="test_server") server.add_tool("admin_tool", {}) server.add_tool("user_tool", {}) server.add_tool("guest_tool", {}) - # Test context-independent filter (should work without context) - def context_independent_filter(context: ToolFilterContext | None, tool: MCPTool) -> bool: + # Test context-independent filter + def context_independent_filter(context: ToolFilterContext, tool: MCPTool) -> bool: return not tool.name.startswith("admin") server.tool_filter = context_independent_filter - tools = await server.list_tools(None, None) + run_context = create_test_context() + agent = create_test_agent() + tools = await server.list_tools(run_context, agent) assert len(tools) == 2 assert {t.name for t in tools} == {"user_tool", "guest_tool"} # Test context-dependent filter (needs context) - def context_dependent_filter(context: ToolFilterContext | None, tool: MCPTool) -> bool: + def context_dependent_filter(context: ToolFilterContext, tool: MCPTool) -> bool: assert context is not None assert context.run_context is not None assert context.agent is not None @@ -141,19 +155,6 @@ def context_dependent_filter(context: ToolFilterContext | None, tool: MCPTool) - tools = await server.list_tools(run_context, admin_agent) assert len(tools) == 3 - # Should fail without context when trying to access context - def context_accessing_filter(context: ToolFilterContext | None, tool: MCPTool) -> bool: - # This will raise AttributeError when context is None - return "admin" in context.agent.name.lower() # type: ignore[union-attr] - - server.tool_filter = context_accessing_filter - with pytest.raises( - UserError, - match="Dynamic tool filters require both run_context and agent when the filter " - "function accesses context information", - ): - await server.list_tools(None, None) - @pytest.mark.asyncio async def test_dynamic_filter_error_handling(): @@ -163,20 +164,15 @@ async def test_dynamic_filter_error_handling(): server.add_tool("error_tool", {}) server.add_tool("another_good_tool", {}) - def error_prone_filter(context: ToolFilterContext | None, tool: MCPTool) -> bool: + def error_prone_filter(context: ToolFilterContext, tool: MCPTool) -> bool: if tool.name == "error_tool": raise ValueError("Simulated filter error") return True server.tool_filter = error_prone_filter - # Test with direct server call - tools = await server.list_tools() - assert len(tools) == 2 - assert {t.name for t in tools} == {"good_tool", "another_good_tool"} - - # Test with agent context - run_context = RunContextWrapper(context=None) + # Test with server call + run_context = create_test_context() agent = create_test_agent() tools = await server.list_tools(run_context, agent) assert len(tools) == 2 @@ -205,11 +201,10 @@ async def test_agent_dynamic_filtering_integration(): ) # Role-based filter for comprehensive testing - async def role_based_filter(context: ToolFilterContext | None, tool: MCPTool) -> bool: + async def role_based_filter(context: ToolFilterContext, tool: MCPTool) -> bool: # Simulate async permission check await asyncio.sleep(0.001) - assert context is not None agent_name = context.agent.name.lower() if "admin" in agent_name: return True From 42acf094e8df9c907b22c7ca602c10fdfb2e4704 Mon Sep 17 00:00:00 2001 From: devtalker Date: Wed, 25 Jun 2025 16:37:09 +0800 Subject: [PATCH 7/7] refactor: replace union type annotations with Optional --- src/agents/mcp/util.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/agents/mcp/util.py b/src/agents/mcp/util.py index 91c51ac6d..48da9f841 100644 --- a/src/agents/mcp/util.py +++ b/src/agents/mcp/util.py @@ -1,7 +1,7 @@ import functools import json from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, Callable, Union +from typing import TYPE_CHECKING, Any, Callable, Optional, Union from typing_extensions import NotRequired, TypedDict @@ -65,9 +65,9 @@ class ToolFilterStatic(TypedDict): def create_static_tool_filter( - allowed_tool_names: list[str] | None = None, - blocked_tool_names: list[str] | None = None, -) -> ToolFilterStatic | None: + allowed_tool_names: Optional[list[str]] = None, + blocked_tool_names: Optional[list[str]] = None, +) -> Optional[ToolFilterStatic]: """Create a static tool filter from allowlist and blocklist parameters. This is a convenience function for creating a ToolFilterStatic.