Skip to content

Commit e249145

Browse files
committed
WIP.. more resiliency
1 parent 40bc356 commit e249145

File tree

6 files changed

+194
-112
lines changed

6 files changed

+194
-112
lines changed

mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/client/transport/Main.java

Lines changed: 32 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -7,44 +7,38 @@
77
import org.springframework.web.reactive.function.client.WebClient;
88

99
public class Main {
10-
public static void main(String[] args) {
11-
McpAsyncClient client = McpClient.async(
12-
new WebClientStreamableHttpTransport(new ObjectMapper(),
13-
WebClient.builder().baseUrl("http://localhost:3001"),
14-
"/mcp", true, false)
15-
).build();
1610

17-
/*
18-
Inspector does this:
19-
1. -> POST initialize request
20-
2. <- capabilities response (with sessionId)
21-
3. -> POST initialized notification
22-
4. -> GET initialize SSE connection (with sessionId)
11+
public static void main(String[] args) {
12+
McpAsyncClient client = McpClient
13+
.async(new WebClientStreamableHttpTransport(new ObjectMapper(),
14+
WebClient.builder().baseUrl("http://localhost:3001"), "/mcp", true, false))
15+
.build();
16+
17+
/*
18+
* Inspector does this: 1. -> POST initialize request 2. <- capabilities response
19+
* (with sessionId) 3. -> POST initialized notification 4. -> GET initialize SSE
20+
* connection (with sessionId)
21+
*
22+
* VS
23+
*
24+
* 1. -> GET initialize SSE connection 2. <- 2xx ok with sessionId 3. -> POST
25+
* initialize request 4. <- capabilities response 5. -> POST initialized
26+
* notification
27+
*
28+
*
29+
* SERVER-A + SERVER-B LOAD BALANCING between SERVER-A and SERVER-B STATELESS
30+
* SERVER
31+
*
32+
* 1. -> (A) POST initialize request 2. <- (A) 2xx ok with capabilities 3. -> (B)
33+
* POST initialized notification 4. -> (B) 2xx ok 5. -> (A or B) POST request
34+
* tools 6. -> 2xx response
35+
*/
36+
37+
client.initialize()
38+
.flatMap(r -> client.listTools())
39+
.map(McpSchema.ListToolsResult::tools)
40+
.doOnNext(System.out::println)
41+
.block();
42+
}
2343

24-
VS
25-
26-
1. -> GET initialize SSE connection
27-
2. <- 2xx ok with sessionId
28-
3. -> POST initialize request
29-
4. <- capabilities response
30-
5. -> POST initialized notification
31-
32-
33-
SERVER-A + SERVER-B
34-
LOAD BALANCING between SERVER-A and SERVER-B
35-
STATELESS SERVER
36-
37-
1. -> (A) POST initialize request
38-
2. <- (A) 2xx ok with capabilities
39-
3. -> (B) POST initialized notification
40-
4. -> (B) 2xx ok
41-
5. -> (A or B) POST request tools
42-
6. -> 2xx response
43-
*/
44-
45-
client.initialize().flatMap(r -> client.listTools())
46-
.map(McpSchema.ListToolsResult::tools)
47-
.doOnNext(System.out::println)
48-
.block();
49-
}
5044
}

mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/client/transport/WebClientStreamableHttpTransport.java

Lines changed: 44 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,7 @@
22

33
import com.fasterxml.jackson.core.type.TypeReference;
44
import com.fasterxml.jackson.databind.ObjectMapper;
5-
import io.modelcontextprotocol.spec.McpClientTransport;
6-
import io.modelcontextprotocol.spec.McpError;
7-
import io.modelcontextprotocol.spec.McpSchema;
8-
import io.modelcontextprotocol.spec.McpSessionNotFoundException;
5+
import io.modelcontextprotocol.spec.*;
96
import org.reactivestreams.Publisher;
107
import org.slf4j.Logger;
118
import org.slf4j.LoggerFactory;
@@ -28,6 +25,7 @@
2825
import java.util.concurrent.atomic.AtomicBoolean;
2926
import java.util.concurrent.atomic.AtomicLong;
3027
import java.util.concurrent.atomic.AtomicReference;
28+
import java.util.function.Consumer;
3129
import java.util.function.Function;
3230

