Skip to content

Commit 07769a6

Browse files
committed
feat(mcp): Add sampling capabilities to MCP server
Implement client-side sampling features in the Model Context Protocol (MCP) server: - Add createMessage method to McpAsyncServer and McpSyncServer for LLM sampling - Move MockMcpTransport to separate file for reusability - Add integration tests for sampling functionality - Test error handling for uninitialized clients and missing capabilities The sampling implementation allows servers to request LLM completions from clients while maintaining client-side control over model access and permissions. Resolves modelcontextprotocol#42
1 parent 0fd93a9 commit 07769a6

File tree

9 files changed

+331
-153
lines changed

9 files changed

+331
-153
lines changed

mcp/src/main/java/org/springframework/ai/mcp/server/McpAsyncServer.java

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -603,4 +603,41 @@ private DefaultMcpSession.RequestHandler setLoggerRequestHandler() {
603603
};
604604
}
605605

606+
// ---------------------------------------
607+
// Sampling
608+
// ---------------------------------------
609+
private static TypeReference<McpSchema.CreateMessageResult> CREATE_MESSAGE_RESULT_TYPE_REF = new TypeReference<>() {
610+
};
611+
612+
/**
613+
* Create a new message using the sampling capabilities of the client. The Model
614+
* Context Protocol (MCP) provides a standardized way for servers to request LLM
615+
* sampling (“completions” or “generations”) from language models via clients. This
616+
* flow allows clients to maintain control over model access, selection, and
617+
* permissions while enabling servers to leverage AI capabilities—with no server API
618+
* keys necessary. Servers can request text or image-based interactions and optionally
619+
* include context from MCP servers in their prompts.
620+
* @param createMessageRequest The request to create a new message
621+
* @return A Mono that completes when the message has been created
622+
* @throws McpError if the client has not been initialized or does not support
623+
* sampling capabilities
624+
* @throws McpError if the client does not support the createMessage method
625+
* @see McpSchema.CreateMessageRequest
626+
* @see McpSchema.CreateMessageResult
627+
* @see <a href=
628+
* "https://spec.modelcontextprotocol.io/specification/client/sampling/">Sampling
629+
* Specification</a>
630+
*/
631+
public Mono<McpSchema.CreateMessageResult> createMessage(McpSchema.CreateMessageRequest createMessageRequest) {
632+
633+
if (this.clientCapabilities == null) {
634+
return Mono.error(new McpError("Client must be initialized. Call the initialize method first!"));
635+
}
636+
if (this.clientCapabilities.sampling() == null) {
637+
return Mono.error(new McpError("Client must be configured with sampling capabilities"));
638+
}
639+
return this.mcpSession.sendRequest(McpSchema.METHOD_SAMPLING_CREATE_MESSAGE, createMessageRequest,
640+
CREATE_MESSAGE_RESULT_TYPE_REF);
641+
}
642+
606643
}

mcp/src/main/java/org/springframework/ai/mcp/server/McpSyncServer.java

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -178,4 +178,27 @@ public McpAsyncServer getAsyncServer() {
178178
return this.asyncServer;
179179
}
180180

