diff --git a/src/mcp/client/session.py b/src/mcp/client/session.py index e29797d17..fc86f0110 100644 --- a/src/mcp/client/session.py +++ b/src/mcp/client/session.py @@ -254,7 +254,10 @@ async def unsubscribe_resource(self, uri: AnyUrl) -> types.EmptyResult: ) async def call_tool( - self, name: str, arguments: dict[str, Any] | None = None + self, + name: str, + arguments: dict[str, Any] | None = None, + read_timeout_seconds: timedelta | None = None, ) -> types.CallToolResult: """Send a tools/call request.""" return await self.send_request( @@ -265,6 +268,7 @@ async def call_tool( ) ), types.CallToolResult, + request_read_timeout_seconds=read_timeout_seconds, ) async def list_prompts(self) -> types.ListPromptsResult: diff --git a/src/mcp/shared/session.py b/src/mcp/shared/session.py index 05fd3ce37..90ad92e33 100644 --- a/src/mcp/shared/session.py +++ b/src/mcp/shared/session.py @@ -185,7 +185,7 @@ def __init__( self._request_id = 0 self._receive_request_type = receive_request_type self._receive_notification_type = receive_notification_type - self._read_timeout_seconds = read_timeout_seconds + self._session_read_timeout_seconds = read_timeout_seconds self._in_flight = {} self._exit_stack = AsyncExitStack() @@ -213,10 +213,12 @@ async def send_request( self, request: SendRequestT, result_type: type[ReceiveResultT], + request_read_timeout_seconds: timedelta | None = None, ) -> ReceiveResultT: """ Sends a request and wait for a response. Raises an McpError if the - response contains an error. + response contains an error. If a request read timeout is provided, it + will take precedence over the session read timeout. Do not use this method to emit notifications! Use send_notification() instead. @@ -243,12 +245,15 @@ async def send_request( await self._write_stream.send(JSONRPCMessage(jsonrpc_request)) + # request read timeout takes precedence over session read timeout + timeout = None + if request_read_timeout_seconds is not None: + timeout = request_read_timeout_seconds.total_seconds() + elif self._session_read_timeout_seconds is not None: + timeout = self._session_read_timeout_seconds.total_seconds() + try: - with anyio.fail_after( - None - if self._read_timeout_seconds is None - else self._read_timeout_seconds.total_seconds() - ): + with anyio.fail_after(timeout): response_or_error = await response_stream_reader.receive() except TimeoutError: raise McpError( @@ -257,7 +262,7 @@ async def send_request( message=( f"Timed out while waiting for response to " f"{request.__class__.__name__}. Waited " - f"{self._read_timeout_seconds} seconds." + f"{timeout} seconds." ), ) )