From 300e12c19815f793f348bb9ed17d2eae66690b0b Mon Sep 17 00:00:00 2001 From: Rohan Mehta Date: Mon, 24 Mar 2025 15:08:02 -0400 Subject: [PATCH 1/2] [1/n] Add MCP types to the SDK ### Summary: 1. Add the MCP dep for python 3.10, since it doesn't support 3.9 and below 2. Create MCPServer, which is the agents SDK representation of an MCP server 3. Create implementations for HTTP-SSE and StdIO servers, directly copying the [MCP SDK example](https://github.com/modelcontextprotocol/python-sdk/blob/main/examples/clients/simple-chatbot/mcp_simple_chatbot/main.py) 4. Add a util to transform MCP tools into Agent SDK tools Note: I added optional caching support to the servers. That way, if you happen to know a server's tools don't change, you can just cache them. ### Test Plan: Checks pass. I added tests at the end of the stack. --- pyproject.toml | 1 + src/agents/mcp/__init__.py | 21 +++ src/agents/mcp/mcp_util.py | 94 +++++++++++++ src/agents/mcp/server.py | 269 +++++++++++++++++++++++++++++++++++++ src/agents/mcp/util.py | 96 +++++++++++++ uv.lock | 93 ++++++++++++- 6 files changed, 572 insertions(+), 2 deletions(-) create mode 100644 src/agents/mcp/__init__.py create mode 100644 src/agents/mcp/mcp_util.py create mode 100644 src/agents/mcp/server.py create mode 100644 src/agents/mcp/util.py diff --git a/pyproject.toml b/pyproject.toml index 667ab355..3678c714 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -13,6 +13,7 @@ dependencies = [ "typing-extensions>=4.12.2, <5", "requests>=2.0, <3", "types-requests>=2.0, <3", + "mcp; python_version >= '3.10'", ] classifiers = [ "Typing :: Typed", diff --git a/src/agents/mcp/__init__.py b/src/agents/mcp/__init__.py new file mode 100644 index 00000000..1a72a89f --- /dev/null +++ b/src/agents/mcp/__init__.py @@ -0,0 +1,21 @@ +try: + from .server import ( + MCPServer, + MCPServerSse, + MCPServerSseParams, + MCPServerStdio, + MCPServerStdioParams, + ) +except ImportError: + pass + +from .util import MCPUtil + +__all__ = [ + "MCPServer", + "MCPServerSse", + "MCPServerSseParams", + "MCPServerStdio", + "MCPServerStdioParams", + "MCPUtil", +] diff --git a/src/agents/mcp/mcp_util.py b/src/agents/mcp/mcp_util.py new file mode 100644 index 00000000..41b4c521 --- /dev/null +++ b/src/agents/mcp/mcp_util.py @@ -0,0 +1,94 @@ +import functools +import json +from typing import Any + +from mcp.types import Tool as MCPTool + +from .. import _debug +from ..exceptions import AgentsException, ModelBehaviorError, UserError +from ..logger import logger +from ..run_context import RunContextWrapper +from ..tool import FunctionTool, Tool +from .server import MCPServer + + +class MCPUtil: + """Set of utilities for interop between MCP and Agents SDK tools.""" + + @classmethod + async def get_all_function_tools(cls, servers: list[MCPServer]) -> list[Tool]: + """Get all function tools from a list of MCP servers.""" + tools = [] + tool_names: set[str] = set() + for server in servers: + server_tools = await cls.get_function_tools(server) + server_tool_names = {tool.name for tool in server_tools} + if len(server_tool_names & tool_names) > 0: + raise UserError( + f"Duplicate tool names found across MCP servers: " + f"{server_tool_names & tool_names}" + ) + tool_names.update(server_tool_names) + tools.extend(server_tools) + + return tools + + @classmethod + async def get_function_tools(cls, server: MCPServer) -> list[Tool]: + """Get all function tools from a single MCP server.""" + tools = await server.list_tools() + return [cls.to_function_tool(tool, server) for tool in tools] + + @classmethod + def to_function_tool(cls, tool: MCPTool, server: MCPServer) -> FunctionTool: + """Convert an MCP tool to an Agents SDK function tool.""" + invoke_func = functools.partial(cls.invoke_mcp_tool, server, tool) + return FunctionTool( + name=tool.name, + description=tool.description or "", + params_json_schema=tool.inputSchema, + on_invoke_tool=invoke_func, + strict_json_schema=False, + ) + + @classmethod + async def invoke_mcp_tool( + cls, server: MCPServer, tool: MCPTool, context: RunContextWrapper[Any], input_json: str + ) -> str: + """Invoke an MCP tool and return the result as a string.""" + try: + json_data: dict[str, Any] = json.loads(input_json) if input_json else {} + except Exception as e: + if _debug.DONT_LOG_TOOL_DATA: + logger.debug(f"Invalid JSON input for tool {tool.name}") + else: + logger.debug(f"Invalid JSON input for tool {tool.name}: {input_json}") + raise ModelBehaviorError( + f"Invalid JSON input for tool {tool.name}: {input_json}" + ) from e + + if _debug.DONT_LOG_TOOL_DATA: + logger.debug(f"Invoking MCP tool {tool.name}") + else: + logger.debug(f"Invoking MCP tool {tool.name} with input {input_json}") + + try: + result = await server.call_tool(tool.name, json_data) + except Exception as e: + logger.error(f"Error invoking MCP tool {tool.name}: {e}") + raise AgentsException(f"Error invoking MCP tool {tool.name}: {e}") from e + + if _debug.DONT_LOG_TOOL_DATA: + logger.debug(f"MCP tool {tool.name} completed.") + else: + logger.debug(f"MCP tool {tool.name} returned {result}") + + # The MCP tool result is a list of content items, whereas OpenAI tool outputs are a single + # string. We'll try to convert. + if len(result.content) == 1: + return result.content[0].model_dump_json() + elif len(result.content) > 1: + return json.dumps([item.model_dump() for item in result.content]) + else: + logger.error(f"Errored MCP tool result: {result}") + return "Error running tool." diff --git a/src/agents/mcp/server.py b/src/agents/mcp/server.py new file mode 100644 index 00000000..e19e686a --- /dev/null +++ b/src/agents/mcp/server.py @@ -0,0 +1,269 @@ +from __future__ import annotations + +import abc +import asyncio +from contextlib import AbstractAsyncContextManager, AsyncExitStack +from pathlib import Path +from typing import Any, Literal + +from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream +from mcp import ClientSession, StdioServerParameters, Tool as MCPTool, stdio_client +from mcp.client.sse import sse_client +from mcp.types import CallToolResult, JSONRPCMessage +from typing_extensions import NotRequired, TypedDict + +from ..exceptions import UserError +from ..logger import logger + + +class MCPServer(abc.ABC): + """Base class for Model Context Protocol servers.""" + + @abc.abstractmethod + async def connect(self): + """Connect to the server. For example, this might mean spawning a subprocess or + opening a network connection. The server is expected to remain connected until + `cleanup()` is called. + """ + pass + + @abc.abstractmethod + async def cleanup(self): + """Cleanup the server. For example, this might mean closing a subprocess or + closing a network connection. + """ + pass + + @abc.abstractmethod + async def list_tools(self) -> list[MCPTool]: + """List the tools available on the server.""" + pass + + @abc.abstractmethod + async def call_tool(self, tool_name: str, arguments: dict[str, Any] | None) -> CallToolResult: + """Invoke a tool on the server.""" + pass + + +class _MCPServerWithClientSession(MCPServer, abc.ABC): + """Base class for MCP servers that use a `ClientSession` to communicate with the server.""" + + def __init__(self, cache_tools_list: bool): + """ + Args: + cache_tools_list: Whether to cache the tools list. If `True`, the tools list will be + cached and only fetched from the server once. If `False`, the tools list will be + fetched from the server on each call to `list_tools()`. The cache can be invalidated + by calling `invalidate_tools_cache()`. You should set this to `True` if you know the + server will not change its tools list, because it can drastically improve latency + (by avoiding a round-trip to the server every time). + """ + self.session: ClientSession | None = None + self.exit_stack: AsyncExitStack = AsyncExitStack() + self._cleanup_lock: asyncio.Lock = asyncio.Lock() + self.cache_tools_list = cache_tools_list + + # The cache is always dirty at startup, so that we fetch tools at least once + self._cache_dirty = True + self._tools_list: list[MCPTool] | None = None + + @abc.abstractmethod + def create_streams( + self, + ) -> AbstractAsyncContextManager[ + tuple[ + MemoryObjectReceiveStream[JSONRPCMessage | Exception], + MemoryObjectSendStream[JSONRPCMessage], + ] + ]: + """Create the streams for the server.""" + pass + + async def __aenter__(self): + await self.connect() + return self + + async def __aexit__(self, exc_type, exc_value, traceback): + await self.cleanup() + + def invalidate_tools_cache(self): + """Invalidate the tools cache.""" + self._cache_dirty = True + + async def connect(self): + """Connect to the server.""" + try: + transport = await self.exit_stack.enter_async_context(self.create_streams()) + read, write = transport + session = await self.exit_stack.enter_async_context(ClientSession(read, write)) + await session.initialize() + self.session = session + except Exception as e: + logger.error(f"Error initializing MCP server: {e}") + await self.cleanup() + raise + + async def list_tools(self) -> list[MCPTool]: + """List the tools available on the server.""" + if not self.session: + raise UserError("Server not initialized. Make sure you call `connect()` first.") + + # Return from cache if caching is enabled, we have tools, and the cache is not dirty + if self.cache_tools_list and not self._cache_dirty and self._tools_list: + return self._tools_list + + # Reset the cache dirty to False + self._cache_dirty = False + + # Fetch the tools from the server + self._tools_list = (await self.session.list_tools()).tools + return self._tools_list + + async def call_tool(self, tool_name: str, arguments: dict[str, Any] | None) -> CallToolResult: + """Invoke a tool on the server.""" + if not self.session: + raise UserError("Server not initialized. Make sure you call `connect()` first.") + + return await self.session.call_tool(tool_name, arguments) + + async def cleanup(self): + """Cleanup the server.""" + async with self._cleanup_lock: + try: + await self.exit_stack.aclose() + self.session = None + except Exception as e: + logger.error(f"Error cleaning up server: {e}") + + +class MCPServerStdioParams(TypedDict): + """Mirrors `mcp.client.stdio.StdioServerParameters`, but lets you pass params without another + import. + """ + + command: str + """The executable to run to start the server. For example, `python` or `node`.""" + + args: NotRequired[list[str]] + """Command line args to pass to the `command` executable. For example, `['foo.py']` or + `['server.js', '--port', '8080']`.""" + + env: NotRequired[dict[str, str]] + """The environment variables to set for the server. .""" + + cwd: NotRequired[str | Path] + """The working directory to use when spawning the process.""" + + encoding: NotRequired[str] + """The text encoding used when sending/receiving messages to the server. Defaults to `utf-8`.""" + + encoding_error_handler: NotRequired[Literal["strict", "ignore", "replace"]] + """The text encoding error handler. Defaults to `strict`. + + See https://docs.python.org/3/library/codecs.html#codec-base-classes for + explanations of possible values. + """ + + +class MCPServerStdio(_MCPServerWithClientSession): + """MCP server implementation that uses the stdio transport. See the [spec] + (https://spec.modelcontextprotocol.io/specification/2024-11-05/basic/transports/#stdio) for + details. + """ + + def __init__(self, params: MCPServerStdioParams, cache_tools_list: bool = False): + """Create a new MCP server based on the stdio transport. + + Args: + params: The params that configure the server. This includes: + - The command (e.g. `python` or `node`) that starts the server. + - The args to pass to the server command (e.g. `foo.py` or `server.js`). + - The environment variables to set for the server. + cache_tools_list: Whether to cache the tools list. If `True`, the tools list will be + cached and only fetched from the server once. If `False`, the tools list will be + fetched from the server on each call to `list_tools()`. The cache can be + invalidated by calling `invalidate_tools_cache()`. You should set this to `True` + if you know the server will not change its tools list, because it can drastically + improve latency (by avoiding a round-trip to the server every time). + """ + super().__init__(cache_tools_list) + + self.params = StdioServerParameters( + command=params["command"], + args=params.get("args", []), + env=params.get("env"), + cwd=params.get("cwd"), + encoding=params.get("encoding", "utf-8"), + encoding_error_handler=params.get("encoding_error_handler", "strict"), + ) + + def create_streams( + self, + ) -> AbstractAsyncContextManager[ + tuple[ + MemoryObjectReceiveStream[JSONRPCMessage | Exception], + MemoryObjectSendStream[JSONRPCMessage], + ] + ]: + """Create the streams for the server.""" + return stdio_client(self.params) + + +class MCPServerSseParams(TypedDict): + """Mirrors the params in`mcp.client.sse.sse_client`.""" + + url: str + """The URL of the server.""" + + headers: NotRequired[dict[str, str]] + """The headers to send to the server.""" + + timeout: NotRequired[float] + """The timeout for the HTTP request. Defaults to 5 seconds.""" + + sse_read_timeout: NotRequired[float] + """The timeout for the SSE connection, in seconds. Defaults to 5 minutes.""" + + +class MCPServerSse(_MCPServerWithClientSession): + """MCP server implementation that uses the HTTP with SSE transport. See the [spec] + (https://spec.modelcontextprotocol.io/specification/2024-11-05/basic/transports/#http-with-sse) + for details. + """ + + def __init__(self, params: MCPServerSseParams, cache_tools_list: bool = False): + """Create a new MCP server based on the HTTP with SSE transport. + + Args: + params: The params that configure the server. This includes: + - The URL of the server. + - The headers to send to the server. + - The timeout for the HTTP request. + - The timeout for the SSE connection. + + cache_tools_list: Whether to cache the tools list. If `True`, the tools list will be + cached and only fetched from the server once. If `False`, the tools list will be + fetched from the server on each call to `list_tools()`. The cache can be + invalidated by calling `invalidate_tools_cache()`. You should set this to `True` + if you know the server will not change its tools list, because it can drastically + improve latency (by avoiding a round-trip to the server every time). + """ + super().__init__(cache_tools_list) + + self.params = params + + def create_streams( + self, + ) -> AbstractAsyncContextManager[ + tuple[ + MemoryObjectReceiveStream[JSONRPCMessage | Exception], + MemoryObjectSendStream[JSONRPCMessage], + ] + ]: + """Create the streams for the server.""" + return sse_client( + url=self.params["url"], + headers=self.params.get("headers", None), + timeout=self.params.get("timeout", 5), + sse_read_timeout=self.params.get("sse_read_timeout", 60 * 5), + ) diff --git a/src/agents/mcp/util.py b/src/agents/mcp/util.py new file mode 100644 index 00000000..038c4fec --- /dev/null +++ b/src/agents/mcp/util.py @@ -0,0 +1,96 @@ +import functools +import json +from typing import TYPE_CHECKING, Any + +from .. import _debug +from ..exceptions import AgentsException, ModelBehaviorError, UserError +from ..logger import logger +from ..run_context import RunContextWrapper +from ..tool import FunctionTool, Tool + +if TYPE_CHECKING: + from mcp.types import Tool as MCPTool + + from .server import MCPServer + + +class MCPUtil: + """Set of utilities for interop between MCP and Agents SDK tools.""" + + @classmethod + async def get_all_function_tools(cls, servers: list["MCPServer"]) -> list[Tool]: + """Get all function tools from a list of MCP servers.""" + tools = [] + tool_names: set[str] = set() + for server in servers: + server_tools = await cls.get_function_tools(server) + server_tool_names = {tool.name for tool in server_tools} + if len(server_tool_names & tool_names) > 0: + raise UserError( + f"Duplicate tool names found across MCP servers: " + f"{server_tool_names & tool_names}" + ) + tool_names.update(server_tool_names) + tools.extend(server_tools) + + return tools + + @classmethod + async def get_function_tools(cls, server: "MCPServer") -> list[Tool]: + """Get all function tools from a single MCP server.""" + tools = await server.list_tools() + return [cls.to_function_tool(tool, server) for tool in tools] + + @classmethod + def to_function_tool(cls, tool: "MCPTool", server: "MCPServer") -> FunctionTool: + """Convert an MCP tool to an Agents SDK function tool.""" + invoke_func = functools.partial(cls.invoke_mcp_tool, server, tool) + return FunctionTool( + name=tool.name, + description=tool.description or "", + params_json_schema=tool.inputSchema, + on_invoke_tool=invoke_func, + strict_json_schema=False, + ) + + @classmethod + async def invoke_mcp_tool( + cls, server: "MCPServer", tool: "MCPTool", context: RunContextWrapper[Any], input_json: str + ) -> str: + """Invoke an MCP tool and return the result as a string.""" + try: + json_data: dict[str, Any] = json.loads(input_json) if input_json else {} + except Exception as e: + if _debug.DONT_LOG_TOOL_DATA: + logger.debug(f"Invalid JSON input for tool {tool.name}") + else: + logger.debug(f"Invalid JSON input for tool {tool.name}: {input_json}") + raise ModelBehaviorError( + f"Invalid JSON input for tool {tool.name}: {input_json}" + ) from e + + if _debug.DONT_LOG_TOOL_DATA: + logger.debug(f"Invoking MCP tool {tool.name}") + else: + logger.debug(f"Invoking MCP tool {tool.name} with input {input_json}") + + try: + result = await server.call_tool(tool.name, json_data) + except Exception as e: + logger.error(f"Error invoking MCP tool {tool.name}: {e}") + raise AgentsException(f"Error invoking MCP tool {tool.name}: {e}") from e + + if _debug.DONT_LOG_TOOL_DATA: + logger.debug(f"MCP tool {tool.name} completed.") + else: + logger.debug(f"MCP tool {tool.name} returned {result}") + + # The MCP tool result is a list of content items, whereas OpenAI tool outputs are a single + # string. We'll try to convert. + if len(result.content) == 1: + return result.content[0].model_dump_json() + elif len(result.content) > 1: + return json.dumps([item.model_dump() for item in result.content]) + else: + logger.error(f"Errored MCP tool result: {result}") + return "Error running tool." diff --git a/uv.lock b/uv.lock index a9c79e21..d6eba43f 100644 --- a/uv.lock +++ b/uv.lock @@ -1,5 +1,4 @@ version = 1 -revision = 1 requires-python = ">=3.9" resolution-markers = [ "python_full_version >= '3.10'", @@ -459,6 +458,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/2a/39/e50c7c3a983047577ee07d2a9e53faf5a69493943ec3f6a384bdc792deb2/httpx-0.28.1-py3-none-any.whl", hash = "sha256:d909fcccc110f8c7faf814ca82a9a4d816bc5a6dbfea25d6591d6985b8ba59ad", size = 73517 }, ] +[[package]] +name = "httpx-sse" +version = "0.4.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/4c/60/8f4281fa9bbf3c8034fd54c0e7412e66edbab6bc74c4996bd616f8d0406e/httpx-sse-0.4.0.tar.gz", hash = "sha256:1e81a3a3070ce322add1d3529ed42eb5f70817f45ed6ec915ab753f961139721", size = 12624 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/e1/9b/a181f281f65d776426002f330c31849b86b31fc9d848db62e16f03ff739f/httpx_sse-0.4.0-py3-none-any.whl", hash = "sha256:f329af6eae57eaa2bdfd962b42524764af68075ea87370a2de920af5341e318f", size = 7819 }, +] + [[package]] name = "idna" version = "3.10" @@ -699,6 +707,25 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/b3/73/085399401383ce949f727afec55ec3abd76648d04b9f22e1c0e99cb4bec3/MarkupSafe-3.0.2-cp39-cp39-win_amd64.whl", hash = "sha256:6e296a513ca3d94054c2c881cc913116e90fd030ad1c656b3869762b754f5f8a", size = 15506 }, ] +[[package]] +name = "mcp" +version = "1.5.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "anyio", marker = "python_full_version >= '3.10'" }, + { name = "httpx", marker = "python_full_version >= '3.10'" }, + { name = "httpx-sse", marker = "python_full_version >= '3.10'" }, + { name = "pydantic", marker = "python_full_version >= '3.10'" }, + { name = "pydantic-settings", marker = "python_full_version >= '3.10'" }, + { name = "sse-starlette", marker = "python_full_version >= '3.10'" }, + { name = "starlette", marker = "python_full_version >= '3.10'" }, + { name = "uvicorn", marker = "python_full_version >= '3.10'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/6d/c9/c55764824e893fdebe777ac7223200986a275c3191dba9169f8eb6d7c978/mcp-1.5.0.tar.gz", hash = "sha256:5b2766c05e68e01a2034875e250139839498c61792163a7b221fc170c12f5aa9", size = 159128 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/c1/d1/3ff566ecf322077d861f1a68a1ff025cad337417bd66ad22a7c6f7dfcfaf/mcp-1.5.0-py3-none-any.whl", hash = "sha256:51c3f35ce93cb702f7513c12406bbea9665ef75a08db909200b07da9db641527", size = 73734 }, +] + [[package]] name = "mdit-py-plugins" version = "0.4.2" @@ -1054,6 +1081,7 @@ version = "0.0.6" source = { editable = "." } dependencies = [ { name = "griffe" }, + { name = "mcp", marker = "python_full_version >= '3.10'" }, { name = "openai" }, { name = "pydantic" }, { name = "requests" }, @@ -1091,6 +1119,7 @@ dev = [ [package.metadata] requires-dist = [ { name = "griffe", specifier = ">=1.5.6,<2" }, + { name = "mcp", marker = "python_full_version >= '3.10'" }, { name = "numpy", marker = "python_full_version >= '3.10' and extra == 'voice'", specifier = ">=2.2.0,<3" }, { name = "openai", specifier = ">=1.66.5" }, { name = "pydantic", specifier = ">=2.10,<3" }, @@ -1099,7 +1128,6 @@ requires-dist = [ { name = "typing-extensions", specifier = ">=4.12.2,<5" }, { name = "websockets", marker = "extra == 'voice'", specifier = ">=15.0,<16" }, ] -provides-extras = ["voice"] [package.metadata.requires-dev] dev = [ @@ -1305,6 +1333,19 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/a1/0c/c5c5cd3689c32ed1fe8c5d234b079c12c281c051759770c05b8bed6412b5/pydantic_core-2.27.2-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:7d0c8399fcc1848491f00e0314bd59fb34a9c008761bcb422a057670c3f65e35", size = 2004961 }, ] +[[package]] +name = "pydantic-settings" +version = "2.8.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "pydantic", marker = "python_full_version >= '3.10'" }, + { name = "python-dotenv", marker = "python_full_version >= '3.10'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/88/82/c79424d7d8c29b994fb01d277da57b0a9b09cc03c3ff875f9bd8a86b2145/pydantic_settings-2.8.1.tar.gz", hash = "sha256:d5c663dfbe9db9d5e1c646b2e161da12f0d734d422ee56f567d0ea2cee4e8585", size = 83550 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/0b/53/a64f03044927dc47aafe029c42a5b7aabc38dfb813475e0e1bf71c4a59d0/pydantic_settings-2.8.1-py3-none-any.whl", hash = "sha256:81942d5ac3d905f7f3ee1a70df5dfb62d5569c12f51a5a647defc1c3d9ee2e9c", size = 30839 }, +] + [[package]] name = "pyee" version = "12.1.1" @@ -1496,6 +1537,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/ec/57/56b9bcc3c9c6a792fcbaf139543cee77261f3651ca9da0c93f5c1221264b/python_dateutil-2.9.0.post0-py2.py3-none-any.whl", hash = "sha256:a8b2bc7bffae282281c8140a97d3aa9c14da0b136dfe83f850eea9a5f7470427", size = 229892 }, ] +[[package]] +name = "python-dotenv" +version = "1.0.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/bc/57/e84d88dfe0aec03b7a2d4327012c1627ab5f03652216c63d49846d7a6c58/python-dotenv-1.0.1.tar.gz", hash = "sha256:e324ee90a023d808f1959c46bcbc04446a10ced277783dc6ee09987c37ec10ca", size = 39115 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/6a/3e/b68c118422ec867fa7ab88444e1274aa40681c606d59ac27de5a5588f082/python_dotenv-1.0.1-py3-none-any.whl", hash = "sha256:f7b63ef50f1b690dddf550d03497b66d609393b40b564ed0d674909a68ebf16a", size = 19863 }, +] + [[package]] name = "python-xlib" version = "0.33" @@ -1660,6 +1710,31 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/af/9b/15217b04f3b36d30de55fef542389d722de63f1ad81f9c72d8afc98cb6ab/sounddevice-0.5.1-py3-none-win_amd64.whl", hash = "sha256:4313b63f2076552b23ac3e0abd3bcfc0c1c6a696fc356759a13bd113c9df90f1", size = 363634 }, ] +[[package]] +name = "sse-starlette" +version = "2.2.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "anyio", marker = "python_full_version >= '3.10'" }, + { name = "starlette", marker = "python_full_version >= '3.10'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/71/a4/80d2a11af59fe75b48230846989e93979c892d3a20016b42bb44edb9e398/sse_starlette-2.2.1.tar.gz", hash = "sha256:54470d5f19274aeed6b2d473430b08b4b379ea851d953b11d7f1c4a2c118b419", size = 17376 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/d9/e0/5b8bd393f27f4a62461c5cf2479c75a2cc2ffa330976f9f00f5f6e4f50eb/sse_starlette-2.2.1-py3-none-any.whl", hash = "sha256:6410a3d3ba0c89e7675d4c273a301d64649c03a5ef1ca101f10b47f895fd0e99", size = 10120 }, +] + +[[package]] +name = "starlette" +version = "0.46.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "anyio", marker = "python_full_version >= '3.10'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/04/1b/52b27f2e13ceedc79a908e29eac426a63465a1a01248e5f24aa36a62aeb3/starlette-0.46.1.tar.gz", hash = "sha256:3c88d58ee4bd1bb807c0d1acb381838afc7752f9ddaec81bbe4383611d833230", size = 2580102 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/a0/4b/528ccf7a982216885a1ff4908e886b8fb5f19862d1962f56a3fce2435a70/starlette-0.46.1-py3-none-any.whl", hash = "sha256:77c74ed9d2720138b25875133f3a2dae6d854af2ec37dceb56aef370c1d8a227", size = 71995 }, +] + [[package]] name = "textual" version = "2.1.2" @@ -1774,6 +1849,20 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/c8/19/4ec628951a74043532ca2cf5d97b7b14863931476d117c471e8e2b1eb39f/urllib3-2.3.0-py3-none-any.whl", hash = "sha256:1cee9ad369867bfdbbb48b7dd50374c0967a0bb7710050facf0dd6911440e3df", size = 128369 }, ] +[[package]] +name = "uvicorn" +version = "0.34.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "click", marker = "python_full_version >= '3.10'" }, + { name = "h11", marker = "python_full_version >= '3.10'" }, + { name = "typing-extensions", marker = "python_full_version == '3.10.*'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/4b/4d/938bd85e5bf2edeec766267a5015ad969730bb91e31b44021dfe8b22df6c/uvicorn-0.34.0.tar.gz", hash = "sha256:404051050cd7e905de2c9a7e61790943440b3416f49cb409f965d9dcd0fa73e9", size = 76568 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/61/14/33a3a1352cfa71812a3a21e8c9bfb83f60b0011f5e36f2b1399d51928209/uvicorn-0.34.0-py3-none-any.whl", hash = "sha256:023dc038422502fa28a09c7a30bf2b6991512da7dcdb8fd35fe57cfc154126f4", size = 62315 }, +] + [[package]] name = "watchdog" version = "6.0.0" From 68c800d2a3a7ea101d144a69b53a0b6658291e16 Mon Sep 17 00:00:00 2001 From: Rohan Mehta Date: Mon, 24 Mar 2025 15:08:02 -0400 Subject: [PATCH 2/2] [2/n] Add MCP support to Runner ### Summary: This enables users to **use** MCP inside the SDK. 1. You add a list of MCP servers to `Agent`, via `mcp_server=[...]` 2. When an agent runs, we look up its MCP tools and add them to the list of tools. 3. When a tool call occurs, we call the relevant MCP server. Notes: 1. There's some refactoring to make sure we send the full list of tools to the Runner/Model etc. 2. Right now, you could have a locally defined tool that conflicts with an MCP defined tool. I didn't add errors for that, will do in a followup. ### Test Plan: See unit tests. Also has an end to end example next PR. --- src/agents/_run_impl.py | 7 +- src/agents/agent.py | 20 +++ src/agents/run.py | 24 +++- tests/mcp/__init__.py | 0 tests/mcp/conftest.py | 11 ++ tests/mcp/helpers.py | 54 ++++++++ tests/mcp/test_caching.py | 57 ++++++++ tests/mcp/test_connect_disconnect.py | 69 ++++++++++ tests/mcp/test_mcp_util.py | 109 +++++++++++++++ tests/mcp/test_runner_calls_mcp.py | 197 +++++++++++++++++++++++++++ tests/mcp/test_server_errors.py | 38 ++++++ tests/test_run_step_execution.py | 1 + tests/test_run_step_processing.py | 100 +++++++++++--- tests/voice/test_openai_stt.py | 10 +- 14 files changed, 662 insertions(+), 35 deletions(-) create mode 100644 tests/mcp/__init__.py create mode 100644 tests/mcp/conftest.py create mode 100644 tests/mcp/helpers.py create mode 100644 tests/mcp/test_caching.py create mode 100644 tests/mcp/test_connect_disconnect.py create mode 100644 tests/mcp/test_mcp_util.py create mode 100644 tests/mcp/test_runner_calls_mcp.py create mode 100644 tests/mcp/test_server_errors.py diff --git a/src/agents/_run_impl.py b/src/agents/_run_impl.py index 2849538d..02e3bf58 100644 --- a/src/agents/_run_impl.py +++ b/src/agents/_run_impl.py @@ -50,7 +50,7 @@ from .models.interface import ModelTracing from .run_context import RunContextWrapper, TContext from .stream_events import RunItemStreamEvent, StreamEvent -from .tool import ComputerTool, FunctionTool, FunctionToolResult +from .tool import ComputerTool, FunctionTool, FunctionToolResult, Tool from .tracing import ( SpanError, Trace, @@ -301,6 +301,7 @@ def process_model_response( cls, *, agent: Agent[Any], + all_tools: list[Tool], response: ModelResponse, output_schema: AgentOutputSchema | None, handoffs: list[Handoff], @@ -312,8 +313,8 @@ def process_model_response( computer_actions = [] handoff_map = {handoff.tool_name: handoff for handoff in handoffs} - function_map = {tool.name: tool for tool in agent.tools if isinstance(tool, FunctionTool)} - computer_tool = next((tool for tool in agent.tools if isinstance(tool, ComputerTool)), None) + function_map = {tool.name: tool for tool in all_tools if isinstance(tool, FunctionTool)} + computer_tool = next((tool for tool in all_tools if isinstance(tool, ComputerTool)), None) for output in response.output: if isinstance(output, ResponseOutputMessage): diff --git a/src/agents/agent.py b/src/agents/agent.py index 2723e678..3258e15a 100644 --- a/src/agents/agent.py +++ b/src/agents/agent.py @@ -12,6 +12,7 @@ from .handoffs import Handoff from .items import ItemHelpers from .logger import logger +from .mcp import MCPUtil from .model_settings import ModelSettings from .models.interface import Model from .run_context import RunContextWrapper, TContext @@ -21,6 +22,7 @@ if TYPE_CHECKING: from .lifecycle import AgentHooks + from .mcp import MCPServer from .result import RunResult @@ -107,6 +109,16 @@ class Agent(Generic[TContext]): tools: list[Tool] = field(default_factory=list) """A list of tools that the agent can use.""" + mcp_servers: list[MCPServer] = field(default_factory=list) + """A list of [Model Context Protocol](https://modelcontextprotocol.io/) servers that + the agent can use. Every time the agent runs, it will include tools from these servers in the + list of available tools. + + NOTE: You are expected to manage the lifecycle of these servers. Specifically, you must call + `server.connect()` before passing it to the agent, and `server.cleanup()` when the server is no + longer needed. + """ + input_guardrails: list[InputGuardrail[TContext]] = field(default_factory=list) """A list of checks that run in parallel to the agent's execution, before generating a response. Runs only if the agent is the first agent in the chain. @@ -205,3 +217,11 @@ async def get_system_prompt(self, run_context: RunContextWrapper[TContext]) -> s logger.error(f"Instructions must be a string or a function, got {self.instructions}") return None + + async def get_mcp_tools(self) -> list[Tool]: + """Fetches the available tools from the MCP servers.""" + return await MCPUtil.get_all_function_tools(self.mcp_servers) + + async def get_all_tools(self) -> list[Tool]: + """All agent tools, including MCP tools and function tools.""" + return await MCPUtil.get_all_function_tools(self.mcp_servers) + self.tools diff --git a/src/agents/run.py b/src/agents/run.py index 934400fe..b7ac85f9 100644 --- a/src/agents/run.py +++ b/src/agents/run.py @@ -7,6 +7,8 @@ from openai.types.responses import ResponseCompletedEvent +from agents.tool import Tool + from ._run_impl import ( NextStepFinalOutput, NextStepHandoff, @@ -177,7 +179,8 @@ async def run( # agent changes, or if the agent loop ends. if current_span is None: handoff_names = [h.agent_name for h in cls._get_handoffs(current_agent)] - tool_names = [t.name for t in current_agent.tools] + all_tools = await cls._get_all_tools(current_agent) + tool_names = [t.name for t in all_tools] if output_schema := cls._get_output_schema(current_agent): output_type_name = output_schema.output_type_name() else: @@ -217,6 +220,7 @@ async def run( ), cls._run_single_turn( agent=current_agent, + all_tools=all_tools, original_input=original_input, generated_items=generated_items, hooks=hooks, @@ -228,6 +232,7 @@ async def run( else: turn_result = await cls._run_single_turn( agent=current_agent, + all_tools=all_tools, original_input=original_input, generated_items=generated_items, hooks=hooks, @@ -627,7 +632,7 @@ async def _run_single_turn_streamed( system_prompt = await agent.get_system_prompt(context_wrapper) handoffs = cls._get_handoffs(agent) - + all_tools = await cls._get_all_tools(agent) model = cls._get_model(agent, run_config) model_settings = agent.model_settings.resolve(run_config.model_settings) final_response: ModelResponse | None = None @@ -640,7 +645,7 @@ async def _run_single_turn_streamed( system_prompt, input, model_settings, - agent.tools, + all_tools, output_schema, handoffs, get_model_tracing_impl( @@ -677,6 +682,7 @@ async def _run_single_turn_streamed( pre_step_items=streamed_result.new_items, new_response=final_response, output_schema=output_schema, + all_tools=all_tools, handoffs=handoffs, hooks=hooks, context_wrapper=context_wrapper, @@ -691,6 +697,7 @@ async def _run_single_turn( cls, *, agent: Agent[TContext], + all_tools: list[Tool], original_input: str | list[TResponseInputItem], generated_items: list[RunItem], hooks: RunHooks[TContext], @@ -721,6 +728,7 @@ async def _run_single_turn( system_prompt, input, output_schema, + all_tools, handoffs, context_wrapper, run_config, @@ -732,6 +740,7 @@ async def _run_single_turn( pre_step_items=generated_items, new_response=new_response, output_schema=output_schema, + all_tools=all_tools, handoffs=handoffs, hooks=hooks, context_wrapper=context_wrapper, @@ -743,6 +752,7 @@ async def _get_single_step_result_from_response( cls, *, agent: Agent[TContext], + all_tools: list[Tool], original_input: str | list[TResponseInputItem], pre_step_items: list[RunItem], new_response: ModelResponse, @@ -754,6 +764,7 @@ async def _get_single_step_result_from_response( ) -> SingleStepResult: processed_response = RunImpl.process_model_response( agent=agent, + all_tools=all_tools, response=new_response, output_schema=output_schema, handoffs=handoffs, @@ -853,6 +864,7 @@ async def _get_new_response( system_prompt: str | None, input: list[TResponseInputItem], output_schema: AgentOutputSchema | None, + all_tools: list[Tool], handoffs: list[Handoff], context_wrapper: RunContextWrapper[TContext], run_config: RunConfig, @@ -863,7 +875,7 @@ async def _get_new_response( system_instructions=system_prompt, input=input, model_settings=model_settings, - tools=agent.tools, + tools=all_tools, output_schema=output_schema, handoffs=handoffs, tracing=get_model_tracing_impl( @@ -892,6 +904,10 @@ def _get_handoffs(cls, agent: Agent[Any]) -> list[Handoff]: handoffs.append(handoff(handoff_item)) return handoffs + @classmethod + async def _get_all_tools(cls, agent: Agent[Any]) -> list[Tool]: + return await agent.get_all_tools() + @classmethod def _get_model(cls, agent: Agent[Any], run_config: RunConfig) -> Model: if isinstance(run_config.model, Model): diff --git a/tests/mcp/__init__.py b/tests/mcp/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/mcp/conftest.py b/tests/mcp/conftest.py new file mode 100644 index 00000000..80fd15ec --- /dev/null +++ b/tests/mcp/conftest.py @@ -0,0 +1,11 @@ +import os +import sys + + +# Skip MCP tests on Python 3.9 +def pytest_ignore_collect(collection_path, config): + if sys.version_info[:2] == (3, 9): + this_dir = os.path.dirname(__file__) + + if str(collection_path).startswith(this_dir): + return True diff --git a/tests/mcp/helpers.py b/tests/mcp/helpers.py new file mode 100644 index 00000000..952b3ea7 --- /dev/null +++ b/tests/mcp/helpers.py @@ -0,0 +1,54 @@ +import json +import shutil +from typing import Any + +from mcp import Tool as MCPTool +from mcp.types import CallToolResult, TextContent + +from agents.mcp import MCPServer + +tee = shutil.which("tee") or "" +assert tee, "tee not found" + + +# Added dummy stream classes for patching stdio_client to avoid real I/O during tests +class DummyStream: + async def send(self, msg): + pass + + async def receive(self): + raise Exception("Dummy receive not implemented") + + +class DummyStreamsContextManager: + async def __aenter__(self): + return (DummyStream(), DummyStream()) + + async def __aexit__(self, exc_type, exc_val, exc_tb): + pass + + +class FakeMCPServer(MCPServer): + def __init__(self, tools: list[MCPTool] | None = None): + self.tools: list[MCPTool] = tools or [] + self.tool_calls: list[str] = [] + self.tool_results: list[str] = [] + + def add_tool(self, name: str, input_schema: dict[str, Any]): + self.tools.append(MCPTool(name=name, inputSchema=input_schema)) + + async def connect(self): + pass + + async def cleanup(self): + pass + + async def list_tools(self): + return self.tools + + async def call_tool(self, tool_name: str, arguments: dict[str, Any] | None) -> CallToolResult: + self.tool_calls.append(tool_name) + self.tool_results.append(f"result_{tool_name}_{json.dumps(arguments)}") + return CallToolResult( + content=[TextContent(text=self.tool_results[-1], type="text")], + ) diff --git a/tests/mcp/test_caching.py b/tests/mcp/test_caching.py new file mode 100644 index 00000000..cac409e6 --- /dev/null +++ b/tests/mcp/test_caching.py @@ -0,0 +1,57 @@ +from unittest.mock import AsyncMock, patch + +import pytest +from mcp.types import ListToolsResult, Tool as MCPTool + +from agents.mcp import MCPServerStdio + +from .helpers import DummyStreamsContextManager, tee + + +@pytest.mark.asyncio +@patch("mcp.client.stdio.stdio_client", return_value=DummyStreamsContextManager()) +@patch("mcp.client.session.ClientSession.initialize", new_callable=AsyncMock, return_value=None) +@patch("mcp.client.session.ClientSession.list_tools") +async def test_server_caching_works( + mock_list_tools: AsyncMock, mock_initialize: AsyncMock, mock_stdio_client +): + """Test that if we turn caching on, the list of tools is cached and not fetched from the server + on each call to `list_tools()`. + """ + server = MCPServerStdio( + params={ + "command": tee, + }, + cache_tools_list=True, + ) + + tools = [ + MCPTool(name="tool1", inputSchema={}), + MCPTool(name="tool2", inputSchema={}), + ] + + mock_list_tools.return_value = ListToolsResult(tools=tools) + + async with server: + # Call list_tools() multiple times + tools = await server.list_tools() + assert tools == tools + + assert mock_list_tools.call_count == 1, "list_tools() should have been called once" + + # Call list_tools() again, should return the cached value + tools = await server.list_tools() + assert tools == tools + + assert mock_list_tools.call_count == 1, "list_tools() should not have been called again" + + # Invalidate the cache and call list_tools() again + server.invalidate_tools_cache() + tools = await server.list_tools() + assert tools == tools + + assert mock_list_tools.call_count == 2, "list_tools() should be called again" + + # Without invalidating the cache, calling list_tools() again should return the cached value + tools = await server.list_tools() + assert tools == tools diff --git a/tests/mcp/test_connect_disconnect.py b/tests/mcp/test_connect_disconnect.py new file mode 100644 index 00000000..b0013039 --- /dev/null +++ b/tests/mcp/test_connect_disconnect.py @@ -0,0 +1,69 @@ +from unittest.mock import AsyncMock, patch + +import pytest +from mcp.types import ListToolsResult, Tool as MCPTool + +from agents.mcp import MCPServerStdio + +from .helpers import DummyStreamsContextManager, tee + + +@pytest.mark.asyncio +@patch("mcp.client.stdio.stdio_client", return_value=DummyStreamsContextManager()) +@patch("mcp.client.session.ClientSession.initialize", new_callable=AsyncMock, return_value=None) +@patch("mcp.client.session.ClientSession.list_tools") +async def test_async_ctx_manager_works( + mock_list_tools: AsyncMock, mock_initialize: AsyncMock, mock_stdio_client +): + """Test that the async context manager works.""" + server = MCPServerStdio( + params={ + "command": tee, + }, + cache_tools_list=True, + ) + + tools = [ + MCPTool(name="tool1", inputSchema={}), + MCPTool(name="tool2", inputSchema={}), + ] + + mock_list_tools.return_value = ListToolsResult(tools=tools) + + assert server.session is None, "Server should not be connected" + + async with server: + assert server.session is not None, "Server should be connected" + + assert server.session is None, "Server should be disconnected" + + +@pytest.mark.asyncio +@patch("mcp.client.stdio.stdio_client", return_value=DummyStreamsContextManager()) +@patch("mcp.client.session.ClientSession.initialize", new_callable=AsyncMock, return_value=None) +@patch("mcp.client.session.ClientSession.list_tools") +async def test_manual_connect_disconnect_works( + mock_list_tools: AsyncMock, mock_initialize: AsyncMock, mock_stdio_client +): + """Test that the async context manager works.""" + server = MCPServerStdio( + params={ + "command": tee, + }, + cache_tools_list=True, + ) + + tools = [ + MCPTool(name="tool1", inputSchema={}), + MCPTool(name="tool2", inputSchema={}), + ] + + mock_list_tools.return_value = ListToolsResult(tools=tools) + + assert server.session is None, "Server should not be connected" + + await server.connect() + assert server.session is not None, "Server should be connected" + + await server.cleanup() + assert server.session is None, "Server should be disconnected" diff --git a/tests/mcp/test_mcp_util.py b/tests/mcp/test_mcp_util.py new file mode 100644 index 00000000..345df996 --- /dev/null +++ b/tests/mcp/test_mcp_util.py @@ -0,0 +1,109 @@ +import logging +from typing import Any + +import pytest +from mcp.types import Tool as MCPTool +from pydantic import BaseModel + +from agents import FunctionTool, RunContextWrapper +from agents.exceptions import AgentsException, ModelBehaviorError +from agents.mcp import MCPServer, MCPUtil + +from .helpers import FakeMCPServer + + +class Foo(BaseModel): + bar: str + baz: int + + +class Bar(BaseModel): + qux: str + + +@pytest.mark.asyncio +async def test_get_all_function_tools(): + """Test that the get_all_function_tools function returns all function tools from a list of MCP + servers. + """ + names = ["test_tool_1", "test_tool_2", "test_tool_3", "test_tool_4", "test_tool_5"] + schemas = [ + {}, + {}, + {}, + Foo.model_json_schema(), + Bar.model_json_schema(), + ] + + server1 = FakeMCPServer() + server1.add_tool(names[0], schemas[0]) + server1.add_tool(names[1], schemas[1]) + + server2 = FakeMCPServer() + server2.add_tool(names[2], schemas[2]) + server2.add_tool(names[3], schemas[3]) + + server3 = FakeMCPServer() + server3.add_tool(names[4], schemas[4]) + + servers: list[MCPServer] = [server1, server2, server3] + tools = await MCPUtil.get_all_function_tools(servers) + assert len(tools) == 5 + assert all(tool.name in names for tool in tools) + + for idx, tool in enumerate(tools): + assert isinstance(tool, FunctionTool) + assert tool.params_json_schema == schemas[idx] + assert tool.name == names[idx] + + +@pytest.mark.asyncio +async def test_invoke_mcp_tool(): + """Test that the invoke_mcp_tool function invokes an MCP tool and returns the result.""" + server = FakeMCPServer() + server.add_tool("test_tool_1", {}) + + ctx = RunContextWrapper(context=None) + tool = MCPTool(name="test_tool_1", inputSchema={}) + + await MCPUtil.invoke_mcp_tool(server, tool, ctx, "") + # Just making sure it doesn't crash + + +@pytest.mark.asyncio +async def test_mcp_invoke_bad_json_errors(caplog: pytest.LogCaptureFixture): + caplog.set_level(logging.DEBUG) + + """Test that bad JSON input errors are logged and re-raised.""" + server = FakeMCPServer() + server.add_tool("test_tool_1", {}) + + ctx = RunContextWrapper(context=None) + tool = MCPTool(name="test_tool_1", inputSchema={}) + + with pytest.raises(ModelBehaviorError): + await MCPUtil.invoke_mcp_tool(server, tool, ctx, "not_json") + + assert "Invalid JSON input for tool test_tool_1" in caplog.text + + +class CrashingFakeMCPServer(FakeMCPServer): + async def call_tool(self, tool_name: str, arguments: dict[str, Any] | None): + raise Exception("Crash!") + + +@pytest.mark.asyncio +async def test_mcp_invocation_crash_causes_error(caplog: pytest.LogCaptureFixture): + caplog.set_level(logging.DEBUG) + + """Test that bad JSON input errors are logged and re-raised.""" + server = CrashingFakeMCPServer() + server.add_tool("test_tool_1", {}) + + ctx = RunContextWrapper(context=None) + tool = MCPTool(name="test_tool_1", inputSchema={}) + + with pytest.raises(AgentsException): + await MCPUtil.invoke_mcp_tool(server, tool, ctx, "") + + assert "Error invoking MCP tool test_tool_1" in caplog.text diff --git a/tests/mcp/test_runner_calls_mcp.py b/tests/mcp/test_runner_calls_mcp.py new file mode 100644 index 00000000..3319c097 --- /dev/null +++ b/tests/mcp/test_runner_calls_mcp.py @@ -0,0 +1,197 @@ +import json + +import pytest +from pydantic import BaseModel + +from agents import Agent, ModelBehaviorError, Runner, UserError + +from ..fake_model import FakeModel +from ..test_responses import get_function_tool_call, get_text_message +from .helpers import FakeMCPServer + + +@pytest.mark.asyncio +@pytest.mark.parametrize("streaming", [False, True]) +async def test_runner_calls_mcp_tool(streaming: bool): + """Test that the runner calls an MCP tool when the model produces a tool call.""" + server = FakeMCPServer() + server.add_tool("test_tool_1", {}) + server.add_tool("test_tool_2", {}) + server.add_tool("test_tool_3", {}) + model = FakeModel() + agent = Agent( + name="test", + model=model, + mcp_servers=[server], + ) + + model.add_multiple_turn_outputs( + [ + # First turn: a message and tool call + [get_text_message("a_message"), get_function_tool_call("test_tool_2", "")], + # Second turn: text message + [get_text_message("done")], + ] + ) + + if streaming: + result = Runner.run_streamed(agent, input="user_message") + async for _ in result.stream_events(): + pass + else: + await Runner.run(agent, input="user_message") + + assert server.tool_calls == ["test_tool_2"] + + +@pytest.mark.asyncio +@pytest.mark.parametrize("streaming", [False, True]) +async def test_runner_asserts_when_mcp_tool_not_found(streaming: bool): + """Test that the runner asserts when an MCP tool is not found.""" + server = FakeMCPServer() + server.add_tool("test_tool_1", {}) + server.add_tool("test_tool_2", {}) + server.add_tool("test_tool_3", {}) + model = FakeModel() + agent = Agent( + name="test", + model=model, + mcp_servers=[server], + ) + + model.add_multiple_turn_outputs( + [ + # First turn: a message and tool call + [get_text_message("a_message"), get_function_tool_call("test_tool_doesnt_exist", "")], + # Second turn: text message + [get_text_message("done")], + ] + ) + + with pytest.raises(ModelBehaviorError): + if streaming: + result = Runner.run_streamed(agent, input="user_message") + async for _ in result.stream_events(): + pass + else: + await Runner.run(agent, input="user_message") + + +@pytest.mark.asyncio +@pytest.mark.parametrize("streaming", [False, True]) +async def test_runner_works_with_multiple_mcp_servers(streaming: bool): + """Test that the runner works with multiple MCP servers.""" + server1 = FakeMCPServer() + server1.add_tool("test_tool_1", {}) + + server2 = FakeMCPServer() + server2.add_tool("test_tool_2", {}) + server2.add_tool("test_tool_3", {}) + + model = FakeModel() + agent = Agent( + name="test", + model=model, + mcp_servers=[server1, server2], + ) + + model.add_multiple_turn_outputs( + [ + # First turn: a message and tool call + [get_text_message("a_message"), get_function_tool_call("test_tool_2", "")], + # Second turn: text message + [get_text_message("done")], + ] + ) + + if streaming: + result = Runner.run_streamed(agent, input="user_message") + async for _ in result.stream_events(): + pass + else: + await Runner.run(agent, input="user_message") + + assert server1.tool_calls == [] + assert server2.tool_calls == ["test_tool_2"] + + +@pytest.mark.asyncio +@pytest.mark.parametrize("streaming", [False, True]) +async def test_runner_errors_when_mcp_tools_clash(streaming: bool): + """Test that the runner errors when multiple servers have the same tool name.""" + server1 = FakeMCPServer() + server1.add_tool("test_tool_1", {}) + server1.add_tool("test_tool_2", {}) + + server2 = FakeMCPServer() + server2.add_tool("test_tool_2", {}) + server2.add_tool("test_tool_3", {}) + + model = FakeModel() + agent = Agent( + name="test", + model=model, + mcp_servers=[server1, server2], + ) + + model.add_multiple_turn_outputs( + [ + # First turn: a message and tool call + [get_text_message("a_message"), get_function_tool_call("test_tool_3", "")], + # Second turn: text message + [get_text_message("done")], + ] + ) + + with pytest.raises(UserError): + if streaming: + result = Runner.run_streamed(agent, input="user_message") + async for _ in result.stream_events(): + pass + else: + await Runner.run(agent, input="user_message") + + +class Foo(BaseModel): + bar: str + baz: int + + +@pytest.mark.asyncio +@pytest.mark.parametrize("streaming", [False, True]) +async def test_runner_calls_mcp_tool_with_args(streaming: bool): + """Test that the runner calls an MCP tool when the model produces a tool call.""" + server = FakeMCPServer() + await server.connect() + server.add_tool("test_tool_1", {}) + server.add_tool("test_tool_2", Foo.model_json_schema()) + server.add_tool("test_tool_3", {}) + model = FakeModel() + agent = Agent( + name="test", + model=model, + mcp_servers=[server], + ) + + json_args = json.dumps(Foo(bar="baz", baz=1).model_dump()) + + model.add_multiple_turn_outputs( + [ + # First turn: a message and tool call + [get_text_message("a_message"), get_function_tool_call("test_tool_2", json_args)], + # Second turn: text message + [get_text_message("done")], + ] + ) + + if streaming: + result = Runner.run_streamed(agent, input="user_message") + async for _ in result.stream_events(): + pass + else: + await Runner.run(agent, input="user_message") + + assert server.tool_calls == ["test_tool_2"] + assert server.tool_results == [f"result_test_tool_2_{json_args}"] + + await server.cleanup() diff --git a/tests/mcp/test_server_errors.py b/tests/mcp/test_server_errors.py new file mode 100644 index 00000000..5c6432bc --- /dev/null +++ b/tests/mcp/test_server_errors.py @@ -0,0 +1,38 @@ +import pytest + +from agents.exceptions import UserError +from agents.mcp.server import _MCPServerWithClientSession + + +class CrashingClientSessionServer(_MCPServerWithClientSession): + def __init__(self): + super().__init__(cache_tools_list=False) + self.cleanup_called = False + + def create_streams(self): + raise ValueError("Crash!") + + async def cleanup(self): + self.cleanup_called = True + await super().cleanup() + + +@pytest.mark.asyncio +async def test_server_errors_cause_error_and_cleanup_called(): + server = CrashingClientSessionServer() + + with pytest.raises(ValueError): + await server.connect() + + assert server.cleanup_called + + +@pytest.mark.asyncio +async def test_not_calling_connect_causes_error(): + server = CrashingClientSessionServer() + + with pytest.raises(UserError): + await server.list_tools() + + with pytest.raises(UserError): + await server.call_tool("foo", {}) diff --git a/tests/test_run_step_execution.py b/tests/test_run_step_execution.py index 2d581bf6..16c62c84 100644 --- a/tests/test_run_step_execution.py +++ b/tests/test_run_step_execution.py @@ -290,6 +290,7 @@ async def get_execute_result( processed_response = RunImpl.process_model_response( agent=agent, + all_tools=await agent.get_all_tools(), response=response, output_schema=output_schema, handoffs=handoffs, diff --git a/tests/test_run_step_processing.py b/tests/test_run_step_processing.py index 24f9e8e3..2a6634ac 100644 --- a/tests/test_run_step_processing.py +++ b/tests/test_run_step_processing.py @@ -43,7 +43,11 @@ def test_empty_response(): ) result = RunImpl.process_model_response( - agent=agent, response=response, output_schema=None, handoffs=[] + agent=agent, + response=response, + output_schema=None, + handoffs=[], + all_tools=[], ) assert not result.handoffs assert not result.functions @@ -57,13 +61,14 @@ def test_no_tool_calls(): referenceable_id=None, ) result = RunImpl.process_model_response( - agent=agent, response=response, output_schema=None, handoffs=[] + agent=agent, response=response, output_schema=None, handoffs=[], all_tools=[] ) assert not result.handoffs assert not result.functions -def test_single_tool_call(): +@pytest.mark.asyncio +async def test_single_tool_call(): agent = Agent(name="test", tools=[get_function_tool(name="test")]) response = ModelResponse( output=[ @@ -74,7 +79,11 @@ def test_single_tool_call(): referenceable_id=None, ) result = RunImpl.process_model_response( - agent=agent, response=response, output_schema=None, handoffs=[] + agent=agent, + response=response, + output_schema=None, + handoffs=[], + all_tools=await agent.get_all_tools(), ) assert not result.handoffs assert result.functions and len(result.functions) == 1 @@ -84,7 +93,8 @@ def test_single_tool_call(): assert func.tool_call.arguments == "" -def test_missing_tool_call_raises_error(): +@pytest.mark.asyncio +async def test_missing_tool_call_raises_error(): agent = Agent(name="test", tools=[get_function_tool(name="test")]) response = ModelResponse( output=[ @@ -97,11 +107,16 @@ def test_missing_tool_call_raises_error(): with pytest.raises(ModelBehaviorError): RunImpl.process_model_response( - agent=agent, response=response, output_schema=None, handoffs=[] + agent=agent, + response=response, + output_schema=None, + handoffs=[], + all_tools=await agent.get_all_tools(), ) -def test_multiple_tool_calls(): +@pytest.mark.asyncio +async def test_multiple_tool_calls(): agent = Agent( name="test", tools=[ @@ -121,7 +136,11 @@ def test_multiple_tool_calls(): ) result = RunImpl.process_model_response( - agent=agent, response=response, output_schema=None, handoffs=[] + agent=agent, + response=response, + output_schema=None, + handoffs=[], + all_tools=await agent.get_all_tools(), ) assert not result.handoffs assert result.functions and len(result.functions) == 2 @@ -146,7 +165,11 @@ async def test_handoffs_parsed_correctly(): referenceable_id=None, ) result = RunImpl.process_model_response( - agent=agent_3, response=response, output_schema=None, handoffs=[] + agent=agent_3, + response=response, + output_schema=None, + handoffs=[], + all_tools=await agent_3.get_all_tools(), ) assert not result.handoffs, "Shouldn't have a handoff here" @@ -160,6 +183,7 @@ async def test_handoffs_parsed_correctly(): response=response, output_schema=None, handoffs=Runner._get_handoffs(agent_3), + all_tools=await agent_3.get_all_tools(), ) assert len(result.handoffs) == 1, "Should have a handoff here" handoff = result.handoffs[0] @@ -189,10 +213,12 @@ async def test_missing_handoff_fails(): response=response, output_schema=None, handoffs=Runner._get_handoffs(agent_3), + all_tools=await agent_3.get_all_tools(), ) -def test_multiple_handoffs_doesnt_error(): +@pytest.mark.asyncio +async def test_multiple_handoffs_doesnt_error(): agent_1 = Agent(name="test_1") agent_2 = Agent(name="test_2") agent_3 = Agent(name="test_3", handoffs=[agent_1, agent_2]) @@ -210,6 +236,7 @@ def test_multiple_handoffs_doesnt_error(): response=response, output_schema=None, handoffs=Runner._get_handoffs(agent_3), + all_tools=await agent_3.get_all_tools(), ) assert len(result.handoffs) == 2, "Should have multiple handoffs here" @@ -218,7 +245,8 @@ class Foo(BaseModel): bar: str -def test_final_output_parsed_correctly(): +@pytest.mark.asyncio +async def test_final_output_parsed_correctly(): agent = Agent(name="test", output_type=Foo) response = ModelResponse( output=[ @@ -234,10 +262,12 @@ def test_final_output_parsed_correctly(): response=response, output_schema=Runner._get_output_schema(agent), handoffs=[], + all_tools=await agent.get_all_tools(), ) -def test_file_search_tool_call_parsed_correctly(): +@pytest.mark.asyncio +async def test_file_search_tool_call_parsed_correctly(): # Ensure that a ResponseFileSearchToolCall output is parsed into a ToolCallItem and that no tool # runs are scheduled. @@ -254,7 +284,11 @@ def test_file_search_tool_call_parsed_correctly(): referenceable_id=None, ) result = RunImpl.process_model_response( - agent=agent, response=response, output_schema=None, handoffs=[] + agent=agent, + response=response, + output_schema=None, + handoffs=[], + all_tools=await agent.get_all_tools(), ) # The final item should be a ToolCallItem for the file search call assert any( @@ -265,7 +299,8 @@ def test_file_search_tool_call_parsed_correctly(): assert not result.handoffs -def test_function_web_search_tool_call_parsed_correctly(): +@pytest.mark.asyncio +async def test_function_web_search_tool_call_parsed_correctly(): agent = Agent(name="test") web_search_call = ResponseFunctionWebSearch(id="w1", status="completed", type="web_search_call") response = ModelResponse( @@ -274,7 +309,11 @@ def test_function_web_search_tool_call_parsed_correctly(): referenceable_id=None, ) result = RunImpl.process_model_response( - agent=agent, response=response, output_schema=None, handoffs=[] + agent=agent, + response=response, + output_schema=None, + handoffs=[], + all_tools=await agent.get_all_tools(), ) assert any( isinstance(item, ToolCallItem) and item.raw_item is web_search_call @@ -284,7 +323,8 @@ def test_function_web_search_tool_call_parsed_correctly(): assert not result.handoffs -def test_reasoning_item_parsed_correctly(): +@pytest.mark.asyncio +async def test_reasoning_item_parsed_correctly(): # Verify that a Reasoning output item is converted into a ReasoningItem. reasoning = ResponseReasoningItem( @@ -296,7 +336,11 @@ def test_reasoning_item_parsed_correctly(): referenceable_id=None, ) result = RunImpl.process_model_response( - agent=Agent(name="test"), response=response, output_schema=None, handoffs=[] + agent=Agent(name="test"), + response=response, + output_schema=None, + handoffs=[], + all_tools=await Agent(name="test").get_all_tools(), ) assert any( isinstance(item, ReasoningItem) and item.raw_item is reasoning for item in result.new_items @@ -342,7 +386,8 @@ def drag(self, path: list[tuple[int, int]]) -> None: return None # pragma: no cover -def test_computer_tool_call_without_computer_tool_raises_error(): +@pytest.mark.asyncio +async def test_computer_tool_call_without_computer_tool_raises_error(): # If the agent has no ComputerTool in its tools, process_model_response should raise a # ModelBehaviorError when encountering a ResponseComputerToolCall. computer_call = ResponseComputerToolCall( @@ -360,11 +405,16 @@ def test_computer_tool_call_without_computer_tool_raises_error(): ) with pytest.raises(ModelBehaviorError): RunImpl.process_model_response( - agent=Agent(name="test"), response=response, output_schema=None, handoffs=[] + agent=Agent(name="test"), + response=response, + output_schema=None, + handoffs=[], + all_tools=await Agent(name="test").get_all_tools(), ) -def test_computer_tool_call_with_computer_tool_parsed_correctly(): +@pytest.mark.asyncio +async def test_computer_tool_call_with_computer_tool_parsed_correctly(): # If the agent contains a ComputerTool, ensure that a ResponseComputerToolCall is parsed into a # ToolCallItem and scheduled to run in computer_actions. dummy_computer = DummyComputer() @@ -383,7 +433,11 @@ def test_computer_tool_call_with_computer_tool_parsed_correctly(): referenceable_id=None, ) result = RunImpl.process_model_response( - agent=agent, response=response, output_schema=None, handoffs=[] + agent=agent, + response=response, + output_schema=None, + handoffs=[], + all_tools=await agent.get_all_tools(), ) assert any( isinstance(item, ToolCallItem) and item.raw_item is computer_call @@ -392,7 +446,8 @@ def test_computer_tool_call_with_computer_tool_parsed_correctly(): assert result.computer_actions and result.computer_actions[0].tool_call == computer_call -def test_tool_and_handoff_parsed_correctly(): +@pytest.mark.asyncio +async def test_tool_and_handoff_parsed_correctly(): agent_1 = Agent(name="test_1") agent_2 = Agent(name="test_2") agent_3 = Agent( @@ -413,6 +468,7 @@ def test_tool_and_handoff_parsed_correctly(): response=response, output_schema=None, handoffs=Runner._get_handoffs(agent_3), + all_tools=await agent_3.get_all_tools(), ) assert result.functions and len(result.functions) == 1 assert len(result.handoffs) == 1, "Should have a handoff here" diff --git a/tests/voice/test_openai_stt.py b/tests/voice/test_openai_stt.py index 75559232..89b5cca7 100644 --- a/tests/voice/test_openai_stt.py +++ b/tests/voice/test_openai_stt.py @@ -269,7 +269,7 @@ def fake_time_func(): async for _ in turns: pass - assert "Timeout waiting for transcription_session.created event" in str(exc_info.value) + assert "Timeout waiting for transcription_session.created event" in str(exc_info.value) await session.close() @@ -302,13 +302,11 @@ async def test_session_error_event(): trace_include_sensitive_audio_data=False, ) - with pytest.raises(STTWebsocketConnectionError) as exc_info: + with pytest.raises(STTWebsocketConnectionError): turns = session.transcribe_turns() async for _ in turns: pass - assert "Simulated server error!" in str(exc_info.value) - await session.close() @@ -362,8 +360,8 @@ async def test_inactivity_timeout(): async for turn in session.transcribe_turns(): collected_turns.append(turn) - assert "Timeout waiting for transcription_session" in str(exc_info.value) + assert "Timeout waiting for transcription_session" in str(exc_info.value) - assert len(collected_turns) == 0, "No transcripts expected, but we got something?" + assert len(collected_turns) == 0, "No transcripts expected, but we got something?" await session.close()