Skip to content

Commit 53012ad

Browse files
committed
WIP
1 parent 2e13f9f commit 53012ad

File tree

7 files changed

+412
-5
lines changed

7 files changed

+412
-5
lines changed
Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
package io.modelcontextprotocol.client.transport;
2+
3+
import com.fasterxml.jackson.databind.ObjectMapper;
4+
import io.modelcontextprotocol.client.McpAsyncClient;
5+
import io.modelcontextprotocol.client.McpClient;
6+
import io.modelcontextprotocol.spec.McpSchema;
7+
import org.springframework.web.reactive.function.client.WebClient;
8+
9+
public class Main {
10+
public static void main(String[] args) {
11+
McpAsyncClient client = McpClient.async(
12+
new WebClientStreamableHttpTransport(new ObjectMapper(),
13+
WebClient.builder().baseUrl("http://localhost:3001"),
14+
"/mcp", true, false)
15+
).build();
16+
17+
/*
18+
Inspector does this:
19+
1. -> POST initialize request
20+
2. <- capabilities response (with sessionId)
21+
3. -> POST initialized notification
22+
4. -> GET initialize SSE connection (with sessionId)
23+
24+
VS
25+
26+
1. -> GET initialize SSE connection
27+
2. <- 2xx ok with sessionId
28+
3. -> POST initialize request
29+
4. <- capabilities response
30+
5. -> POST initialized notification
31+
32+
33+
SERVER-A + SERVER-B
34+
LOAD BALANCING between SERVER-A and SERVER-B
35+
STATELESS SERVER
36+
37+
1. -> (A) POST initialize request
38+
2. <- (A) 2xx ok with capabilities
39+
3. -> (B) POST initialized notification
40+
4. -> (B) 2xx ok
41+
5. -> (A or B) POST request tools
42+
6. -> 2xx response
43+
*/
44+
45+
client.initialize().flatMap(r -> client.listTools())
46+
.map(McpSchema.ListToolsResult::tools)
47+
.doOnNext(System.out::println)
48+
.block();
49+
}
50+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,291 @@
1+
package io.modelcontextprotocol.client.transport;
2+
3+
import com.fasterxml.jackson.core.type.TypeReference;
4+
import com.fasterxml.jackson.databind.ObjectMapper;
5+
import io.modelcontextprotocol.spec.McpClientTransport;
6+
import io.modelcontextprotocol.spec.McpError;
7+
import io.modelcontextprotocol.spec.McpSchema;
8+
import io.modelcontextprotocol.spec.McpSessionNotFoundException;
9+
import org.reactivestreams.Publisher;
10+
import org.slf4j.Logger;
11+
import org.slf4j.LoggerFactory;
12+
import org.springframework.core.ParameterizedTypeReference;
13+
import org.springframework.http.HttpStatus;
14+
import org.springframework.http.MediaType;
15+
import org.springframework.http.codec.ServerSentEvent;
16+
import org.springframework.web.reactive.function.client.WebClient;
17+
import reactor.core.Disposable;
18+
import reactor.core.Disposables;
19+
import reactor.core.publisher.Flux;
20+
import reactor.core.publisher.Mono;
21+
import reactor.util.context.ContextView;
22+
import reactor.util.function.Tuple2;
23+
import reactor.util.function.Tuples;
24+
25+
import java.io.IOException;
26+
import java.util.List;
27+
import java.util.Optional;
28+
import java.util.concurrent.atomic.AtomicBoolean;
29+
import java.util.concurrent.atomic.AtomicLong;
30+
import java.util.concurrent.atomic.AtomicReference;
31+
import java.util.function.Function;
32+
33+
public class WebClientStreamableHttpTransport implements McpClientTransport {
34+
35+
private static final Logger logger = LoggerFactory.getLogger(WebClientStreamableHttpTransport.class);
36+
37+
/**
38+
* Event type for JSON-RPC messages received through the SSE connection. The server
39+
* sends messages with this event type to transmit JSON-RPC protocol data.
40+
*/
41+
private static final String MESSAGE_EVENT_TYPE = "message";
42+
43+
private final ObjectMapper objectMapper;
44+
private final WebClient webClient;
45+
private final String endpoint;
46+
private final boolean openConnectionOnStartup;
47+
private final boolean resumableStreams;
48+
49+
private AtomicReference<Function<Mono<McpSchema.JSONRPCMessage>,
50+
Mono<McpSchema.JSONRPCMessage>>> handler = new AtomicReference<>();
51+
52+
private final Disposable.Composite openConnections = Disposables.composite();
53+
private final AtomicBoolean initialized = new AtomicBoolean();
54+
private final AtomicReference<String> sessionId = new AtomicReference<>();
55+
56+
public WebClientStreamableHttpTransport(
57+
ObjectMapper objectMapper,
58+
WebClient.Builder webClientBuilder,
59+
String endpoint,
60+
boolean resumableStreams,
61+
boolean openConnectionOnStartup) {
62+
this.objectMapper = objectMapper;
63+
this.webClient = webClientBuilder.build();
64+
this.endpoint = endpoint;
65+
this.resumableStreams = resumableStreams;
66+
this.openConnectionOnStartup = openConnectionOnStartup;
67+
}
68+
69+
@Override
70+
public Mono<Void> connect(Function<Mono<McpSchema.JSONRPCMessage>, Mono<McpSchema.JSONRPCMessage>> handler) {
71+
if (this.openConnections.isDisposed()) {
72+
return Mono.error(new RuntimeException("Transport already disposed"));
73+
}
74+
this.handler.set(handler);
75+
return openConnectionOnStartup ? startOrResumeSession(null) : Mono.empty();
76+
}
77+
78+
@Override
79+
public Mono<Void> closeGracefully() {
80+
return Mono.fromRunnable(this.openConnections::dispose);
81+
}
82+
83+
private void reconnect(McpStream stream, ContextView ctx) {
84+
Disposable connection = this.startOrResumeSession(stream)
85+
.contextWrite(ctx)
86+
.subscribe();
87+
this.openConnections.add(connection);
88+
}
89+
90+
private Mono<Void> startOrResumeSession(McpStream stream) {
91+
return Mono.create(sink -> {
92+
// Here we attempt to initialize the client.
93+
// In case the server supports SSE, we will establish a long-running session here and
94+
// listen for messages.
95+
// If it doesn't, nothing actually happens here, that's just the way it is...
96+
97+
Disposable connection = webClient.get()
98+
.uri(this.endpoint)
99+
.accept(MediaType.TEXT_EVENT_STREAM)
100+
.headers(httpHeaders -> {
101+
if (sessionId.get() != null) {
102+
httpHeaders.add("mcp-session-id", sessionId.get());
103+
}
104+
if (stream != null && stream.lastId() != null) {
105+
httpHeaders.add("last-event-id", stream.lastId());
106+
}
107+
})
108+
.exchangeToFlux(response -> {
109+
// Per spec, we are not checking whether it's 2xx, but only if the Accept header is proper.
110+
if (response.headers().contentType().isPresent()
111+
&& response.headers().contentType().get().isCompatibleWith(MediaType.TEXT_EVENT_STREAM)) {
112+
113+
sink.success();
114+
115+
McpStream sessionStream = stream != null ? stream : new McpStream(this.resumableStreams);
116+
117+
Flux<Tuple2<Optional<String>, Iterable<McpSchema.JSONRPCMessage>>> idWithMessages =
118+
response.bodyToFlux(new ParameterizedTypeReference<ServerSentEvent<String>>() {
119+
}).map(this::parse);
120+
121+
return sessionStream.consumeSseStream(idWithMessages);
122+
} else if (response.statusCode().isSameCodeAs(HttpStatus.METHOD_NOT_ALLOWED)) {
123+
sink.success();
124+
logger.info("The server does not support SSE streams, using request-response mode.");
125+
return Flux.empty();
126+
} else {
127+
return response.<McpSchema.JSONRPCMessage>createError().doOnError(e -> {
128+
sink.error(new RuntimeException("Connection on client startup failed", e));
129+
}).flux();
130+
}
131+
})
132+
// TODO: Consider retries - examine cause to decide whether a retry is needed.
133+
.contextWrite(sink.contextView())
134+
.subscribe();
135+
this.openConnections.add(connection);
136+
});
137+
}
138+
139+
@Override
140+
public Mono<Void> sendMessage(McpSchema.JSONRPCMessage message) {
141+
return Mono.create(sink -> {
142+
System.out.println("Sending message " + message);
143+
// Here we attempt to initialize the client.
144+
// In case the server supports SSE, we will establish a long-running session here and
145+
// listen for messages.
146+
// If it doesn't, nothing actually happens here, that's just the way it is...
147+
Disposable connection = webClient.post()
148+
.uri(this.endpoint)
149+
.accept(MediaType.TEXT_EVENT_STREAM, MediaType.APPLICATION_JSON)
150+
.headers(httpHeaders -> {
151+
if (sessionId.get() != null) {
152+
httpHeaders.add("mcp-session-id", sessionId.get());
153+
}
154+
})
155+
.bodyValue(message)
156+
.exchangeToFlux(response -> {
157+
// TODO: this goes into the request phase
158+
if (!initialized.compareAndExchange(false, true)) {
159+
if (!response.headers().header("mcp-session-id").isEmpty()) {
160+
sessionId.set(response.headers().asHttpHeaders().getFirst("mcp-session-id"));
161+
// Once we have a session, we try to open an async stream for the server to send notifications and requests out-of-band.
162+
startOrResumeSession(null)
163+
.contextWrite(sink.contextView())
164+
.subscribe();
165+
}
166+
}
167+
168+
// The spec mentions only ACCEPTED, but the existing SDKs can return 200 OK for notifications
169+
// if (!response.statusCode().isSameCodeAs(HttpStatus.ACCEPTED)) {
170+
if (!response.statusCode().is2xxSuccessful()) {
171+
if (response.statusCode().isSameCodeAs(HttpStatus.NOT_FOUND)) {
172+
logger.info("Session {} was not found on the MCP server", sessionId.get());
173+
174+
McpSessionNotFoundException notFoundException = new McpSessionNotFoundException("Session " + sessionId.get() + " not found");
175+
// inform the caller of sendMessage
176+
sink.error(notFoundException);
177+
// inform the stream/connection subscriber
178+
return Flux.error(notFoundException);
179+
}
180+
return response.<McpSchema.JSONRPCMessage>createError().doOnError(e -> {
181+
sink.error(new RuntimeException("Sending request failed", e));
182+
}).flux();
183+
}
184+
185+
// Existing SDKs consume notifications with no response body nor content type
186+
if (response.headers().contentType().isEmpty()) {
187+
sink.success();
188+
return Flux.empty();
189+
// return response.<McpSchema.JSONRPCMessage>createError().doOnError(e -> {
190+
//// sink.error(new RuntimeException("Response has no content type"));
191+
// }).flux();
192+
}
193+
194+
MediaType contentType = response.headers().contentType().get();
195+
196+
if (contentType.isCompatibleWith(MediaType.TEXT_EVENT_STREAM)) {
197+
sink.success();
198+
McpStream sessionStream = new McpStream(this.resumableStreams);
199+
200+
Flux<Tuple2<Optional<String>, Iterable<McpSchema.JSONRPCMessage>>> idWithMessages =
201+
response.bodyToFlux(new ParameterizedTypeReference<ServerSentEvent<String>>() {
202+
}).map(this::parse);
203+
204+
return sessionStream.consumeSseStream(idWithMessages);
205+
} else if (contentType.isCompatibleWith(MediaType.APPLICATION_JSON)) {
206+
sink.success();
207+
// return response.bodyToMono(new ParameterizedTypeReference<Iterable<McpSchema.JSONRPCMessage>>() {});
208+
return response.bodyToMono(String.class)
209+
.<Iterable<McpSchema.JSONRPCMessage>>handle((responseMessage, s) -> {
210+
try {
211+
McpSchema.JSONRPCMessage jsonRpcResponse = McpSchema.deserializeJsonRpcMessage(objectMapper, responseMessage);
212+
s.next(List.of(jsonRpcResponse));
213+
} catch (IOException e) {
214+
s.error(e);
215+
}
216+
})
217+
.flatMapIterable(Function.identity());
218+
// .map(Mono::just)
219+
// .flatMap(this.handler.get());
220+
} else {
221+
sink.error(new RuntimeException("Unknown media type"));
222+
return Flux.empty();
223+
}
224+
})
225+
.map(Mono::just)
226+
.flatMap(this.handler.get())
227+
// TODO: Consider retries - examine cause to decide whether a retry is needed.
228+
.contextWrite(sink.contextView())
229+
.subscribe();
230+
this.openConnections.add(connection);
231+
});
232+
}
233+
234+
@Override
235+
public <T> T unmarshalFrom(Object data, TypeReference<T> typeRef) {
236+
return this.objectMapper.convertValue(data, typeRef);
237+
}
238+
239+
private Tuple2<Optional<String>, Iterable<McpSchema.JSONRPCMessage>> parse(ServerSentEvent<String> event) {
240+
if (MESSAGE_EVENT_TYPE.equals(event.event())) {
241+
try {
242+
// TODO: support batching
243+
McpSchema.JSONRPCMessage message = McpSchema.deserializeJsonRpcMessage(this.objectMapper, event.data());
244+
return Tuples.of(Optional.ofNullable(event.id()), List.of(message));
245+
}
246+
catch (IOException ioException) {
247+
throw new McpError("Error parsing JSON-RPC message: " + event.data());
248+
}
249+
}
250+
else {
251+
throw new McpError("Received unrecognized SSE event type: " + event.event());
252+
}
253+
}
254+
255+
private class McpStream {
256+
257+
private static final AtomicLong counter = new AtomicLong();
258+
259+
private final AtomicReference<String> lastId = new AtomicReference<>();
260+
261+
private final long streamId;
262+
private final boolean resumable;
263+
264+
McpStream(boolean resumable) {
265+
this.streamId = counter.getAndIncrement();
266+
this.resumable = resumable;
267+
}
268+
269+
String lastId() {
270+
return this.lastId.get();
271+
}
272+
273+
Flux<McpSchema.JSONRPCMessage> consumeSseStream(Publisher<Tuple2<Optional<String>, Iterable<McpSchema.JSONRPCMessage>>> eventStream) {
274+
return Flux.deferContextual(ctx ->
275+
Flux.from(eventStream)
276+
.doOnError(e -> {
277+
// TODO: examine which error :)
278+
if (resumable) {
279+
Disposable connection = WebClientStreamableHttpTransport.this.startOrResumeSession(this)
280+
.contextWrite(ctx)
281+
.subscribe();
282+
WebClientStreamableHttpTransport.this.openConnections.add(connection);
283+
}
284+
})
285+
.doOnNext(idAndMessage -> idAndMessage.getT1().ifPresent(this.lastId::set))
286+
.flatMapIterable(Tuple2::getT2)
287+
);
288+
}
289+
290+
}
291+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
package io.modelcontextprotocol.client;
2+
3+
import com.fasterxml.jackson.databind.ObjectMapper;
4+
import io.modelcontextprotocol.client.transport.WebClientStreamableHttpTransport;
5+
import io.modelcontextprotocol.spec.McpClientTransport;
6+
import org.junit.jupiter.api.Timeout;
7+
import org.springframework.web.reactive.function.client.WebClient;
8+
import org.testcontainers.containers.GenericContainer;
9+
import org.testcontainers.containers.wait.strategy.Wait;
10+
import org.testcontainers.images.builder.ImageFromDockerfile;
11+
12+
@Timeout(15)
13+
public class WebClientStreamableHttpAsyncClientTests extends AbstractMcpAsyncClientTests {
14+
15+
static String host = "http://localhost:3001";
16+
17+
// Uses the https://github.com/tzolov/mcp-everything-server-docker-image
18+
@SuppressWarnings("resource")
19+
GenericContainer<?> container = new GenericContainer<>("docker.io/tzolov/mcp-everything-server-streamable:v2")
20+
.withLogConsumer(outputFrame -> System.out.println(outputFrame.getUtf8String()))
21+
.withExposedPorts(3001)
22+
.waitingFor(Wait.forHttp("/").forStatusCode(404));
23+
24+
@Override
25+
protected McpClientTransport createMcpTransport() {
26+
return new WebClientStreamableHttpTransport(new ObjectMapper(), WebClient.builder(), "/mcp", true, false);
27+
}
28+
29+
@Override
30+
protected void onStart() {
31+
container.start();
32+
int port = container.getMappedPort(3001);
33+
host = "http://" + container.getHost() + ":" + port;
34+
}
35+
36+
@Override
37+
public void onClose() {
38+
container.stop();
39+
}
40+
}

0 commit comments

Comments
 (0)