From c6fb822c865c231db6daf4507cdc2088e29456ba Mon Sep 17 00:00:00 2001 From: ihrpr Date: Mon, 12 May 2025 18:31:35 +0100 Subject: [PATCH] Fix streamable http sampling (#693) --- src/mcp/cli/claude.py | 2 + src/mcp/client/streamable_http.py | 24 ++++-- src/mcp/server/session.py | 10 ++- src/mcp/server/streamable_http.py | 21 +++-- src/mcp/shared/session.py | 1 - tests/client/test_config.py | 6 +- tests/shared/test_streamable_http.py | 111 ++++++++++++++++++++++++++- 7 files changed, 152 insertions(+), 23 deletions(-) diff --git a/src/mcp/cli/claude.py b/src/mcp/cli/claude.py index 17c957df..1629f928 100644 --- a/src/mcp/cli/claude.py +++ b/src/mcp/cli/claude.py @@ -31,6 +31,7 @@ def get_claude_config_path() -> Path | None: return path return None + def get_uv_path() -> str: """Get the full path to the uv executable.""" uv_path = shutil.which("uv") @@ -42,6 +43,7 @@ def get_uv_path() -> str: return "uv" # Fall back to just "uv" if not found return uv_path + def update_claude_config( file_spec: str, server_name: str, diff --git a/src/mcp/client/streamable_http.py b/src/mcp/client/streamable_http.py index 183653b9..893aeb84 100644 --- a/src/mcp/client/streamable_http.py +++ b/src/mcp/client/streamable_http.py @@ -15,6 +15,7 @@ import anyio import httpx +from anyio.abc import TaskGroup from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream from httpx_sse import EventSource, ServerSentEvent, aconnect_sse @@ -239,7 +240,7 @@ async def _handle_resumption_request(self, ctx: RequestContext) -> None: break async def _handle_post_request(self, ctx: RequestContext) -> None: - """Handle a POST request with response processing.""" + """Handle a POST request with response processing.""" headers = self._update_headers_with_session(ctx.headers) message = ctx.session_message.message is_initialization = self._is_initialization_request(message) @@ -300,7 +301,7 @@ async def _handle_sse_response( try: event_source = EventSource(response) async for sse in event_source.aiter_sse(): - await self._handle_sse_event( + is_complete = await self._handle_sse_event( sse, ctx.read_stream_writer, resumption_callback=( @@ -309,6 +310,10 @@ async def _handle_sse_response( else None ), ) + # If the SSE event indicates completion, like returning respose/error + # break the loop + if is_complete: + break except Exception as e: logger.exception("Error reading SSE stream:") await ctx.read_stream_writer.send(e) @@ -344,6 +349,7 @@ async def post_writer( read_stream_writer: StreamWriter, write_stream: MemoryObjectSendStream[SessionMessage], start_get_stream: Callable[[], None], + tg: TaskGroup, ) -> None: """Handle writing requests to the server.""" try: @@ -375,10 +381,17 @@ async def post_writer( sse_read_timeout=self.sse_read_timeout, ) - if is_resumption: - await self._handle_resumption_request(ctx) + async def handle_request_async(): + if is_resumption: + await self._handle_resumption_request(ctx) + else: + await self._handle_post_request(ctx) + + # If this is a request, start a new task to handle it + if isinstance(message.root, JSONRPCRequest): + tg.start_soon(handle_request_async) else: - await self._handle_post_request(ctx) + await handle_request_async() except Exception as exc: logger.error(f"Error in post_writer: {exc}") @@ -466,6 +479,7 @@ def start_get_stream() -> None: read_stream_writer, write_stream, start_get_stream, + tg, ) try: diff --git a/src/mcp/server/session.py b/src/mcp/server/session.py index c769d1aa..f4e72eac 100644 --- a/src/mcp/server/session.py +++ b/src/mcp/server/session.py @@ -47,7 +47,7 @@ async def handle_list_prompts(ctx: RequestContext) -> list[types.Prompt]: import mcp.types as types from mcp.server.models import InitializationOptions -from mcp.shared.message import SessionMessage +from mcp.shared.message import ServerMessageMetadata, SessionMessage from mcp.shared.session import ( BaseSession, RequestResponder, @@ -230,10 +230,11 @@ async def create_message( stop_sequences: list[str] | None = None, metadata: dict[str, Any] | None = None, model_preferences: types.ModelPreferences | None = None, + related_request_id: types.RequestId | None = None, ) -> types.CreateMessageResult: """Send a sampling/create_message request.""" return await self.send_request( - types.ServerRequest( + request=types.ServerRequest( types.CreateMessageRequest( method="sampling/createMessage", params=types.CreateMessageRequestParams( @@ -248,7 +249,10 @@ async def create_message( ), ) ), - types.CreateMessageResult, + result_type=types.CreateMessageResult, + metadata=ServerMessageMetadata( + related_request_id=related_request_id, + ), ) async def list_roots(self) -> types.ListRootsResult: diff --git a/src/mcp/server/streamable_http.py b/src/mcp/server/streamable_http.py index ace74b33..8f4a1f51 100644 --- a/src/mcp/server/streamable_http.py +++ b/src/mcp/server/streamable_http.py @@ -33,7 +33,6 @@ ErrorData, JSONRPCError, JSONRPCMessage, - JSONRPCNotification, JSONRPCRequest, JSONRPCResponse, RequestId, @@ -849,9 +848,15 @@ async def message_router(): # Determine which request stream(s) should receive this message message = session_message.message target_request_id = None - if isinstance( - message.root, JSONRPCNotification | JSONRPCRequest - ): + # Check if this is a response + if isinstance(message.root, JSONRPCResponse | JSONRPCError): + response_id = str(message.root.id) + # If this response is for an existing request stream, + # send it there + if response_id in self._request_streams: + target_request_id = response_id + + else: # Extract related_request_id from meta if it exists if ( session_message.metadata is not None @@ -865,10 +870,12 @@ async def message_router(): target_request_id = str( session_message.metadata.related_request_id ) - else: - target_request_id = str(message.root.id) - request_stream_id = target_request_id or GET_STREAM_KEY + request_stream_id = ( + target_request_id + if target_request_id is not None + else GET_STREAM_KEY + ) # Store the event if we have an event store, # regardless of whether a client is connected diff --git a/src/mcp/shared/session.py b/src/mcp/shared/session.py index cce8b118..c390386a 100644 --- a/src/mcp/shared/session.py +++ b/src/mcp/shared/session.py @@ -223,7 +223,6 @@ async def send_request( Do not use this method to emit notifications! Use send_notification() instead. """ - request_id = self._request_id self._request_id = request_id + 1 diff --git a/tests/client/test_config.py b/tests/client/test_config.py index 6577d663..9f1cd8ee 100644 --- a/tests/client/test_config.py +++ b/tests/client/test_config.py @@ -54,7 +54,7 @@ def test_absolute_uv_path(mock_config_path: Path): """Test that the absolute path to uv is used when available.""" # Mock the shutil.which function to return a fake path mock_uv_path = "/usr/local/bin/uv" - + with patch("mcp.cli.claude.get_uv_path", return_value=mock_uv_path): # Setup server_name = "test_server" @@ -71,5 +71,5 @@ def test_absolute_uv_path(mock_config_path: Path): # Verify the command is the absolute path server_config = config["mcpServers"][server_name] command = server_config["command"] - - assert command == mock_uv_path \ No newline at end of file + + assert command == mock_uv_path diff --git a/tests/shared/test_streamable_http.py b/tests/shared/test_streamable_http.py index 28d29ac2..9b32254a 100644 --- a/tests/shared/test_streamable_http.py +++ b/tests/shared/test_streamable_http.py @@ -8,6 +8,7 @@ import socket import time from collections.abc import Generator +from typing import Any import anyio import httpx @@ -33,6 +34,7 @@ StreamId, ) from mcp.server.streamable_http_manager import StreamableHTTPSessionManager +from mcp.shared.context import RequestContext from mcp.shared.exceptions import McpError from mcp.shared.message import ( ClientMessageMetadata, @@ -139,6 +141,11 @@ async def handle_list_tools() -> list[Tool]: description="A long-running tool that sends periodic notifications", inputSchema={"type": "object", "properties": {}}, ), + Tool( + name="test_sampling_tool", + description="A tool that triggers server-side sampling", + inputSchema={"type": "object", "properties": {}}, + ), ] @self.call_tool() @@ -174,6 +181,34 @@ async def handle_call_tool(name: str, args: dict) -> list[TextContent]: return [TextContent(type="text", text="Completed!")] + elif name == "test_sampling_tool": + # Test sampling by requesting the client to sample a message + sampling_result = await ctx.session.create_message( + messages=[ + types.SamplingMessage( + role="user", + content=types.TextContent( + type="text", text="Server needs client sampling" + ), + ) + ], + max_tokens=100, + related_request_id=ctx.request_id, + ) + + # Return the sampling result in the tool response + response = ( + sampling_result.content.text + if sampling_result.content.type == "text" + else None + ) + return [ + TextContent( + type="text", + text=f"Response from sampling: {response}", + ) + ] + return [TextContent(type="text", text=f"Called {name}")] @@ -754,7 +789,7 @@ async def test_streamablehttp_client_tool_invocation(initialized_client_session) """Test client tool invocation.""" # First list tools tools = await initialized_client_session.list_tools() - assert len(tools.tools) == 3 + assert len(tools.tools) == 4 assert tools.tools[0].name == "test_tool" # Call the tool @@ -795,7 +830,7 @@ async def test_streamablehttp_client_session_persistence( # Make multiple requests to verify session persistence tools = await session.list_tools() - assert len(tools.tools) == 3 + assert len(tools.tools) == 4 # Read a resource resource = await session.read_resource(uri=AnyUrl("foobar://test-persist")) @@ -826,7 +861,7 @@ async def test_streamablehttp_client_json_response( # Check tool listing tools = await session.list_tools() - assert len(tools.tools) == 3 + assert len(tools.tools) == 4 # Call a tool and verify JSON response handling result = await session.call_tool("test_tool", {}) @@ -905,7 +940,7 @@ async def test_streamablehttp_client_session_termination( # Make a request to confirm session is working tools = await session.list_tools() - assert len(tools.tools) == 3 + assert len(tools.tools) == 4 headers = {} if captured_session_id: @@ -1054,3 +1089,71 @@ async def run_tool(): assert not any( n in captured_notifications_pre for n in captured_notifications ) + + +@pytest.mark.anyio +async def test_streamablehttp_server_sampling(basic_server, basic_server_url): + """Test server-initiated sampling request through streamable HTTP transport.""" + print("Testing server sampling...") + # Variable to track if sampling callback was invoked + sampling_callback_invoked = False + captured_message_params = None + + # Define sampling callback that returns a mock response + async def sampling_callback( + context: RequestContext[ClientSession, Any], + params: types.CreateMessageRequestParams, + ) -> types.CreateMessageResult: + nonlocal sampling_callback_invoked, captured_message_params + sampling_callback_invoked = True + captured_message_params = params + message_received = ( + params.messages[0].content.text + if params.messages[0].content.type == "text" + else None + ) + + return types.CreateMessageResult( + role="assistant", + content=types.TextContent( + type="text", + text=f"Received message from server: {message_received}", + ), + model="test-model", + stopReason="endTurn", + ) + + # Create client with sampling callback + async with streamablehttp_client(f"{basic_server_url}/mcp") as ( + read_stream, + write_stream, + _, + ): + async with ClientSession( + read_stream, + write_stream, + sampling_callback=sampling_callback, + ) as session: + # Initialize the session + result = await session.initialize() + assert isinstance(result, InitializeResult) + + # Call the tool that triggers server-side sampling + tool_result = await session.call_tool("test_sampling_tool", {}) + + # Verify the tool result contains the expected content + assert len(tool_result.content) == 1 + assert tool_result.content[0].type == "text" + assert ( + "Response from sampling: Received message from server" + in tool_result.content[0].text + ) + + # Verify sampling callback was invoked + assert sampling_callback_invoked + assert captured_message_params is not None + assert len(captured_message_params.messages) == 1 + assert ( + captured_message_params.messages[0].content.text + == "Server needs client sampling" + )