From cce8519221b6174c76c99c9a17240a9a41866b5b Mon Sep 17 00:00:00 2001 From: TerminalMan <84923604+SecretiveShell@users.noreply.github.com> Date: Sat, 4 Jan 2025 13:28:36 +0000 Subject: [PATCH 1/7] add sampling callback paramater --- src/mcp/client/session.py | 25 ++++++++++++++++++++++++- 1 file changed, 24 insertions(+), 1 deletion(-) diff --git a/src/mcp/client/session.py b/src/mcp/client/session.py index 27ca74d8c..254b5bc10 100644 --- a/src/mcp/client/session.py +++ b/src/mcp/client/session.py @@ -1,4 +1,6 @@ from datetime import timedelta +from inspect import iscoroutinefunction +from typing import Awaitable, Callable from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream from pydantic import AnyUrl @@ -7,6 +9,10 @@ from mcp.shared.session import BaseSession from mcp.shared.version import SUPPORTED_PROTOCOL_VERSIONS +sampling_function_signature = Callable[ + [types.CreateMessageRequestParams], Awaitable[types.CreateMessageResult] +] + class ClientSession( BaseSession[ @@ -17,11 +23,14 @@ class ClientSession( types.ServerNotification, ] ): + sampling_callback: sampling_function_signature | None = None + def __init__( self, read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception], write_stream: MemoryObjectSendStream[types.JSONRPCMessage], read_timeout_seconds: timedelta | None = None, + sampling_callback: sampling_function_signature | None = None, ) -> None: super().__init__( read_stream, @@ -31,7 +40,21 @@ def __init__( read_timeout_seconds=read_timeout_seconds, ) + # validate sampling_callback + # use asserts here because this should be known at compile time + if sampling_callback is not None: + assert callable(sampling_callback), "sampling_callback must be callable" + assert iscoroutinefunction( + sampling_callback + ), "sampling_callback must be an async function" + + self.sampling_callback = sampling_callback + async def initialize(self) -> types.InitializeResult: + sampling = None + if self.sampling_callback is not None: + sampling = types.SamplingCapability() + result = await self.send_request( types.ClientRequest( types.InitializeRequest( @@ -39,7 +62,7 @@ async def initialize(self) -> types.InitializeResult: params=types.InitializeRequestParams( protocolVersion=types.LATEST_PROTOCOL_VERSION, capabilities=types.ClientCapabilities( - sampling=None, + sampling=sampling, experimental=None, roots=types.RootsCapability( # TODO: Should this be based on whether we From 368782c15543a655d06e6098afe81e3926275712 Mon Sep 17 00:00:00 2001 From: TerminalMan <84923604+SecretiveShell@users.noreply.github.com> Date: Sat, 4 Jan 2025 16:29:37 +0000 Subject: [PATCH 2/7] add request handler --- src/mcp/client/session.py | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/src/mcp/client/session.py b/src/mcp/client/session.py index 254b5bc10..87df78000 100644 --- a/src/mcp/client/session.py +++ b/src/mcp/client/session.py @@ -6,7 +6,7 @@ from pydantic import AnyUrl import mcp.types as types -from mcp.shared.session import BaseSession +from mcp.shared.session import BaseSession, RequestResponder from mcp.shared.version import SUPPORTED_PROTOCOL_VERSIONS sampling_function_signature = Callable[ @@ -255,3 +255,17 @@ async def send_roots_list_changed(self) -> None: ) ) ) + + async def _received_request( + self, responder: RequestResponder["types.ServerRequest", "types.ClientResult"] + ) -> None: + if isinstance(responder.request.root, types.CreateMessageRequest): + print("Received create message request") + if self.sampling_callback is None: + raise RuntimeError("Sampling callback is not set") + response = await self.sampling_callback(responder.request.root.params) + + client_response = types.ClientResult(**response.model_dump()) + + print(f"Response: {response.dict()}") + await responder.respond(client_response) From f9de5f096d30e379386b801e1c2865ce3ae6f940 Mon Sep 17 00:00:00 2001 From: TerminalMan <84923604+SecretiveShell@users.noreply.github.com> Date: Tue, 7 Jan 2025 16:20:20 +0000 Subject: [PATCH 3/7] cleanup print statements --- src/mcp/client/session.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/mcp/client/session.py b/src/mcp/client/session.py index 87df78000..71284ab0f 100644 --- a/src/mcp/client/session.py +++ b/src/mcp/client/session.py @@ -259,13 +259,13 @@ async def send_roots_list_changed(self) -> None: async def _received_request( self, responder: RequestResponder["types.ServerRequest", "types.ClientResult"] ) -> None: + if isinstance(responder.request.root, types.CreateMessageRequest): - print("Received create message request") + # handle create message request (sampling) + if self.sampling_callback is None: raise RuntimeError("Sampling callback is not set") + response = await self.sampling_callback(responder.request.root.params) - client_response = types.ClientResult(**response.model_dump()) - - print(f"Response: {response.dict()}") await responder.respond(client_response) From 9e68fa8f144d54a9771df6b8a0be43c3ab32d162 Mon Sep 17 00:00:00 2001 From: TerminalMan <84923604+SecretiveShell@users.noreply.github.com> Date: Tue, 7 Jan 2025 18:32:35 +0000 Subject: [PATCH 4/7] add docs to readme --- README.md | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index 335542c79..f8852d215 100644 --- a/README.md +++ b/README.md @@ -417,9 +417,21 @@ server_params = StdioServerParameters( env=None # Optional environment variables ) +# Optional: create a sampling callback +async def handle_sampling_message(message: types.CreateMessageRequestParams) -> types.CreateMessageResult: + return types.CreateMessageResult( + role="assistant", + content=types.TextContent( + type="text", + text="Hello, world! from model", + ), + model="gpt-3.5-turbo", + stopReason="endTurn", + ) + async def run(): async with stdio_client(server_params) as (read, write): - async with ClientSession(read, write) as session: + async with ClientSession(read, write, sampling_callback=handle_sampling_message) as session: # Initialize the connection await session.initialize() From f979ca9980689df7816a99bcfc391a653f73eca4 Mon Sep 17 00:00:00 2001 From: TerminalMan <84923604+SecretiveShell@users.noreply.github.com> Date: Tue, 7 Jan 2025 18:47:15 +0000 Subject: [PATCH 5/7] ruff format --- src/mcp/client/session.py | 3 +-- tests/client/test_stdio.py | 7 ++++++- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/src/mcp/client/session.py b/src/mcp/client/session.py index 71284ab0f..93206d728 100644 --- a/src/mcp/client/session.py +++ b/src/mcp/client/session.py @@ -259,13 +259,12 @@ async def send_roots_list_changed(self) -> None: async def _received_request( self, responder: RequestResponder["types.ServerRequest", "types.ClientResult"] ) -> None: - if isinstance(responder.request.root, types.CreateMessageRequest): # handle create message request (sampling) if self.sampling_callback is None: raise RuntimeError("Sampling callback is not set") - + response = await self.sampling_callback(responder.request.root.params) client_response = types.ClientResult(**response.model_dump()) await responder.respond(client_response) diff --git a/tests/client/test_stdio.py b/tests/client/test_stdio.py index 0bdec72d0..ba9461e6e 100644 --- a/tests/client/test_stdio.py +++ b/tests/client/test_stdio.py @@ -1,12 +1,17 @@ +import shutil + import pytest from mcp.client.stdio import StdioServerParameters, stdio_client from mcp.types import JSONRPCMessage, JSONRPCRequest, JSONRPCResponse +tee: str = shutil.which("tee") # type: ignore +assert tee is not None, "could not find tee command" + @pytest.mark.anyio async def test_stdio_client(): - server_parameters = StdioServerParameters(command="/usr/bin/tee") + server_parameters = StdioServerParameters(command=tee) async with stdio_client(server_parameters) as (read_stream, write_stream): # Test sending and receiving messages From e2e2f4335c471a69d44c4b8435c0b8d9e334e7c3 Mon Sep 17 00:00:00 2001 From: David Soria Parra Date: Thu, 13 Feb 2025 15:30:51 +0000 Subject: [PATCH 6/7] simplify the implementation --- src/mcp/client/session.py | 28 +++++++--------------------- 1 file changed, 7 insertions(+), 21 deletions(-) diff --git a/src/mcp/client/session.py b/src/mcp/client/session.py index 93206d728..cf1aee137 100644 --- a/src/mcp/client/session.py +++ b/src/mcp/client/session.py @@ -1,5 +1,4 @@ from datetime import timedelta -from inspect import iscoroutinefunction from typing import Awaitable, Callable from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream @@ -39,21 +38,12 @@ def __init__( types.ServerNotification, read_timeout_seconds=read_timeout_seconds, ) - - # validate sampling_callback - # use asserts here because this should be known at compile time - if sampling_callback is not None: - assert callable(sampling_callback), "sampling_callback must be callable" - assert iscoroutinefunction( - sampling_callback - ), "sampling_callback must be an async function" - self.sampling_callback = sampling_callback async def initialize(self) -> types.InitializeResult: - sampling = None - if self.sampling_callback is not None: - sampling = types.SamplingCapability() + sampling = ( + types.SamplingCapability() if self.sampling_callback is not None else None + ) result = await self.send_request( types.ClientRequest( @@ -260,11 +250,7 @@ async def _received_request( self, responder: RequestResponder["types.ServerRequest", "types.ClientResult"] ) -> None: if isinstance(responder.request.root, types.CreateMessageRequest): - # handle create message request (sampling) - - if self.sampling_callback is None: - raise RuntimeError("Sampling callback is not set") - - response = await self.sampling_callback(responder.request.root.params) - client_response = types.ClientResult(**response.model_dump()) - await responder.respond(client_response) + if self.sampling_callback is not None: + response = await self.sampling_callback(responder.request.root.params) + client_response = types.ClientResult(root=response) + await responder.respond(client_response) From 8f0f7c5d00ed00195c28d4f88d5e31fe6ceb33a7 Mon Sep 17 00:00:00 2001 From: David Soria Parra Date: Thu, 13 Feb 2025 16:29:40 +0000 Subject: [PATCH 7/7] fix: simplify implementation and add test --- src/mcp/client/session.py | 9 +++-- src/mcp/shared/memory.py | 4 +- tests/client/test_sampling_callback.py | 53 ++++++++++++++++++++++++++ 3 files changed, 61 insertions(+), 5 deletions(-) create mode 100644 tests/client/test_sampling_callback.py diff --git a/src/mcp/client/session.py b/src/mcp/client/session.py index cf1aee137..caa8e0f2c 100644 --- a/src/mcp/client/session.py +++ b/src/mcp/client/session.py @@ -8,7 +8,7 @@ from mcp.shared.session import BaseSession, RequestResponder from mcp.shared.version import SUPPORTED_PROTOCOL_VERSIONS -sampling_function_signature = Callable[ +SamplingFnT = Callable[ [types.CreateMessageRequestParams], Awaitable[types.CreateMessageResult] ] @@ -22,14 +22,14 @@ class ClientSession( types.ServerNotification, ] ): - sampling_callback: sampling_function_signature | None = None + sampling_callback: SamplingFnT | None = None def __init__( self, read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception], write_stream: MemoryObjectSendStream[types.JSONRPCMessage], read_timeout_seconds: timedelta | None = None, - sampling_callback: sampling_function_signature | None = None, + sampling_callback: SamplingFnT | None = None, ) -> None: super().__init__( read_stream, @@ -253,4 +253,5 @@ async def _received_request( if self.sampling_callback is not None: response = await self.sampling_callback(responder.request.root.params) client_response = types.ClientResult(root=response) - await responder.respond(client_response) + with responder: + await responder.respond(client_response) diff --git a/src/mcp/shared/memory.py b/src/mcp/shared/memory.py index 72549925b..0900cfd87 100644 --- a/src/mcp/shared/memory.py +++ b/src/mcp/shared/memory.py @@ -9,7 +9,7 @@ import anyio from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream -from mcp.client.session import ClientSession +from mcp.client.session import ClientSession, SamplingFnT from mcp.server import Server from mcp.types import JSONRPCMessage @@ -54,6 +54,7 @@ async def create_client_server_memory_streams() -> ( async def create_connected_server_and_client_session( server: Server, read_timeout_seconds: timedelta | None = None, + sampling_callback: SamplingFnT | None = None, raise_exceptions: bool = False, ) -> AsyncGenerator[ClientSession, None]: """Creates a ClientSession that is connected to a running MCP server.""" @@ -80,6 +81,7 @@ async def create_connected_server_and_client_session( read_stream=client_read, write_stream=client_write, read_timeout_seconds=read_timeout_seconds, + sampling_callback=sampling_callback, ) as client_session: await client_session.initialize() yield client_session diff --git a/tests/client/test_sampling_callback.py b/tests/client/test_sampling_callback.py new file mode 100644 index 000000000..5f4ff30fe --- /dev/null +++ b/tests/client/test_sampling_callback.py @@ -0,0 +1,53 @@ +import pytest + +from mcp.shared.memory import ( + create_connected_server_and_client_session as create_session, +) +from mcp.types import ( + CreateMessageRequestParams, + CreateMessageResult, + SamplingMessage, + TextContent, +) + + +@pytest.mark.anyio +async def test_sampling_callback(): + from mcp.server.fastmcp import FastMCP + + server = FastMCP("test") + + callback_return = CreateMessageResult( + role="assistant", + content=TextContent( + type="text", text="This is a response from the sampling callback" + ), + model="test-model", + stopReason="endTurn", + ) + + async def sampling_callback( + message: CreateMessageRequestParams, + ) -> CreateMessageResult: + return callback_return + + @server.tool("test_sampling") + async def test_sampling_tool(message: str): + value = await server.get_context().session.create_message( + messages=[ + SamplingMessage( + role="user", content=TextContent(type="text", text=message) + ) + ], + max_tokens=100, + ) + assert value == callback_return + return True + + async with create_session( + server._mcp_server, sampling_callback=sampling_callback + ) as client_session: + # Make a request to trigger sampling callback + assert await client_session.call_tool( + "test_sampling", {"message": "Test message for sampling"} + )