181+
/**
182+
* Create a new message using the sampling capabilities of the client. The Model
183+
* Context Protocol (MCP) provides a standardized way for servers to request LLM
184+
* sampling (“completions” or “generations”) from language models via clients. This
185+
* flow allows clients to maintain control over model access, selection, and
186+
* permissions while enabling servers to leverage AI capabilities—with no server API
187+
* keys necessary. Servers can request text or image-based interactions and optionally
188+
* include context from MCP servers in their prompts.
189+
* @param createMessageRequest The request to create a new message
190+
* @return A Mono that completes when the message has been created
191+
* @throws McpError if the client has not been initialized or does not support
192+
* sampling capabilities
193+
* @throws McpError if the client does not support the createMessage method
194+
* @see McpSchema.CreateMessageRequest
195+
* @see McpSchema.CreateMessageResult
196+
* @see <a href=
197+
* "https://spec.modelcontextprotocol.io/specification/client/sampling/">Sampling
198+
* Specification</a>
199+
*/
200+
public McpSchema.CreateMessageResult createMessage(McpSchema.CreateMessageRequest createMessageRequest) {
201+
return this.asyncServer.createMessage(createMessageRequest).block();
202+
}
203+
181204
}
Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,103 @@
1+
/*
2+
* Copyright 2024-2024 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package org.springframework.ai.mcp;
18+
19+
import java.util.concurrent.atomic.AtomicInteger;
20+
import java.util.function.Function;
21+
22+
import com.fasterxml.jackson.core.type.TypeReference;
23+
import com.fasterxml.jackson.databind.ObjectMapper;
24+
import reactor.core.publisher.Flux;
25+
import reactor.core.publisher.Mono;
26+
import reactor.core.publisher.Sinks;
27+
import reactor.core.scheduler.Schedulers;
28+
29+
import org.springframework.ai.mcp.spec.McpSchema;
30+
import org.springframework.ai.mcp.spec.McpTransport;
31+
import org.springframework.ai.mcp.spec.McpSchema.JSONRPCMessage;
32+
import org.springframework.ai.mcp.spec.McpSchema.JSONRPCNotification;
33+
import org.springframework.ai.mcp.spec.McpSchema.JSONRPCRequest;
34+
35+
@SuppressWarnings("unused")
36+
public class MockMcpTransport implements McpTransport {
37+
38+
private final AtomicInteger inboundMessageCount = new AtomicInteger(0);
39+
40+
private final Sinks.Many<McpSchema.JSONRPCMessage> outgoing = Sinks.many().multicast().onBackpressureBuffer();
41+
42+
private final Sinks.Many<McpSchema.JSONRPCMessage> inbound = Sinks.many().unicast().onBackpressureBuffer();
43+
44+
private final Flux<McpSchema.JSONRPCMessage> outboundView = outgoing.asFlux().cache(1);
45+
46+
public void simulateIncomingMessage(McpSchema.JSONRPCMessage message) {
47+
if (inbound.tryEmitNext(message).isFailure()) {
48+
throw new RuntimeException("Failed to emit message " + message);
49+
}
50+
inboundMessageCount.incrementAndGet();
51+
}
52+
53+
@Override
54+
public Mono<Void> sendMessage(McpSchema.JSONRPCMessage message) {
55+
if (outgoing.tryEmitNext(message).isFailure()) {
56+
return Mono.error(new RuntimeException("Can't emit outgoing message " + message));
57+
}
58+
return Mono.empty();
59+
}
60+
61+
public McpSchema.JSONRPCRequest getLastSentMessageAsRequest() {
62+
return (JSONRPCRequest) outboundView.blockFirst();
63+
}
64+
65+
public McpSchema.JSONRPCNotification getLastSentMessageAsNotifiation() {
66+
return (JSONRPCNotification) outboundView.blockFirst();
67+
}
68+
69+
public McpSchema.JSONRPCMessage getLastSentMessage() {
70+
return outboundView.blockFirst();
71+
}
72+
73+
private volatile boolean connected = false;
74+
75+
@Override
76+
public Mono<Void> connect(Function<Mono<McpSchema.JSONRPCMessage>, Mono<McpSchema.JSONRPCMessage>> handler) {
77+
if (connected) {
78+
return Mono.error(new IllegalStateException("Already connected"));
79+
}
80+
connected = true;
81+
return inbound.asFlux()
82+
.publishOn(Schedulers.boundedElastic())
83+
.flatMap(message -> Mono.just(message).transform(handler))
84+
.doFinally(signal -> connected = false)
85+
.then();
86+
}
87+
88+
@Override
89+
public Mono<Void> closeGracefully() {
90+
return Mono.defer(() -> {
91+
connected = false;
92+
outgoing.tryEmitComplete();
93+
inbound.tryEmitComplete();
94+
return Mono.empty();
95+
});
96+
}
97+
98+
@Override
99+
public <T> T unmarshalFrom(Object data, TypeReference<T> typeRef) {
100+
return new ObjectMapper().convertValue(data, typeRef);
101+
}
102+
103+
}

mcp/src/test/java/org/springframework/ai/mcp/client/McpAsyncClientResponseHandlerTests.java

Lines changed: 1 addition & 79 deletions
Original file line numberDiff line numberDiff line change
@@ -20,100 +20,22 @@
2020
import java.util.ArrayList;
2121
import java.util.List;
2222
import java.util.Map;
23-
import java.util.concurrent.atomic.AtomicInteger;
2423
import java.util.function.Consumer;
2524
import java.util.function.Function;
2625

2726
import com.fasterxml.jackson.core.type.TypeReference;
28-
import com.fasterxml.jackson.databind.ObjectMapper;
2927
import org.junit.jupiter.api.Test;
30-
import reactor.core.publisher.Flux;
31-
import reactor.core.publisher.Mono;
32-
import reactor.core.publisher.Sinks;
33-
import reactor.core.scheduler.Schedulers;
3428

29+
import org.springframework.ai.mcp.MockMcpTransport;
3530
import org.springframework.ai.mcp.spec.McpSchema;
3631
import org.springframework.ai.mcp.spec.McpSchema.ClientCapabilities;
37-
import org.springframework.ai.mcp.spec.McpSchema.JSONRPCNotification;
38-
import org.springframework.ai.mcp.spec.McpSchema.JSONRPCRequest;
3932
import org.springframework.ai.mcp.spec.McpSchema.Root;
40-
import org.springframework.ai.mcp.spec.McpTransport;
4133

4234
import static org.assertj.core.api.Assertions.assertThat;
4335
import static org.awaitility.Awaitility.await;
4436

4537
class McpAsyncClientResponseHandlerTests {
4638

47-
@SuppressWarnings("unused")
48-
private static class MockMcpTransport implements McpTransport {
49-
50-
private final AtomicInteger inboundMessageCount = new AtomicInteger(0);
51-
52-
private final Sinks.Many<McpSchema.JSONRPCMessage> outgoing = Sinks.many().multicast().onBackpressureBuffer();
53-
54-
private final Sinks.Many<McpSchema.JSONRPCMessage> inbound = Sinks.many().unicast().onBackpressureBuffer();
55-
56-
private final Flux<McpSchema.JSONRPCMessage> outboundView = outgoing.asFlux().cache(1);
57-
58-
public void simulateIncomingMessage(McpSchema.JSONRPCMessage message) {
59-
if (inbound.tryEmitNext(message).isFailure()) {
60-
throw new RuntimeException("Failed to emit message " + message);
61-
}
62-
inboundMessageCount.incrementAndGet();
63-
}
64-
65-
@Override
66-
public Mono<Void> sendMessage(McpSchema.JSONRPCMessage message) {
67-
if (outgoing.tryEmitNext(message).isFailure()) {
68-
return Mono.error(new RuntimeException("Can't emit outgoing message " + message));
69-
}
70-
return Mono.empty();
71-
}
72-
73-
public McpSchema.JSONRPCRequest getLastSentMessageAsRequest() {
74-
return (JSONRPCRequest) outboundView.blockFirst();
75-
}
76-
77-
public McpSchema.JSONRPCNotification getLastSentMessageAsNotifiation() {
78-
return (JSONRPCNotification) outboundView.blockFirst();
79-
}
80-
81-
public McpSchema.JSONRPCMessage getLastSentMessage() {
82-
return outboundView.blockFirst();
83-
}
84-
85-
private volatile boolean connected = false;
86-
87-
@Override
88-
public Mono<Void> connect(Function<Mono<McpSchema.JSONRPCMessage>, Mono<McpSchema.JSONRPCMessage>> handler) {
89-
if (connected) {
90-
return Mono.error(new IllegalStateException("Already connected"));
91-
}
92-
connected = true;
93-
return inbound.asFlux()
94-
.publishOn(Schedulers.boundedElastic())
95-
.flatMap(message -> Mono.just(message).transform(handler))
96-
.doFinally(signal -> connected = false)
97-
.then();
98-
}
99-
100-
@Override
101-
public Mono<Void> closeGracefully() {
102-
return Mono.defer(() -> {
103-
connected = false;
104-
outgoing.tryEmitComplete();
105-
inbound.tryEmitComplete();
106-
return Mono.empty();
107-
});
108-
}
109-
110-
@Override
111-
public <T> T unmarshalFrom(Object data, TypeReference<T> typeRef) {
112-
return new ObjectMapper().convertValue(data, typeRef);
113-
}
114-
115-
}
116-
11739
@Test
11840
void testToolsChangeNotificationHandling() {
11941
MockMcpTransport transport = new MockMcpTransport();

mcp/src/test/java/org/springframework/ai/mcp/server/AbstractMcpAsyncServerTests.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,11 +25,11 @@
2525
import org.junit.jupiter.api.Test;
2626
import reactor.test.StepVerifier;
2727

28-
import org.springframework.ai.mcp.spec.McpSchema;
2928
import org.springframework.ai.mcp.server.McpServer.PromptRegistration;
3029
import org.springframework.ai.mcp.server.McpServer.ResourceRegistration;
3130
import org.springframework.ai.mcp.server.McpServer.ToolRegistration;
3231
import org.springframework.ai.mcp.spec.McpError;
32+
import org.springframework.ai.mcp.spec.McpSchema;
3333
import org.springframework.ai.mcp.spec.McpSchema.CallToolResult;
3434
import org.springframework.ai.mcp.spec.McpSchema.GetPromptResult;
3535
import org.springframework.ai.mcp.spec.McpSchema.Prompt;

0 commit comments

Comments
 (0)