2
2
3
3
import com .fasterxml .jackson .core .type .TypeReference ;
4
4
import com .fasterxml .jackson .databind .ObjectMapper ;
5
- import io .modelcontextprotocol .spec .McpClientTransport ;
6
- import io .modelcontextprotocol .spec .McpError ;
7
- import io .modelcontextprotocol .spec .McpSchema ;
8
- import io .modelcontextprotocol .spec .McpSessionNotFoundException ;
5
+ import io .modelcontextprotocol .spec .*;
9
6
import org .reactivestreams .Publisher ;
10
7
import org .slf4j .Logger ;
11
8
import org .slf4j .LoggerFactory ;
28
25
import java .util .concurrent .atomic .AtomicBoolean ;
29
26
import java .util .concurrent .atomic .AtomicLong ;
30
27
import java .util .concurrent .atomic .AtomicReference ;
28
+ import java .util .function .Consumer ;
31
29
import java .util .function .Function ;
32
30
33
31
public class WebClientStreamableHttpTransport implements McpClientTransport {
@@ -52,11 +50,9 @@ public class WebClientStreamableHttpTransport implements McpClientTransport {
52
50
53
51
private AtomicReference <Function <Mono <McpSchema .JSONRPCMessage >, Mono <McpSchema .JSONRPCMessage >>> handler = new AtomicReference <>();
54
52
55
- private final Disposable . Composite openConnections = Disposables . composite ();
53
+ private final AtomicReference < McpTransportSession > activeSession = new AtomicReference <> ();
56
54
57
- private final AtomicBoolean initialized = new AtomicBoolean ();
58
-
59
- private final AtomicReference <String > sessionId = new AtomicReference <>();
55
+ private final AtomicReference <Consumer <Throwable >> exceptionHandler = new AtomicReference <>();
60
56
61
57
public WebClientStreamableHttpTransport (ObjectMapper objectMapper , WebClient .Builder webClientBuilder ,
62
58
String endpoint , boolean resumableStreams , boolean openConnectionOnStartup ) {
@@ -65,14 +61,12 @@ public WebClientStreamableHttpTransport(ObjectMapper objectMapper, WebClient.Bui
65
61
this .endpoint = endpoint ;
66
62
this .resumableStreams = resumableStreams ;
67
63
this .openConnectionOnStartup = openConnectionOnStartup ;
64
+ this .activeSession .set (new McpTransportSession ());
68
65
}
69
66
70
67
@ Override
71
68
public Mono <Void > connect (Function <Mono <McpSchema .JSONRPCMessage >, Mono <McpSchema .JSONRPCMessage >> handler ) {
72
69
return Mono .deferContextual (ctx -> {
73
- if (this .openConnections .isDisposed ()) {
74
- return Mono .error (new RuntimeException ("Transport already disposed" ));
75
- }
76
70
this .handler .set (handler );
77
71
if (openConnectionOnStartup ) {
78
72
this .reconnect (null , ctx );
@@ -81,9 +75,20 @@ public Mono<Void> connect(Function<Mono<McpSchema.JSONRPCMessage>, Mono<McpSchem
81
75
});
82
76
}
83
77
78
+ @ Override
79
+ public void handleException (Consumer <Throwable > handler ) {
80
+ this .exceptionHandler .set (handler );
81
+ }
82
+
84
83
@ Override
85
84
public Mono <Void > closeGracefully () {
86
- return Mono .fromRunnable (this .openConnections ::dispose );
85
+ return Mono .defer (() -> {
86
+ McpTransportSession currentSession = this .activeSession .get ();
87
+ if (currentSession != null ) {
88
+ return currentSession .closeGracefully ();
89
+ }
90
+ return Mono .empty ();
91
+ });
87
92
}
88
93
89
94
private void reconnect (McpStream stream , ContextView ctx ) {
@@ -93,12 +98,13 @@ private void reconnect(McpStream stream, ContextView ctx) {
93
98
// listen for messages.
94
99
// If it doesn't, nothing actually happens here, that's just the way it is...
95
100
final AtomicReference <Disposable > disposableRef = new AtomicReference <>();
101
+ final McpTransportSession transportSession = this .activeSession .get ();
96
102
Disposable connection = webClient .get ()
97
103
.uri (this .endpoint )
98
104
.accept (MediaType .TEXT_EVENT_STREAM )
99
105
.headers (httpHeaders -> {
100
- if (sessionId . get () != null ) {
101
- httpHeaders .add ("mcp-session-id" , sessionId . get ());
106
+ if (transportSession . sessionId () != null ) {
107
+ httpHeaders .add ("mcp-session-id" , transportSession . sessionId ());
102
108
}
103
109
if (stream != null && stream .lastId () != null ) {
104
110
httpHeaders .add ("last-event-id" , stream .lastId ());
@@ -123,22 +129,33 @@ else if (response.statusCode().isSameCodeAs(HttpStatus.METHOD_NOT_ALLOWED)) {
123
129
logger .info ("The server does not support SSE streams, using request-response mode." );
124
130
return Flux .empty ();
125
131
}
132
+ else if (response .statusCode ().isSameCodeAs (HttpStatus .NOT_FOUND )) {
133
+ logger .info ("Session {} was not found on the MCP server" , transportSession .sessionId ());
134
+
135
+ McpSessionNotFoundException notFoundException = new McpSessionNotFoundException (
136
+ "Session " + transportSession .sessionId () + " not found" );
137
+ // inform the stream/connection subscriber
138
+ return Flux .error (notFoundException );
139
+ }
126
140
else {
127
141
return response .<McpSchema .JSONRPCMessage >createError ().doOnError (e -> {
128
142
logger .info ("Opening an SSE stream failed. This can be safely ignored." , e );
129
143
}).flux ();
130
144
}
131
145
})
146
+ .doOnError (e -> {
147
+ this .exceptionHandler .get ().accept (e );
148
+ })
132
149
.doFinally (s -> {
133
150
Disposable ref = disposableRef .getAndSet (null );
134
151
if (ref != null ) {
135
- this . openConnections . remove (ref );
152
+ transportSession . removeConnection (ref );
136
153
}
137
154
})
138
155
.contextWrite (ctx )
139
156
.subscribe ();
140
157
disposableRef .set (connection );
141
- this . openConnections . add (connection );
158
+ transportSession . addConnection (connection );
142
159
}
143
160
144
161
@ Override
@@ -151,20 +168,22 @@ public Mono<Void> sendMessage(McpSchema.JSONRPCMessage message) {
151
168
// listen for messages.
152
169
// If it doesn't, nothing actually happens here, that's just the way it is...
153
170
final AtomicReference <Disposable > disposableRef = new AtomicReference <>();
171
+ final McpTransportSession transportSession = this .activeSession .get ();
172
+
154
173
Disposable connection = webClient .post ()
155
174
.uri (this .endpoint )
156
175
.accept (MediaType .TEXT_EVENT_STREAM , MediaType .APPLICATION_JSON )
157
176
.headers (httpHeaders -> {
158
- if (sessionId . get () != null ) {
159
- httpHeaders .add ("mcp-session-id" , sessionId . get ());
177
+ if (transportSession . sessionId () != null ) {
178
+ httpHeaders .add ("mcp-session-id" , transportSession . sessionId ());
160
179
}
161
180
})
162
181
.bodyValue (message )
163
182
.exchangeToFlux (response -> {
164
- // TODO: this goes into the request phase
165
- if (!initialized .compareAndExchange (false , true )) {
183
+ if (transportSession .markInitialized ()) {
166
184
if (!response .headers ().header ("mcp-session-id" ).isEmpty ()) {
167
- sessionId .set (response .headers ().asHttpHeaders ().getFirst ("mcp-session-id" ));
185
+ transportSession
186
+ .setSessionId (response .headers ().asHttpHeaders ().getFirst ("mcp-session-id" ));
168
187
// Once we have a session, we try to open an async stream for
169
188
// the server to send notifications and requests out-of-band.
170
189
reconnect (null , sink .contextView ());
@@ -176,10 +195,10 @@ public Mono<Void> sendMessage(McpSchema.JSONRPCMessage message) {
176
195
// if (!response.statusCode().isSameCodeAs(HttpStatus.ACCEPTED)) {
177
196
if (!response .statusCode ().is2xxSuccessful ()) {
178
197
if (response .statusCode ().isSameCodeAs (HttpStatus .NOT_FOUND )) {
179
- logger .info ("Session {} was not found on the MCP server" , sessionId . get ());
198
+ logger .info ("Session {} was not found on the MCP server" , transportSession . sessionId ());
180
199
181
200
McpSessionNotFoundException notFoundException = new McpSessionNotFoundException (
182
- "Session " + sessionId . get () + " not found" );
201
+ "Session " + transportSession . sessionId () + " not found" );
183
202
// inform the caller of sendMessage
184
203
sink .error (notFoundException );
185
204
// inform the stream/connection subscriber
@@ -233,8 +252,6 @@ else if (contentType.isCompatibleWith(MediaType.APPLICATION_JSON)) {
233
252
}
234
253
})
235
254
.flatMapIterable (Function .identity ());
236
- // .map(Mono::just)
237
- // .flatMap(this.handler.get());
238
255
}
239
256
else {
240
257
sink .error (new RuntimeException ("Unknown media type" ));
@@ -246,13 +263,13 @@ else if (contentType.isCompatibleWith(MediaType.APPLICATION_JSON)) {
246
263
.doFinally (s -> {
247
264
Disposable ref = disposableRef .getAndSet (null );
248
265
if (ref != null ) {
249
- this . openConnections . remove (ref );
266
+ transportSession . removeConnection (ref );
250
267
}
251
268
})
252
269
.contextWrite (sink .contextView ())
253
270
.subscribe ();
254
271
disposableRef .set (connection );
255
- this . openConnections . add (connection );
272
+ transportSession . addConnection (connection );
256
273
});
257
274
}
258
275
0 commit comments