diff --git a/README.md b/README.md index 335542c79..f8852d215 100644 --- a/README.md +++ b/README.md @@ -417,9 +417,21 @@ server_params = StdioServerParameters( env=None # Optional environment variables ) +# Optional: create a sampling callback +async def handle_sampling_message(message: types.CreateMessageRequestParams) -> types.CreateMessageResult: + return types.CreateMessageResult( + role="assistant", + content=types.TextContent( + type="text", + text="Hello, world! from model", + ), + model="gpt-3.5-turbo", + stopReason="endTurn", + ) + async def run(): async with stdio_client(server_params) as (read, write): - async with ClientSession(read, write) as session: + async with ClientSession(read, write, sampling_callback=handle_sampling_message) as session: # Initialize the connection await session.initialize() diff --git a/src/mcp/client/session.py b/src/mcp/client/session.py index 27ca74d8c..caa8e0f2c 100644 --- a/src/mcp/client/session.py +++ b/src/mcp/client/session.py @@ -1,12 +1,17 @@ from datetime import timedelta +from typing import Awaitable, Callable from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream from pydantic import AnyUrl import mcp.types as types -from mcp.shared.session import BaseSession +from mcp.shared.session import BaseSession, RequestResponder from mcp.shared.version import SUPPORTED_PROTOCOL_VERSIONS +SamplingFnT = Callable[ + [types.CreateMessageRequestParams], Awaitable[types.CreateMessageResult] +] + class ClientSession( BaseSession[ @@ -17,11 +22,14 @@ class ClientSession( types.ServerNotification, ] ): + sampling_callback: SamplingFnT | None = None + def __init__( self, read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception], write_stream: MemoryObjectSendStream[types.JSONRPCMessage], read_timeout_seconds: timedelta | None = None, + sampling_callback: SamplingFnT | None = None, ) -> None: super().__init__( read_stream, @@ -30,8 +38,13 @@ def __init__( types.ServerNotification, read_timeout_seconds=read_timeout_seconds, ) + self.sampling_callback = sampling_callback async def initialize(self) -> types.InitializeResult: + sampling = ( + types.SamplingCapability() if self.sampling_callback is not None else None + ) + result = await self.send_request( types.ClientRequest( types.InitializeRequest( @@ -39,7 +52,7 @@ async def initialize(self) -> types.InitializeResult: params=types.InitializeRequestParams( protocolVersion=types.LATEST_PROTOCOL_VERSION, capabilities=types.ClientCapabilities( - sampling=None, + sampling=sampling, experimental=None, roots=types.RootsCapability( # TODO: Should this be based on whether we @@ -232,3 +245,13 @@ async def send_roots_list_changed(self) -> None: ) ) ) + + async def _received_request( + self, responder: RequestResponder["types.ServerRequest", "types.ClientResult"] + ) -> None: + if isinstance(responder.request.root, types.CreateMessageRequest): + if self.sampling_callback is not None: + response = await self.sampling_callback(responder.request.root.params) + client_response = types.ClientResult(root=response) + with responder: + await responder.respond(client_response) diff --git a/src/mcp/shared/memory.py b/src/mcp/shared/memory.py index 72549925b..0900cfd87 100644 --- a/src/mcp/shared/memory.py +++ b/src/mcp/shared/memory.py @@ -9,7 +9,7 @@ import anyio from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream -from mcp.client.session import ClientSession +from mcp.client.session import ClientSession, SamplingFnT from mcp.server import Server from mcp.types import JSONRPCMessage @@ -54,6 +54,7 @@ async def create_client_server_memory_streams() -> ( async def create_connected_server_and_client_session( server: Server, read_timeout_seconds: timedelta | None = None, + sampling_callback: SamplingFnT | None = None, raise_exceptions: bool = False, ) -> AsyncGenerator[ClientSession, None]: """Creates a ClientSession that is connected to a running MCP server.""" @@ -80,6 +81,7 @@ async def create_connected_server_and_client_session( read_stream=client_read, write_stream=client_write, read_timeout_seconds=read_timeout_seconds, + sampling_callback=sampling_callback, ) as client_session: await client_session.initialize() yield client_session diff --git a/tests/client/test_sampling_callback.py b/tests/client/test_sampling_callback.py new file mode 100644 index 000000000..5f4ff30fe --- /dev/null +++ b/tests/client/test_sampling_callback.py @@ -0,0 +1,53 @@ +import pytest + +from mcp.shared.memory import ( + create_connected_server_and_client_session as create_session, +) +from mcp.types import ( + CreateMessageRequestParams, + CreateMessageResult, + SamplingMessage, + TextContent, +) + + +@pytest.mark.anyio +async def test_sampling_callback(): + from mcp.server.fastmcp import FastMCP + + server = FastMCP("test") + + callback_return = CreateMessageResult( + role="assistant", + content=TextContent( + type="text", text="This is a response from the sampling callback" + ), + model="test-model", + stopReason="endTurn", + ) + + async def sampling_callback( + message: CreateMessageRequestParams, + ) -> CreateMessageResult: + return callback_return + + @server.tool("test_sampling") + async def test_sampling_tool(message: str): + value = await server.get_context().session.create_message( + messages=[ + SamplingMessage( + role="user", content=TextContent(type="text", text=message) + ) + ], + max_tokens=100, + ) + assert value == callback_return + return True + + async with create_session( + server._mcp_server, sampling_callback=sampling_callback + ) as client_session: + # Make a request to trigger sampling callback + assert await client_session.call_tool( + "test_sampling", {"message": "Test message for sampling"} + ) diff --git a/tests/client/test_stdio.py b/tests/client/test_stdio.py index 0bdec72d0..ba9461e6e 100644 --- a/tests/client/test_stdio.py +++ b/tests/client/test_stdio.py @@ -1,12 +1,17 @@ +import shutil + import pytest from mcp.client.stdio import StdioServerParameters, stdio_client from mcp.types import JSONRPCMessage, JSONRPCRequest, JSONRPCResponse +tee: str = shutil.which("tee") # type: ignore +assert tee is not None, "could not find tee command" + @pytest.mark.anyio async def test_stdio_client(): - server_parameters = StdioServerParameters(command="/usr/bin/tee") + server_parameters = StdioServerParameters(command=tee) async with stdio_client(server_parameters) as (read_stream, write_stream): # Test sending and receiving messages