3331
public class WebClientStreamableHttpTransport implements McpClientTransport {
@@ -52,11 +50,9 @@ public class WebClientStreamableHttpTransport implements McpClientTransport {
5250

5351
private AtomicReference<Function<Mono<McpSchema.JSONRPCMessage>, Mono<McpSchema.JSONRPCMessage>>> handler = new AtomicReference<>();
5452

55-
private final Disposable.Composite openConnections = Disposables.composite();
53+
private final AtomicReference<McpTransportSession> activeSession = new AtomicReference<>();
5654

57-
private final AtomicBoolean initialized = new AtomicBoolean();
58-
59-
private final AtomicReference<String> sessionId = new AtomicReference<>();
55+
private final AtomicReference<Consumer<Throwable>> exceptionHandler = new AtomicReference<>();
6056

6157
public WebClientStreamableHttpTransport(ObjectMapper objectMapper, WebClient.Builder webClientBuilder,
6258
String endpoint, boolean resumableStreams, boolean openConnectionOnStartup) {
@@ -65,14 +61,12 @@ public WebClientStreamableHttpTransport(ObjectMapper objectMapper, WebClient.Bui
6561
this.endpoint = endpoint;
6662
this.resumableStreams = resumableStreams;
6763
this.openConnectionOnStartup = openConnectionOnStartup;
64+
this.activeSession.set(new McpTransportSession());
6865
}
6966

7067
@Override
7168
public Mono<Void> connect(Function<Mono<McpSchema.JSONRPCMessage>, Mono<McpSchema.JSONRPCMessage>> handler) {
7269
return Mono.deferContextual(ctx -> {
73-
if (this.openConnections.isDisposed()) {
74-
return Mono.error(new RuntimeException("Transport already disposed"));
75-
}
7670
this.handler.set(handler);
7771
if (openConnectionOnStartup) {
7872
this.reconnect(null, ctx);
@@ -81,9 +75,20 @@ public Mono<Void> connect(Function<Mono<McpSchema.JSONRPCMessage>, Mono<McpSchem
8175
});
8276
}
8377

78+
@Override
79+
public void handleException(Consumer<Throwable> handler) {
80+
this.exceptionHandler.set(handler);
81+
}
82+
8483
@Override
8584
public Mono<Void> closeGracefully() {
86-
return Mono.fromRunnable(this.openConnections::dispose);
85+
return Mono.defer(() -> {
86+
McpTransportSession currentSession = this.activeSession.get();
87+
if (currentSession != null) {
88+
return currentSession.closeGracefully();
89+
}
90+
return Mono.empty();
91+
});
8792
}
8893

8994
private void reconnect(McpStream stream, ContextView ctx) {
@@ -93,12 +98,13 @@ private void reconnect(McpStream stream, ContextView ctx) {
9398
// listen for messages.
9499
// If it doesn't, nothing actually happens here, that's just the way it is...
95100
final AtomicReference<Disposable> disposableRef = new AtomicReference<>();
101+
final McpTransportSession transportSession = this.activeSession.get();
96102
Disposable connection = webClient.get()
97103
.uri(this.endpoint)
98104
.accept(MediaType.TEXT_EVENT_STREAM)
99105
.headers(httpHeaders -> {
100-
if (sessionId.get() != null) {
101-
httpHeaders.add("mcp-session-id", sessionId.get());
106+
if (transportSession.sessionId() != null) {
107+
httpHeaders.add("mcp-session-id", transportSession.sessionId());
102108
}
103109
if (stream != null && stream.lastId() != null) {
104110
httpHeaders.add("last-event-id", stream.lastId());
@@ -123,22 +129,33 @@ else if (response.statusCode().isSameCodeAs(HttpStatus.METHOD_NOT_ALLOWED)) {
123129
logger.info("The server does not support SSE streams, using request-response mode.");
124130
return Flux.empty();
125131
}
132+
else if (response.statusCode().isSameCodeAs(HttpStatus.NOT_FOUND)) {
133+
logger.info("Session {} was not found on the MCP server", transportSession.sessionId());
134+
135+
McpSessionNotFoundException notFoundException = new McpSessionNotFoundException(
136+
"Session " + transportSession.sessionId() + " not found");
137+
// inform the stream/connection subscriber
138+
return Flux.error(notFoundException);
139+
}
126140
else {
127141
return response.<McpSchema.JSONRPCMessage>createError().doOnError(e -> {
128142
logger.info("Opening an SSE stream failed. This can be safely ignored.", e);
129143
}).flux();
130144
}
131145
})
146+
.doOnError(e -> {
147+
this.exceptionHandler.get().accept(e);
148+
})
132149
.doFinally(s -> {
133150
Disposable ref = disposableRef.getAndSet(null);
134151
if (ref != null) {
135-
this.openConnections.remove(ref);
152+
transportSession.removeConnection(ref);
136153
}
137154
})
138155
.contextWrite(ctx)
139156
.subscribe();
140157
disposableRef.set(connection);
141-
this.openConnections.add(connection);
158+
transportSession.addConnection(connection);
142159
}
143160

144161
@Override
@@ -151,20 +168,22 @@ public Mono<Void> sendMessage(McpSchema.JSONRPCMessage message) {
151168
// listen for messages.
152169
// If it doesn't, nothing actually happens here, that's just the way it is...
153170
final AtomicReference<Disposable> disposableRef = new AtomicReference<>();
171+
final McpTransportSession transportSession = this.activeSession.get();
172+
154173
Disposable connection = webClient.post()
155174
.uri(this.endpoint)
156175
.accept(MediaType.TEXT_EVENT_STREAM, MediaType.APPLICATION_JSON)
157176
.headers(httpHeaders -> {
158-
if (sessionId.get() != null) {
159-
httpHeaders.add("mcp-session-id", sessionId.get());
177+
if (transportSession.sessionId() != null) {
178+
httpHeaders.add("mcp-session-id", transportSession.sessionId());
160179
}
161180
})
162181
.bodyValue(message)
163182
.exchangeToFlux(response -> {
164-
// TODO: this goes into the request phase
165-
if (!initialized.compareAndExchange(false, true)) {
183+
if (transportSession.markInitialized()) {
166184
if (!response.headers().header("mcp-session-id").isEmpty()) {
167-
sessionId.set(response.headers().asHttpHeaders().getFirst("mcp-session-id"));
185+
transportSession
186+
.setSessionId(response.headers().asHttpHeaders().getFirst("mcp-session-id"));
168187
// Once we have a session, we try to open an async stream for
169188
// the server to send notifications and requests out-of-band.
170189
reconnect(null, sink.contextView());
@@ -176,10 +195,10 @@ public Mono<Void> sendMessage(McpSchema.JSONRPCMessage message) {
176195
// if (!response.statusCode().isSameCodeAs(HttpStatus.ACCEPTED)) {
177196
if (!response.statusCode().is2xxSuccessful()) {
178197
if (response.statusCode().isSameCodeAs(HttpStatus.NOT_FOUND)) {
179-
logger.info("Session {} was not found on the MCP server", sessionId.get());
198+
logger.info("Session {} was not found on the MCP server", transportSession.sessionId());
180199

181200
McpSessionNotFoundException notFoundException = new McpSessionNotFoundException(
182-
"Session " + sessionId.get() + " not found");
201+
"Session " + transportSession.sessionId() + " not found");
183202
// inform the caller of sendMessage
184203
sink.error(notFoundException);
185204
// inform the stream/connection subscriber
@@ -233,8 +252,6 @@ else if (contentType.isCompatibleWith(MediaType.APPLICATION_JSON)) {
233252
}
234253
})
235254
.flatMapIterable(Function.identity());
236-
// .map(Mono::just)
237-
// .flatMap(this.handler.get());
238255
}
239256
else {
240257
sink.error(new RuntimeException("Unknown media type"));
@@ -246,13 +263,13 @@ else if (contentType.isCompatibleWith(MediaType.APPLICATION_JSON)) {
246263
.doFinally(s -> {
247264
Disposable ref = disposableRef.getAndSet(null);
248265
if (ref != null) {
249-
this.openConnections.remove(ref);
266+
transportSession.removeConnection(ref);
250267
}
251268
})
252269
.contextWrite(sink.contextView())
253270
.subscribe();
254271
disposableRef.set(connection);
255-
this.openConnections.add(connection);
272+
transportSession.addConnection(connection);
256273
});
257274
}
258275

0 commit comments

Comments
 (0)