diff --git a/docs/ja/mcp.md b/docs/ja/mcp.md index 09804beb2..1e394a5e6 100644 --- a/docs/ja/mcp.md +++ b/docs/ja/mcp.md @@ -23,13 +23,20 @@ Agents SDK は MCP をサポートしており、これにより幅広い MCP たとえば、[公式 MCP filesystem サーバー](https://www.npmjs.com/package/@modelcontextprotocol/server-filesystem)を利用する場合は次のようになります。 ```python +from agents.run_context import RunContextWrapper + async with MCPServerStdio( params={ "command": "npx", "args": ["-y", "@modelcontextprotocol/server-filesystem", samples_dir], } ) as server: - tools = await server.list_tools() + # 注意:実際には通常は MCP サーバーをエージェントに追加し、 + # フレームワークがツール一覧の取得を自動的に処理するようにします。 + # list_tools() への直接呼び出しには run_context と agent パラメータが必要です。 + run_context = RunContextWrapper(context=None) + agent = Agent(name="test", instructions="test") + tools = await server.list_tools(run_context, agent) ``` ## MCP サーバーの利用 diff --git a/docs/mcp.md b/docs/mcp.md index 76d142029..d30a916ac 100644 --- a/docs/mcp.md +++ b/docs/mcp.md @@ -19,13 +19,20 @@ You can use the [`MCPServerStdio`][agents.mcp.server.MCPServerStdio], [`MCPServe For example, this is how you'd use the [official MCP filesystem server](https://www.npmjs.com/package/@modelcontextprotocol/server-filesystem). ```python +from agents.run_context import RunContextWrapper + async with MCPServerStdio( params={ "command": "npx", "args": ["-y", "@modelcontextprotocol/server-filesystem", samples_dir], } ) as server: - tools = await server.list_tools() + # Note: In practice, you typically add the server to an Agent + # and let the framework handle tool listing automatically. + # Direct calls to list_tools() require run_context and agent parameters. + run_context = RunContextWrapper(context=None) + agent = Agent(name="test", instructions="test") + tools = await server.list_tools(run_context, agent) ``` ## Using MCP servers @@ -41,6 +48,93 @@ agent=Agent( ) ``` +## Tool filtering + +You can filter which tools are available to your Agent by configuring tool filters on MCP servers. The SDK supports both static and dynamic tool filtering. + +### Static tool filtering + +For simple allow/block lists, you can use static filtering: + +```python +from agents.mcp import create_static_tool_filter + +# Only expose specific tools from this server +server = MCPServerStdio( + params={ + "command": "npx", + "args": ["-y", "@modelcontextprotocol/server-filesystem", samples_dir], + }, + tool_filter=create_static_tool_filter( + allowed_tool_names=["read_file", "write_file"] + ) +) + +# Exclude specific tools from this server +server = MCPServerStdio( + params={ + "command": "npx", + "args": ["-y", "@modelcontextprotocol/server-filesystem", samples_dir], + }, + tool_filter=create_static_tool_filter( + blocked_tool_names=["delete_file"] + ) +) + +``` + +**When both `allowed_tool_names` and `blocked_tool_names` are configured, the processing order is:** +1. First apply `allowed_tool_names` (allowlist) - only keep the specified tools +2. Then apply `blocked_tool_names` (blocklist) - exclude specified tools from the remaining tools + +For example, if you configure `allowed_tool_names=["read_file", "write_file", "delete_file"]` and `blocked_tool_names=["delete_file"]`, only `read_file` and `write_file` tools will be available. + +### Dynamic tool filtering + +For more complex filtering logic, you can use dynamic filters with functions: + +```python +from agents.mcp import ToolFilterContext + +# Simple synchronous filter +def custom_filter(context: ToolFilterContext, tool) -> bool: + """Example of a custom tool filter.""" + # Filter logic based on tool name patterns + return tool.name.startswith("allowed_prefix") + +# Context-aware filter +def context_aware_filter(context: ToolFilterContext, tool) -> bool: + """Filter tools based on context information.""" + # Access agent information + agent_name = context.agent.name + + # Access server information + server_name = context.server_name + + # Implement your custom filtering logic here + return some_filtering_logic(agent_name, server_name, tool) + +# Asynchronous filter +async def async_filter(context: ToolFilterContext, tool) -> bool: + """Example of an asynchronous filter.""" + # Perform async operations if needed + result = await some_async_check(context, tool) + return result + +server = MCPServerStdio( + params={ + "command": "npx", + "args": ["-y", "@modelcontextprotocol/server-filesystem", samples_dir], + }, + tool_filter=custom_filter # or context_aware_filter or async_filter +) +``` + +The `ToolFilterContext` provides access to: +- `run_context`: The current run context +- `agent`: The agent requesting the tools +- `server_name`: The name of the MCP server + ## Caching Every time an Agent runs, it calls `list_tools()` on the MCP server. This can be a latency hit, especially if the server is a remote server. To automatically cache the list of tools, you can pass `cache_tools_list=True` to [`MCPServerStdio`][agents.mcp.server.MCPServerStdio], [`MCPServerSse`][agents.mcp.server.MCPServerSse], and [`MCPServerStreamableHttp`][agents.mcp.server.MCPServerStreamableHttp]. You should only do this if you're certain the tool list will not change. diff --git a/src/agents/agent.py b/src/agents/agent.py index 61a9abe0c..6c87297f1 100644 --- a/src/agents/agent.py +++ b/src/agents/agent.py @@ -256,14 +256,18 @@ async def get_prompt( """Get the prompt for the agent.""" return await PromptUtil.to_model_input(self.prompt, run_context, self) - async def get_mcp_tools(self) -> list[Tool]: + async def get_mcp_tools( + self, run_context: RunContextWrapper[TContext] + ) -> list[Tool]: """Fetches the available tools from the MCP servers.""" convert_schemas_to_strict = self.mcp_config.get("convert_schemas_to_strict", False) - return await MCPUtil.get_all_function_tools(self.mcp_servers, convert_schemas_to_strict) + return await MCPUtil.get_all_function_tools( + self.mcp_servers, convert_schemas_to_strict, run_context, self + ) async def get_all_tools(self, run_context: RunContextWrapper[Any]) -> list[Tool]: """All agent tools, including MCP tools and function tools.""" - mcp_tools = await self.get_mcp_tools() + mcp_tools = await self.get_mcp_tools(run_context) async def _check_tool_enabled(tool: Tool) -> bool: if not isinstance(tool, FunctionTool): diff --git a/src/agents/mcp/__init__.py b/src/agents/mcp/__init__.py index d4eb8fa68..da5a68b16 100644 --- a/src/agents/mcp/__init__.py +++ b/src/agents/mcp/__init__.py @@ -11,7 +11,14 @@ except ImportError: pass -from .util import MCPUtil +from .util import ( + MCPUtil, + ToolFilter, + ToolFilterCallable, + ToolFilterContext, + ToolFilterStatic, + create_static_tool_filter, +) __all__ = [ "MCPServer", @@ -22,4 +29,9 @@ "MCPServerStreamableHttp", "MCPServerStreamableHttpParams", "MCPUtil", + "ToolFilter", + "ToolFilterCallable", + "ToolFilterContext", + "ToolFilterStatic", + "create_static_tool_filter", ] diff --git a/src/agents/mcp/server.py b/src/agents/mcp/server.py index 3d1e17790..f6c2b58ef 100644 --- a/src/agents/mcp/server.py +++ b/src/agents/mcp/server.py @@ -2,10 +2,11 @@ import abc import asyncio +import inspect from contextlib import AbstractAsyncContextManager, AsyncExitStack from datetime import timedelta from pathlib import Path -from typing import Any, Literal +from typing import TYPE_CHECKING, Any, Literal, cast from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream from mcp import ClientSession, StdioServerParameters, Tool as MCPTool, stdio_client @@ -17,6 +18,11 @@ from ..exceptions import UserError from ..logger import logger +from ..run_context import RunContextWrapper +from .util import ToolFilter, ToolFilterCallable, ToolFilterContext, ToolFilterStatic + +if TYPE_CHECKING: + from ..agent import Agent class MCPServer(abc.ABC): @@ -44,7 +50,11 @@ async def cleanup(self): pass @abc.abstractmethod - async def list_tools(self) -> list[MCPTool]: + async def list_tools( + self, + run_context: RunContextWrapper[Any], + agent: Agent[Any], + ) -> list[MCPTool]: """List the tools available on the server.""" pass @@ -57,7 +67,12 @@ async def call_tool(self, tool_name: str, arguments: dict[str, Any] | None) -> C class _MCPServerWithClientSession(MCPServer, abc.ABC): """Base class for MCP servers that use a `ClientSession` to communicate with the server.""" - def __init__(self, cache_tools_list: bool, client_session_timeout_seconds: float | None): + def __init__( + self, + cache_tools_list: bool, + client_session_timeout_seconds: float | None, + tool_filter: ToolFilter = None, + ): """ Args: cache_tools_list: Whether to cache the tools list. If `True`, the tools list will be @@ -68,6 +83,7 @@ def __init__(self, cache_tools_list: bool, client_session_timeout_seconds: float (by avoiding a round-trip to the server every time). client_session_timeout_seconds: the read timeout passed to the MCP ClientSession. + tool_filter: The tool filter to use for filtering tools. """ self.session: ClientSession | None = None self.exit_stack: AsyncExitStack = AsyncExitStack() @@ -81,6 +97,88 @@ def __init__(self, cache_tools_list: bool, client_session_timeout_seconds: float self._cache_dirty = True self._tools_list: list[MCPTool] | None = None + self.tool_filter = tool_filter + + async def _apply_tool_filter( + self, + tools: list[MCPTool], + run_context: RunContextWrapper[Any], + agent: Agent[Any], + ) -> list[MCPTool]: + """Apply the tool filter to the list of tools.""" + if self.tool_filter is None: + return tools + + # Handle static tool filter + if isinstance(self.tool_filter, dict): + return self._apply_static_tool_filter(tools, self.tool_filter) + + # Handle callable tool filter (dynamic filter) + else: + return await self._apply_dynamic_tool_filter(tools, run_context, agent) + + def _apply_static_tool_filter( + self, + tools: list[MCPTool], + static_filter: ToolFilterStatic + ) -> list[MCPTool]: + """Apply static tool filtering based on allowlist and blocklist.""" + filtered_tools = tools + + # Apply allowed_tool_names filter (whitelist) + if "allowed_tool_names" in static_filter: + allowed_names = static_filter["allowed_tool_names"] + filtered_tools = [t for t in filtered_tools if t.name in allowed_names] + + # Apply blocked_tool_names filter (blacklist) + if "blocked_tool_names" in static_filter: + blocked_names = static_filter["blocked_tool_names"] + filtered_tools = [t for t in filtered_tools if t.name not in blocked_names] + + return filtered_tools + + async def _apply_dynamic_tool_filter( + self, + tools: list[MCPTool], + run_context: RunContextWrapper[Any], + agent: Agent[Any], + ) -> list[MCPTool]: + """Apply dynamic tool filtering using a callable filter function.""" + + # Ensure we have a callable filter and cast to help mypy + if not callable(self.tool_filter): + raise ValueError("Tool filter must be callable for dynamic filtering") + tool_filter_func = cast(ToolFilterCallable, self.tool_filter) + + # Create filter context + filter_context = ToolFilterContext( + run_context=run_context, + agent=agent, + server_name=self.name, + ) + + filtered_tools = [] + for tool in tools: + try: + # Call the filter function with context + result = tool_filter_func(filter_context, tool) + + if inspect.isawaitable(result): + should_include = await result + else: + should_include = result + + if should_include: + filtered_tools.append(tool) + except Exception as e: + logger.error( + f"Error applying tool filter to tool '{tool.name}' on server '{self.name}': {e}" + ) + # On error, exclude the tool for safety + continue + + return filtered_tools + @abc.abstractmethod def create_streams( self, @@ -131,21 +229,30 @@ async def connect(self): await self.cleanup() raise - async def list_tools(self) -> list[MCPTool]: + async def list_tools( + self, + run_context: RunContextWrapper[Any], + agent: Agent[Any], + ) -> list[MCPTool]: """List the tools available on the server.""" if not self.session: raise UserError("Server not initialized. Make sure you call `connect()` first.") # Return from cache if caching is enabled, we have tools, and the cache is not dirty if self.cache_tools_list and not self._cache_dirty and self._tools_list: - return self._tools_list - - # Reset the cache dirty to False - self._cache_dirty = False - - # Fetch the tools from the server - self._tools_list = (await self.session.list_tools()).tools - return self._tools_list + tools = self._tools_list + else: + # Reset the cache dirty to False + self._cache_dirty = False + # Fetch the tools from the server + self._tools_list = (await self.session.list_tools()).tools + tools = self._tools_list + + # Filter tools based on tool_filter + filtered_tools = tools + if self.tool_filter is not None: + filtered_tools = await self._apply_tool_filter(filtered_tools, run_context, agent) + return filtered_tools async def call_tool(self, tool_name: str, arguments: dict[str, Any] | None) -> CallToolResult: """Invoke a tool on the server.""" @@ -206,6 +313,7 @@ def __init__( cache_tools_list: bool = False, name: str | None = None, client_session_timeout_seconds: float | None = 5, + tool_filter: ToolFilter = None, ): """Create a new MCP server based on the stdio transport. @@ -223,8 +331,13 @@ def __init__( name: A readable name for the server. If not provided, we'll create one from the command. client_session_timeout_seconds: the read timeout passed to the MCP ClientSession. + tool_filter: The tool filter to use for filtering tools. """ - super().__init__(cache_tools_list, client_session_timeout_seconds) + super().__init__( + cache_tools_list, + client_session_timeout_seconds, + tool_filter, + ) self.params = StdioServerParameters( command=params["command"], @@ -283,6 +396,7 @@ def __init__( cache_tools_list: bool = False, name: str | None = None, client_session_timeout_seconds: float | None = 5, + tool_filter: ToolFilter = None, ): """Create a new MCP server based on the HTTP with SSE transport. @@ -302,8 +416,13 @@ def __init__( URL. client_session_timeout_seconds: the read timeout passed to the MCP ClientSession. + tool_filter: The tool filter to use for filtering tools. """ - super().__init__(cache_tools_list, client_session_timeout_seconds) + super().__init__( + cache_tools_list, + client_session_timeout_seconds, + tool_filter, + ) self.params = params self._name = name or f"sse: {self.params['url']}" @@ -362,6 +481,7 @@ def __init__( cache_tools_list: bool = False, name: str | None = None, client_session_timeout_seconds: float | None = 5, + tool_filter: ToolFilter = None, ): """Create a new MCP server based on the Streamable HTTP transport. @@ -382,8 +502,13 @@ def __init__( URL. client_session_timeout_seconds: the read timeout passed to the MCP ClientSession. + tool_filter: The tool filter to use for filtering tools. """ - super().__init__(cache_tools_list, client_session_timeout_seconds) + super().__init__( + cache_tools_list, + client_session_timeout_seconds, + tool_filter, + ) self.params = params self._name = name or f"streamable_http: {self.params['url']}" diff --git a/src/agents/mcp/util.py b/src/agents/mcp/util.py index 5a963bc01..48da9f841 100644 --- a/src/agents/mcp/util.py +++ b/src/agents/mcp/util.py @@ -1,6 +1,9 @@ import functools import json -from typing import TYPE_CHECKING, Any +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any, Callable, Optional, Union + +from typing_extensions import NotRequired, TypedDict from agents.strict_schema import ensure_strict_json_schema @@ -10,25 +13,102 @@ from ..run_context import RunContextWrapper from ..tool import FunctionTool, Tool from ..tracing import FunctionSpanData, get_current_span, mcp_tools_span +from ..util._types import MaybeAwaitable if TYPE_CHECKING: from mcp.types import Tool as MCPTool + from ..agent import Agent from .server import MCPServer +@dataclass +class ToolFilterContext: + """Context information available to tool filter functions.""" + + run_context: RunContextWrapper[Any] + """The current run context.""" + + agent: "Agent[Any]" + """The agent that is requesting the tool list.""" + + server_name: str + """The name of the MCP server.""" + + +ToolFilterCallable = Callable[["ToolFilterContext", "MCPTool"], MaybeAwaitable[bool]] +"""A function that determines whether a tool should be available. + +Args: + context: The context information including run context, agent, and server name. + tool: The MCP tool to filter. + +Returns: + Whether the tool should be available (True) or filtered out (False). +""" + + +class ToolFilterStatic(TypedDict): + """Static tool filter configuration using allowlists and blocklists.""" + + allowed_tool_names: NotRequired[list[str]] + """Optional list of tool names to allow (whitelist). + If set, only these tools will be available.""" + + blocked_tool_names: NotRequired[list[str]] + """Optional list of tool names to exclude (blacklist). + If set, these tools will be filtered out.""" + + +ToolFilter = Union[ToolFilterCallable, ToolFilterStatic, None] +"""A tool filter that can be either a function, static configuration, or None (no filtering).""" + + +def create_static_tool_filter( + allowed_tool_names: Optional[list[str]] = None, + blocked_tool_names: Optional[list[str]] = None, +) -> Optional[ToolFilterStatic]: + """Create a static tool filter from allowlist and blocklist parameters. + + This is a convenience function for creating a ToolFilterStatic. + + Args: + allowed_tool_names: Optional list of tool names to allow (whitelist). + blocked_tool_names: Optional list of tool names to exclude (blacklist). + + Returns: + A ToolFilterStatic if any filtering is specified, None otherwise. + """ + if allowed_tool_names is None and blocked_tool_names is None: + return None + + filter_dict: ToolFilterStatic = {} + if allowed_tool_names is not None: + filter_dict["allowed_tool_names"] = allowed_tool_names + if blocked_tool_names is not None: + filter_dict["blocked_tool_names"] = blocked_tool_names + + return filter_dict + + class MCPUtil: """Set of utilities for interop between MCP and Agents SDK tools.""" @classmethod async def get_all_function_tools( - cls, servers: list["MCPServer"], convert_schemas_to_strict: bool + cls, + servers: list["MCPServer"], + convert_schemas_to_strict: bool, + run_context: RunContextWrapper[Any], + agent: "Agent[Any]", ) -> list[Tool]: """Get all function tools from a list of MCP servers.""" tools = [] tool_names: set[str] = set() for server in servers: - server_tools = await cls.get_function_tools(server, convert_schemas_to_strict) + server_tools = await cls.get_function_tools( + server, convert_schemas_to_strict, run_context, agent + ) server_tool_names = {tool.name for tool in server_tools} if len(server_tool_names & tool_names) > 0: raise UserError( @@ -42,12 +122,16 @@ async def get_all_function_tools( @classmethod async def get_function_tools( - cls, server: "MCPServer", convert_schemas_to_strict: bool + cls, + server: "MCPServer", + convert_schemas_to_strict: bool, + run_context: RunContextWrapper[Any], + agent: "Agent[Any]", ) -> list[Tool]: """Get all function tools from a single MCP server.""" with mcp_tools_span(server=server.name) as span: - tools = await server.list_tools() + tools = await server.list_tools(run_context, agent) span.span_data.result = [tool.name for tool in tools] return [cls.to_function_tool(tool, server, convert_schemas_to_strict) for tool in tools] diff --git a/src/agents/tool.py b/src/agents/tool.py index ce66a53ba..283a9e24f 100644 --- a/src/agents/tool.py +++ b/src/agents/tool.py @@ -26,6 +26,7 @@ from .util._types import MaybeAwaitable if TYPE_CHECKING: + from .agent import Agent ToolParams = ParamSpec("ToolParams") diff --git a/tests/mcp/helpers.py b/tests/mcp/helpers.py index 8ff153c18..e0d8a813d 100644 --- a/tests/mcp/helpers.py +++ b/tests/mcp/helpers.py @@ -1,3 +1,4 @@ +import asyncio import json import shutil from typing import Any @@ -6,6 +7,8 @@ from mcp.types import CallToolResult, TextContent from agents.mcp import MCPServer +from agents.mcp.server import _MCPServerWithClientSession +from agents.mcp.util import ToolFilter tee = shutil.which("tee") or "" assert tee, "tee not found" @@ -28,11 +31,41 @@ async def __aexit__(self, exc_type, exc_val, exc_tb): pass +class _TestFilterServer(_MCPServerWithClientSession): + """Minimal implementation of _MCPServerWithClientSession for testing tool filtering""" + + def __init__(self, tool_filter: ToolFilter, server_name: str): + # Initialize parent class properly to avoid type errors + super().__init__( + cache_tools_list=False, + client_session_timeout_seconds=None, + tool_filter=tool_filter, + ) + self._server_name: str = server_name + # Override some attributes for test isolation + self.session = None + self._cleanup_lock = asyncio.Lock() + + def create_streams(self): + raise NotImplementedError("Not needed for filtering tests") + + @property + def name(self) -> str: + return self._server_name + + class FakeMCPServer(MCPServer): - def __init__(self, tools: list[MCPTool] | None = None): + def __init__( + self, + tools: list[MCPTool] | None = None, + tool_filter: ToolFilter = None, + server_name: str = "fake_mcp_server", + ): self.tools: list[MCPTool] = tools or [] self.tool_calls: list[str] = [] self.tool_results: list[str] = [] + self.tool_filter = tool_filter + self._server_name = server_name def add_tool(self, name: str, input_schema: dict[str, Any]): self.tools.append(MCPTool(name=name, inputSchema=input_schema)) @@ -43,8 +76,16 @@ async def connect(self): async def cleanup(self): pass - async def list_tools(self): - return self.tools + async def list_tools(self, run_context=None, agent=None): + tools = self.tools + + # Apply tool filtering using the REAL implementation + if self.tool_filter is not None: + # Use the real _MCPServerWithClientSession filtering logic + filter_server = _TestFilterServer(self.tool_filter, self.name) + tools = await filter_server._apply_tool_filter(tools, run_context, agent) + + return tools async def call_tool(self, tool_name: str, arguments: dict[str, Any] | None) -> CallToolResult: self.tool_calls.append(tool_name) @@ -55,4 +96,4 @@ async def call_tool(self, tool_name: str, arguments: dict[str, Any] | None) -> C @property def name(self) -> str: - return "fake_mcp_server" + return self._server_name diff --git a/tests/mcp/test_caching.py b/tests/mcp/test_caching.py index cac409e6e..f31cdf951 100644 --- a/tests/mcp/test_caching.py +++ b/tests/mcp/test_caching.py @@ -3,7 +3,9 @@ import pytest from mcp.types import ListToolsResult, Tool as MCPTool +from agents import Agent from agents.mcp import MCPServerStdio +from agents.run_context import RunContextWrapper from .helpers import DummyStreamsContextManager, tee @@ -33,25 +35,29 @@ async def test_server_caching_works( mock_list_tools.return_value = ListToolsResult(tools=tools) async with server: + # Create test context and agent + run_context = RunContextWrapper(context=None) + agent = Agent(name="test_agent", instructions="Test agent") + # Call list_tools() multiple times - tools = await server.list_tools() - assert tools == tools + result_tools = await server.list_tools(run_context, agent) + assert result_tools == tools assert mock_list_tools.call_count == 1, "list_tools() should have been called once" # Call list_tools() again, should return the cached value - tools = await server.list_tools() - assert tools == tools + result_tools = await server.list_tools(run_context, agent) + assert result_tools == tools assert mock_list_tools.call_count == 1, "list_tools() should not have been called again" # Invalidate the cache and call list_tools() again server.invalidate_tools_cache() - tools = await server.list_tools() - assert tools == tools + result_tools = await server.list_tools(run_context, agent) + assert result_tools == tools assert mock_list_tools.call_count == 2, "list_tools() should be called again" # Without invalidating the cache, calling list_tools() again should return the cached value - tools = await server.list_tools() - assert tools == tools + result_tools = await server.list_tools(run_context, agent) + assert result_tools == tools diff --git a/tests/mcp/test_mcp_util.py b/tests/mcp/test_mcp_util.py index 74356a16d..3230e63dd 100644 --- a/tests/mcp/test_mcp_util.py +++ b/tests/mcp/test_mcp_util.py @@ -57,7 +57,10 @@ async def test_get_all_function_tools(): server3.add_tool(names[4], schemas[4]) servers: list[MCPServer] = [server1, server2, server3] - tools = await MCPUtil.get_all_function_tools(servers, convert_schemas_to_strict=False) + run_context = RunContextWrapper(context=None) + agent = Agent(name="test_agent", instructions="Test agent") + + tools = await MCPUtil.get_all_function_tools(servers, False, run_context, agent) assert len(tools) == 5 assert all(tool.name in names for tool in tools) @@ -70,7 +73,7 @@ async def test_get_all_function_tools(): assert tool.name == names[idx] # Also make sure it works with strict schemas - tools = await MCPUtil.get_all_function_tools(servers, convert_schemas_to_strict=True) + tools = await MCPUtil.get_all_function_tools(servers, True, run_context, agent) assert len(tools) == 5 assert all(tool.name in names for tool in tools) @@ -144,7 +147,8 @@ async def test_agent_convert_schemas_true(): agent = Agent( name="test_agent", mcp_servers=[server], mcp_config={"convert_schemas_to_strict": True} ) - tools = await agent.get_mcp_tools() + run_context = RunContextWrapper(context=None) + tools = await agent.get_mcp_tools(run_context) foo_tool = next(tool for tool in tools if tool.name == "foo") assert isinstance(foo_tool, FunctionTool) @@ -208,7 +212,8 @@ async def test_agent_convert_schemas_false(): agent = Agent( name="test_agent", mcp_servers=[server], mcp_config={"convert_schemas_to_strict": False} ) - tools = await agent.get_mcp_tools() + run_context = RunContextWrapper(context=None) + tools = await agent.get_mcp_tools(run_context) foo_tool = next(tool for tool in tools if tool.name == "foo") assert isinstance(foo_tool, FunctionTool) @@ -245,7 +250,8 @@ async def test_agent_convert_schemas_unset(): server.add_tool("bar", non_strict_schema) server.add_tool("baz", possible_to_convert_schema) agent = Agent(name="test_agent", mcp_servers=[server]) - tools = await agent.get_mcp_tools() + run_context = RunContextWrapper(context=None) + tools = await agent.get_mcp_tools(run_context) foo_tool = next(tool for tool in tools if tool.name == "foo") assert isinstance(foo_tool, FunctionTool) @@ -279,7 +285,9 @@ async def test_util_adds_properties(): server = FakeMCPServer() server.add_tool("test_tool", schema) - tools = await MCPUtil.get_all_function_tools([server], convert_schemas_to_strict=False) + run_context = RunContextWrapper(context=None) + agent = Agent(name="test_agent", instructions="Test agent") + tools = await MCPUtil.get_all_function_tools([server], False, run_context, agent) tool = next(tool for tool in tools if tool.name == "test_tool") assert isinstance(tool, FunctionTool) diff --git a/tests/mcp/test_server_errors.py b/tests/mcp/test_server_errors.py index fbd8db17d..9e0455115 100644 --- a/tests/mcp/test_server_errors.py +++ b/tests/mcp/test_server_errors.py @@ -1,7 +1,9 @@ import pytest +from agents import Agent from agents.exceptions import UserError from agents.mcp.server import _MCPServerWithClientSession +from agents.run_context import RunContextWrapper class CrashingClientSessionServer(_MCPServerWithClientSession): @@ -35,8 +37,11 @@ async def test_server_errors_cause_error_and_cleanup_called(): async def test_not_calling_connect_causes_error(): server = CrashingClientSessionServer() + run_context = RunContextWrapper(context=None) + agent = Agent(name="test_agent", instructions="Test agent") + with pytest.raises(UserError): - await server.list_tools() + await server.list_tools(run_context, agent) with pytest.raises(UserError): await server.call_tool("foo", {}) diff --git a/tests/mcp/test_tool_filtering.py b/tests/mcp/test_tool_filtering.py new file mode 100644 index 000000000..c1ffff4b8 --- /dev/null +++ b/tests/mcp/test_tool_filtering.py @@ -0,0 +1,243 @@ +""" +Tool filtering tests use FakeMCPServer instead of real MCPServer implementations to avoid +external dependencies (processes, network connections) and ensure fast, reliable unit tests. +FakeMCPServer delegates filtering logic to the real _MCPServerWithClientSession implementation. +""" +import asyncio + +import pytest +from mcp import Tool as MCPTool + +from agents import Agent +from agents.mcp import ToolFilterContext, create_static_tool_filter +from agents.run_context import RunContextWrapper + +from .helpers import FakeMCPServer + + +def create_test_agent(name: str = "test_agent") -> Agent: + """Create a test agent for filtering tests.""" + return Agent(name=name, instructions="Test agent") + + +def create_test_context() -> RunContextWrapper: + """Create a test run context for filtering tests.""" + return RunContextWrapper(context=None) + + +# === Static Tool Filtering Tests === + +@pytest.mark.asyncio +async def test_static_tool_filtering(): + """Test all static tool filtering scenarios: allowed, blocked, both, none, etc.""" + server = FakeMCPServer(server_name="test_server") + server.add_tool("tool1", {}) + server.add_tool("tool2", {}) + server.add_tool("tool3", {}) + server.add_tool("tool4", {}) + + # Create test context and agent for all calls + run_context = create_test_context() + agent = create_test_agent() + + # Test allowed_tool_names only + server.tool_filter = {"allowed_tool_names": ["tool1", "tool2"]} + tools = await server.list_tools(run_context, agent) + assert len(tools) == 2 + assert {t.name for t in tools} == {"tool1", "tool2"} + + # Test blocked_tool_names only + server.tool_filter = {"blocked_tool_names": ["tool3", "tool4"]} + tools = await server.list_tools(run_context, agent) + assert len(tools) == 2 + assert {t.name for t in tools} == {"tool1", "tool2"} + + # Test both filters together (allowed first, then blocked) + server.tool_filter = { + "allowed_tool_names": ["tool1", "tool2", "tool3"], + "blocked_tool_names": ["tool3"] + } + tools = await server.list_tools(run_context, agent) + assert len(tools) == 2 + assert {t.name for t in tools} == {"tool1", "tool2"} + + # Test no filter + server.tool_filter = None + tools = await server.list_tools(run_context, agent) + assert len(tools) == 4 + + # Test helper function + server.tool_filter = create_static_tool_filter( + allowed_tool_names=["tool1", "tool2"], + blocked_tool_names=["tool2"] + ) + tools = await server.list_tools(run_context, agent) + assert len(tools) == 1 + assert tools[0].name == "tool1" + + +# === Dynamic Tool Filtering Core Tests === + +@pytest.mark.asyncio +async def test_dynamic_filter_sync_and_async(): + """Test both synchronous and asynchronous dynamic filters""" + server = FakeMCPServer(server_name="test_server") + server.add_tool("allowed_tool", {}) + server.add_tool("blocked_tool", {}) + server.add_tool("restricted_tool", {}) + + # Create test context and agent + run_context = create_test_context() + agent = create_test_agent() + + # Test sync filter + def sync_filter(context: ToolFilterContext, tool: MCPTool) -> bool: + return tool.name.startswith("allowed") + + server.tool_filter = sync_filter + tools = await server.list_tools(run_context, agent) + assert len(tools) == 1 + assert tools[0].name == "allowed_tool" + + # Test async filter + async def async_filter(context: ToolFilterContext, tool: MCPTool) -> bool: + await asyncio.sleep(0.001) # Simulate async operation + return "restricted" not in tool.name + + server.tool_filter = async_filter + tools = await server.list_tools(run_context, agent) + assert len(tools) == 2 + assert {t.name for t in tools} == {"allowed_tool", "blocked_tool"} + + +@pytest.mark.asyncio +async def test_dynamic_filter_context_handling(): + """Test dynamic filters with context access""" + server = FakeMCPServer(server_name="test_server") + server.add_tool("admin_tool", {}) + server.add_tool("user_tool", {}) + server.add_tool("guest_tool", {}) + + # Test context-independent filter + def context_independent_filter(context: ToolFilterContext, tool: MCPTool) -> bool: + return not tool.name.startswith("admin") + + server.tool_filter = context_independent_filter + run_context = create_test_context() + agent = create_test_agent() + tools = await server.list_tools(run_context, agent) + assert len(tools) == 2 + assert {t.name for t in tools} == {"user_tool", "guest_tool"} + + # Test context-dependent filter (needs context) + def context_dependent_filter(context: ToolFilterContext, tool: MCPTool) -> bool: + assert context is not None + assert context.run_context is not None + assert context.agent is not None + assert context.server_name == "test_server" + + # Only admin tools for agents with "admin" in name + if "admin" in context.agent.name.lower(): + return True + else: + return not tool.name.startswith("admin") + + server.tool_filter = context_dependent_filter + + # Should work with context + run_context = RunContextWrapper(context=None) + regular_agent = create_test_agent("regular_user") + tools = await server.list_tools(run_context, regular_agent) + assert len(tools) == 2 + assert {t.name for t in tools} == {"user_tool", "guest_tool"} + + admin_agent = create_test_agent("admin_user") + tools = await server.list_tools(run_context, admin_agent) + assert len(tools) == 3 + + +@pytest.mark.asyncio +async def test_dynamic_filter_error_handling(): + """Test error handling in dynamic filters""" + server = FakeMCPServer(server_name="test_server") + server.add_tool("good_tool", {}) + server.add_tool("error_tool", {}) + server.add_tool("another_good_tool", {}) + + def error_prone_filter(context: ToolFilterContext, tool: MCPTool) -> bool: + if tool.name == "error_tool": + raise ValueError("Simulated filter error") + return True + + server.tool_filter = error_prone_filter + + # Test with server call + run_context = create_test_context() + agent = create_test_agent() + tools = await server.list_tools(run_context, agent) + assert len(tools) == 2 + assert {t.name for t in tools} == {"good_tool", "another_good_tool"} + + +# === Integration Tests === + +@pytest.mark.asyncio +async def test_agent_dynamic_filtering_integration(): + """Test dynamic filtering integration with Agent methods""" + server = FakeMCPServer() + server.add_tool("file_read", {"type": "object", "properties": {"path": {"type": "string"}}}) + server.add_tool( + "file_write", + { + "type": "object", + "properties": {"path": {"type": "string"}, "content": {"type": "string"}}, + }, + ) + server.add_tool( + "database_query", {"type": "object", "properties": {"query": {"type": "string"}}} + ) + server.add_tool( + "network_request", {"type": "object", "properties": {"url": {"type": "string"}}} + ) + + # Role-based filter for comprehensive testing + async def role_based_filter(context: ToolFilterContext, tool: MCPTool) -> bool: + # Simulate async permission check + await asyncio.sleep(0.001) + + agent_name = context.agent.name.lower() + if "admin" in agent_name: + return True + elif "readonly" in agent_name: + return "read" in tool.name or "query" in tool.name + else: + return tool.name.startswith("file_") + + server.tool_filter = role_based_filter + + # Test admin agent + admin_agent = Agent(name="admin_user", instructions="Admin", mcp_servers=[server]) + run_context = RunContextWrapper(context=None) + admin_tools = await admin_agent.get_mcp_tools(run_context) + assert len(admin_tools) == 4 + + # Test readonly agent + readonly_agent = Agent(name="readonly_viewer", instructions="Read-only", mcp_servers=[server]) + readonly_tools = await readonly_agent.get_mcp_tools(run_context) + assert len(readonly_tools) == 2 + assert {t.name for t in readonly_tools} == {"file_read", "database_query"} + + # Test regular agent + regular_agent = Agent(name="regular_user", instructions="Regular", mcp_servers=[server]) + regular_tools = await regular_agent.get_mcp_tools(run_context) + assert len(regular_tools) == 2 + assert {t.name for t in regular_tools} == {"file_read", "file_write"} + + # Test get_all_tools method + all_tools = await regular_agent.get_all_tools(run_context) + mcp_tool_names = { + t.name + for t in all_tools + if t.name in {"file_read", "file_write", "database_query", "network_request"} + } + assert mcp_tool_names == {"file_read", "file_write"}