Skip to content

[1/n] Add MCP types to the SDK #320

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Mar 25, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ dependencies = [
"typing-extensions>=4.12.2, <5",
"requests>=2.0, <3",
"types-requests>=2.0, <3",
"mcp; python_version >= '3.10'",
]
classifiers = [
"Typing :: Typed",
Expand Down
21 changes: 21 additions & 0 deletions src/agents/mcp/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
try:
from .server import (
MCPServer,
MCPServerSse,
MCPServerSseParams,
MCPServerStdio,
MCPServerStdioParams,
)
except ImportError:
pass

from .util import MCPUtil

__all__ = [
"MCPServer",
"MCPServerSse",
"MCPServerSseParams",
"MCPServerStdio",
"MCPServerStdioParams",
"MCPUtil",
]
94 changes: 94 additions & 0 deletions src/agents/mcp/mcp_util.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
import functools
import json
from typing import Any

from mcp.types import Tool as MCPTool

from .. import _debug
from ..exceptions import AgentsException, ModelBehaviorError, UserError
from ..logger import logger
from ..run_context import RunContextWrapper
from ..tool import FunctionTool, Tool
from .server import MCPServer


class MCPUtil:
"""Set of utilities for interop between MCP and Agents SDK tools."""

@classmethod
async def get_all_function_tools(cls, servers: list[MCPServer]) -> list[Tool]:
"""Get all function tools from a list of MCP servers."""
tools = []
tool_names: set[str] = set()
for server in servers:
server_tools = await cls.get_function_tools(server)
server_tool_names = {tool.name for tool in server_tools}
if len(server_tool_names & tool_names) > 0:
raise UserError(
f"Duplicate tool names found across MCP servers: "
f"{server_tool_names & tool_names}"
)
tool_names.update(server_tool_names)
tools.extend(server_tools)

return tools

@classmethod
async def get_function_tools(cls, server: MCPServer) -> list[Tool]:
"""Get all function tools from a single MCP server."""
tools = await server.list_tools()
return [cls.to_function_tool(tool, server) for tool in tools]

@classmethod
def to_function_tool(cls, tool: MCPTool, server: MCPServer) -> FunctionTool:
"""Convert an MCP tool to an Agents SDK function tool."""
invoke_func = functools.partial(cls.invoke_mcp_tool, server, tool)
return FunctionTool(
name=tool.name,
description=tool.description or "",
params_json_schema=tool.inputSchema,
on_invoke_tool=invoke_func,
strict_json_schema=False,
)

@classmethod
async def invoke_mcp_tool(
cls, server: MCPServer, tool: MCPTool, context: RunContextWrapper[Any], input_json: str
) -> str:
"""Invoke an MCP tool and return the result as a string."""
try:
json_data: dict[str, Any] = json.loads(input_json) if input_json else {}
except Exception as e:
if _debug.DONT_LOG_TOOL_DATA:
logger.debug(f"Invalid JSON input for tool {tool.name}")
else:
logger.debug(f"Invalid JSON input for tool {tool.name}: {input_json}")
raise ModelBehaviorError(
f"Invalid JSON input for tool {tool.name}: {input_json}"
) from e

if _debug.DONT_LOG_TOOL_DATA:
logger.debug(f"Invoking MCP tool {tool.name}")
else:
logger.debug(f"Invoking MCP tool {tool.name} with input {input_json}")

try:
result = await server.call_tool(tool.name, json_data)
except Exception as e:
logger.error(f"Error invoking MCP tool {tool.name}: {e}")
raise AgentsException(f"Error invoking MCP tool {tool.name}: {e}") from e

if _debug.DONT_LOG_TOOL_DATA:
logger.debug(f"MCP tool {tool.name} completed.")
else:
logger.debug(f"MCP tool {tool.name} returned {result}")

# The MCP tool result is a list of content items, whereas OpenAI tool outputs are a single
# string. We'll try to convert.
if len(result.content) == 1:
return result.content[0].model_dump_json()
elif len(result.content) > 1:
return json.dumps([item.model_dump() for item in result.content])
else:
logger.error(f"Errored MCP tool result: {result}")
return "Error running tool."
269 changes: 269 additions & 0 deletions src/agents/mcp/server.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,269 @@
from __future__ import annotations

import abc
import asyncio
from contextlib import AbstractAsyncContextManager, AsyncExitStack
from pathlib import Path
from typing import Any, Literal

from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
from mcp import ClientSession, StdioServerParameters, Tool as MCPTool, stdio_client
from mcp.client.sse import sse_client
from mcp.types import CallToolResult, JSONRPCMessage
from typing_extensions import NotRequired, TypedDict

from ..exceptions import UserError
from ..logger import logger


class MCPServer(abc.ABC):
"""Base class for Model Context Protocol servers."""

@abc.abstractmethod
async def connect(self):
"""Connect to the server. For example, this might mean spawning a subprocess or
opening a network connection. The server is expected to remain connected until
`cleanup()` is called.
"""
pass

@abc.abstractmethod
async def cleanup(self):
"""Cleanup the server. For example, this might mean closing a subprocess or
closing a network connection.
"""
pass

@abc.abstractmethod
async def list_tools(self) -> list[MCPTool]:
"""List the tools available on the server."""
pass

@abc.abstractmethod
async def call_tool(self, tool_name: str, arguments: dict[str, Any] | None) -> CallToolResult:
"""Invoke a tool on the server."""
pass
Comment on lines +37 to +45
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what do you think about making these just return our Tool class? ie. the MCPServer is responsible for doing the conversion

this is an MCPServer class specifically for integrating with our SDK, so it can use our types

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah good question. I considered that, but I actually think MCP servers should be shared between frameworks. I talked to @dsp-ant and plan to make a PR to the MCP repo to make this part of their SDK, so then this import would just be an alias to mcp.client.MCPServer.



class _MCPServerWithClientSession(MCPServer, abc.ABC):
"""Base class for MCP servers that use a `ClientSession` to communicate with the server."""

def __init__(self, cache_tools_list: bool):
"""
Args:
cache_tools_list: Whether to cache the tools list. If `True`, the tools list will be
cached and only fetched from the server once. If `False`, the tools list will be
fetched from the server on each call to `list_tools()`. The cache can be invalidated
by calling `invalidate_tools_cache()`. You should set this to `True` if you know the
server will not change its tools list, because it can drastically improve latency
(by avoiding a round-trip to the server every time).
"""
self.session: ClientSession | None = None
self.exit_stack: AsyncExitStack = AsyncExitStack()
self._cleanup_lock: asyncio.Lock = asyncio.Lock()
self.cache_tools_list = cache_tools_list

# The cache is always dirty at startup, so that we fetch tools at least once
self._cache_dirty = True
self._tools_list: list[MCPTool] | None = None

@abc.abstractmethod
def create_streams(
self,
) -> AbstractAsyncContextManager[
tuple[
MemoryObjectReceiveStream[JSONRPCMessage | Exception],
MemoryObjectSendStream[JSONRPCMessage],
]
]:
"""Create the streams for the server."""
pass

async def __aenter__(self):
await self.connect()
return self

async def __aexit__(self, exc_type, exc_value, traceback):
await self.cleanup()

def invalidate_tools_cache(self):
"""Invalidate the tools cache."""
self._cache_dirty = True

async def connect(self):
"""Connect to the server."""
try:
transport = await self.exit_stack.enter_async_context(self.create_streams())
read, write = transport
session = await self.exit_stack.enter_async_context(ClientSession(read, write))
await session.initialize()
self.session = session
except Exception as e:
logger.error(f"Error initializing MCP server: {e}")
await self.cleanup()
raise

async def list_tools(self) -> list[MCPTool]:
"""List the tools available on the server."""
if not self.session:
raise UserError("Server not initialized. Make sure you call `connect()` first.")

# Return from cache if caching is enabled, we have tools, and the cache is not dirty
if self.cache_tools_list and not self._cache_dirty and self._tools_list:
return self._tools_list

# Reset the cache dirty to False
self._cache_dirty = False

# Fetch the tools from the server
self._tools_list = (await self.session.list_tools()).tools
return self._tools_list

async def call_tool(self, tool_name: str, arguments: dict[str, Any] | None) -> CallToolResult:
"""Invoke a tool on the server."""
if not self.session:
raise UserError("Server not initialized. Make sure you call `connect()` first.")

return await self.session.call_tool(tool_name, arguments)

async def cleanup(self):
"""Cleanup the server."""
async with self._cleanup_lock:
try:
await self.exit_stack.aclose()
self.session = None
except Exception as e:
logger.error(f"Error cleaning up server: {e}")


class MCPServerStdioParams(TypedDict):
"""Mirrors `mcp.client.stdio.StdioServerParameters`, but lets you pass params without another
import.
"""

command: str
"""The executable to run to start the server. For example, `python` or `node`."""

args: NotRequired[list[str]]
"""Command line args to pass to the `command` executable. For example, `['foo.py']` or
`['server.js', '--port', '8080']`."""

env: NotRequired[dict[str, str]]
"""The environment variables to set for the server. ."""

cwd: NotRequired[str | Path]
"""The working directory to use when spawning the process."""

encoding: NotRequired[str]
"""The text encoding used when sending/receiving messages to the server. Defaults to `utf-8`."""

encoding_error_handler: NotRequired[Literal["strict", "ignore", "replace"]]
"""The text encoding error handler. Defaults to `strict`.

See https://docs.python.org/3/library/codecs.html#codec-base-classes for
explanations of possible values.
"""


class MCPServerStdio(_MCPServerWithClientSession):
"""MCP server implementation that uses the stdio transport. See the [spec]
(https://spec.modelcontextprotocol.io/specification/2024-11-05/basic/transports/#stdio) for
details.
"""

def __init__(self, params: MCPServerStdioParams, cache_tools_list: bool = False):
"""Create a new MCP server based on the stdio transport.

Args:
params: The params that configure the server. This includes:
- The command (e.g. `python` or `node`) that starts the server.
- The args to pass to the server command (e.g. `foo.py` or `server.js`).
- The environment variables to set for the server.
cache_tools_list: Whether to cache the tools list. If `True`, the tools list will be
cached and only fetched from the server once. If `False`, the tools list will be
fetched from the server on each call to `list_tools()`. The cache can be
invalidated by calling `invalidate_tools_cache()`. You should set this to `True`
if you know the server will not change its tools list, because it can drastically
improve latency (by avoiding a round-trip to the server every time).
"""
super().__init__(cache_tools_list)

self.params = StdioServerParameters(
command=params["command"],
args=params.get("args", []),
env=params.get("env"),
cwd=params.get("cwd"),
encoding=params.get("encoding", "utf-8"),
encoding_error_handler=params.get("encoding_error_handler", "strict"),
)

def create_streams(
self,
) -> AbstractAsyncContextManager[
tuple[
MemoryObjectReceiveStream[JSONRPCMessage | Exception],
MemoryObjectSendStream[JSONRPCMessage],
]
]:
"""Create the streams for the server."""
return stdio_client(self.params)


class MCPServerSseParams(TypedDict):
"""Mirrors the params in`mcp.client.sse.sse_client`."""

url: str
"""The URL of the server."""

headers: NotRequired[dict[str, str]]
"""The headers to send to the server."""

timeout: NotRequired[float]
"""The timeout for the HTTP request. Defaults to 5 seconds."""

sse_read_timeout: NotRequired[float]
"""The timeout for the SSE connection, in seconds. Defaults to 5 minutes."""


class MCPServerSse(_MCPServerWithClientSession):
"""MCP server implementation that uses the HTTP with SSE transport. See the [spec]
(https://spec.modelcontextprotocol.io/specification/2024-11-05/basic/transports/#http-with-sse)
for details.
"""

def __init__(self, params: MCPServerSseParams, cache_tools_list: bool = False):
"""Create a new MCP server based on the HTTP with SSE transport.

Args:
params: The params that configure the server. This includes:
- The URL of the server.
- The headers to send to the server.
- The timeout for the HTTP request.
- The timeout for the SSE connection.

cache_tools_list: Whether to cache the tools list. If `True`, the tools list will be
cached and only fetched from the server once. If `False`, the tools list will be
fetched from the server on each call to `list_tools()`. The cache can be
invalidated by calling `invalidate_tools_cache()`. You should set this to `True`
if you know the server will not change its tools list, because it can drastically
improve latency (by avoiding a round-trip to the server every time).
"""
super().__init__(cache_tools_list)

self.params = params

def create_streams(
self,
) -> AbstractAsyncContextManager[
tuple[
MemoryObjectReceiveStream[JSONRPCMessage | Exception],
MemoryObjectSendStream[JSONRPCMessage],
]
]:
"""Create the streams for the server."""
return sse_client(
url=self.params["url"],
headers=self.params.get("headers", None),
timeout=self.params.get("timeout", 5),
sse_read_timeout=self.params.get("sse_read_timeout", 60 * 5),
)
Loading
Loading