Skip to content

Commit 122a06a

Browse files
committed
Properly handling session invalidation
1 parent a42f2bf commit 122a06a

File tree

6 files changed

+66
-14
lines changed

6 files changed

+66
-14
lines changed

mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/client/transport/WebClientStreamableHttpTransport.java

Lines changed: 44 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,8 @@
1515
import org.springframework.http.MediaType;
1616
import org.springframework.http.codec.ServerSentEvent;
1717
import org.springframework.web.reactive.function.client.WebClient;
18+
import org.springframework.web.reactive.function.client.WebClientResponseException;
1819
import reactor.core.Disposable;
19-
import reactor.core.Disposables;
2020
import reactor.core.publisher.Flux;
2121
import reactor.core.publisher.Mono;
2222
import reactor.util.context.ContextView;
@@ -26,7 +26,6 @@
2626
import java.io.IOException;
2727
import java.util.List;
2828
import java.util.Optional;
29-
import java.util.concurrent.atomic.AtomicBoolean;
3029
import java.util.concurrent.atomic.AtomicLong;
3130
import java.util.concurrent.atomic.AtomicReference;
3231
import java.util.function.Consumer;
@@ -52,10 +51,10 @@ public class WebClientStreamableHttpTransport implements McpClientTransport {
5251

5352
private final boolean resumableStreams;
5453

55-
private AtomicReference<Function<Mono<McpSchema.JSONRPCMessage>, Mono<McpSchema.JSONRPCMessage>>> handler = new AtomicReference<>();
56-
5754
private final AtomicReference<McpTransportSession> activeSession = new AtomicReference<>();
5855

56+
private final AtomicReference<Function<Mono<McpSchema.JSONRPCMessage>, Mono<McpSchema.JSONRPCMessage>>> handler = new AtomicReference<>();
57+
5958
private final AtomicReference<Consumer<Throwable>> exceptionHandler = new AtomicReference<>();
6059

6160
public WebClientStreamableHttpTransport(ObjectMapper objectMapper, WebClient.Builder webClientBuilder,
@@ -88,6 +87,11 @@ public void registerExceptionHandler(Consumer<Throwable> handler) {
8887

8988
private void handleException(Throwable t) {
9089
logger.debug("Handling exception for session {}", activeSession.get().sessionId(), t);
90+
if (t instanceof McpSessionNotFoundException) {
91+
McpTransportSession invalidSession = this.activeSession.getAndSet(new McpTransportSession());
92+
logger.warn("Server does not recognize session {}. Invalidating.", invalidSession.sessionId());
93+
invalidSession.close();
94+
}
9195
Consumer<Throwable> handler = this.exceptionHandler.get();
9296
if (handler != null) {
9397
handler.accept(t);
@@ -106,6 +110,8 @@ public Mono<Void> closeGracefully() {
106110
});
107111
}
108112

113+
// FIXME: Avoid passing the ContextView - add hook allowing the Reactor Context to be
114+
// attached to the chain?
109115
private void reconnect(McpStream stream, ContextView ctx) {
110116
if (stream != null) {
111117
logger.debug("Reconnecting stream {} with lastId {}", stream.streamId(), stream.lastId());
@@ -273,8 +279,7 @@ else if (contentType.isCompatibleWith(MediaType.APPLICATION_JSON)) {
273279
else {
274280
logger.warn("Unknown media type {} returned for POST in session {}", contentType,
275281
transportSession.sessionId());
276-
sink.error(new RuntimeException("Unknown media type returned: " + contentType));
277-
return Flux.empty();
282+
return Flux.error(new RuntimeException("Unknown media type returned: " + contentType));
278283
}
279284
}
280285
else {
@@ -283,20 +288,45 @@ else if (contentType.isCompatibleWith(MediaType.APPLICATION_JSON)) {
283288

284289
McpSessionNotFoundException notFoundException = new McpSessionNotFoundException(
285290
transportSession.sessionId());
286-
// inform the caller of sendMessage
287-
sink.error(notFoundException);
288291
// inform the stream/connection subscriber
289292
return Flux.error(notFoundException);
290293
}
291-
return response.<McpSchema.JSONRPCMessage>createError().doOnError(e -> {
292-
sink.error(new RuntimeException("Sending request failed", e));
294+
return response.<McpSchema.JSONRPCMessage>createError().onErrorResume(e -> {
295+
WebClientResponseException responseException = (WebClientResponseException) e;
296+
byte[] body = responseException.getResponseBodyAsByteArray();
297+
McpSchema.JSONRPCResponse.JSONRPCError jsonRpcError = null;
298+
Exception toPropagate;
299+
try {
300+
McpSchema.JSONRPCResponse jsonRpcResponse = objectMapper.readValue(body,
301+
McpSchema.JSONRPCResponse.class);
302+
jsonRpcError = jsonRpcResponse.error();
303+
toPropagate = new McpError(jsonRpcError);
304+
}
305+
catch (IOException ex) {
306+
toPropagate = new RuntimeException("Sending request failed", e);
307+
logger.debug("Received content together with {} HTTP code response: {}",
308+
response.statusCode(), body);
309+
}
310+
311+
// Some implementations can return 400 when presented with a
312+
// session id that it doesn't know about, so we will
313+
// invalidate the session
314+
// https://github.com/modelcontextprotocol/typescript-sdk/issues/389
315+
if (responseException.getStatusCode().isSameCodeAs(HttpStatus.BAD_REQUEST)) {
316+
return Mono.error(new McpSessionNotFoundException(this.activeSession.get().sessionId(),
317+
toPropagate));
318+
}
319+
return Mono.empty();
293320
}).flux();
294321
}
295322
})
296323
.map(Mono::just)
297324
.flatMap(this.handler.get())
298325
.onErrorResume(t -> {
326+
// handle the error first
299327
this.handleException(t);
328+
329+
// inform the caller of sendMessage
300330
sink.error(t);
301331
return Flux.empty();
302332
})
@@ -321,7 +351,8 @@ public <T> T unmarshalFrom(Object data, TypeReference<T> typeRef) {
321351
private Tuple2<Optional<String>, Iterable<McpSchema.JSONRPCMessage>> parse(ServerSentEvent<String> event) {
322352
if (MESSAGE_EVENT_TYPE.equals(event.event())) {
323353
try {
324-
// TODO: support batching?
354+
// We don't support batching ATM and probably won't since the next version
355+
// considers removing it.
325356
McpSchema.JSONRPCMessage message = McpSchema.deserializeJsonRpcMessage(this.objectMapper, event.data());
326357
return Tuples.of(Optional.ofNullable(event.id()), List.of(message));
327358
}
@@ -340,6 +371,7 @@ private class McpStream {
340371

341372
private final AtomicReference<String> lastId = new AtomicReference<>();
342373

374+
// Used only for internal accounting
343375
private final long streamId;
344376

345377
private final boolean resumable;
@@ -360,8 +392,7 @@ long streamId() {
360392
Flux<McpSchema.JSONRPCMessage> consumeSseStream(
361393
Publisher<Tuple2<Optional<String>, Iterable<McpSchema.JSONRPCMessage>>> eventStream) {
362394
return Flux.deferContextual(ctx -> Flux.from(eventStream).doOnError(e -> {
363-
// TODO: examine which error :)
364-
if (resumable) {
395+
if (resumable && !(e instanceof McpSessionNotFoundException)) {
365396
reconnect(this, ctx);
366397
}
367398
})

mcp-spring/mcp-spring-webflux/src/test/resources/logback.xml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
</appender>
1010

1111
<!-- Main MCP package -->
12-
<logger name="org.springframework.ai.mcp" level="DEBUG"/>
12+
<logger name="io.modelcontextprotocol" level="DEBUG"/>
1313

1414
<!-- Client packages -->
1515
<logger name="org.springframework.ai.mcp.client" level="DEBUG"/>

mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientResiliencyTests.java

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -159,4 +159,18 @@ void testPing() {
159159
});
160160
}
161161

162+
@Test
163+
void testSessionInvalidation() {
164+
withClient(createMcpTransport(), mcpAsyncClient -> {
165+
StepVerifier.create(mcpAsyncClient.initialize()).expectNextCount(1).verifyComplete();
166+
167+
container.stop();
168+
container.start();
169+
170+
// The first try will face the session mismatch exception and the second one
171+
// will go through the re-initialization process.
172+
StepVerifier.create(mcpAsyncClient.ping().retry(1)).expectNextCount(1).verifyComplete();
173+
});
174+
}
175+
162176
}

mcp/src/main/java/io/modelcontextprotocol/client/McpAsyncClient.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -221,6 +221,7 @@ public class McpAsyncClient {
221221

222222
private void handleException(Throwable t) {
223223
if (t instanceof McpSessionNotFoundException) {
224+
this.initialization.set(null);
224225
this.initialize().subscribe();
225226
}
226227
}

mcp/src/main/java/io/modelcontextprotocol/spec/McpClientSession.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -252,6 +252,7 @@ public <T> Mono<T> sendRequest(String method, Object requestParams, TypeReferenc
252252
String requestId = this.generateRequestId();
253253

254254
return Mono.deferContextual(ctx -> Mono.<McpSchema.JSONRPCResponse>create(sink -> {
255+
logger.debug("Sending message for method {}", method);
255256
this.pendingResponses.put(requestId, sink);
256257
McpSchema.JSONRPCRequest jsonrpcRequest = new McpSchema.JSONRPCRequest(McpSchema.JSONRPC_VERSION, method,
257258
requestId, requestParams);

mcp/src/main/java/io/modelcontextprotocol/spec/McpSessionNotFoundException.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,11 @@
22

33
public class McpSessionNotFoundException extends RuntimeException {
44

5+
public McpSessionNotFoundException(String sessionId, Exception cause) {
6+
super("Session " + sessionId + " not found on the server", cause);
7+
8+
}
9+
510
public McpSessionNotFoundException(String sessionId) {
611
super("Session " + sessionId + " not found on the server");
712
}

0 commit comments

Comments
 (0)