Skip to content

Commit 51b75dc

Browse files
committed
feat: enforce sampling handler presence and add CreateMessageResult builder
Improves error handling by enforcing sampling handler presence when sampling capabilities are enabled. Adds a builder pattern for CreateMessageResult to improve API ergonomics.
1 parent 07769a6 commit 51b75dc

File tree

4 files changed

+64
-28
lines changed

4 files changed

+64
-28
lines changed

mcp/src/main/java/org/springframework/ai/mcp/client/McpAsyncClient.java

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -160,7 +160,10 @@ public McpAsyncClient(McpTransport transport, Duration requestTimeout, Implement
160160
}
161161

162162
// Sampling Handler
163-
if (samplingHandler != null && this.clientCapabilities.sampling() != null) {
163+
if (this.clientCapabilities.sampling() != null) {
164+
if (samplingHandler == null) {
165+
throw new McpError("Sampling handler must not be null when client capabilities include sampling");
166+
}
164167
this.samplingHandler = samplingHandler;
165168
requestHanlers.put(McpSchema.METHOD_SAMPLING_CREATE_MESSAGE, samplingCreateMessageHandler());
166169
}

mcp/src/main/java/org/springframework/ai/mcp/spec/McpSchema.java

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -761,6 +761,46 @@ public enum StopReason {
761761
@JsonProperty("stop_sequence") STOP_SEQUENCE,
762762
@JsonProperty("max_tokens") MAX_TOKENS
763763
}
764+
765+
public static Builder builder() {
766+
return new Builder();
767+
}
768+
769+
public static class Builder {
770+
private Role role = Role.ASSISTANT;
771+
private Content content;
772+
private String model;
773+
private StopReason stopReason = StopReason.END_TURN;
774+
775+
public Builder role(Role role) {
776+
this.role = role;
777+
return this;
778+
}
779+
780+
public Builder content(Content content) {
781+
this.content = content;
782+
return this;
783+
}
784+
785+
public Builder model(String model) {
786+
this.model = model;
787+
return this;
788+
}
789+
790+
public Builder stopReason(StopReason stopReason) {
791+
this.stopReason = stopReason;
792+
return this;
793+
}
794+
795+
public Builder message(String message) {
796+
this.content = new TextContent(message);
797+
return this;
798+
}
799+
800+
public CreateMessageResult build() {
801+
return new CreateMessageResult(role, content, model, stopReason);
802+
}
803+
}
764804
}// @formatter:on
765805

766806
// ---------------------------

mcp/src/test/java/org/springframework/ai/mcp/client/AbstractMcpAsyncClientTests.java

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import java.time.Duration;
2020
import java.util.Map;
2121
import java.util.concurrent.atomic.AtomicBoolean;
22+
import java.util.function.Function;
2223

