diff --git a/src/agents/mcp/server.py b/src/agents/mcp/server.py index 9a137bbdd..9916c92b0 100644 --- a/src/agents/mcp/server.py +++ b/src/agents/mcp/server.py @@ -3,6 +3,7 @@ import abc import asyncio from contextlib import AbstractAsyncContextManager, AsyncExitStack +from datetime import timedelta from pathlib import Path from typing import Any, Literal @@ -54,7 +55,7 @@ async def call_tool(self, tool_name: str, arguments: dict[str, Any] | None) -> C class _MCPServerWithClientSession(MCPServer, abc.ABC): """Base class for MCP servers that use a `ClientSession` to communicate with the server.""" - def __init__(self, cache_tools_list: bool): + def __init__(self, cache_tools_list: bool, client_session_timeout_seconds: float | None): """ Args: cache_tools_list: Whether to cache the tools list. If `True`, the tools list will be @@ -63,12 +64,16 @@ def __init__(self, cache_tools_list: bool): by calling `invalidate_tools_cache()`. You should set this to `True` if you know the server will not change its tools list, because it can drastically improve latency (by avoiding a round-trip to the server every time). + + client_session_timeout_seconds: the read timeout passed to the MCP ClientSession. """ self.session: ClientSession | None = None self.exit_stack: AsyncExitStack = AsyncExitStack() self._cleanup_lock: asyncio.Lock = asyncio.Lock() self.cache_tools_list = cache_tools_list + self.client_session_timeout_seconds = client_session_timeout_seconds + # The cache is always dirty at startup, so that we fetch tools at least once self._cache_dirty = True self._tools_list: list[MCPTool] | None = None @@ -101,7 +106,15 @@ async def connect(self): try: transport = await self.exit_stack.enter_async_context(self.create_streams()) read, write = transport - session = await self.exit_stack.enter_async_context(ClientSession(read, write)) + session = await self.exit_stack.enter_async_context( + ClientSession( + read, + write, + timedelta(seconds=self.client_session_timeout_seconds) + if self.client_session_timeout_seconds + else None, + ) + ) await session.initialize() self.session = session except Exception as e: @@ -183,6 +196,7 @@ def __init__( params: MCPServerStdioParams, cache_tools_list: bool = False, name: str | None = None, + client_session_timeout_seconds: float | None = 5, ): """Create a new MCP server based on the stdio transport. @@ -199,8 +213,9 @@ def __init__( improve latency (by avoiding a round-trip to the server every time). name: A readable name for the server. If not provided, we'll create one from the command. + client_session_timeout_seconds: the read timeout passed to the MCP ClientSession. """ - super().__init__(cache_tools_list) + super().__init__(cache_tools_list, client_session_timeout_seconds) self.params = StdioServerParameters( command=params["command"], @@ -257,6 +272,7 @@ def __init__( params: MCPServerSseParams, cache_tools_list: bool = False, name: str | None = None, + client_session_timeout_seconds: float | None = 5, ): """Create a new MCP server based on the HTTP with SSE transport. @@ -274,8 +290,10 @@ def __init__( name: A readable name for the server. If not provided, we'll create one from the URL. + + client_session_timeout_seconds: the read timeout passed to the MCP ClientSession. """ - super().__init__(cache_tools_list) + super().__init__(cache_tools_list, client_session_timeout_seconds) self.params = params self._name = name or f"sse: {self.params['url']}" diff --git a/tests/mcp/test_server_errors.py b/tests/mcp/test_server_errors.py index bdca7ce62..fbd8db17d 100644 --- a/tests/mcp/test_server_errors.py +++ b/tests/mcp/test_server_errors.py @@ -6,7 +6,7 @@ class CrashingClientSessionServer(_MCPServerWithClientSession): def __init__(self): - super().__init__(cache_tools_list=False) + super().__init__(cache_tools_list=False, client_session_timeout_seconds=5) self.cleanup_called = False def create_streams(self):