15
15
import org .springframework .http .MediaType ;
16
16
import org .springframework .http .codec .ServerSentEvent ;
17
17
import org .springframework .web .reactive .function .client .WebClient ;
18
+ import org .springframework .web .reactive .function .client .WebClientResponseException ;
18
19
import reactor .core .Disposable ;
19
- import reactor .core .Disposables ;
20
20
import reactor .core .publisher .Flux ;
21
21
import reactor .core .publisher .Mono ;
22
22
import reactor .util .context .ContextView ;
26
26
import java .io .IOException ;
27
27
import java .util .List ;
28
28
import java .util .Optional ;
29
- import java .util .concurrent .atomic .AtomicBoolean ;
30
29
import java .util .concurrent .atomic .AtomicLong ;
31
30
import java .util .concurrent .atomic .AtomicReference ;
32
31
import java .util .function .Consumer ;
@@ -52,10 +51,10 @@ public class WebClientStreamableHttpTransport implements McpClientTransport {
52
51
53
52
private final boolean resumableStreams ;
54
53
55
- private AtomicReference <Function <Mono <McpSchema .JSONRPCMessage >, Mono <McpSchema .JSONRPCMessage >>> handler = new AtomicReference <>();
56
-
57
54
private final AtomicReference <McpTransportSession > activeSession = new AtomicReference <>();
58
55
56
+ private final AtomicReference <Function <Mono <McpSchema .JSONRPCMessage >, Mono <McpSchema .JSONRPCMessage >>> handler = new AtomicReference <>();
57
+
59
58
private final AtomicReference <Consumer <Throwable >> exceptionHandler = new AtomicReference <>();
60
59
61
60
public WebClientStreamableHttpTransport (ObjectMapper objectMapper , WebClient .Builder webClientBuilder ,
@@ -88,6 +87,11 @@ public void registerExceptionHandler(Consumer<Throwable> handler) {
88
87
89
88
private void handleException (Throwable t ) {
90
89
logger .debug ("Handling exception for session {}" , activeSession .get ().sessionId (), t );
90
+ if (t instanceof McpSessionNotFoundException ) {
91
+ McpTransportSession invalidSession = this .activeSession .getAndSet (new McpTransportSession ());
92
+ logger .warn ("Server does not recognize session {}. Invalidating." , invalidSession .sessionId ());
93
+ invalidSession .close ();
94
+ }
91
95
Consumer <Throwable > handler = this .exceptionHandler .get ();
92
96
if (handler != null ) {
93
97
handler .accept (t );
@@ -106,6 +110,8 @@ public Mono<Void> closeGracefully() {
106
110
});
107
111
}
108
112
113
+ // FIXME: Avoid passing the ContextView - add hook allowing the Reactor Context to be
114
+ // attached to the chain?
109
115
private void reconnect (McpStream stream , ContextView ctx ) {
110
116
if (stream != null ) {
111
117
logger .debug ("Reconnecting stream {} with lastId {}" , stream .streamId (), stream .lastId ());
@@ -273,8 +279,7 @@ else if (contentType.isCompatibleWith(MediaType.APPLICATION_JSON)) {
273
279
else {
274
280
logger .warn ("Unknown media type {} returned for POST in session {}" , contentType ,
275
281
transportSession .sessionId ());
276
- sink .error (new RuntimeException ("Unknown media type returned: " + contentType ));
277
- return Flux .empty ();
282
+ return Flux .error (new RuntimeException ("Unknown media type returned: " + contentType ));
278
283
}
279
284
}
280
285
else {
@@ -283,20 +288,45 @@ else if (contentType.isCompatibleWith(MediaType.APPLICATION_JSON)) {
283
288
284
289
McpSessionNotFoundException notFoundException = new McpSessionNotFoundException (
285
290
transportSession .sessionId ());
286
- // inform the caller of sendMessage
287
- sink .error (notFoundException );
288
291
// inform the stream/connection subscriber
289
292
return Flux .error (notFoundException );
290
293
}
291
- return response .<McpSchema .JSONRPCMessage >createError ().doOnError (e -> {
292
- sink .error (new RuntimeException ("Sending request failed" , e ));
294
+ return response .<McpSchema .JSONRPCMessage >createError ().onErrorResume (e -> {
295
+ WebClientResponseException responseException = (WebClientResponseException ) e ;
296
+ byte [] body = responseException .getResponseBodyAsByteArray ();
297
+ McpSchema .JSONRPCResponse .JSONRPCError jsonRpcError = null ;
298
+ Exception toPropagate ;
299
+ try {
300
+ McpSchema .JSONRPCResponse jsonRpcResponse = objectMapper .readValue (body ,
301
+ McpSchema .JSONRPCResponse .class );
302
+ jsonRpcError = jsonRpcResponse .error ();
303
+ toPropagate = new McpError (jsonRpcError );
304
+ }
305
+ catch (IOException ex ) {
306
+ toPropagate = new RuntimeException ("Sending request failed" , e );
307
+ logger .debug ("Received content together with {} HTTP code response: {}" ,
308
+ response .statusCode (), body );
309
+ }
310
+
311
+ // Some implementations can return 400 when presented with a
312
+ // session id that it doesn't know about, so we will
313
+ // invalidate the session
314
+ // https://github.com/modelcontextprotocol/typescript-sdk/issues/389
315
+ if (responseException .getStatusCode ().isSameCodeAs (HttpStatus .BAD_REQUEST )) {
316
+ return Mono .error (new McpSessionNotFoundException (this .activeSession .get ().sessionId (),
317
+ toPropagate ));
318
+ }
319
+ return Mono .empty ();
293
320
}).flux ();
294
321
}
295
322
})
296
323
.map (Mono ::just )
297
324
.flatMap (this .handler .get ())
298
325
.onErrorResume (t -> {
326
+ // handle the error first
299
327
this .handleException (t );
328
+
329
+ // inform the caller of sendMessage
300
330
sink .error (t );
301
331
return Flux .empty ();
302
332
})
@@ -321,7 +351,8 @@ public <T> T unmarshalFrom(Object data, TypeReference<T> typeRef) {
321
351
private Tuple2 <Optional <String >, Iterable <McpSchema .JSONRPCMessage >> parse (ServerSentEvent <String > event ) {
322
352
if (MESSAGE_EVENT_TYPE .equals (event .event ())) {
323
353
try {
324
- // TODO: support batching?
354
+ // We don't support batching ATM and probably won't since the next version
355
+ // considers removing it.
325
356
McpSchema .JSONRPCMessage message = McpSchema .deserializeJsonRpcMessage (this .objectMapper , event .data ());
326
357
return Tuples .of (Optional .ofNullable (event .id ()), List .of (message ));
327
358
}
@@ -340,6 +371,7 @@ private class McpStream {
340
371
341
372
private final AtomicReference <String > lastId = new AtomicReference <>();
342
373
374
+ // Used only for internal accounting
343
375
private final long streamId ;
344
376
345
377
private final boolean resumable ;
@@ -360,8 +392,7 @@ long streamId() {
360
392
Flux <McpSchema .JSONRPCMessage > consumeSseStream (
361
393
Publisher <Tuple2 <Optional <String >, Iterable <McpSchema .JSONRPCMessage >>> eventStream ) {
362
394
return Flux .deferContextual (ctx -> Flux .from (eventStream ).doOnError (e -> {
363
- // TODO: examine which error :)
364
- if (resumable ) {
395
+ if (resumable && !(e instanceof McpSessionNotFoundException )) {
365
396
reconnect (this , ctx );
366
397
}
367
398
})
0 commit comments