2324
import org.junit.jupiter.api.AfterEach;
2425
import org.junit.jupiter.api.BeforeEach;
@@ -29,6 +30,8 @@
2930
import org.springframework.ai.mcp.spec.McpSchema;
3031
import org.springframework.ai.mcp.spec.McpSchema.CallToolRequest;
3132
import org.springframework.ai.mcp.spec.McpSchema.ClientCapabilities;
33+
import org.springframework.ai.mcp.spec.McpSchema.CreateMessageRequest;
34+
import org.springframework.ai.mcp.spec.McpSchema.CreateMessageResult;
3235
import org.springframework.ai.mcp.spec.McpSchema.GetPromptRequest;
3336
import org.springframework.ai.mcp.spec.McpSchema.Prompt;
3437
import org.springframework.ai.mcp.spec.McpSchema.Resource;
@@ -295,7 +298,9 @@ void testInitializeWithSamplingCapability() {
295298

296299
var capabilities = ClientCapabilities.builder().sampling().build();
297300

298-
var client = McpClient.using(transport).requestTimeout(TIMEOUT).capabilities(capabilities).async();
301+
var client = McpClient.using(transport).requestTimeout(TIMEOUT).capabilities(capabilities).sampling(request -> {
302+
return CreateMessageResult.builder().message("test").model("test-model").build();
303+
}).async();
299304

300305
assertThatCode(() -> {
301306
client.initialize().block(Duration.ofSeconds(10));
@@ -313,7 +318,14 @@ void testInitializeWithAllCapabilities() {
313318
.sampling()
314319
.build();
315320

316-
var client = McpClient.using(transport).requestTimeout(TIMEOUT).capabilities(capabilities).async();
321+
Function<CreateMessageRequest, CreateMessageResult> samplingHandler = request -> {
322+
return CreateMessageResult.builder().message("test").model("test-model").build();
323+
};
324+
var client = McpClient.using(transport)
325+
.requestTimeout(TIMEOUT)
326+
.capabilities(capabilities)
327+
.sampling(samplingHandler)
328+
.async();
317329

318330
assertThatCode(() -> {
319331
var result = client.initialize().block(Duration.ofSeconds(10));

mcp/src/test/java/org/springframework/ai/mcp/client/McpAsyncClientResponseHandlerTests.java

Lines changed: 6 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -27,11 +27,13 @@
2727
import org.junit.jupiter.api.Test;
2828

2929
import org.springframework.ai.mcp.MockMcpTransport;
30+
import org.springframework.ai.mcp.spec.McpError;
3031
import org.springframework.ai.mcp.spec.McpSchema;
3132
import org.springframework.ai.mcp.spec.McpSchema.ClientCapabilities;
3233
import org.springframework.ai.mcp.spec.McpSchema.Root;
3334

3435
import static org.assertj.core.api.Assertions.assertThat;
36+
import static org.assertj.core.api.Assertions.assertThatThrownBy;
3537
import static org.awaitility.Awaitility.await;
3638

3739
class McpAsyncClientResponseHandlerTests {
@@ -285,31 +287,10 @@ void testSamplingCreateMessageRequestHandlingWithNullHandler() {
285287
MockMcpTransport transport = new MockMcpTransport();
286288

287289
// Create client with sampling capability but null handler
288-
McpAsyncClient asyncMcpClient = McpClient.using(transport)
289-
.capabilities(ClientCapabilities.builder().sampling().build())
290-
.async();
291-
292-
// Create a mock create message request
293-
var messageRequest = new McpSchema.CreateMessageRequest(
294-
List.of(new McpSchema.SamplingMessage(McpSchema.Role.USER, new McpSchema.TextContent("Test message"))),
295-
null, null, null, null, 0, null, null);
296-
297-
// Simulate incoming request
298-
McpSchema.JSONRPCRequest request = new McpSchema.JSONRPCRequest(McpSchema.JSONRPC_VERSION,
299-
McpSchema.METHOD_SAMPLING_CREATE_MESSAGE, "test-id", messageRequest);
300-
transport.simulateIncomingMessage(request);
301-
302-
// Verify error response
303-
McpSchema.JSONRPCMessage sentMessage = transport.getLastSentMessage();
304-
assertThat(sentMessage).isInstanceOf(McpSchema.JSONRPCResponse.class);
305-
306-
McpSchema.JSONRPCResponse response = (McpSchema.JSONRPCResponse) sentMessage;
307-
assertThat(response.id()).isEqualTo("test-id");
308-
assertThat(response.result()).isNull();
309-
assertThat(response.error()).isNotNull();
310-
assertThat(response.error().message()).contains("Method not found: sampling/createMessage");
311-
312-
asyncMcpClient.closeGracefully();
290+
assertThatThrownBy(
291+
() -> McpClient.using(transport).capabilities(ClientCapabilities.builder().sampling().build()).async())
292+
.isInstanceOf(McpError.class)
293+
.hasMessage("Sampling handler must not be null when client capabilities include sampling");
313294
}
314295

315296
}

0 commit comments

Comments
 (0)