diff --git a/CODE_OF_CONDUCT.md b/CODE_OF_CONDUCT.md new file mode 100644 index 00000000..6009a645 --- /dev/null +++ b/CODE_OF_CONDUCT.md @@ -0,0 +1,119 @@ +# Contributor Covenant Code of Conduct + +## Our Pledge + +We as members, contributors, and leaders pledge to make participation in our community a +harassment-free experience for everyone, regardless of age, body size, visible or +invisible disability, ethnicity, sex characteristics, gender identity and expression, +level of experience, education, socio-economic status, nationality, personal appearance, +race, religion, or sexual identity and orientation. + +We pledge to act and interact in ways that contribute to an open, welcoming, diverse, +inclusive, and healthy community. + +## Our Standards + +Examples of behavior that contributes to a positive environment for our community +include: + +- Demonstrating empathy and kindness toward other people +- Being respectful of differing opinions, viewpoints, and experiences +- Giving and gracefully accepting constructive feedback +- Accepting responsibility and apologizing to those affected by our mistakes, and + learning from the experience +- Focusing on what is best not just for us as individuals, but for the overall community + +Examples of unacceptable behavior include: + +- The use of sexualized language or imagery, and sexual attention or advances of any kind +- Trolling, insulting or derogatory comments, and personal or political attacks +- Public or private harassment +- Publishing others' private information, such as a physical or email address, without + their explicit permission +- Other conduct which could reasonably be considered inappropriate in a professional + setting + +## Enforcement Responsibilities + +Community leaders are responsible for clarifying and enforcing our standards of +acceptable behavior and will take appropriate and fair corrective action in response to +any behavior that they deem inappropriate, threatening, offensive, or harmful. + +Community leaders have the right and responsibility to remove, edit, or reject comments, +commits, code, wiki edits, issues, and other contributions that are not aligned to this +Code of Conduct, and will communicate reasons for moderation decisions when appropriate. + +## Scope + +This Code of Conduct applies within all community spaces, and also applies when an +individual is officially representing the community in public spaces. Examples of +representing our community include using an official e-mail address, posting via an +official social media account, or acting as an appointed representative at an online or +offline event. + +## Enforcement + +Instances of abusive, harassing, or otherwise unacceptable behavior may be reported to +the community leaders responsible for enforcement at mcp-coc@anthropic.com. All +complaints will be reviewed and investigated promptly and fairly. + +All community leaders are obligated to respect the privacy and security of the reporter +of any incident. + +## Enforcement Guidelines + +Community leaders will follow these Community Impact Guidelines in determining the +consequences for any action they deem in violation of this Code of Conduct: + +### 1. Correction + +**Community Impact**: Use of inappropriate language or other behavior deemed +unprofessional or unwelcome in the community. + +**Consequence**: A private, written warning from community leaders, providing clarity +around the nature of the violation and an explanation of why the behavior was +inappropriate. A public apology may be requested. + +### 2. Warning + +**Community Impact**: A violation through a single incident or series of actions. + +**Consequence**: A warning with consequences for continued behavior. No interaction with +the people involved, including unsolicited interaction with those enforcing the Code of +Conduct, for a specified period of time. This includes avoiding interactions in community +spaces as well as external channels like social media. Violating these terms may lead to +a temporary or permanent ban. + +### 3. Temporary Ban + +**Community Impact**: A serious violation of community standards, including sustained +inappropriate behavior. + +**Consequence**: A temporary ban from any sort of interaction or public communication +with the community for a specified period of time. No public or private interaction with +the people involved, including unsolicited interaction with those enforcing the Code of +Conduct, is allowed during this period. Violating these terms may lead to a permanent +ban. + +### 4. Permanent Ban + +**Community Impact**: Demonstrating a pattern of violation of community standards, +including sustained inappropriate behavior, harassment of an individual, or aggression +toward or disparagement of classes of individuals. + +**Consequence**: A permanent ban from any sort of public interaction within the +community. + +## Attribution + +This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 2.0, +available at https://www.contributor-covenant.org/version/2/0/code_of_conduct.html. + +Community Impact Guidelines were inspired by +[Mozilla's code of conduct enforcement ladder](https://github.com/mozilla/diversity). + +[homepage]: https://www.contributor-covenant.org + +For answers to common questions about this code of conduct, see the FAQ at +https://www.contributor-covenant.org/faq. Translations are available at +https://www.contributor-covenant.org/translations. \ No newline at end of file diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md new file mode 100644 index 00000000..517f3255 --- /dev/null +++ b/CONTRIBUTING.md @@ -0,0 +1,94 @@ +# Contributing to Model Context Protocol Java SDK + +Thank you for your interest in contributing to the Model Context Protocol Java SDK! +This document outlines how to contribute to this project. + +## Prerequisites + +The following software is required to work on the codebase: + +- `Java 17` or above +- `Docker` +- `npx` + +## Getting Started + +1. Fork the repository +2. Clone your fork: + +```bash +git clone https://github.com/YOUR-USERNAME/java-sdk.git +cd java-sdk +``` + +3. Build from source: + +```bash +./mvnw clean install -DskipTests # skip the tests +./mvnw test # run tests +``` + +## Reporting Issues + +Please create an issue in the repository if you discover a bug or would like to +propose an enhancement. Bug reports should have a reproducer in the form of a code +sample or a repository attached that the maintainers or contributors can work with to +address the problem. + +## Making Changes + +1. Create a new branch: + +```bash +git checkout -b feature/your-feature-name +``` + +2. Make your changes +3. Validate your changes: + +```bash +./mvnw clean test +``` + +### Change Proposal Guidelines + +#### Principles of MCP + +1. **Simple + Minimal**: It is much easier to add things to the codebase than it is to + remove them. To maintain simplicity, we keep a high bar for adding new concepts and + primitives as each addition requires maintenance and compatibility consideration. +2. **Concrete**: Code changes need to be based on specific usage and implementation + challenges and not on speculative ideas. Most importantly, the SDK is meant to + implement the MCP specification. + +## Submitting Changes + +1. For non-trivial changes, please clarify with the maintainers in an issue whether + you can contribute the change and the desired scope of the change. +2. For trivial changes (for example a couple of lines or documentation changes) there + is no need to open an issue first. +3. Push your changes to your fork. +4. Submit a pull request to the main repository. +5. Follow the pull request template. +6. Wait for review. +7. For any follow-up work, please add new commits instead of force-pushing. This will + allow the reviewer to focus on incremental changes instead of having to restart the + review process. + +## Code of Conduct + +This project follows a Code of Conduct. Please review it in +[CODE_OF_CONDUCT.md](CODE_OF_CONDUCT.md). + +## Questions + +If you have questions, please create a discussion in the repository. + +## License + +By contributing, you agree that your contributions will be licensed under the MIT +License. + +## Security + +Please review our [Security Policy](SECURITY.md) for reporting security issues. \ No newline at end of file diff --git a/README.md b/README.md index 9fc17306..0cd3f84a 100644 --- a/README.md +++ b/README.md @@ -30,11 +30,8 @@ To run the tests you have to pre-install `Docker` and `npx`. ## Contributing -Contributions are welcome! Please: - -1. Fork the repository -2. Create a feature branch -3. Submit a Pull Request +Contributions are welcome! +Please follow the [Contributing Guidelines](CONTRIBUTING.md). ## Team diff --git a/SECURITY.md b/SECURITY.md new file mode 100644 index 00000000..74e9880f --- /dev/null +++ b/SECURITY.md @@ -0,0 +1,21 @@ +# Security Policy + +Thank you for helping us keep the SDKs and systems they interact with secure. + +## Reporting Security Issues + +This SDK is maintained by [Anthropic](https://www.anthropic.com/) as part of the Model +Context Protocol project. + +The security of our systems and user data is Anthropic’s top priority. We appreciate the +work of security researchers acting in good faith in identifying and reporting potential +vulnerabilities. + +Our security program is managed on HackerOne and we ask that any validated vulnerability +in this functionality be reported through their +[submission form](https://hackerone.com/anthropic-vdp/reports/new?type=team&report_type=vulnerability). + +## Vulnerability Disclosure Program + +Our Vulnerability Program Guidelines are defined on our +[HackerOne program page](https://hackerone.com/anthropic-vdp). \ No newline at end of file diff --git a/mcp-bom/pom.xml b/mcp-bom/pom.xml index 4f24f719..7214dacd 100644 --- a/mcp-bom/pom.xml +++ b/mcp-bom/pom.xml @@ -7,7 +7,7 @@ io.modelcontextprotocol.sdk mcp-parent - 0.10.0-SNAPSHOT + 0.11.0-SNAPSHOT mcp-bom diff --git a/mcp-spring/mcp-spring-webflux/pom.xml b/mcp-spring/mcp-spring-webflux/pom.xml index 86f46bf9..26452fe9 100644 --- a/mcp-spring/mcp-spring-webflux/pom.xml +++ b/mcp-spring/mcp-spring-webflux/pom.xml @@ -6,7 +6,7 @@ io.modelcontextprotocol.sdk mcp-parent - 0.10.0-SNAPSHOT + 0.11.0-SNAPSHOT ../../pom.xml mcp-spring-webflux @@ -25,13 +25,13 @@ io.modelcontextprotocol.sdk mcp - 0.10.0-SNAPSHOT + 0.11.0-SNAPSHOT io.modelcontextprotocol.sdk mcp-test - 0.10.0-SNAPSHOT + 0.11.0-SNAPSHOT test @@ -99,6 +99,12 @@ ${testcontainers.version} test + + org.testcontainers + toxiproxy + ${toxiproxy.version} + test + org.awaitility diff --git a/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/client/transport/WebClientStreamableHttpTransport.java b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/client/transport/WebClientStreamableHttpTransport.java new file mode 100644 index 00000000..e7b7c8ee --- /dev/null +++ b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/client/transport/WebClientStreamableHttpTransport.java @@ -0,0 +1,520 @@ +package io.modelcontextprotocol.client.transport; + +import com.fasterxml.jackson.core.type.TypeReference; +import com.fasterxml.jackson.databind.ObjectMapper; +import io.modelcontextprotocol.spec.DefaultMcpTransportSession; +import io.modelcontextprotocol.spec.DefaultMcpTransportStream; +import io.modelcontextprotocol.spec.McpClientTransport; +import io.modelcontextprotocol.spec.McpError; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpTransportSessionNotFoundException; +import io.modelcontextprotocol.spec.McpTransportSession; +import io.modelcontextprotocol.spec.McpTransportStream; +import io.modelcontextprotocol.util.Assert; +import org.reactivestreams.Publisher; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springframework.core.ParameterizedTypeReference; +import org.springframework.http.HttpStatus; +import org.springframework.http.MediaType; +import org.springframework.http.codec.ServerSentEvent; +import org.springframework.web.reactive.function.client.ClientResponse; +import org.springframework.web.reactive.function.client.WebClient; +import org.springframework.web.reactive.function.client.WebClientResponseException; +import reactor.core.Disposable; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; +import reactor.util.function.Tuple2; +import reactor.util.function.Tuples; + +import java.io.IOException; +import java.util.List; +import java.util.Optional; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Consumer; +import java.util.function.Function; +import java.util.function.Supplier; + +/** + * An implementation of the Streamable HTTP protocol as defined by the + * 2025-03-26 version of the MCP specification. + * + *

+ * The transport is capable of resumability and reconnects. It reacts to transport-level + * session invalidation and will propagate {@link McpTransportSessionNotFoundException + * appropriate exceptions} to the higher level abstraction layer when needed in order to + * allow proper state management. The implementation handles servers that are stateful and + * provide session meta information, but can also communicate with stateless servers that + * do not provide a session identifier and do not support SSE streams. + *

+ *

+ * This implementation does not handle backwards compatibility with the "HTTP + * with SSE" transport. In order to communicate over the phased-out + * 2024-11-05 protocol, use {@link HttpClientSseClientTransport} or + * {@link WebFluxSseClientTransport}. + *

+ * + * @author Dariusz Jędrzejczyk + * @see Streamable + * HTTP transport specification + */ +public class WebClientStreamableHttpTransport implements McpClientTransport { + + private static final Logger logger = LoggerFactory.getLogger(WebClientStreamableHttpTransport.class); + + private static final String DEFAULT_ENDPOINT = "/mcp"; + + /** + * Event type for JSON-RPC messages received through the SSE connection. The server + * sends messages with this event type to transmit JSON-RPC protocol data. + */ + private static final String MESSAGE_EVENT_TYPE = "message"; + + private static final ParameterizedTypeReference> PARAMETERIZED_TYPE_REF = new ParameterizedTypeReference<>() { + }; + + private final ObjectMapper objectMapper; + + private final WebClient webClient; + + private final String endpoint; + + private final boolean openConnectionOnStartup; + + private final boolean resumableStreams; + + private final AtomicReference activeSession = new AtomicReference<>(); + + private final AtomicReference, Mono>> handler = new AtomicReference<>(); + + private final AtomicReference> exceptionHandler = new AtomicReference<>(); + + private WebClientStreamableHttpTransport(ObjectMapper objectMapper, WebClient.Builder webClientBuilder, + String endpoint, boolean resumableStreams, boolean openConnectionOnStartup) { + this.objectMapper = objectMapper; + this.webClient = webClientBuilder.build(); + this.endpoint = endpoint; + this.resumableStreams = resumableStreams; + this.openConnectionOnStartup = openConnectionOnStartup; + this.activeSession.set(createTransportSession()); + } + + /** + * Create a stateful builder for creating {@link WebClientStreamableHttpTransport} + * instances. + * @param webClientBuilder the {@link WebClient.Builder} to use + * @return a builder which will create an instance of + * {@link WebClientStreamableHttpTransport} once {@link Builder#build()} is called + */ + public static Builder builder(WebClient.Builder webClientBuilder) { + return new Builder(webClientBuilder); + } + + @Override + public Mono connect(Function, Mono> handler) { + return Mono.deferContextual(ctx -> { + this.handler.set(handler); + if (openConnectionOnStartup) { + logger.debug("Eagerly opening connection on startup"); + return this.reconnect(null).then(); + } + return Mono.empty(); + }); + } + + private DefaultMcpTransportSession createTransportSession() { + Supplier> onClose = () -> { + DefaultMcpTransportSession transportSession = this.activeSession.get(); + return transportSession.sessionId().isEmpty() ? Mono.empty() + : webClient.delete().uri(this.endpoint).headers(httpHeaders -> { + httpHeaders.add("mcp-session-id", transportSession.sessionId().get()); + }).retrieve().toBodilessEntity().doOnError(e -> logger.info("Got response {}", e)).then(); + }; + return new DefaultMcpTransportSession(onClose); + } + + @Override + public void setExceptionHandler(Consumer handler) { + logger.debug("Exception handler registered"); + this.exceptionHandler.set(handler); + } + + private void handleException(Throwable t) { + logger.debug("Handling exception for session {}", sessionIdOrPlaceholder(this.activeSession.get()), t); + if (t instanceof McpTransportSessionNotFoundException) { + McpTransportSession invalidSession = this.activeSession.getAndSet(createTransportSession()); + logger.warn("Server does not recognize session {}. Invalidating.", invalidSession.sessionId()); + invalidSession.close(); + } + Consumer handler = this.exceptionHandler.get(); + if (handler != null) { + handler.accept(t); + } + } + + @Override + public Mono closeGracefully() { + return Mono.defer(() -> { + logger.debug("Graceful close triggered"); + DefaultMcpTransportSession currentSession = this.activeSession.getAndSet(createTransportSession()); + if (currentSession != null) { + return currentSession.closeGracefully(); + } + return Mono.empty(); + }); + } + + private Mono reconnect(McpTransportStream stream) { + return Mono.deferContextual(ctx -> { + if (stream != null) { + logger.debug("Reconnecting stream {} with lastId {}", stream.streamId(), stream.lastId()); + } + else { + logger.debug("Reconnecting with no prior stream"); + } + // Here we attempt to initialize the client. In case the server supports SSE, + // we will establish a long-running + // session here and listen for messages. If it doesn't, that's ok, the server + // is a simple, stateless one. + final AtomicReference disposableRef = new AtomicReference<>(); + final McpTransportSession transportSession = this.activeSession.get(); + + Disposable connection = webClient.get() + .uri(this.endpoint) + .accept(MediaType.TEXT_EVENT_STREAM) + .headers(httpHeaders -> { + transportSession.sessionId().ifPresent(id -> httpHeaders.add("mcp-session-id", id)); + if (stream != null) { + stream.lastId().ifPresent(id -> httpHeaders.add("last-event-id", id)); + } + }) + .exchangeToFlux(response -> { + if (isEventStream(response)) { + return eventStream(stream, response); + } + else if (isNotAllowed(response)) { + logger.debug("The server does not support SSE streams, using request-response mode."); + return Flux.empty(); + } + else if (isNotFound(response)) { + String sessionIdRepresentation = sessionIdOrPlaceholder(transportSession); + return mcpSessionNotFoundError(sessionIdRepresentation); + } + else { + return response.createError().doOnError(e -> { + logger.info("Opening an SSE stream failed. This can be safely ignored.", e); + }).flux(); + } + }) + .onErrorComplete(t -> { + this.handleException(t); + return true; + }) + .doFinally(s -> { + Disposable ref = disposableRef.getAndSet(null); + if (ref != null) { + transportSession.removeConnection(ref); + } + }) + .contextWrite(ctx) + .subscribe(); + + disposableRef.set(connection); + transportSession.addConnection(connection); + return Mono.just(connection); + }); + } + + @Override + public Mono sendMessage(McpSchema.JSONRPCMessage message) { + return Mono.create(sink -> { + logger.debug("Sending message {}", message); + // Here we attempt to initialize the client. + // In case the server supports SSE, we will establish a long-running session + // here and + // listen for messages. + // If it doesn't, nothing actually happens here, that's just the way it is... + final AtomicReference disposableRef = new AtomicReference<>(); + final McpTransportSession transportSession = this.activeSession.get(); + + Disposable connection = webClient.post() + .uri(this.endpoint) + .accept(MediaType.TEXT_EVENT_STREAM, MediaType.APPLICATION_JSON) + .headers(httpHeaders -> { + transportSession.sessionId().ifPresent(id -> httpHeaders.add("mcp-session-id", id)); + }) + .bodyValue(message) + .exchangeToFlux(response -> { + if (transportSession + .markInitialized(response.headers().asHttpHeaders().getFirst("mcp-session-id"))) { + // Once we have a session, we try to open an async stream for + // the server to send notifications and requests out-of-band. + reconnect(null).contextWrite(sink.contextView()).subscribe(); + } + + String sessionRepresentation = sessionIdOrPlaceholder(transportSession); + + // The spec mentions only ACCEPTED, but the existing SDKs can return + // 200 OK for notifications + if (response.statusCode().is2xxSuccessful()) { + Optional contentType = response.headers().contentType(); + // Existing SDKs consume notifications with no response body nor + // content type + if (contentType.isEmpty()) { + logger.trace("Message was successfully sent via POST for session {}", + sessionRepresentation); + // signal the caller that the message was successfully + // delivered + sink.success(); + // communicate to downstream there is no streamed data coming + return Flux.empty(); + } + else { + MediaType mediaType = contentType.get(); + if (mediaType.isCompatibleWith(MediaType.TEXT_EVENT_STREAM)) { + // communicate to caller that the message was delivered + sink.success(); + // starting a stream + return newEventStream(response, sessionRepresentation); + } + else if (mediaType.isCompatibleWith(MediaType.APPLICATION_JSON)) { + logger.trace("Received response to POST for session {}", sessionRepresentation); + // communicate to caller the message was delivered + sink.success(); + return responseFlux(response); + } + else { + logger.warn("Unknown media type {} returned for POST in session {}", contentType, + sessionRepresentation); + return Flux.error(new RuntimeException("Unknown media type returned: " + contentType)); + } + } + } + else { + if (isNotFound(response)) { + return mcpSessionNotFoundError(sessionRepresentation); + } + return extractError(response, sessionRepresentation); + } + }) + .flatMap(jsonRpcMessage -> this.handler.get().apply(Mono.just(jsonRpcMessage))) + .onErrorResume(t -> { + // handle the error first + this.handleException(t); + // inform the caller of sendMessage + sink.error(t); + return Flux.empty(); + }) + .doFinally(s -> { + Disposable ref = disposableRef.getAndSet(null); + if (ref != null) { + transportSession.removeConnection(ref); + } + }) + .contextWrite(sink.contextView()) + .subscribe(); + disposableRef.set(connection); + transportSession.addConnection(connection); + }); + } + + private static Flux mcpSessionNotFoundError(String sessionRepresentation) { + logger.warn("Session {} was not found on the MCP server", sessionRepresentation); + // inform the stream/connection subscriber + return Flux.error(new McpTransportSessionNotFoundException(sessionRepresentation)); + } + + private Flux extractError(ClientResponse response, String sessionRepresentation) { + return response.createError().onErrorResume(e -> { + WebClientResponseException responseException = (WebClientResponseException) e; + byte[] body = responseException.getResponseBodyAsByteArray(); + McpSchema.JSONRPCResponse.JSONRPCError jsonRpcError = null; + Exception toPropagate; + try { + McpSchema.JSONRPCResponse jsonRpcResponse = objectMapper.readValue(body, + McpSchema.JSONRPCResponse.class); + jsonRpcError = jsonRpcResponse.error(); + toPropagate = new McpError(jsonRpcError); + } + catch (IOException ex) { + toPropagate = new RuntimeException("Sending request failed", e); + logger.debug("Received content together with {} HTTP code response: {}", response.statusCode(), body); + } + + // Some implementations can return 400 when presented with a + // session id that it doesn't know about, so we will + // invalidate the session + // https://github.com/modelcontextprotocol/typescript-sdk/issues/389 + if (responseException.getStatusCode().isSameCodeAs(HttpStatus.BAD_REQUEST)) { + return Mono.error(new McpTransportSessionNotFoundException(sessionRepresentation, toPropagate)); + } + return Mono.empty(); + }).flux(); + } + + private Flux eventStream(McpTransportStream stream, ClientResponse response) { + McpTransportStream sessionStream = stream != null ? stream + : new DefaultMcpTransportStream<>(this.resumableStreams, this::reconnect); + logger.debug("Connected stream {}", sessionStream.streamId()); + + var idWithMessages = response.bodyToFlux(PARAMETERIZED_TYPE_REF).map(this::parse); + return Flux.from(sessionStream.consumeSseStream(idWithMessages)); + } + + private static boolean isNotFound(ClientResponse response) { + return response.statusCode().isSameCodeAs(HttpStatus.NOT_FOUND); + } + + private static boolean isNotAllowed(ClientResponse response) { + return response.statusCode().isSameCodeAs(HttpStatus.METHOD_NOT_ALLOWED); + } + + private static boolean isEventStream(ClientResponse response) { + return response.statusCode().is2xxSuccessful() && response.headers().contentType().isPresent() + && response.headers().contentType().get().isCompatibleWith(MediaType.TEXT_EVENT_STREAM); + } + + private static String sessionIdOrPlaceholder(McpTransportSession transportSession) { + return transportSession.sessionId().orElse("[missing_session_id]"); + } + + private Flux responseFlux(ClientResponse response) { + return response.bodyToMono(String.class).>handle((responseMessage, s) -> { + try { + McpSchema.JSONRPCMessage jsonRpcResponse = McpSchema.deserializeJsonRpcMessage(objectMapper, + responseMessage); + s.next(List.of(jsonRpcResponse)); + } + catch (IOException e) { + s.error(e); + } + }).flatMapIterable(Function.identity()); + } + + private Flux newEventStream(ClientResponse response, String sessionRepresentation) { + McpTransportStream sessionStream = new DefaultMcpTransportStream<>(this.resumableStreams, + this::reconnect); + logger.trace("Sent POST and opened a stream ({}) for session {}", sessionStream.streamId(), + sessionRepresentation); + return eventStream(sessionStream, response); + } + + @Override + public T unmarshalFrom(Object data, TypeReference typeRef) { + return this.objectMapper.convertValue(data, typeRef); + } + + private Tuple2, Iterable> parse(ServerSentEvent event) { + if (MESSAGE_EVENT_TYPE.equals(event.event())) { + try { + // We don't support batching ATM and probably won't since the next version + // considers removing it. + McpSchema.JSONRPCMessage message = McpSchema.deserializeJsonRpcMessage(this.objectMapper, event.data()); + return Tuples.of(Optional.ofNullable(event.id()), List.of(message)); + } + catch (IOException ioException) { + throw new McpError("Error parsing JSON-RPC message: " + event.data()); + } + } + else { + throw new McpError("Received unrecognized SSE event type: " + event.event()); + } + } + + /** + * Builder for {@link WebClientStreamableHttpTransport}. + */ + public static class Builder { + + private ObjectMapper objectMapper; + + private WebClient.Builder webClientBuilder; + + private String endpoint = DEFAULT_ENDPOINT; + + private boolean resumableStreams = true; + + private boolean openConnectionOnStartup = false; + + private Builder(WebClient.Builder webClientBuilder) { + Assert.notNull(webClientBuilder, "WebClient.Builder must not be null"); + this.webClientBuilder = webClientBuilder; + } + + /** + * Configure the {@link ObjectMapper} to use. + * @param objectMapper instance to use + * @return the builder instance + */ + public Builder objectMapper(ObjectMapper objectMapper) { + Assert.notNull(objectMapper, "ObjectMapper must not be null"); + this.objectMapper = objectMapper; + return this; + } + + /** + * Configure the {@link WebClient.Builder} to construct the {@link WebClient}. + * @param webClientBuilder instance to use + * @return the builder instance + */ + public Builder webClientBuilder(WebClient.Builder webClientBuilder) { + Assert.notNull(webClientBuilder, "WebClient.Builder must not be null"); + this.webClientBuilder = webClientBuilder; + return this; + } + + /** + * Configure the endpoint to make HTTP requests against. + * @param endpoint endpoint to use + * @return the builder instance + */ + public Builder endpoint(String endpoint) { + Assert.hasText(endpoint, "endpoint must be a non-empty String"); + this.endpoint = endpoint; + return this; + } + + /** + * Configure whether to use the stream resumability feature by keeping track of + * SSE event ids. + * @param resumableStreams if {@code true} event ids will be tracked and upon + * disconnection, the last seen id will be used upon reconnection as a header to + * resume consuming messages. + * @return the builder instance + */ + public Builder resumableStreams(boolean resumableStreams) { + this.resumableStreams = resumableStreams; + return this; + } + + /** + * Configure whether the client should open an SSE connection upon startup. Not + * all servers support this (although it is in theory possible with the current + * specification), so use with caution. By default, this value is {@code false}. + * @param openConnectionOnStartup if {@code true} the {@link #connect(Function)} + * method call will try to open an SSE connection before sending any JSON-RPC + * request + * @return the builder instance + */ + public Builder openConnectionOnStartup(boolean openConnectionOnStartup) { + this.openConnectionOnStartup = openConnectionOnStartup; + return this; + } + + /** + * Construct a fresh instance of {@link WebClientStreamableHttpTransport} using + * the current builder configuration. + * @return a new instance of {@link WebClientStreamableHttpTransport} + */ + public WebClientStreamableHttpTransport build() { + ObjectMapper objectMapper = this.objectMapper != null ? this.objectMapper : new ObjectMapper(); + + return new WebClientStreamableHttpTransport(objectMapper, this.webClientBuilder, endpoint, resumableStreams, + openConnectionOnStartup); + } + + } + +} diff --git a/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/client/transport/WebFluxSseClientTransport.java b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/client/transport/WebFluxSseClientTransport.java index 37abe295..128cda4c 100644 --- a/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/client/transport/WebFluxSseClientTransport.java +++ b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/client/transport/WebFluxSseClientTransport.java @@ -190,6 +190,9 @@ public WebFluxSseClientTransport(WebClient.Builder webClientBuilder, ObjectMappe */ @Override public Mono connect(Function, Mono> handler) { + // TODO: Avoid eager connection opening and enable resilience + // -> upon disconnects, re-establish connection + // -> allow optimizing for eager connection start using a constructor flag Flux> events = eventStream(); this.inboundSubscription = events.concatMap(event -> Mono.just(event).handle((e, s) -> { if (ENDPOINT_EVENT_TYPE.equals(event.event())) { diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxSseIntegrationTests.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxSseIntegrationTests.java index 2ba04746..2f85654e 100644 --- a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxSseIntegrationTests.java +++ b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxSseIntegrationTests.java @@ -4,10 +4,11 @@ package io.modelcontextprotocol; import java.time.Duration; -import java.util.ArrayList; import java.util.List; import java.util.Map; import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.CopyOnWriteArrayList; +import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicReference; import java.util.function.BiFunction; @@ -26,11 +27,11 @@ import io.modelcontextprotocol.spec.McpError; import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.spec.McpSchema.*; -import io.modelcontextprotocol.spec.McpSchema.ServerCapabilities.CompletionCapabilities; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.ValueSource; +import reactor.core.publisher.Mono; import reactor.netty.DisposableServer; import reactor.netty.http.server.HttpServer; @@ -39,6 +40,7 @@ import org.springframework.web.client.RestClient; import org.springframework.web.reactive.function.client.WebClient; import org.springframework.web.reactive.function.server.RouterFunctions; +import reactor.test.StepVerifier; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatExceptionOfType; @@ -329,6 +331,226 @@ void testCreateMessageWithRequestTimeoutFail(String clientType) throws Interrupt mcpServer.closeGracefully().block(); } + // --------------------------------------- + // Elicitation Tests + // --------------------------------------- + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "httpclient", "webflux" }) + void testCreateElicitationWithoutElicitationCapabilities(String clientType) { + + var clientBuilder = clientBuilders.get(clientType); + + McpServerFeatures.AsyncToolSpecification tool = new McpServerFeatures.AsyncToolSpecification( + new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request) -> { + + exchange.createElicitation(mock(ElicitRequest.class)).block(); + + return Mono.just(mock(CallToolResult.class)); + }); + + var server = McpServer.async(mcpServerTransportProvider).serverInfo("test-server", "1.0.0").tools(tool).build(); + + try ( + // Create client without elicitation capabilities + var client = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")).build()) { + + assertThat(client.initialize()).isNotNull(); + + try { + client.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); + } + catch (McpError e) { + assertThat(e).isInstanceOf(McpError.class) + .hasMessage("Client must be configured with elicitation capabilities"); + } + } + server.closeGracefully().block(); + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "httpclient", "webflux" }) + void testCreateElicitationSuccess(String clientType) { + + var clientBuilder = clientBuilders.get(clientType); + + Function elicitationHandler = request -> { + assertThat(request.message()).isNotEmpty(); + assertThat(request.requestedSchema()).isNotNull(); + + return new ElicitResult(ElicitResult.Action.ACCEPT, Map.of("message", request.message())); + }; + + CallToolResult callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), + null); + + McpServerFeatures.AsyncToolSpecification tool = new McpServerFeatures.AsyncToolSpecification( + new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request) -> { + + var elicitationRequest = ElicitRequest.builder() + .message("Test message") + .requestedSchema( + Map.of("type", "object", "properties", Map.of("message", Map.of("type", "string")))) + .build(); + + StepVerifier.create(exchange.createElicitation(elicitationRequest)).consumeNextWith(result -> { + assertThat(result).isNotNull(); + assertThat(result.action()).isEqualTo(ElicitResult.Action.ACCEPT); + assertThat(result.content().get("message")).isEqualTo("Test message"); + }).verifyComplete(); + + return Mono.just(callResponse); + }); + + var mcpServer = McpServer.async(mcpServerTransportProvider) + .serverInfo("test-server", "1.0.0") + .tools(tool) + .build(); + + try (var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")) + .capabilities(ClientCapabilities.builder().elicitation().build()) + .elicitation(elicitationHandler) + .build()) { + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + CallToolResult response = mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); + + assertThat(response).isNotNull(); + assertThat(response).isEqualTo(callResponse); + } + mcpServer.closeGracefully().block(); + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "httpclient", "webflux" }) + void testCreateElicitationWithRequestTimeoutSuccess(String clientType) { + + // Client + var clientBuilder = clientBuilders.get(clientType); + + Function elicitationHandler = request -> { + assertThat(request.message()).isNotEmpty(); + assertThat(request.requestedSchema()).isNotNull(); + try { + TimeUnit.SECONDS.sleep(2); + } + catch (InterruptedException e) { + throw new RuntimeException(e); + } + return new ElicitResult(ElicitResult.Action.ACCEPT, Map.of("message", request.message())); + }; + + var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")) + .capabilities(ClientCapabilities.builder().elicitation().build()) + .elicitation(elicitationHandler) + .build(); + + // Server + + CallToolResult callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), + null); + + McpServerFeatures.AsyncToolSpecification tool = new McpServerFeatures.AsyncToolSpecification( + new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request) -> { + + var elicitationRequest = ElicitRequest.builder() + .message("Test message") + .requestedSchema( + Map.of("type", "object", "properties", Map.of("message", Map.of("type", "string")))) + .build(); + + StepVerifier.create(exchange.createElicitation(elicitationRequest)).consumeNextWith(result -> { + assertThat(result).isNotNull(); + assertThat(result.action()).isEqualTo(ElicitResult.Action.ACCEPT); + assertThat(result.content().get("message")).isEqualTo("Test message"); + }).verifyComplete(); + + return Mono.just(callResponse); + }); + + var mcpServer = McpServer.async(mcpServerTransportProvider) + .serverInfo("test-server", "1.0.0") + .requestTimeout(Duration.ofSeconds(3)) + .tools(tool) + .build(); + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + CallToolResult response = mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); + + assertThat(response).isNotNull(); + assertThat(response).isEqualTo(callResponse); + + mcpClient.closeGracefully(); + mcpServer.closeGracefully().block(); + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "httpclient", "webflux" }) + void testCreateElicitationWithRequestTimeoutFail(String clientType) { + + // Client + var clientBuilder = clientBuilders.get(clientType); + + Function elicitationHandler = request -> { + assertThat(request.message()).isNotEmpty(); + assertThat(request.requestedSchema()).isNotNull(); + try { + TimeUnit.SECONDS.sleep(2); + } + catch (InterruptedException e) { + throw new RuntimeException(e); + } + return new ElicitResult(ElicitResult.Action.ACCEPT, Map.of("message", request.message())); + }; + + var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")) + .capabilities(ClientCapabilities.builder().elicitation().build()) + .elicitation(elicitationHandler) + .build(); + + // Server + + CallToolResult callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), + null); + + McpServerFeatures.AsyncToolSpecification tool = new McpServerFeatures.AsyncToolSpecification( + new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request) -> { + + var elicitationRequest = ElicitRequest.builder() + .message("Test message") + .requestedSchema( + Map.of("type", "object", "properties", Map.of("message", Map.of("type", "string")))) + .build(); + + StepVerifier.create(exchange.createElicitation(elicitationRequest)).consumeNextWith(result -> { + assertThat(result).isNotNull(); + assertThat(result.action()).isEqualTo(ElicitResult.Action.ACCEPT); + assertThat(result.content().get("message")).isEqualTo("Test message"); + }).verifyComplete(); + + return Mono.just(callResponse); + }); + + var mcpServer = McpServer.async(mcpServerTransportProvider) + .serverInfo("test-server", "1.0.0") + .requestTimeout(Duration.ofSeconds(1)) + .tools(tool) + .build(); + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + assertThatExceptionOfType(McpError.class).isThrownBy(() -> { + mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); + }).withMessageContaining("within 1000ms"); + + mcpClient.closeGracefully(); + mcpServer.closeGracefully().block(); + } + // --------------------------------------- // Roots Tests // --------------------------------------- @@ -651,9 +873,11 @@ void testInitialize(String clientType) { @ParameterizedTest(name = "{0} : {displayName} ") @ValueSource(strings = { "httpclient", "webflux" }) - void testLoggingNotification(String clientType) { + void testLoggingNotification(String clientType) throws InterruptedException { + int expectedNotificationsCount = 3; + CountDownLatch latch = new CountDownLatch(expectedNotificationsCount); // Create a list to store received logging notifications - List receivedNotifications = new ArrayList<>(); + List receivedNotifications = new CopyOnWriteArrayList<>(); var clientBuilder = clientBuilders.get(clientType); @@ -709,6 +933,7 @@ void testLoggingNotification(String clientType) { // Create client with logging notification handler var mcpClient = clientBuilder.loggingConsumer(notification -> { receivedNotifications.add(notification); + latch.countDown(); }).build()) { // Initialize client @@ -724,31 +949,28 @@ void testLoggingNotification(String clientType) { assertThat(result.content().get(0)).isInstanceOf(McpSchema.TextContent.class); assertThat(((McpSchema.TextContent) result.content().get(0)).text()).isEqualTo("Logging test completed"); - // Wait for notifications to be processed - await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { + assertThat(latch.await(5, TimeUnit.SECONDS)).as("Should receive notifications in reasonable time").isTrue(); - // Should have received 3 notifications (1 NOTICE and 2 ERROR) - assertThat(receivedNotifications).hasSize(3); + // Should have received 3 notifications (1 NOTICE and 2 ERROR) + assertThat(receivedNotifications).hasSize(expectedNotificationsCount); - Map notificationMap = receivedNotifications.stream() - .collect(Collectors.toMap(n -> n.data(), n -> n)); + Map notificationMap = receivedNotifications.stream() + .collect(Collectors.toMap(n -> n.data(), n -> n)); - // First notification should be NOTICE level - assertThat(notificationMap.get("Notice message").level()).isEqualTo(McpSchema.LoggingLevel.NOTICE); - assertThat(notificationMap.get("Notice message").logger()).isEqualTo("test-logger"); - assertThat(notificationMap.get("Notice message").data()).isEqualTo("Notice message"); + // First notification should be NOTICE level + assertThat(notificationMap.get("Notice message").level()).isEqualTo(McpSchema.LoggingLevel.NOTICE); + assertThat(notificationMap.get("Notice message").logger()).isEqualTo("test-logger"); + assertThat(notificationMap.get("Notice message").data()).isEqualTo("Notice message"); - // Second notification should be ERROR level - assertThat(notificationMap.get("Error message").level()).isEqualTo(McpSchema.LoggingLevel.ERROR); - assertThat(notificationMap.get("Error message").logger()).isEqualTo("test-logger"); - assertThat(notificationMap.get("Error message").data()).isEqualTo("Error message"); + // Second notification should be ERROR level + assertThat(notificationMap.get("Error message").level()).isEqualTo(McpSchema.LoggingLevel.ERROR); + assertThat(notificationMap.get("Error message").logger()).isEqualTo("test-logger"); + assertThat(notificationMap.get("Error message").data()).isEqualTo("Error message"); - // Third notification should be ERROR level - assertThat(notificationMap.get("Another error message").level()) - .isEqualTo(McpSchema.LoggingLevel.ERROR); - assertThat(notificationMap.get("Another error message").logger()).isEqualTo("test-logger"); - assertThat(notificationMap.get("Another error message").data()).isEqualTo("Another error message"); - }); + // Third notification should be ERROR level + assertThat(notificationMap.get("Another error message").level()).isEqualTo(McpSchema.LoggingLevel.ERROR); + assertThat(notificationMap.get("Another error message").logger()).isEqualTo("test-logger"); + assertThat(notificationMap.get("Another error message").data()).isEqualTo("Another error message"); } mcpServer.close(); } diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebClientStreamableHttpAsyncClientResiliencyTests.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebClientStreamableHttpAsyncClientResiliencyTests.java new file mode 100644 index 00000000..80fc671e --- /dev/null +++ b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebClientStreamableHttpAsyncClientResiliencyTests.java @@ -0,0 +1,17 @@ +package io.modelcontextprotocol.client; + +import com.fasterxml.jackson.databind.ObjectMapper; +import io.modelcontextprotocol.client.transport.WebClientStreamableHttpTransport; +import io.modelcontextprotocol.spec.McpClientTransport; +import org.junit.jupiter.api.Timeout; +import org.springframework.web.reactive.function.client.WebClient; + +@Timeout(15) +public class WebClientStreamableHttpAsyncClientResiliencyTests extends AbstractMcpAsyncClientResiliencyTests { + + @Override + protected McpClientTransport createMcpTransport() { + return WebClientStreamableHttpTransport.builder(WebClient.builder().baseUrl(host)).build(); + } + +} diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebClientStreamableHttpAsyncClientTests.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebClientStreamableHttpAsyncClientTests.java new file mode 100644 index 00000000..4c803265 --- /dev/null +++ b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebClientStreamableHttpAsyncClientTests.java @@ -0,0 +1,42 @@ +package io.modelcontextprotocol.client; + +import com.fasterxml.jackson.databind.ObjectMapper; +import io.modelcontextprotocol.client.transport.WebClientStreamableHttpTransport; +import io.modelcontextprotocol.spec.McpClientTransport; +import org.junit.jupiter.api.Timeout; +import org.springframework.web.reactive.function.client.WebClient; +import org.testcontainers.containers.GenericContainer; +import org.testcontainers.containers.wait.strategy.Wait; +import org.testcontainers.images.builder.ImageFromDockerfile; + +@Timeout(15) +public class WebClientStreamableHttpAsyncClientTests extends AbstractMcpAsyncClientTests { + + static String host = "http://localhost:3001"; + + // Uses the https://github.com/tzolov/mcp-everything-server-docker-image + @SuppressWarnings("resource") + GenericContainer container = new GenericContainer<>("docker.io/tzolov/mcp-everything-server:v2") + .withCommand("node dist/index.js streamableHttp") + .withLogConsumer(outputFrame -> System.out.println(outputFrame.getUtf8String())) + .withExposedPorts(3001) + .waitingFor(Wait.forHttp("/").forStatusCode(404)); + + @Override + protected McpClientTransport createMcpTransport() { + return WebClientStreamableHttpTransport.builder(WebClient.builder().baseUrl(host)).build(); + } + + @Override + protected void onStart() { + container.start(); + int port = container.getMappedPort(3001); + host = "http://" + container.getHost() + ":" + port; + } + + @Override + public void onClose() { + container.stop(); + } + +} diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebClientStreamableHttpSyncClientTests.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebClientStreamableHttpSyncClientTests.java new file mode 100644 index 00000000..a8cad489 --- /dev/null +++ b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebClientStreamableHttpSyncClientTests.java @@ -0,0 +1,41 @@ +package io.modelcontextprotocol.client; + +import com.fasterxml.jackson.databind.ObjectMapper; +import io.modelcontextprotocol.client.transport.WebClientStreamableHttpTransport; +import io.modelcontextprotocol.spec.McpClientTransport; +import org.junit.jupiter.api.Timeout; +import org.springframework.web.reactive.function.client.WebClient; +import org.testcontainers.containers.GenericContainer; +import org.testcontainers.containers.wait.strategy.Wait; + +@Timeout(15) +public class WebClientStreamableHttpSyncClientTests extends AbstractMcpSyncClientTests { + + static String host = "http://localhost:3001"; + + // Uses the https://github.com/tzolov/mcp-everything-server-docker-image + @SuppressWarnings("resource") + GenericContainer container = new GenericContainer<>("docker.io/tzolov/mcp-everything-server:v2") + .withCommand("node dist/index.js streamableHttp") + .withLogConsumer(outputFrame -> System.out.println(outputFrame.getUtf8String())) + .withExposedPorts(3001) + .waitingFor(Wait.forHttp("/").forStatusCode(404)); + + @Override + protected McpClientTransport createMcpTransport() { + return WebClientStreamableHttpTransport.builder(WebClient.builder().baseUrl(host)).build(); + } + + @Override + protected void onStart() { + container.start(); + int port = container.getMappedPort(3001); + host = "http://" + container.getHost() + ":" + port; + } + + @Override + public void onClose() { + container.stop(); + } + +} diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebFluxSseMcpAsyncClientTests.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebFluxSseMcpAsyncClientTests.java index b43c1449..f0533cb4 100644 --- a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebFluxSseMcpAsyncClientTests.java +++ b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebFluxSseMcpAsyncClientTests.java @@ -26,7 +26,8 @@ class WebFluxSseMcpAsyncClientTests extends AbstractMcpAsyncClientTests { // Uses the https://github.com/tzolov/mcp-everything-server-docker-image @SuppressWarnings("resource") - GenericContainer container = new GenericContainer<>("docker.io/tzolov/mcp-everything-server:v1") + GenericContainer container = new GenericContainer<>("docker.io/tzolov/mcp-everything-server:v2") + .withCommand("node dist/index.js sse") .withLogConsumer(outputFrame -> System.out.println(outputFrame.getUtf8String())) .withExposedPorts(3001) .waitingFor(Wait.forHttp("/").forStatusCode(404)); diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebFluxSseMcpSyncClientTests.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebFluxSseMcpSyncClientTests.java index 66ac8a6d..9b0959a3 100644 --- a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebFluxSseMcpSyncClientTests.java +++ b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebFluxSseMcpSyncClientTests.java @@ -26,7 +26,8 @@ class WebFluxSseMcpSyncClientTests extends AbstractMcpSyncClientTests { // Uses the https://github.com/tzolov/mcp-everything-server-docker-image @SuppressWarnings("resource") - GenericContainer container = new GenericContainer<>("docker.io/tzolov/mcp-everything-server:v1") + GenericContainer container = new GenericContainer<>("docker.io/tzolov/mcp-everything-server:v2") + .withCommand("node dist/index.js sse") .withLogConsumer(outputFrame -> System.out.println(outputFrame.getUtf8String())) .withExposedPorts(3001) .waitingFor(Wait.forHttp("/").forStatusCode(404)); diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/transport/WebFluxSseClientTransportTests.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/transport/WebFluxSseClientTransportTests.java index c757d3da..42b91d14 100644 --- a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/transport/WebFluxSseClientTransportTests.java +++ b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/transport/WebFluxSseClientTransportTests.java @@ -41,7 +41,8 @@ class WebFluxSseClientTransportTests { static String host = "http://localhost:3001"; @SuppressWarnings("resource") - GenericContainer container = new GenericContainer<>("docker.io/tzolov/mcp-everything-server:v1") + GenericContainer container = new GenericContainer<>("docker.io/tzolov/mcp-everything-server:v2") + .withCommand("node dist/index.js sse") .withLogConsumer(outputFrame -> System.out.println(outputFrame.getUtf8String())) .withExposedPorts(3001) .waitingFor(Wait.forHttp("/").forStatusCode(404)); diff --git a/mcp-spring/mcp-spring-webflux/src/test/resources/logback.xml b/mcp-spring/mcp-spring-webflux/src/test/resources/logback.xml index 5ad73374..abc831d1 100644 --- a/mcp-spring/mcp-spring-webflux/src/test/resources/logback.xml +++ b/mcp-spring/mcp-spring-webflux/src/test/resources/logback.xml @@ -9,13 +9,13 @@ - + - - + + - + diff --git a/mcp-spring/mcp-spring-webmvc/pom.xml b/mcp-spring/mcp-spring-webmvc/pom.xml index 82fbbf3e..48d1c346 100644 --- a/mcp-spring/mcp-spring-webmvc/pom.xml +++ b/mcp-spring/mcp-spring-webmvc/pom.xml @@ -6,7 +6,7 @@ io.modelcontextprotocol.sdk mcp-parent - 0.10.0-SNAPSHOT + 0.11.0-SNAPSHOT ../../pom.xml mcp-spring-webmvc @@ -25,13 +25,13 @@ io.modelcontextprotocol.sdk mcp - 0.10.0-SNAPSHOT + 0.11.0-SNAPSHOT io.modelcontextprotocol.sdk mcp-test - 0.10.0-SNAPSHOT + 0.11.0-SNAPSHOT test diff --git a/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseIntegrationTests.java b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseIntegrationTests.java index b12d6843..3f3f7be6 100644 --- a/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseIntegrationTests.java +++ b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseIntegrationTests.java @@ -357,6 +357,219 @@ void testCreateMessageWithRequestTimeoutFail() throws InterruptedException { mcpServer.close(); } + // --------------------------------------- + // Elicitation Tests + // --------------------------------------- + @Test + void testCreateElicitationWithoutElicitationCapabilities() { + + McpServerFeatures.AsyncToolSpecification tool = new McpServerFeatures.AsyncToolSpecification( + new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request) -> { + + exchange.createElicitation(mock(McpSchema.ElicitRequest.class)).block(); + + return Mono.just(mock(CallToolResult.class)); + }); + + var server = McpServer.async(mcpServerTransportProvider).serverInfo("test-server", "1.0.0").tools(tool).build(); + + try ( + // Create client without elicitation capabilities + var client = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")).build()) { + + assertThat(client.initialize()).isNotNull(); + + try { + client.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); + } + catch (McpError e) { + assertThat(e).isInstanceOf(McpError.class) + .hasMessage("Client must be configured with elicitation capabilities"); + } + } + server.closeGracefully().block(); + } + + @Test + void testCreateElicitationSuccess() { + + Function elicitationHandler = request -> { + assertThat(request.message()).isNotEmpty(); + assertThat(request.requestedSchema()).isNotNull(); + + return new McpSchema.ElicitResult(McpSchema.ElicitResult.Action.ACCEPT, + Map.of("message", request.message())); + }; + + CallToolResult callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), + null); + + McpServerFeatures.AsyncToolSpecification tool = new McpServerFeatures.AsyncToolSpecification( + new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request) -> { + + var elicitationRequest = McpSchema.ElicitRequest.builder() + .message("Test message") + .requestedSchema( + Map.of("type", "object", "properties", Map.of("message", Map.of("type", "string")))) + .build(); + + StepVerifier.create(exchange.createElicitation(elicitationRequest)).consumeNextWith(result -> { + assertThat(result).isNotNull(); + assertThat(result.action()).isEqualTo(McpSchema.ElicitResult.Action.ACCEPT); + assertThat(result.content().get("message")).isEqualTo("Test message"); + }).verifyComplete(); + + return Mono.just(callResponse); + }); + + var mcpServer = McpServer.async(mcpServerTransportProvider) + .serverInfo("test-server", "1.0.0") + .tools(tool) + .build(); + + try (var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")) + .capabilities(ClientCapabilities.builder().elicitation().build()) + .elicitation(elicitationHandler) + .build()) { + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + CallToolResult response = mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); + + assertThat(response).isNotNull(); + assertThat(response).isEqualTo(callResponse); + } + mcpServer.closeGracefully().block(); + } + + @Test + void testCreateElicitationWithRequestTimeoutSuccess() { + + // Client + + Function elicitationHandler = request -> { + assertThat(request.message()).isNotEmpty(); + assertThat(request.requestedSchema()).isNotNull(); + try { + TimeUnit.SECONDS.sleep(2); + } + catch (InterruptedException e) { + throw new RuntimeException(e); + } + return new McpSchema.ElicitResult(McpSchema.ElicitResult.Action.ACCEPT, + Map.of("message", request.message())); + }; + + var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")) + .capabilities(ClientCapabilities.builder().elicitation().build()) + .elicitation(elicitationHandler) + .build(); + + // Server + + CallToolResult callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), + null); + + McpServerFeatures.AsyncToolSpecification tool = new McpServerFeatures.AsyncToolSpecification( + new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request) -> { + + var elicitationRequest = McpSchema.ElicitRequest.builder() + .message("Test message") + .requestedSchema( + Map.of("type", "object", "properties", Map.of("message", Map.of("type", "string")))) + .build(); + + StepVerifier.create(exchange.createElicitation(elicitationRequest)).consumeNextWith(result -> { + assertThat(result).isNotNull(); + assertThat(result.action()).isEqualTo(McpSchema.ElicitResult.Action.ACCEPT); + assertThat(result.content().get("message")).isEqualTo("Test message"); + }).verifyComplete(); + + return Mono.just(callResponse); + }); + + var mcpServer = McpServer.async(mcpServerTransportProvider) + .serverInfo("test-server", "1.0.0") + .requestTimeout(Duration.ofSeconds(3)) + .tools(tool) + .build(); + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + CallToolResult response = mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); + + assertThat(response).isNotNull(); + assertThat(response).isEqualTo(callResponse); + + mcpClient.closeGracefully(); + mcpServer.closeGracefully().block(); + } + + @Test + void testCreateElicitationWithRequestTimeoutFail() { + + // Client + + Function elicitationHandler = request -> { + assertThat(request.message()).isNotEmpty(); + assertThat(request.requestedSchema()).isNotNull(); + try { + TimeUnit.SECONDS.sleep(2); + } + catch (InterruptedException e) { + throw new RuntimeException(e); + } + return new McpSchema.ElicitResult(McpSchema.ElicitResult.Action.ACCEPT, + Map.of("message", request.message())); + }; + + var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")) + .capabilities(ClientCapabilities.builder().elicitation().build()) + .elicitation(elicitationHandler) + .build(); + + // Server + + CallToolResult callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), + null); + + McpServerFeatures.AsyncToolSpecification tool = new McpServerFeatures.AsyncToolSpecification( + new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request) -> { + + var elicitationRequest = McpSchema.ElicitRequest.builder() + .message("Test message") + .requestedSchema( + Map.of("type", "object", "properties", Map.of("message", Map.of("type", "string")))) + .build(); + + StepVerifier.create(exchange.createElicitation(elicitationRequest)).consumeNextWith(result -> { + assertThat(result).isNotNull(); + assertThat(result.action()).isEqualTo(McpSchema.ElicitResult.Action.ACCEPT); + assertThat(result.content().get("message")).isEqualTo("Test message"); + }).verifyComplete(); + + return Mono.just(callResponse); + }); + + var mcpServer = McpServer.async(mcpServerTransportProvider) + .serverInfo("test-server", "1.0.0") + .requestTimeout(Duration.ofSeconds(1)) + .tools(tool) + .build(); + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + assertThatExceptionOfType(McpError.class).isThrownBy(() -> { + mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); + }).withMessageContaining("Timeout"); + + mcpClient.closeGracefully(); + mcpServer.closeGracefully().block(); + } + // --------------------------------------- // Roots Tests // --------------------------------------- diff --git a/mcp-test/pom.xml b/mcp-test/pom.xml index f1484ae7..9998569d 100644 --- a/mcp-test/pom.xml +++ b/mcp-test/pom.xml @@ -6,7 +6,7 @@ io.modelcontextprotocol.sdk mcp-parent - 0.10.0-SNAPSHOT + 0.11.0-SNAPSHOT mcp-test jar @@ -24,7 +24,7 @@ io.modelcontextprotocol.sdk mcp - 0.10.0-SNAPSHOT + 0.11.0-SNAPSHOT @@ -68,6 +68,11 @@ junit-jupiter ${testcontainers.version} + + org.testcontainers + toxiproxy + ${toxiproxy.version} + org.awaitility diff --git a/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientResiliencyTests.java b/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientResiliencyTests.java new file mode 100644 index 00000000..85d6a88e --- /dev/null +++ b/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientResiliencyTests.java @@ -0,0 +1,222 @@ +package io.modelcontextprotocol.client; + +import eu.rekawek.toxiproxy.Proxy; +import eu.rekawek.toxiproxy.ToxiproxyClient; +import eu.rekawek.toxiproxy.model.ToxicDirection; +import io.modelcontextprotocol.spec.McpClientTransport; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpTransport; +import org.junit.jupiter.api.Test; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.testcontainers.containers.GenericContainer; +import org.testcontainers.containers.Network; +import org.testcontainers.containers.ToxiproxyContainer; +import org.testcontainers.containers.wait.strategy.Wait; +import reactor.test.StepVerifier; + +import java.io.IOException; +import java.time.Duration; +import java.util.List; +import java.util.Map; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Consumer; +import java.util.function.Function; + +import static org.assertj.core.api.Assertions.assertThatCode; + +/** + * Resiliency test suite for the {@link McpAsyncClient} that can be used with different + * {@link McpTransport} implementations that support Streamable HTTP. + * + * The purpose of these tests is to allow validating the transport layer resiliency + * instead of the functionality offered by the logical layer of MCP concepts such as + * tools, resources, prompts, etc. + * + * @author Dariusz Jędrzejczyk + */ +public abstract class AbstractMcpAsyncClientResiliencyTests { + + private static final Logger logger = LoggerFactory.getLogger(AbstractMcpAsyncClientResiliencyTests.class); + + static Network network = Network.newNetwork(); + static String host = "http://localhost:3001"; + + // Uses the https://github.com/tzolov/mcp-everything-server-docker-image + @SuppressWarnings("resource") + static GenericContainer container = new GenericContainer<>("docker.io/tzolov/mcp-everything-server:v2") + .withCommand("node dist/index.js streamableHttp") + .withLogConsumer(outputFrame -> System.out.println(outputFrame.getUtf8String())) + .withNetwork(network) + .withNetworkAliases("everything-server") + .withExposedPorts(3001) + .waitingFor(Wait.forHttp("/").forStatusCode(404)); + + static ToxiproxyContainer toxiproxy = new ToxiproxyContainer("ghcr.io/shopify/toxiproxy:2.5.0").withNetwork(network) + .withExposedPorts(8474, 3000); + + static Proxy proxy; + + static { + container.start(); + + toxiproxy.start(); + + final ToxiproxyClient toxiproxyClient = new ToxiproxyClient(toxiproxy.getHost(), toxiproxy.getControlPort()); + try { + proxy = toxiproxyClient.createProxy("everything-server", "0.0.0.0:3000", "everything-server:3001"); + } + catch (IOException e) { + throw new RuntimeException("Can't create proxy!", e); + } + + final String ipAddressViaToxiproxy = toxiproxy.getHost(); + final int portViaToxiproxy = toxiproxy.getMappedPort(3000); + + host = "http://" + ipAddressViaToxiproxy + ":" + portViaToxiproxy; + } + + private static void disconnect() { + long start = System.nanoTime(); + try { + // disconnect + // proxy.toxics().bandwidth("CUT_CONNECTION_DOWNSTREAM", + // ToxicDirection.DOWNSTREAM, 0); + // proxy.toxics().bandwidth("CUT_CONNECTION_UPSTREAM", + // ToxicDirection.UPSTREAM, 0); + proxy.toxics().resetPeer("RESET_DOWNSTREAM", ToxicDirection.DOWNSTREAM, 0); + proxy.toxics().resetPeer("RESET_UPSTREAM", ToxicDirection.UPSTREAM, 0); + logger.info("Disconnect took {} ms", Duration.ofNanos(System.nanoTime() - start).toMillis()); + } + catch (IOException e) { + throw new RuntimeException("Failed to disconnect", e); + } + } + + private static void reconnect() { + long start = System.nanoTime(); + try { + proxy.toxics().get("RESET_UPSTREAM").remove(); + proxy.toxics().get("RESET_DOWNSTREAM").remove(); + // proxy.toxics().get("CUT_CONNECTION_DOWNSTREAM").remove(); + // proxy.toxics().get("CUT_CONNECTION_UPSTREAM").remove(); + logger.info("Reconnect took {} ms", Duration.ofNanos(System.nanoTime() - start).toMillis()); + } + catch (IOException e) { + throw new RuntimeException("Failed to reconnect", e); + } + } + + private static void restartMcpServer() { + container.stop(); + container.start(); + } + + abstract McpClientTransport createMcpTransport(); + + protected Duration getRequestTimeout() { + return Duration.ofSeconds(14); + } + + protected Duration getInitializationTimeout() { + return Duration.ofSeconds(2); + } + + McpAsyncClient client(McpClientTransport transport) { + return client(transport, Function.identity()); + } + + McpAsyncClient client(McpClientTransport transport, Function customizer) { + AtomicReference client = new AtomicReference<>(); + + assertThatCode(() -> { + McpClient.AsyncSpec builder = McpClient.async(transport) + .requestTimeout(getRequestTimeout()) + .initializationTimeout(getInitializationTimeout()) + .capabilities(McpSchema.ClientCapabilities.builder().roots(true).build()); + builder = customizer.apply(builder); + client.set(builder.build()); + }).doesNotThrowAnyException(); + + return client.get(); + } + + void withClient(McpClientTransport transport, Consumer c) { + withClient(transport, Function.identity(), c); + } + + void withClient(McpClientTransport transport, Function customizer, + Consumer c) { + var client = client(transport, customizer); + try { + c.accept(client); + } + finally { + StepVerifier.create(client.closeGracefully()).expectComplete().verify(Duration.ofSeconds(10)); + } + } + + @Test + void testPing() { + withClient(createMcpTransport(), mcpAsyncClient -> { + StepVerifier.create(mcpAsyncClient.initialize()).expectNextCount(1).verifyComplete(); + + disconnect(); + + StepVerifier.create(mcpAsyncClient.ping()).expectError().verify(); + + reconnect(); + + StepVerifier.create(mcpAsyncClient.ping()).expectNextCount(1).verifyComplete(); + }); + } + + @Test + void testSessionInvalidation() { + withClient(createMcpTransport(), mcpAsyncClient -> { + StepVerifier.create(mcpAsyncClient.initialize()).expectNextCount(1).verifyComplete(); + + restartMcpServer(); + + // The first try will face the session mismatch exception and the second one + // will go through the re-initialization process. + StepVerifier.create(mcpAsyncClient.ping().retry(1)).expectNextCount(1).verifyComplete(); + }); + } + + @Test + void testCallTool() { + withClient(createMcpTransport(), mcpAsyncClient -> { + AtomicReference> tools = new AtomicReference<>(); + StepVerifier.create(mcpAsyncClient.initialize()).expectNextCount(1).verifyComplete(); + StepVerifier.create(mcpAsyncClient.listTools()) + .consumeNextWith(list -> tools.set(list.tools())) + .verifyComplete(); + + disconnect(); + + String name = tools.get().get(0).name(); + // Assuming this is the echo tool + McpSchema.CallToolRequest request = new McpSchema.CallToolRequest(name, Map.of("message", "hello")); + StepVerifier.create(mcpAsyncClient.callTool(request)).expectError().verify(); + + reconnect(); + + StepVerifier.create(mcpAsyncClient.callTool(request)).expectNextCount(1).verifyComplete(); + }); + } + + @Test + void testSessionClose() { + withClient(createMcpTransport(), mcpAsyncClient -> { + StepVerifier.create(mcpAsyncClient.initialize()).expectNextCount(1).verifyComplete(); + // In case of Streamable HTTP this call should issue a HTTP DELETE request + // invalidating the session + StepVerifier.create(mcpAsyncClient.closeGracefully()).expectComplete().verify(); + // The next use should immediately re-initialize with no issue and send the + // request without any broken connections. + StepVerifier.create(mcpAsyncClient.ping()).expectNextCount(1).verifyComplete(); + }); + } + +} diff --git a/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientTests.java b/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientTests.java index 5452c8ea..049bea00 100644 --- a/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientTests.java +++ b/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientTests.java @@ -110,14 +110,16 @@ void tearDown() { onClose(); } - void verifyInitializationTimeout(Function> operation, String action) { + void verifyNotificationSucceedsWithImplicitInitialization(Function> operation, + String action) { withClient(createMcpTransport(), mcpAsyncClient -> { - StepVerifier.withVirtualTime(() -> operation.apply(mcpAsyncClient)) - .expectSubscription() - .thenAwait(getInitializationTimeout()) - .consumeErrorWith(e -> assertThat(e).isInstanceOf(McpError.class) - .hasMessage("Client must be initialized before " + action)) - .verify(); + StepVerifier.create(operation.apply(mcpAsyncClient)).verifyComplete(); + }); + } + + void verifyCallSucceedsWithImplicitInitialization(Function> operation, String action) { + withClient(createMcpTransport(), mcpAsyncClient -> { + StepVerifier.create(operation.apply(mcpAsyncClient)).expectNextCount(1).verifyComplete(); }); } @@ -133,7 +135,7 @@ void testConstructorWithInvalidArguments() { @Test void testListToolsWithoutInitialization() { - verifyInitializationTimeout(client -> client.listTools(null), "listing tools"); + verifyCallSucceedsWithImplicitInitialization(client -> client.listTools(null), "listing tools"); } @Test @@ -153,7 +155,7 @@ void testListTools() { @Test void testPingWithoutInitialization() { - verifyInitializationTimeout(client -> client.ping(), "pinging the server"); + verifyCallSucceedsWithImplicitInitialization(client -> client.ping(), "pinging the server"); } @Test @@ -168,7 +170,7 @@ void testPing() { @Test void testCallToolWithoutInitialization() { CallToolRequest callToolRequest = new CallToolRequest("echo", Map.of("message", ECHO_TEST_MESSAGE)); - verifyInitializationTimeout(client -> client.callTool(callToolRequest), "calling tools"); + verifyCallSucceedsWithImplicitInitialization(client -> client.callTool(callToolRequest), "calling tools"); } @Test @@ -202,7 +204,7 @@ void testCallToolWithInvalidTool() { @Test void testListResourcesWithoutInitialization() { - verifyInitializationTimeout(client -> client.listResources(null), "listing resources"); + verifyCallSucceedsWithImplicitInitialization(client -> client.listResources(null), "listing resources"); } @Test @@ -233,7 +235,7 @@ void testMcpAsyncClientState() { @Test void testListPromptsWithoutInitialization() { - verifyInitializationTimeout(client -> client.listPrompts(null), "listing " + "prompts"); + verifyCallSucceedsWithImplicitInitialization(client -> client.listPrompts(null), "listing " + "prompts"); } @Test @@ -258,7 +260,7 @@ void testListPrompts() { @Test void testGetPromptWithoutInitialization() { GetPromptRequest request = new GetPromptRequest("simple_prompt", Map.of()); - verifyInitializationTimeout(client -> client.getPrompt(request), "getting " + "prompts"); + verifyCallSucceedsWithImplicitInitialization(client -> client.getPrompt(request), "getting " + "prompts"); } @Test @@ -279,7 +281,7 @@ void testGetPrompt() { @Test void testRootsListChangedWithoutInitialization() { - verifyInitializationTimeout(client -> client.rootsListChangedNotification(), + verifyNotificationSucceedsWithImplicitInitialization(client -> client.rootsListChangedNotification(), "sending roots list changed notification"); } @@ -354,7 +356,8 @@ void testReadResource() { @Test void testListResourceTemplatesWithoutInitialization() { - verifyInitializationTimeout(client -> client.listResourceTemplates(), "listing resource templates"); + verifyCallSucceedsWithImplicitInitialization(client -> client.listResourceTemplates(), + "listing resource templates"); } @Test @@ -447,8 +450,8 @@ void testInitializeWithAllCapabilities() { @Test void testLoggingLevelsWithoutInitialization() { - verifyInitializationTimeout(client -> client.setLoggingLevel(McpSchema.LoggingLevel.DEBUG), - "setting logging level"); + verifyNotificationSucceedsWithImplicitInitialization( + client -> client.setLoggingLevel(McpSchema.LoggingLevel.DEBUG), "setting logging level"); } @Test diff --git a/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpSyncClientTests.java b/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpSyncClientTests.java index 128441f8..3785fd64 100644 --- a/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpSyncClientTests.java +++ b/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpSyncClientTests.java @@ -5,6 +5,7 @@ package io.modelcontextprotocol.client; import java.time.Duration; +import java.util.List; import java.util.Map; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicReference; @@ -12,7 +13,6 @@ import java.util.function.Function; import io.modelcontextprotocol.spec.McpClientTransport; -import io.modelcontextprotocol.spec.McpError; import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.spec.McpSchema.CallToolRequest; import io.modelcontextprotocol.spec.McpSchema.CallToolResult; @@ -112,33 +112,18 @@ void tearDown() { static final Object DUMMY_RETURN_VALUE = new Object(); - void verifyNotificationTimesOut(Consumer operation, String action) { - verifyCallTimesOut(client -> { + void verifyNotificationSucceedsWithImplicitInitialization(Consumer operation, String action) { + verifyCallSucceedsWithImplicitInitialization(client -> { operation.accept(client); return DUMMY_RETURN_VALUE; }, action); } - void verifyCallTimesOut(Function blockingOperation, String action) { + void verifyCallSucceedsWithImplicitInitialization(Function blockingOperation, String action) { withClient(createMcpTransport(), mcpSyncClient -> { - // This scheduler is not replaced by virtual time scheduler - Scheduler customScheduler = Schedulers.newBoundedElastic(1, 1, "actualBoundedElastic"); - - StepVerifier.withVirtualTime(() -> Mono.fromSupplier(() -> blockingOperation.apply(mcpSyncClient)) - // Offload the blocking call to the real scheduler - .subscribeOn(customScheduler)) - .expectSubscription() - // This works without actually waiting but executes all the - // tasks pending execution on the VirtualTimeScheduler. - // It is possible to execute the blocking code from the operation - // because it is blocked on a dedicated Scheduler and the main - // flow is not blocked and uses the VirtualTimeScheduler. - .thenAwait(getInitializationTimeout()) - .consumeErrorWith(e -> assertThat(e).isInstanceOf(McpError.class) - .hasMessage("Client must be initialized before " + action)) - .verify(); - - customScheduler.dispose(); + StepVerifier.create(Mono.fromSupplier(() -> blockingOperation.apply(mcpSyncClient))) + .expectNextCount(1) + .verifyComplete(); }); } @@ -154,7 +139,7 @@ void testConstructorWithInvalidArguments() { @Test void testListToolsWithoutInitialization() { - verifyCallTimesOut(client -> client.listTools(null), "listing tools"); + verifyCallSucceedsWithImplicitInitialization(client -> client.listTools(null), "listing tools"); } @Test @@ -175,8 +160,8 @@ void testListTools() { @Test void testCallToolsWithoutInitialization() { - verifyCallTimesOut(client -> client.callTool(new CallToolRequest("add", Map.of("a", 3, "b", 4))), - "calling tools"); + verifyCallSucceedsWithImplicitInitialization( + client -> client.callTool(new CallToolRequest("add", Map.of("a", 3, "b", 4))), "calling tools"); } @Test @@ -200,7 +185,7 @@ void testCallTools() { @Test void testPingWithoutInitialization() { - verifyCallTimesOut(client -> client.ping(), "pinging the server"); + verifyCallSucceedsWithImplicitInitialization(client -> client.ping(), "pinging the server"); } @Test @@ -214,7 +199,7 @@ void testPing() { @Test void testCallToolWithoutInitialization() { CallToolRequest callToolRequest = new CallToolRequest("echo", Map.of("message", TEST_MESSAGE)); - verifyCallTimesOut(client -> client.callTool(callToolRequest), "calling tools"); + verifyCallSucceedsWithImplicitInitialization(client -> client.callTool(callToolRequest), "calling tools"); } @Test @@ -243,7 +228,7 @@ void testCallToolWithInvalidTool() { @Test void testRootsListChangedWithoutInitialization() { - verifyNotificationTimesOut(client -> client.rootsListChangedNotification(), + verifyNotificationSucceedsWithImplicitInitialization(client -> client.rootsListChangedNotification(), "sending roots list changed notification"); } @@ -257,7 +242,7 @@ void testRootsListChanged() { @Test void testListResourcesWithoutInitialization() { - verifyCallTimesOut(client -> client.listResources(null), "listing resources"); + verifyCallSucceedsWithImplicitInitialization(client -> client.listResources(null), "listing resources"); } @Test @@ -333,8 +318,14 @@ void testRemoveNonExistentRoot() { @Test void testReadResourceWithoutInitialization() { - Resource resource = new Resource("test://uri", "Test Resource", null, null, null); - verifyCallTimesOut(client -> client.readResource(resource), "reading resources"); + AtomicReference> resources = new AtomicReference<>(); + withClient(createMcpTransport(), mcpSyncClient -> { + mcpSyncClient.initialize(); + resources.set(mcpSyncClient.listResources().resources()); + }); + + verifyCallSucceedsWithImplicitInitialization(client -> client.readResource(resources.get().get(0)), + "reading resources"); } @Test @@ -355,7 +346,8 @@ void testReadResource() { @Test void testListResourceTemplatesWithoutInitialization() { - verifyCallTimesOut(client -> client.listResourceTemplates(null), "listing resource templates"); + verifyCallSucceedsWithImplicitInitialization(client -> client.listResourceTemplates(null), + "listing resource templates"); } @Test @@ -413,8 +405,8 @@ void testNotificationHandlers() { @Test void testLoggingLevelsWithoutInitialization() { - verifyNotificationTimesOut(client -> client.setLoggingLevel(McpSchema.LoggingLevel.DEBUG), - "setting logging level"); + verifyNotificationSucceedsWithImplicitInitialization( + client -> client.setLoggingLevel(McpSchema.LoggingLevel.DEBUG), "setting logging level"); } @Test diff --git a/mcp/pom.xml b/mcp/pom.xml index 17693ab3..77343282 100644 --- a/mcp/pom.xml +++ b/mcp/pom.xml @@ -6,7 +6,7 @@ io.modelcontextprotocol.sdk mcp-parent - 0.10.0-SNAPSHOT + 0.11.0-SNAPSHOT mcp jar diff --git a/mcp/src/main/java/io/modelcontextprotocol/client/McpAsyncClient.java b/mcp/src/main/java/io/modelcontextprotocol/client/McpAsyncClient.java index e3a997ba..8f0433eb 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/client/McpAsyncClient.java +++ b/mcp/src/main/java/io/modelcontextprotocol/client/McpAsyncClient.java @@ -9,9 +9,9 @@ import java.util.List; import java.util.Map; import java.util.concurrent.ConcurrentHashMap; -import java.util.concurrent.TimeoutException; -import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicReference; import java.util.function.Function; +import java.util.function.Supplier; import com.fasterxml.jackson.core.type.TypeReference; import io.modelcontextprotocol.spec.McpClientSession; @@ -23,6 +23,8 @@ import io.modelcontextprotocol.spec.McpSchema.ClientCapabilities; import io.modelcontextprotocol.spec.McpSchema.CreateMessageRequest; import io.modelcontextprotocol.spec.McpSchema.CreateMessageResult; +import io.modelcontextprotocol.spec.McpSchema.ElicitRequest; +import io.modelcontextprotocol.spec.McpSchema.ElicitResult; import io.modelcontextprotocol.spec.McpSchema.GetPromptRequest; import io.modelcontextprotocol.spec.McpSchema.GetPromptResult; import io.modelcontextprotocol.spec.McpSchema.ListPromptsResult; @@ -30,7 +32,7 @@ import io.modelcontextprotocol.spec.McpSchema.LoggingMessageNotification; import io.modelcontextprotocol.spec.McpSchema.PaginatedRequest; import io.modelcontextprotocol.spec.McpSchema.Root; -import io.modelcontextprotocol.spec.McpTransport; +import io.modelcontextprotocol.spec.McpTransportSessionNotFoundException; import io.modelcontextprotocol.util.Assert; import io.modelcontextprotocol.util.Utils; import org.slf4j.Logger; @@ -75,29 +77,37 @@ * @see McpClient * @see McpSchema * @see McpClientSession + * @see McpClientTransport */ public class McpAsyncClient { private static final Logger logger = LoggerFactory.getLogger(McpAsyncClient.class); - private static TypeReference VOID_TYPE_REFERENCE = new TypeReference<>() { + private static final TypeReference VOID_TYPE_REFERENCE = new TypeReference<>() { }; - protected final Sinks.One initializedSink = Sinks.one(); + public static final TypeReference OBJECT_TYPE_REF = new TypeReference<>() { + }; + + public static final TypeReference PAGINATED_REQUEST_TYPE_REF = new TypeReference<>() { + }; + + public static final TypeReference INITIALIZE_RESULT_TYPE_REF = new TypeReference<>() { + }; - private AtomicBoolean initialized = new AtomicBoolean(false); + public static final TypeReference CREATE_MESSAGE_REQUEST_TYPE_REF = new TypeReference<>() { + }; + + public static final TypeReference LOGGING_MESSAGE_NOTIFICATION_TYPE_REF = new TypeReference<>() { + }; + + private final AtomicReference initializationRef = new AtomicReference<>(); /** * The max timeout to await for the client-server connection to be initialized. */ private final Duration initializationTimeout; - /** - * The MCP session implementation that manages bidirectional JSON-RPC communication - * between clients and servers. - */ - private final McpClientSession mcpSession; - /** * Client capabilities. */ @@ -108,21 +118,6 @@ public class McpAsyncClient { */ private final McpSchema.Implementation clientInfo; - /** - * Server capabilities. - */ - private McpSchema.ServerCapabilities serverCapabilities; - - /** - * Server instructions. - */ - private String serverInstructions; - - /** - * Server implementation information. - */ - private McpSchema.Implementation serverInfo; - /** * Roots define the boundaries of where servers can operate within the filesystem, * allowing them to understand which directories and files they have access to. @@ -141,16 +136,31 @@ public class McpAsyncClient { */ private Function> samplingHandler; + /** + * MCP provides a standardized way for servers to request additional information from + * users through the client during interactions. This flow allows clients to maintain + * control over user interactions and data sharing while enabling servers to gather + * necessary information dynamically. Servers can request structured data from users + * with optional JSON schemas to validate responses. + */ + private Function> elicitationHandler; + /** * Client transport implementation. */ - private final McpTransport transport; + private final McpClientTransport transport; /** * Supported protocol versions. */ private List protocolVersions = List.of(McpSchema.LATEST_PROTOCOL_VERSION); + /** + * The MCP session supplier that manages bidirectional JSON-RPC communication between + * clients and servers. + */ + private final Supplier sessionSupplier; + /** * Create a new McpAsyncClient with the given transport and session request-response * timeout. @@ -189,6 +199,15 @@ public class McpAsyncClient { requestHandlers.put(McpSchema.METHOD_SAMPLING_CREATE_MESSAGE, samplingCreateMessageHandler()); } + // Elicitation Handler + if (this.clientCapabilities.elicitation() != null) { + if (features.elicitationHandler() == null) { + throw new McpError("Elicitation handler must not be null when client capabilities include elicitation"); + } + this.elicitationHandler = features.elicitationHandler(); + requestHandlers.put(McpSchema.METHOD_ELICITATION_CREATE, elicitationCreateHandler()); + } + // Notification Handlers Map notificationHandlers = new HashMap<>(); @@ -234,16 +253,38 @@ public class McpAsyncClient { notificationHandlers.put(McpSchema.METHOD_NOTIFICATION_MESSAGE, asyncLoggingNotificationHandler(loggingConsumersFinal)); - this.mcpSession = new McpClientSession(requestTimeout, transport, requestHandlers, notificationHandlers); + this.transport.setExceptionHandler(this::handleException); + this.sessionSupplier = () -> new McpClientSession(requestTimeout, transport, requestHandlers, + notificationHandlers); } + private void handleException(Throwable t) { + logger.warn("Handling exception", t); + if (t instanceof McpTransportSessionNotFoundException) { + Initialization previous = this.initializationRef.getAndSet(null); + if (previous != null) { + previous.close(); + } + // Providing an empty operation since we are only interested in triggering the + // implicit initialization step. + withSession("re-initializing", result -> Mono.empty()).subscribe(); + } + } + + private McpSchema.InitializeResult currentInitializationResult() { + Initialization current = this.initializationRef.get(); + McpSchema.InitializeResult initializeResult = current != null ? current.result.get() : null; + return initializeResult; + } + /** * Get the server capabilities that define the supported features and functionality. * @return The server capabilities */ public McpSchema.ServerCapabilities getServerCapabilities() { - return this.serverCapabilities; + McpSchema.InitializeResult initializeResult = currentInitializationResult(); + return initializeResult != null ? initializeResult.capabilities() : null; } /** @@ -252,7 +293,8 @@ public McpSchema.ServerCapabilities getServerCapabilities() { * @return The server instructions */ public String getServerInstructions() { - return this.serverInstructions; + McpSchema.InitializeResult initializeResult = currentInitializationResult(); + return initializeResult != null ? initializeResult.instructions() : null; } /** @@ -260,7 +302,8 @@ public String getServerInstructions() { * @return The server implementation details */ public McpSchema.Implementation getServerInfo() { - return this.serverInfo; + McpSchema.InitializeResult initializeResult = currentInitializationResult(); + return initializeResult != null ? initializeResult.serverInfo() : null; } /** @@ -268,7 +311,8 @@ public McpSchema.Implementation getServerInfo() { * @return true if the client-server connection is initialized */ public boolean isInitialized() { - return this.initialized.get(); + Initialization current = this.initializationRef.get(); + return current != null && (current.result.get() != null); } /** @@ -291,7 +335,11 @@ public McpSchema.Implementation getClientInfo() { * Closes the client connection immediately. */ public void close() { - this.mcpSession.close(); + Initialization current = this.initializationRef.getAndSet(null); + if (current != null) { + current.close(); + } + this.transport.close(); } /** @@ -299,14 +347,21 @@ public void close() { * @return A Mono that completes when the connection is closed */ public Mono closeGracefully() { - return this.mcpSession.closeGracefully(); + return Mono.defer(() -> { + Initialization current = this.initializationRef.getAndSet(null); + Mono sessionClose = current != null ? current.closeGracefully() : Mono.empty(); + return sessionClose.then(transport.closeGracefully()); + }); } // -------------------------- // Initialization // -------------------------- /** - * The initialization phase MUST be the first interaction between client and server. + * The initialization phase should be the first interaction between client and server. + * The client will ensure it happens in case it has not been explicitly called and in + * case of transport session invalidation. + *

* During this phase, the client and server: *

    *
  • Establish protocol version compatibility
  • @@ -326,9 +381,13 @@ public Mono closeGracefully() { * @see MCP * Initialization Spec + *

    */ public Mono initialize() { + return withSession("by explicit API call", init -> Mono.just(init.get())); + } + private Mono doInitialize(McpClientSession mcpClientSession) { String latestVersion = this.protocolVersions.get(this.protocolVersions.size() - 1); McpSchema.InitializeRequest initializeRequest = new McpSchema.InitializeRequest(// @formatter:off @@ -336,16 +395,10 @@ public Mono initialize() { this.clientCapabilities, this.clientInfo); // @formatter:on - Mono result = this.mcpSession.sendRequest(McpSchema.METHOD_INITIALIZE, - initializeRequest, new TypeReference() { - }); + Mono result = mcpClientSession.sendRequest(McpSchema.METHOD_INITIALIZE, + initializeRequest, INITIALIZE_RESULT_TYPE_REF); return result.flatMap(initializeResult -> { - - this.serverCapabilities = initializeResult.capabilities(); - this.serverInstructions = initializeResult.instructions(); - this.serverInfo = initializeResult.serverInfo(); - logger.info("Server response with Protocol: {}, Capabilities: {}, Info: {} and Instructions {}", initializeResult.protocolVersion(), initializeResult.capabilities(), initializeResult.serverInfo(), initializeResult.instructions()); @@ -355,28 +408,93 @@ public Mono initialize() { "Unsupported protocol version from the server: " + initializeResult.protocolVersion())); } - return this.mcpSession.sendNotification(McpSchema.METHOD_NOTIFICATION_INITIALIZED, null).doOnSuccess(v -> { - this.initialized.set(true); - this.initializedSink.tryEmitValue(initializeResult); - }).thenReturn(initializeResult); + return mcpClientSession.sendNotification(McpSchema.METHOD_NOTIFICATION_INITIALIZED, null) + .thenReturn(initializeResult); }); } + private static class Initialization { + + private final Sinks.One initSink = Sinks.one(); + + private final AtomicReference result = new AtomicReference<>(); + + private final AtomicReference mcpClientSession = new AtomicReference<>(); + + static Initialization create() { + return new Initialization(); + } + + void setMcpClientSession(McpClientSession mcpClientSession) { + this.mcpClientSession.set(mcpClientSession); + } + + McpClientSession mcpSession() { + return this.mcpClientSession.get(); + } + + McpSchema.InitializeResult get() { + return this.result.get(); + } + + Mono await() { + return this.initSink.asMono(); + } + + void complete(McpSchema.InitializeResult initializeResult) { + // first ensure the result is cached + this.result.set(initializeResult); + // inform all the subscribers waiting for the initialization + this.initSink.emitValue(initializeResult, Sinks.EmitFailureHandler.FAIL_FAST); + } + + void error(Throwable t) { + this.initSink.emitError(t, Sinks.EmitFailureHandler.FAIL_FAST); + } + + void close() { + this.mcpSession().close(); + } + + Mono closeGracefully() { + return this.mcpSession().closeGracefully(); + } + + } + /** - * Utility method to handle the common pattern of checking initialization before + * Utility method to handle the common pattern of ensuring initialization before * executing an operation. * @param The type of the result Mono - * @param actionName The action to perform if the client is initialized - * @param operation The operation to execute if the client is initialized + * @param actionName The action to perform when the client is initialized + * @param operation The operation to execute when the client is initialized * @return A Mono that completes with the result of the operation */ - private Mono withInitializationCheck(String actionName, - Function> operation) { - return this.initializedSink.asMono() - .timeout(this.initializationTimeout) - .onErrorResume(TimeoutException.class, - ex -> Mono.error(new McpError("Client must be initialized before " + actionName))) - .flatMap(operation); + private Mono withSession(String actionName, Function> operation) { + return Mono.defer(() -> { + Initialization newInit = Initialization.create(); + Initialization previous = this.initializationRef.compareAndExchange(null, newInit); + + boolean needsToInitialize = previous == null; + logger.debug(needsToInitialize ? "Initialization process started" : "Joining previous initialization"); + if (needsToInitialize) { + newInit.setMcpClientSession(this.sessionSupplier.get()); + } + + Mono initializationJob = needsToInitialize + ? doInitialize(newInit.mcpSession()).doOnNext(newInit::complete).onErrorResume(ex -> { + newInit.error(ex); + return Mono.error(ex); + }) : previous.await(); + + return initializationJob.map(initializeResult -> this.initializationRef.get()) + .timeout(this.initializationTimeout) + .onErrorResume(ex -> { + logger.warn("Failed to initialize", ex); + return Mono.error(new McpError("Client failed to initialize " + actionName)); + }) + .flatMap(operation); + }); } // -------------------------- @@ -388,9 +506,8 @@ private Mono withInitializationCheck(String actionName, * @return A Mono that completes with the server's ping response */ public Mono ping() { - return this.withInitializationCheck("pinging the server", initializedResult -> this.mcpSession - .sendRequest(McpSchema.METHOD_PING, null, new TypeReference() { - })); + return this.withSession("pinging the server", + init -> init.mcpSession().sendRequest(McpSchema.METHOD_PING, null, OBJECT_TYPE_REF)); } // -------------------------- @@ -470,16 +587,14 @@ public Mono removeRoot(String rootUri) { * @return A Mono that completes when the notification is sent. */ public Mono rootsListChangedNotification() { - return this.withInitializationCheck("sending roots list changed notification", - initResult -> this.mcpSession.sendNotification(McpSchema.METHOD_NOTIFICATION_ROOTS_LIST_CHANGED)); + return this.withSession("sending roots list changed notification", + init -> init.mcpSession().sendNotification(McpSchema.METHOD_NOTIFICATION_ROOTS_LIST_CHANGED)); } private RequestHandler rootsListRequestHandler() { return params -> { @SuppressWarnings("unused") - McpSchema.PaginatedRequest request = transport.unmarshalFrom(params, - new TypeReference() { - }); + McpSchema.PaginatedRequest request = transport.unmarshalFrom(params, PAGINATED_REQUEST_TYPE_REF); List roots = this.roots.values().stream().toList(); @@ -492,14 +607,24 @@ private RequestHandler rootsListRequestHandler() { // -------------------------- private RequestHandler samplingCreateMessageHandler() { return params -> { - McpSchema.CreateMessageRequest request = transport.unmarshalFrom(params, - new TypeReference() { - }); + McpSchema.CreateMessageRequest request = transport.unmarshalFrom(params, CREATE_MESSAGE_REQUEST_TYPE_REF); return this.samplingHandler.apply(request); }; } + // -------------------------- + // Elicitation + // -------------------------- + private RequestHandler elicitationCreateHandler() { + return params -> { + ElicitRequest request = transport.unmarshalFrom(params, new TypeReference<>() { + }); + + return this.elicitationHandler.apply(request); + }; + } + // -------------------------- // Tools // -------------------------- @@ -521,11 +646,12 @@ private RequestHandler samplingCreateMessageHandler() { * @see #listTools() */ public Mono callTool(McpSchema.CallToolRequest callToolRequest) { - return this.withInitializationCheck("calling tools", initializedResult -> { - if (this.serverCapabilities.tools() == null) { + return this.withSession("calling tools", init -> { + if (init.get().capabilities().tools() == null) { return Mono.error(new McpError("Server does not provide tools capability")); } - return this.mcpSession.sendRequest(McpSchema.METHOD_TOOLS_CALL, callToolRequest, CALL_TOOL_RESULT_TYPE_REF); + return init.mcpSession() + .sendRequest(McpSchema.METHOD_TOOLS_CALL, callToolRequest, CALL_TOOL_RESULT_TYPE_REF); }); } @@ -543,12 +669,13 @@ public Mono listTools() { * @return A Mono that emits the list of tools result */ public Mono listTools(String cursor) { - return this.withInitializationCheck("listing tools", initializedResult -> { - if (this.serverCapabilities.tools() == null) { + return this.withSession("listing tools", init -> { + if (init.get().capabilities().tools() == null) { return Mono.error(new McpError("Server does not provide tools capability")); } - return this.mcpSession.sendRequest(McpSchema.METHOD_TOOLS_LIST, new McpSchema.PaginatedRequest(cursor), - LIST_TOOLS_RESULT_TYPE_REF); + return init.mcpSession() + .sendRequest(McpSchema.METHOD_TOOLS_LIST, new McpSchema.PaginatedRequest(cursor), + LIST_TOOLS_RESULT_TYPE_REF); }); } @@ -600,12 +727,13 @@ public Mono listResources() { * @see #readResource(McpSchema.Resource) */ public Mono listResources(String cursor) { - return this.withInitializationCheck("listing resources", initializedResult -> { - if (this.serverCapabilities.resources() == null) { + return this.withSession("listing resources", init -> { + if (init.get().capabilities().resources() == null) { return Mono.error(new McpError("Server does not provide the resources capability")); } - return this.mcpSession.sendRequest(McpSchema.METHOD_RESOURCES_LIST, new McpSchema.PaginatedRequest(cursor), - LIST_RESOURCES_RESULT_TYPE_REF); + return init.mcpSession() + .sendRequest(McpSchema.METHOD_RESOURCES_LIST, new McpSchema.PaginatedRequest(cursor), + LIST_RESOURCES_RESULT_TYPE_REF); }); } @@ -631,12 +759,12 @@ public Mono readResource(McpSchema.Resource resour * @see McpSchema.ReadResourceResult */ public Mono readResource(McpSchema.ReadResourceRequest readResourceRequest) { - return this.withInitializationCheck("reading resources", initializedResult -> { - if (this.serverCapabilities.resources() == null) { + return this.withSession("reading resources", init -> { + if (init.get().capabilities().resources() == null) { return Mono.error(new McpError("Server does not provide the resources capability")); } - return this.mcpSession.sendRequest(McpSchema.METHOD_RESOURCES_READ, readResourceRequest, - READ_RESOURCE_RESULT_TYPE_REF); + return init.mcpSession() + .sendRequest(McpSchema.METHOD_RESOURCES_READ, readResourceRequest, READ_RESOURCE_RESULT_TYPE_REF); }); } @@ -660,12 +788,13 @@ public Mono listResourceTemplates() { * @see McpSchema.ListResourceTemplatesResult */ public Mono listResourceTemplates(String cursor) { - return this.withInitializationCheck("listing resource templates", initializedResult -> { - if (this.serverCapabilities.resources() == null) { + return this.withSession("listing resource templates", init -> { + if (init.get().capabilities().resources() == null) { return Mono.error(new McpError("Server does not provide the resources capability")); } - return this.mcpSession.sendRequest(McpSchema.METHOD_RESOURCES_TEMPLATES_LIST, - new McpSchema.PaginatedRequest(cursor), LIST_RESOURCE_TEMPLATES_RESULT_TYPE_REF); + return init.mcpSession() + .sendRequest(McpSchema.METHOD_RESOURCES_TEMPLATES_LIST, new McpSchema.PaginatedRequest(cursor), + LIST_RESOURCE_TEMPLATES_RESULT_TYPE_REF); }); } @@ -679,7 +808,7 @@ public Mono listResourceTemplates(String * @see #unsubscribeResource(McpSchema.UnsubscribeRequest) */ public Mono subscribeResource(McpSchema.SubscribeRequest subscribeRequest) { - return this.withInitializationCheck("subscribing to resources", initializedResult -> this.mcpSession + return this.withSession("subscribing to resources", init -> init.mcpSession() .sendRequest(McpSchema.METHOD_RESOURCES_SUBSCRIBE, subscribeRequest, VOID_TYPE_REFERENCE)); } @@ -693,7 +822,7 @@ public Mono subscribeResource(McpSchema.SubscribeRequest subscribeRequest) * @see #subscribeResource(McpSchema.SubscribeRequest) */ public Mono unsubscribeResource(McpSchema.UnsubscribeRequest unsubscribeRequest) { - return this.withInitializationCheck("unsubscribing from resources", initializedResult -> this.mcpSession + return this.withSession("unsubscribing from resources", init -> init.mcpSession() .sendRequest(McpSchema.METHOD_RESOURCES_UNSUBSCRIBE, unsubscribeRequest, VOID_TYPE_REFERENCE)); } @@ -735,7 +864,7 @@ public Mono listPrompts() { * @see #getPrompt(GetPromptRequest) */ public Mono listPrompts(String cursor) { - return this.withInitializationCheck("listing prompts", initializedResult -> this.mcpSession + return this.withSession("listing prompts", init -> init.mcpSession() .sendRequest(McpSchema.METHOD_PROMPT_LIST, new PaginatedRequest(cursor), LIST_PROMPTS_RESULT_TYPE_REF)); } @@ -749,7 +878,7 @@ public Mono listPrompts(String cursor) { * @see #listPrompts() */ public Mono getPrompt(GetPromptRequest getPromptRequest) { - return this.withInitializationCheck("getting prompts", initializedResult -> this.mcpSession + return this.withSession("getting prompts", init -> init.mcpSession() .sendRequest(McpSchema.METHOD_PROMPT_GET, getPromptRequest, GET_PROMPT_RESULT_TYPE_REF)); } @@ -780,8 +909,7 @@ private NotificationHandler asyncLoggingNotificationHandler( return params -> { McpSchema.LoggingMessageNotification loggingMessageNotification = transport.unmarshalFrom(params, - new TypeReference() { - }); + LOGGING_MESSAGE_NOTIFICATION_TYPE_REF); return Flux.fromIterable(loggingConsumers) .flatMap(consumer -> consumer.apply(loggingMessageNotification)) @@ -801,10 +929,9 @@ public Mono setLoggingLevel(LoggingLevel loggingLevel) { return Mono.error(new McpError("Logging level must not be null")); } - return this.withInitializationCheck("setting logging level", initializedResult -> { + return this.withSession("setting logging level", init -> { var params = new McpSchema.SetLevelRequest(loggingLevel); - return this.mcpSession.sendRequest(McpSchema.METHOD_LOGGING_SET_LEVEL, params, new TypeReference() { - }).then(); + return init.mcpSession().sendRequest(McpSchema.METHOD_LOGGING_SET_LEVEL, params, OBJECT_TYPE_REF).then(); }); } @@ -834,7 +961,7 @@ void setProtocolVersions(List protocolVersions) { * @see McpSchema.CompleteResult */ public Mono completeCompletion(McpSchema.CompleteRequest completeRequest) { - return this.withInitializationCheck("complete completions", initializedResult -> this.mcpSession + return this.withSession("complete completions", init -> init.mcpSession() .sendRequest(McpSchema.METHOD_COMPLETION_COMPLETE, completeRequest, COMPLETION_COMPLETE_RESULT_TYPE_REF)); } diff --git a/mcp/src/main/java/io/modelcontextprotocol/client/McpClient.java b/mcp/src/main/java/io/modelcontextprotocol/client/McpClient.java index a1dc1168..280906cf 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/client/McpClient.java +++ b/mcp/src/main/java/io/modelcontextprotocol/client/McpClient.java @@ -18,6 +18,8 @@ import io.modelcontextprotocol.spec.McpSchema.ClientCapabilities; import io.modelcontextprotocol.spec.McpSchema.CreateMessageRequest; import io.modelcontextprotocol.spec.McpSchema.CreateMessageResult; +import io.modelcontextprotocol.spec.McpSchema.ElicitRequest; +import io.modelcontextprotocol.spec.McpSchema.ElicitResult; import io.modelcontextprotocol.spec.McpSchema.Implementation; import io.modelcontextprotocol.spec.McpSchema.Root; import io.modelcontextprotocol.util.Assert; @@ -175,6 +177,8 @@ class SyncSpec { private Function samplingHandler; + private Function elicitationHandler; + private SyncSpec(McpClientTransport transport) { Assert.notNull(transport, "Transport must not be null"); this.transport = transport; @@ -283,6 +287,21 @@ public SyncSpec sampling(Function sam return this; } + /** + * Sets a custom elicitation handler for processing elicitation message requests. + * The elicitation handler can modify or validate messages before they are sent to + * the server, enabling custom processing logic. + * @param elicitationHandler A function that processes elicitation requests and + * returns results. Must not be null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if elicitationHandler is null + */ + public SyncSpec elicitation(Function elicitationHandler) { + Assert.notNull(elicitationHandler, "Elicitation handler must not be null"); + this.elicitationHandler = elicitationHandler; + return this; + } + /** * Adds a consumer to be notified when the available tools change. This allows the * client to react to changes in the server's tool capabilities, such as tools @@ -364,7 +383,7 @@ public SyncSpec loggingConsumers(List> samplingHandler; + private Function> elicitationHandler; + private AsyncSpec(McpClientTransport transport) { Assert.notNull(transport, "Transport must not be null"); this.transport = transport; @@ -522,6 +543,21 @@ public AsyncSpec sampling(Function> elicitationHandler) { + Assert.notNull(elicitationHandler, "Elicitation handler must not be null"); + this.elicitationHandler = elicitationHandler; + return this; + } + /** * Adds a consumer to be notified when the available tools change. This allows the * client to react to changes in the server's tool capabilities, such as tools @@ -606,7 +642,7 @@ public McpAsyncClient build() { return new McpAsyncClient(this.transport, this.requestTimeout, this.initializationTimeout, new McpClientFeatures.Async(this.clientInfo, this.capabilities, this.roots, this.toolsChangeConsumers, this.resourcesChangeConsumers, this.promptsChangeConsumers, - this.loggingConsumers, this.samplingHandler)); + this.loggingConsumers, this.samplingHandler, this.elicitationHandler)); } } diff --git a/mcp/src/main/java/io/modelcontextprotocol/client/McpClientFeatures.java b/mcp/src/main/java/io/modelcontextprotocol/client/McpClientFeatures.java index 284b93f8..23d7c6a6 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/client/McpClientFeatures.java +++ b/mcp/src/main/java/io/modelcontextprotocol/client/McpClientFeatures.java @@ -60,13 +60,15 @@ class McpClientFeatures { * @param promptsChangeConsumers the prompts change consumers. * @param loggingConsumers the logging consumers. * @param samplingHandler the sampling handler. + * @param elicitationHandler the elicitation handler. */ record Async(McpSchema.Implementation clientInfo, McpSchema.ClientCapabilities clientCapabilities, Map roots, List, Mono>> toolsChangeConsumers, List, Mono>> resourcesChangeConsumers, List, Mono>> promptsChangeConsumers, List>> loggingConsumers, - Function> samplingHandler) { + Function> samplingHandler, + Function> elicitationHandler) { /** * Create an instance and validate the arguments. @@ -77,6 +79,7 @@ record Async(McpSchema.Implementation clientInfo, McpSchema.ClientCapabilities c * @param promptsChangeConsumers the prompts change consumers. * @param loggingConsumers the logging consumers. * @param samplingHandler the sampling handler. + * @param elicitationHandler the elicitation handler. */ public Async(McpSchema.Implementation clientInfo, McpSchema.ClientCapabilities clientCapabilities, Map roots, @@ -84,14 +87,16 @@ public Async(McpSchema.Implementation clientInfo, McpSchema.ClientCapabilities c List, Mono>> resourcesChangeConsumers, List, Mono>> promptsChangeConsumers, List>> loggingConsumers, - Function> samplingHandler) { + Function> samplingHandler, + Function> elicitationHandler) { Assert.notNull(clientInfo, "Client info must not be null"); this.clientInfo = clientInfo; this.clientCapabilities = (clientCapabilities != null) ? clientCapabilities : new McpSchema.ClientCapabilities(null, !Utils.isEmpty(roots) ? new McpSchema.ClientCapabilities.RootCapabilities(false) : null, - samplingHandler != null ? new McpSchema.ClientCapabilities.Sampling() : null); + samplingHandler != null ? new McpSchema.ClientCapabilities.Sampling() : null, + elicitationHandler != null ? new McpSchema.ClientCapabilities.Elicitation() : null); this.roots = roots != null ? new ConcurrentHashMap<>(roots) : new ConcurrentHashMap<>(); this.toolsChangeConsumers = toolsChangeConsumers != null ? toolsChangeConsumers : List.of(); @@ -99,6 +104,7 @@ public Async(McpSchema.Implementation clientInfo, McpSchema.ClientCapabilities c this.promptsChangeConsumers = promptsChangeConsumers != null ? promptsChangeConsumers : List.of(); this.loggingConsumers = loggingConsumers != null ? loggingConsumers : List.of(); this.samplingHandler = samplingHandler; + this.elicitationHandler = elicitationHandler; } /** @@ -138,9 +144,14 @@ public static Async fromSync(Sync syncSpec) { Function> samplingHandler = r -> Mono .fromCallable(() -> syncSpec.samplingHandler().apply(r)) .subscribeOn(Schedulers.boundedElastic()); + + Function> elicitationHandler = r -> Mono + .fromCallable(() -> syncSpec.elicitationHandler().apply(r)) + .subscribeOn(Schedulers.boundedElastic()); + return new Async(syncSpec.clientInfo(), syncSpec.clientCapabilities(), syncSpec.roots(), toolsChangeConsumers, resourcesChangeConsumers, promptsChangeConsumers, loggingConsumers, - samplingHandler); + samplingHandler, elicitationHandler); } } @@ -156,13 +167,15 @@ public static Async fromSync(Sync syncSpec) { * @param promptsChangeConsumers the prompts change consumers. * @param loggingConsumers the logging consumers. * @param samplingHandler the sampling handler. + * @param elicitationHandler the elicitation handler. */ public record Sync(McpSchema.Implementation clientInfo, McpSchema.ClientCapabilities clientCapabilities, Map roots, List>> toolsChangeConsumers, List>> resourcesChangeConsumers, List>> promptsChangeConsumers, List> loggingConsumers, - Function samplingHandler) { + Function samplingHandler, + Function elicitationHandler) { /** * Create an instance and validate the arguments. @@ -174,20 +187,23 @@ public record Sync(McpSchema.Implementation clientInfo, McpSchema.ClientCapabili * @param promptsChangeConsumers the prompts change consumers. * @param loggingConsumers the logging consumers. * @param samplingHandler the sampling handler. + * @param elicitationHandler the elicitation handler. */ public Sync(McpSchema.Implementation clientInfo, McpSchema.ClientCapabilities clientCapabilities, Map roots, List>> toolsChangeConsumers, List>> resourcesChangeConsumers, List>> promptsChangeConsumers, List> loggingConsumers, - Function samplingHandler) { + Function samplingHandler, + Function elicitationHandler) { Assert.notNull(clientInfo, "Client info must not be null"); this.clientInfo = clientInfo; this.clientCapabilities = (clientCapabilities != null) ? clientCapabilities : new McpSchema.ClientCapabilities(null, !Utils.isEmpty(roots) ? new McpSchema.ClientCapabilities.RootCapabilities(false) : null, - samplingHandler != null ? new McpSchema.ClientCapabilities.Sampling() : null); + samplingHandler != null ? new McpSchema.ClientCapabilities.Sampling() : null, + elicitationHandler != null ? new McpSchema.ClientCapabilities.Elicitation() : null); this.roots = roots != null ? new HashMap<>(roots) : new HashMap<>(); this.toolsChangeConsumers = toolsChangeConsumers != null ? toolsChangeConsumers : List.of(); @@ -195,6 +211,7 @@ public Sync(McpSchema.Implementation clientInfo, McpSchema.ClientCapabilities cl this.promptsChangeConsumers = promptsChangeConsumers != null ? promptsChangeConsumers : List.of(); this.loggingConsumers = loggingConsumers != null ? loggingConsumers : List.of(); this.samplingHandler = samplingHandler; + this.elicitationHandler = elicitationHandler; } } diff --git a/mcp/src/main/java/io/modelcontextprotocol/client/transport/StdioClientTransport.java b/mcp/src/main/java/io/modelcontextprotocol/client/transport/StdioClientTransport.java index 9d71cbb4..246ff11c 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/client/transport/StdioClientTransport.java +++ b/mcp/src/main/java/io/modelcontextprotocol/client/transport/StdioClientTransport.java @@ -112,6 +112,7 @@ public StdioClientTransport(ServerParameters params, ObjectMapper objectMapper) @Override public Mono connect(Function, Mono> handler) { return Mono.fromRunnable(() -> { + logger.info("MCP server starting."); handleIncomingMessages(handler); handleIncomingErrors(); @@ -142,6 +143,7 @@ public Mono connect(Function, Mono> h startInboundProcessing(); startOutboundProcessing(); startErrorProcessing(); + logger.info("MCP server started"); }).subscribeOn(Schedulers.boundedElastic()); } @@ -366,6 +368,9 @@ public Mono closeGracefully() { if (process.exitValue() != 0) { logger.warn("Process terminated with code " + process.exitValue()); } + else { + logger.info("MCP server process stopped"); + } }).then(Mono.fromRunnable(() -> { try { // The Threads are blocked on readLine so disposeGracefully would not diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServerExchange.java b/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServerExchange.java index 889dc66d..cfb07d26 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServerExchange.java +++ b/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServerExchange.java @@ -36,6 +36,9 @@ public class McpAsyncServerExchange { private static final TypeReference LIST_ROOTS_RESULT_TYPE_REF = new TypeReference<>() { }; + private static final TypeReference ELICITATION_RESULT_TYPE_REF = new TypeReference<>() { + }; + /** * Create a new asynchronous exchange with the client. * @param session The server session representing a 1-1 interaction. @@ -93,6 +96,31 @@ public Mono createMessage(McpSchema.CreateMessage CREATE_MESSAGE_RESULT_TYPE_REF); } + /** + * Creates a new elicitation. MCP provides a standardized way for servers to request + * additional information from users through the client during interactions. This flow + * allows clients to maintain control over user interactions and data sharing while + * enabling servers to gather necessary information dynamically. Servers can request + * structured data from users with optional JSON schemas to validate responses. + * @param elicitRequest The request to create a new elicitation + * @return A Mono that completes when the elicitation has been resolved. + * @see McpSchema.ElicitRequest + * @see McpSchema.ElicitResult + * @see Elicitation + * Specification + */ + public Mono createElicitation(McpSchema.ElicitRequest elicitRequest) { + if (this.clientCapabilities == null) { + return Mono.error(new McpError("Client must be initialized. Call the initialize method first!")); + } + if (this.clientCapabilities.elicitation() == null) { + return Mono.error(new McpError("Client must be configured with elicitation capabilities")); + } + return this.session.sendRequest(McpSchema.METHOD_ELICITATION_CREATE, elicitRequest, + ELICITATION_RESULT_TYPE_REF); + } + /** * Retrieves the list of all roots provided by the client. * @return A Mono that emits the list of roots result. diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/McpSyncServerExchange.java b/mcp/src/main/java/io/modelcontextprotocol/server/McpSyncServerExchange.java index 52360e54..084412b9 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/server/McpSyncServerExchange.java +++ b/mcp/src/main/java/io/modelcontextprotocol/server/McpSyncServerExchange.java @@ -64,6 +64,24 @@ public McpSchema.CreateMessageResult createMessage(McpSchema.CreateMessageReques return this.exchange.createMessage(createMessageRequest).block(); } + /** + * Creates a new elicitation. MCP provides a standardized way for servers to request + * additional information from users through the client during interactions. This flow + * allows clients to maintain control over user interactions and data sharing while + * enabling servers to gather necessary information dynamically. Servers can request + * structured data from users with optional JSON schemas to validate responses. + * @param elicitRequest The request to create a new elicitation + * @return A result containing the elicitation response. + * @see McpSchema.ElicitRequest + * @see McpSchema.ElicitResult + * @see Elicitation + * Specification + */ + public McpSchema.ElicitResult createElicitation(McpSchema.ElicitRequest elicitRequest) { + return this.exchange.createElicitation(elicitRequest).block(); + } + /** * Retrieves the list of all roots provided by the client. * @return The list of roots result. diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/DefaultMcpTransportSession.java b/mcp/src/main/java/io/modelcontextprotocol/spec/DefaultMcpTransportSession.java new file mode 100644 index 00000000..d06d5b32 --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/DefaultMcpTransportSession.java @@ -0,0 +1,79 @@ +package io.modelcontextprotocol.spec; + +import org.reactivestreams.Publisher; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import reactor.core.Disposable; +import reactor.core.Disposables; +import reactor.core.publisher.Mono; + +import java.util.Optional; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Supplier; + +/** + * Default implementation of {@link McpTransportSession} which manages the open + * connections using tye {@link Disposable} type and allows to perform clean up using the + * {@link Disposable#dispose()} method. + * + * @author Dariusz Jędrzejczyk + */ +public class DefaultMcpTransportSession implements McpTransportSession { + + private static final Logger logger = LoggerFactory.getLogger(DefaultMcpTransportSession.class); + + private final Disposable.Composite openConnections = Disposables.composite(); + + private final AtomicBoolean initialized = new AtomicBoolean(false); + + private final AtomicReference sessionId = new AtomicReference<>(); + + private final Supplier> onClose; + + public DefaultMcpTransportSession(Supplier> onClose) { + this.onClose = onClose; + } + + @Override + public Optional sessionId() { + return Optional.ofNullable(this.sessionId.get()); + } + + @Override + public boolean markInitialized(String sessionId) { + boolean flipped = this.initialized.compareAndSet(false, true); + if (flipped) { + this.sessionId.set(sessionId); + logger.debug("Established session with id {}", sessionId); + } + else { + if (sessionId != null && !sessionId.equals(this.sessionId.get())) { + logger.warn("Different session id provided in response. Expecting {} but server returned {}", + this.sessionId.get(), sessionId); + } + } + return flipped; + } + + @Override + public void addConnection(Disposable connection) { + this.openConnections.add(connection); + } + + @Override + public void removeConnection(Disposable connection) { + this.openConnections.remove(connection); + } + + @Override + public void close() { + this.closeGracefully().subscribe(); + } + + @Override + public Mono closeGracefully() { + return Mono.from(this.onClose.get()).then(Mono.fromRunnable(this.openConnections::dispose)); + } + +} diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/DefaultMcpTransportStream.java b/mcp/src/main/java/io/modelcontextprotocol/spec/DefaultMcpTransportStream.java new file mode 100644 index 00000000..ecc6f866 --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/DefaultMcpTransportStream.java @@ -0,0 +1,74 @@ +package io.modelcontextprotocol.spec; + +import org.reactivestreams.Publisher; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; +import reactor.util.function.Tuple2; + +import java.util.Optional; +import java.util.concurrent.atomic.AtomicLong; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Function; + +/** + * An implementation of {@link McpTransportStream} using Project Reactor types. + * + * @param the resource serving the stream + * @author Dariusz Jędrzejczyk + */ +public class DefaultMcpTransportStream implements McpTransportStream { + + private static final Logger logger = LoggerFactory.getLogger(DefaultMcpTransportStream.class); + + private static final AtomicLong counter = new AtomicLong(); + + private final AtomicReference lastId = new AtomicReference<>(); + + // Used only for internal accounting + private final long streamId; + + private final boolean resumable; + + private final Function, Publisher> reconnect; + + /** + * Constructs a new instance representing a particular stream that can resume using + * the provided reconnect mechanism. + * @param resumable whether the stream is resumable and should try to reconnect + * @param reconnect the mechanism to use in case an error is observed on the current + * event stream to asynchronously kick off a resumed stream consumption, potentially + * using the stored {@link #lastId()}. + */ + public DefaultMcpTransportStream(boolean resumable, + Function, Publisher> reconnect) { + this.reconnect = reconnect; + this.streamId = counter.getAndIncrement(); + this.resumable = resumable; + } + + @Override + public Optional lastId() { + return Optional.ofNullable(this.lastId.get()); + } + + @Override + public long streamId() { + return this.streamId; + } + + @Override + public Publisher consumeSseStream( + Publisher, Iterable>> eventStream) { + return Flux.deferContextual(ctx -> Flux.from(eventStream).doOnError(e -> { + if (resumable && !(e instanceof McpTransportSessionNotFoundException)) { + Mono.from(reconnect.apply(this)).contextWrite(ctx).subscribe(); + } + }).doOnNext(idAndMessage -> idAndMessage.getT1().ifPresent(id -> { + String previousId = this.lastId.getAndSet(id); + logger.debug("Updating last id {} -> {} for stream {}", previousId, id, this.streamId); + })).flatMapIterable(Tuple2::getT2)); + } + +} diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpClientSession.java b/mcp/src/main/java/io/modelcontextprotocol/spec/McpClientSession.java index f577b493..c8399240 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/spec/McpClientSession.java +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/McpClientSession.java @@ -4,20 +4,19 @@ package io.modelcontextprotocol.spec; -import java.time.Duration; -import java.util.Map; -import java.util.UUID; -import java.util.concurrent.ConcurrentHashMap; -import java.util.concurrent.atomic.AtomicLong; - import com.fasterxml.jackson.core.type.TypeReference; import io.modelcontextprotocol.util.Assert; import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import reactor.core.Disposable; import reactor.core.publisher.Mono; import reactor.core.publisher.MonoSink; +import java.time.Duration; +import java.util.Map; +import java.util.UUID; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.atomic.AtomicLong; + /** * Default implementation of the MCP (Model Context Protocol) session that manages * bidirectional JSON-RPC communication between clients and servers. This implementation @@ -37,7 +36,6 @@ */ public class McpClientSession implements McpSession { - /** Logger for this class */ private static final Logger logger = LoggerFactory.getLogger(McpClientSession.class); /** Duration to wait for request responses before timing out */ @@ -61,8 +59,6 @@ public class McpClientSession implements McpSession { /** Atomic counter for generating unique request IDs */ private final AtomicLong requestCounter = new AtomicLong(0); - private final Disposable connection; - /** * Functional interface for handling incoming JSON-RPC requests. Implementations * should process the request parameters and return a response. @@ -117,12 +113,15 @@ public McpClientSession(Duration requestTimeout, McpClientTransport transport, this.requestHandlers.putAll(requestHandlers); this.notificationHandlers.putAll(notificationHandlers); - // TODO: consider mono.transformDeferredContextual where the Context contains - // the - // Observation associated with the individual message - it can be used to - // create child Observation and emit it together with the message to the - // consumer - this.connection = this.transport.connect(mono -> mono.doOnNext(this::handle)).subscribe(); + this.transport.connect(mono -> mono.doOnNext(this::handle)).subscribe(); + } + + private void dismissPendingResponses() { + this.pendingResponses.forEach((id, sink) -> { + logger.warn("Abruptly terminating exchange for request {}", id); + sink.error(new RuntimeException("MCP session with server terminated")); + }); + this.pendingResponses.clear(); } private void handle(McpSchema.JSONRPCMessage message) { @@ -231,17 +230,15 @@ public Mono sendRequest(String method, Object requestParams, TypeReferenc String requestId = this.generateRequestId(); return Mono.deferContextual(ctx -> Mono.create(sink -> { + logger.debug("Sending message for method {}", method); this.pendingResponses.put(requestId, sink); McpSchema.JSONRPCRequest jsonrpcRequest = new McpSchema.JSONRPCRequest(McpSchema.JSONRPC_VERSION, method, requestId, requestParams); - this.transport.sendMessage(jsonrpcRequest) - .contextWrite(ctx) - // TODO: It's most efficient to create a dedicated Subscriber here - .subscribe(v -> { - }, error -> { - this.pendingResponses.remove(requestId); - sink.error(error); - }); + this.transport.sendMessage(jsonrpcRequest).contextWrite(ctx).subscribe(v -> { + }, error -> { + this.pendingResponses.remove(requestId); + sink.error(error); + }); })).timeout(this.requestTimeout).handle((jsonRpcResponse, sink) -> { if (jsonRpcResponse.error() != null) { logger.error("Error handling request: {}", jsonRpcResponse.error()); @@ -277,10 +274,7 @@ public Mono sendNotification(String method, Object params) { */ @Override public Mono closeGracefully() { - return Mono.defer(() -> { - this.connection.dispose(); - return transport.closeGracefully(); - }); + return Mono.fromRunnable(this::dismissPendingResponses); } /** @@ -288,8 +282,7 @@ public Mono closeGracefully() { */ @Override public void close() { - this.connection.dispose(); - transport.close(); + dismissPendingResponses(); } } diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpClientTransport.java b/mcp/src/main/java/io/modelcontextprotocol/spec/McpClientTransport.java index f2909124..5c3b3313 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/spec/McpClientTransport.java +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/McpClientTransport.java @@ -3,18 +3,38 @@ */ package io.modelcontextprotocol.spec; +import java.util.function.Consumer; import java.util.function.Function; import reactor.core.publisher.Mono; /** - * Marker interface for the client-side MCP transport. + * Interface for the client side of the {@link McpTransport}. It allows setting handlers + * for messages that are incoming from the MCP server and hooking in to exceptions raised + * on the transport layer. * * @author Christian Tzolov * @author Dariusz Jędrzejczyk */ public interface McpClientTransport extends McpTransport { + /** + * Used to register the incoming messages' handler and potentially (eagerly) connect + * to the server. + * @param handler a transformer for incoming messages + * @return a {@link Mono} that terminates upon successful client setup. It can mean + * establishing a connection (which can be later disposed) but it doesn't have to, + * depending on the transport type. The successful termination of the returned + * {@link Mono} simply means the client can now be used. An error can be retried + * according to the application requirements. + */ Mono connect(Function, Mono> handler); + /** + * Sets the exception handler for exceptions raised on the transport layer. + * @param handler Allows reacting to transport level exceptions by the higher layers + */ + default void setExceptionHandler(Consumer handler) { + } + } diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpSchema.java b/mcp/src/main/java/io/modelcontextprotocol/spec/McpSchema.java index 8df8a158..e21d53c8 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/spec/McpSchema.java +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/McpSchema.java @@ -94,6 +94,9 @@ private McpSchema() { // Sampling Methods public static final String METHOD_SAMPLING_CREATE_MESSAGE = "sampling/createMessage"; + // Elicitation Methods + public static final String METHOD_ELICITATION_CREATE = "elicitation/create"; + private static final ObjectMapper OBJECT_MAPPER = new ObjectMapper(); // --------------------------- @@ -131,8 +134,8 @@ public static final class ErrorCodes { } - public sealed interface Request - permits InitializeRequest, CallToolRequest, CreateMessageRequest, CompleteRequest, GetPromptRequest { + public sealed interface Request permits InitializeRequest, CallToolRequest, CreateMessageRequest, ElicitRequest, + CompleteRequest, GetPromptRequest { } @@ -181,6 +184,8 @@ public sealed interface JSONRPCMessage permits JSONRPCRequest, JSONRPCNotificati @JsonInclude(JsonInclude.Include.NON_ABSENT) @JsonIgnoreProperties(ignoreUnknown = true) + // TODO: batching support + // @JsonFormat(with = JsonFormat.Feature.ACCEPT_SINGLE_VALUE_AS_ARRAY) public record JSONRPCRequest( // @formatter:off @JsonProperty("jsonrpc") String jsonrpc, @JsonProperty("method") String method, @@ -190,6 +195,8 @@ public record JSONRPCRequest( // @formatter:off @JsonInclude(JsonInclude.Include.NON_ABSENT) @JsonIgnoreProperties(ignoreUnknown = true) + // TODO: batching support + // @JsonFormat(with = JsonFormat.Feature.ACCEPT_SINGLE_VALUE_AS_ARRAY) public record JSONRPCNotification( // @formatter:off @JsonProperty("jsonrpc") String jsonrpc, @JsonProperty("method") String method, @@ -198,6 +205,8 @@ public record JSONRPCNotification( // @formatter:off @JsonInclude(JsonInclude.Include.NON_ABSENT) @JsonIgnoreProperties(ignoreUnknown = true) + // TODO: batching support + // @JsonFormat(with = JsonFormat.Feature.ACCEPT_SINGLE_VALUE_AS_ARRAY) public record JSONRPCResponse( // @formatter:off @JsonProperty("jsonrpc") String jsonrpc, @JsonProperty("id") Object id, @@ -221,7 +230,7 @@ public record JSONRPCError( public record InitializeRequest( // @formatter:off @JsonProperty("protocolVersion") String protocolVersion, @JsonProperty("capabilities") ClientCapabilities capabilities, - @JsonProperty("clientInfo") Implementation clientInfo) implements Request { + @JsonProperty("clientInfo") Implementation clientInfo) implements Request { } // @formatter:on @JsonInclude(JsonInclude.Include.NON_ABSENT) @@ -245,6 +254,8 @@ public record InitializeResult( // @formatter:off * access to. * @param sampling Provides a standardized way for servers to request LLM sampling * (“completions” or “generations”) from language models via clients. + * @param elicitation Provides a standardized way for servers to request additional + * information from users through the client during interactions. * */ @JsonInclude(JsonInclude.Include.NON_ABSENT) @@ -252,7 +263,8 @@ public record InitializeResult( // @formatter:off public record ClientCapabilities( // @formatter:off @JsonProperty("experimental") Map experimental, @JsonProperty("roots") RootCapabilities roots, - @JsonProperty("sampling") Sampling sampling) { + @JsonProperty("sampling") Sampling sampling, + @JsonProperty("elicitation") Elicitation elicitation) { /** * Roots define the boundaries of where servers can operate within the filesystem, @@ -264,7 +276,7 @@ public record ClientCapabilities( // @formatter:off * has changed since the last time the server checked. */ @JsonInclude(JsonInclude.Include.NON_ABSENT) - @JsonIgnoreProperties(ignoreUnknown = true) + @JsonIgnoreProperties(ignoreUnknown = true) public record RootCapabilities( @JsonProperty("listChanged") Boolean listChanged) { } @@ -279,10 +291,22 @@ public record RootCapabilities( * image-based interactions and optionally include context * from MCP servers in their prompts. */ - @JsonInclude(JsonInclude.Include.NON_ABSENT) + @JsonInclude(JsonInclude.Include.NON_ABSENT) public record Sampling() { } + /** + * Provides a standardized way for servers to request additional + * information from users through the client during interactions. + * This flow allows clients to maintain control over user + * interactions and data sharing while enabling servers to gather + * necessary information dynamically. Servers can request structured + * data from users with optional JSON schemas to validate responses. + */ + @JsonInclude(JsonInclude.Include.NON_ABSENT) + public record Elicitation() { + } + public static Builder builder() { return new Builder(); } @@ -291,6 +315,7 @@ public static class Builder { private Map experimental; private RootCapabilities roots; private Sampling sampling; + private Elicitation elicitation; public Builder experimental(Map experimental) { this.experimental = experimental; @@ -307,8 +332,13 @@ public Builder sampling() { return this; } + public Builder elicitation() { + this.elicitation = new Elicitation(); + return this; + } + public ClientCapabilities build() { - return new ClientCapabilities(experimental, roots, sampling); + return new ClientCapabilities(experimental, roots, sampling, elicitation); } } }// @formatter:on @@ -326,11 +356,11 @@ public record ServerCapabilities( // @formatter:off @JsonInclude(JsonInclude.Include.NON_ABSENT) public record CompletionCapabilities() { } - + @JsonInclude(JsonInclude.Include.NON_ABSENT) public record LoggingCapabilities() { } - + @JsonInclude(JsonInclude.Include.NON_ABSENT) public record PromptCapabilities( @JsonProperty("listChanged") Boolean listChanged) { @@ -727,11 +757,11 @@ public record Tool( // @formatter:off @JsonProperty("name") String name, @JsonProperty("description") String description, @JsonProperty("inputSchema") JsonSchema inputSchema) { - + public Tool(String name, String description, String schema) { this(name, description, parseSchema(schema)); } - + } // @formatter:on private static JsonSchema parseSchema(String schema) { @@ -758,7 +788,7 @@ public record CallToolRequest(// @formatter:off @JsonProperty("arguments") Map arguments) implements Request { public CallToolRequest(String name, String jsonArguments) { - this(name, parseJsonArguments(jsonArguments)); + this(name, parseJsonArguments(jsonArguments)); } private static Map parseJsonArguments(String jsonArguments) { @@ -893,7 +923,7 @@ public record ModelPreferences(// @formatter:off @JsonProperty("costPriority") Double costPriority, @JsonProperty("speedPriority") Double speedPriority, @JsonProperty("intelligencePriority") Double intelligencePriority) { - + public static Builder builder() { return new Builder(); } @@ -963,7 +993,7 @@ public record CreateMessageRequest(// @formatter:off @JsonProperty("includeContext") ContextInclusionStrategy includeContext, @JsonProperty("temperature") Double temperature, @JsonProperty("maxTokens") int maxTokens, - @JsonProperty("stopSequences") List stopSequences, + @JsonProperty("stopSequences") List stopSequences, @JsonProperty("metadata") Map metadata) implements Request { public enum ContextInclusionStrategy { @@ -971,7 +1001,7 @@ public enum ContextInclusionStrategy { @JsonProperty("thisServer") THIS_SERVER, @JsonProperty("allServers") ALL_SERVERS } - + public static Builder builder() { return new Builder(); } @@ -1040,7 +1070,7 @@ public record CreateMessageResult(// @formatter:off @JsonProperty("content") Content content, @JsonProperty("model") String model, @JsonProperty("stopReason") StopReason stopReason) { - + public enum StopReason { @JsonProperty("endTurn") END_TURN, @JsonProperty("stopSequence") STOP_SEQUENCE, @@ -1088,6 +1118,79 @@ public CreateMessageResult build() { } }// @formatter:on + // Elicitation + /** + * Used by the server to send an elicitation to the client. + * + * @param message The body of the elicitation message. + * @param requestedSchema The elicitation response schema that must be satisfied. + */ + @JsonInclude(JsonInclude.Include.NON_ABSENT) + @JsonIgnoreProperties(ignoreUnknown = true) + public record ElicitRequest(// @formatter:off + @JsonProperty("message") String message, + @JsonProperty("requestedSchema") Map requestedSchema) implements Request { + + public static Builder builder() { + return new Builder(); + } + + public static class Builder { + private String message; + private Map requestedSchema; + + public Builder message(String message) { + this.message = message; + return this; + } + + public Builder requestedSchema(Map requestedSchema) { + this.requestedSchema = requestedSchema; + return this; + } + + public ElicitRequest build() { + return new ElicitRequest(message, requestedSchema); + } + } + }// @formatter:on + + @JsonInclude(JsonInclude.Include.NON_ABSENT) + @JsonIgnoreProperties(ignoreUnknown = true) + public record ElicitResult(// @formatter:off + @JsonProperty("action") Action action, + @JsonProperty("content") Map content) { + + public enum Action { + @JsonProperty("accept") ACCEPT, + @JsonProperty("decline") DECLINE, + @JsonProperty("cancel") CANCEL + } + + public static Builder builder() { + return new Builder(); + } + + public static class Builder { + private Action action; + private Map content; + + public Builder message(Action action) { + this.action = action; + return this; + } + + public Builder content(Map content) { + this.content = content; + return this; + } + + public ElicitResult build() { + return new ElicitResult(action, content); + } + } + }// @formatter:on + // --------------------------- // Pagination Interfaces // --------------------------- diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpTransportSession.java b/mcp/src/main/java/io/modelcontextprotocol/spec/McpTransportSession.java new file mode 100644 index 00000000..555f018f --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/McpTransportSession.java @@ -0,0 +1,60 @@ +package io.modelcontextprotocol.spec; + +import org.reactivestreams.Publisher; + +import java.util.Optional; + +/** + * An abstraction of the session as perceived from the MCP transport layer. Not to be + * confused with the {@link McpSession} type that operates at the level of the JSON-RPC + * communication protocol and matches asynchronous responses with previously issued + * requests. + * + * @param the resource representing the connection that the transport + * manages. + * @author Dariusz Jędrzejczyk + */ +public interface McpTransportSession { + + /** + * In case of stateful MCP servers, the value is present and contains the String + * identifier for the transport-level session. + * @return optional session id + */ + Optional sessionId(); + + /** + * Stateful operation that flips the un-initialized state to initialized if this is + * the first call. If the transport provides a session id for the communication, + * argument should not be null to record the current identifier. + * @param sessionId session identifier as provided by the server + * @return if successful, this method returns {@code true} and means that a + * post-initialization step can be performed + */ + boolean markInitialized(String sessionId); + + /** + * Adds a resource that this transport session can monitor and dismiss when needed. + * @param connection the managed resource + */ + void addConnection(CONNECTION connection); + + /** + * Called when the resource is terminating by itself and the transport session does + * not need to track it anymore. + * @param connection the resource to remove from the monitored collection + */ + void removeConnection(CONNECTION connection); + + /** + * Close and clear the monitored resources. Potentially asynchronous. + */ + void close(); + + /** + * Close and clear the monitored resources in a graceful manner. + * @return completes once all resources have been dismissed + */ + Publisher closeGracefully(); + +} diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpTransportSessionNotFoundException.java b/mcp/src/main/java/io/modelcontextprotocol/spec/McpTransportSessionNotFoundException.java new file mode 100644 index 00000000..474a18ae --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/McpTransportSessionNotFoundException.java @@ -0,0 +1,29 @@ +package io.modelcontextprotocol.spec; + +/** + * Exception that signifies that the server does not recognize the connecting client via + * the presented transport session identifier. + * + * @author Dariusz Jędrzejczyk + */ +public class McpTransportSessionNotFoundException extends RuntimeException { + + /** + * Construct an instance with a known {@link Exception cause}. + * @param sessionId transport session identifier + * @param cause the cause that was identified as a session not found error + */ + public McpTransportSessionNotFoundException(String sessionId, Exception cause) { + super("Session " + sessionId + " not found on the server", cause); + } + + /** + * Construct an instance with the session identifier but without a {@link Exception + * cause}. + * @param sessionId transport session identifier + */ + public McpTransportSessionNotFoundException(String sessionId) { + super("Session " + sessionId + " not found on the server"); + } + +} diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpTransportStream.java b/mcp/src/main/java/io/modelcontextprotocol/spec/McpTransportStream.java new file mode 100644 index 00000000..2d6dcce7 --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/McpTransportStream.java @@ -0,0 +1,45 @@ +package io.modelcontextprotocol.spec; + +import org.reactivestreams.Publisher; +import reactor.util.function.Tuple2; + +import java.util.Optional; + +/** + * A representation of a stream at the transport layer of the MCP protocol. In particular, + * it is currently used in the Streamable HTTP implementation to potentially be able to + * resume a broken connection from where it left off by optionally keeping track of + * attached SSE event ids. + * + * @param the resource on which the stream is being served and consumed via + * this mechanism + * @author Dariusz Jędrzejczyk + */ +public interface McpTransportStream { + + /** + * The last observed event identifier. + * @return if not empty, contains the most recent event that was consumed + */ + Optional lastId(); + + /** + * An internal stream identifier used to distinguish streams while debugging. + * @return a {@code long} stream identifier value + */ + long streamId(); + + /** + * Allows keeping track of the transport stream of events (currently an SSE stream + * from Streamable HTTP specification) and enable resumability and reconnects in case + * of stream errors. + * @param eventStream a {@link Publisher} of tuples (pairs) of an optional identifier + * associated with a collection of messages + * @return a flattened {@link Publisher} of + * {@link io.modelcontextprotocol.spec.McpSchema.JSONRPCMessage JSON-RPC messages} + * with the identifier stripped away + */ + Publisher consumeSseStream( + Publisher, Iterable>> eventStream); + +} diff --git a/mcp/src/test/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientTests.java b/mcp/src/test/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientTests.java index 72b409af..37f9e71a 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientTests.java @@ -19,6 +19,8 @@ import io.modelcontextprotocol.spec.McpSchema.ClientCapabilities; import io.modelcontextprotocol.spec.McpSchema.CreateMessageRequest; import io.modelcontextprotocol.spec.McpSchema.CreateMessageResult; +import io.modelcontextprotocol.spec.McpSchema.ElicitRequest; +import io.modelcontextprotocol.spec.McpSchema.ElicitResult; import io.modelcontextprotocol.spec.McpSchema.GetPromptRequest; import io.modelcontextprotocol.spec.McpSchema.Prompt; import io.modelcontextprotocol.spec.McpSchema.Resource; @@ -111,14 +113,16 @@ void tearDown() { onClose(); } - void verifyInitializationTimeout(Function> operation, String action) { + void verifyNotificationSucceedsWithImplicitInitialization(Function> operation, + String action) { withClient(createMcpTransport(), mcpAsyncClient -> { - StepVerifier.withVirtualTime(() -> operation.apply(mcpAsyncClient)) - .expectSubscription() - .thenAwait(getInitializationTimeout()) - .consumeErrorWith(e -> assertThat(e).isInstanceOf(McpError.class) - .hasMessage("Client must be initialized before " + action)) - .verify(); + StepVerifier.create(operation.apply(mcpAsyncClient)).verifyComplete(); + }); + } + + void verifyCallSucceedsWithImplicitInitialization(Function> operation, String action) { + withClient(createMcpTransport(), mcpAsyncClient -> { + StepVerifier.create(operation.apply(mcpAsyncClient)).expectNextCount(1).verifyComplete(); }); } @@ -134,7 +138,7 @@ void testConstructorWithInvalidArguments() { @Test void testListToolsWithoutInitialization() { - verifyInitializationTimeout(client -> client.listTools(null), "listing tools"); + verifyCallSucceedsWithImplicitInitialization(client -> client.listTools(null), "listing tools"); } @Test @@ -154,7 +158,7 @@ void testListTools() { @Test void testPingWithoutInitialization() { - verifyInitializationTimeout(client -> client.ping(), "pinging the server"); + verifyCallSucceedsWithImplicitInitialization(client -> client.ping(), "pinging the server"); } @Test @@ -169,7 +173,7 @@ void testPing() { @Test void testCallToolWithoutInitialization() { CallToolRequest callToolRequest = new CallToolRequest("echo", Map.of("message", ECHO_TEST_MESSAGE)); - verifyInitializationTimeout(client -> client.callTool(callToolRequest), "calling tools"); + verifyCallSucceedsWithImplicitInitialization(client -> client.callTool(callToolRequest), "calling tools"); } @Test @@ -203,7 +207,7 @@ void testCallToolWithInvalidTool() { @Test void testListResourcesWithoutInitialization() { - verifyInitializationTimeout(client -> client.listResources(null), "listing resources"); + verifyCallSucceedsWithImplicitInitialization(client -> client.listResources(null), "listing resources"); } @Test @@ -234,7 +238,7 @@ void testMcpAsyncClientState() { @Test void testListPromptsWithoutInitialization() { - verifyInitializationTimeout(client -> client.listPrompts(null), "listing " + "prompts"); + verifyCallSucceedsWithImplicitInitialization(client -> client.listPrompts(null), "listing " + "prompts"); } @Test @@ -259,7 +263,7 @@ void testListPrompts() { @Test void testGetPromptWithoutInitialization() { GetPromptRequest request = new GetPromptRequest("simple_prompt", Map.of()); - verifyInitializationTimeout(client -> client.getPrompt(request), "getting " + "prompts"); + verifyCallSucceedsWithImplicitInitialization(client -> client.getPrompt(request), "getting " + "prompts"); } @Test @@ -280,7 +284,7 @@ void testGetPrompt() { @Test void testRootsListChangedWithoutInitialization() { - verifyInitializationTimeout(client -> client.rootsListChangedNotification(), + verifyNotificationSucceedsWithImplicitInitialization(client -> client.rootsListChangedNotification(), "sending roots list changed notification"); } @@ -355,7 +359,8 @@ void testReadResource() { @Test void testListResourceTemplatesWithoutInitialization() { - verifyInitializationTimeout(client -> client.listResourceTemplates(), "listing resource templates"); + verifyCallSucceedsWithImplicitInitialization(client -> client.listResourceTemplates(), + "listing resource templates"); } @Test @@ -422,6 +427,20 @@ void testInitializeWithSamplingCapability() { }); } + @Test + void testInitializeWithElicitationCapability() { + ClientCapabilities capabilities = ClientCapabilities.builder().elicitation().build(); + ElicitResult elicitResult = ElicitResult.builder() + .message(ElicitResult.Action.ACCEPT) + .content(Map.of("foo", "bar")) + .build(); + withClient(createMcpTransport(), + builder -> builder.capabilities(capabilities).elicitation(request -> Mono.just(elicitResult)), + client -> { + StepVerifier.create(client.initialize()).expectNextMatches(Objects::nonNull).verifyComplete(); + }); + } + @Test void testInitializeWithAllCapabilities() { var capabilities = ClientCapabilities.builder() @@ -433,7 +452,11 @@ void testInitializeWithAllCapabilities() { Function> samplingHandler = request -> Mono .just(CreateMessageResult.builder().message("test").model("test-model").build()); - withClient(createMcpTransport(), builder -> builder.capabilities(capabilities).sampling(samplingHandler), + Function> elicitationHandler = request -> Mono + .just(ElicitResult.builder().message(ElicitResult.Action.ACCEPT).content(Map.of("foo", "bar")).build()); + + withClient(createMcpTransport(), + builder -> builder.capabilities(capabilities).sampling(samplingHandler).elicitation(elicitationHandler), client -> StepVerifier.create(client.initialize()).assertNext(result -> { @@ -448,8 +471,8 @@ void testInitializeWithAllCapabilities() { @Test void testLoggingLevelsWithoutInitialization() { - verifyInitializationTimeout(client -> client.setLoggingLevel(McpSchema.LoggingLevel.DEBUG), - "setting logging level"); + verifyNotificationSucceedsWithImplicitInitialization( + client -> client.setLoggingLevel(McpSchema.LoggingLevel.DEBUG), "setting logging level"); } @Test diff --git a/mcp/src/test/java/io/modelcontextprotocol/client/AbstractMcpSyncClientTests.java b/mcp/src/test/java/io/modelcontextprotocol/client/AbstractMcpSyncClientTests.java index 24c161eb..77989577 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/client/AbstractMcpSyncClientTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/client/AbstractMcpSyncClientTests.java @@ -5,6 +5,7 @@ package io.modelcontextprotocol.client; import java.time.Duration; +import java.util.List; import java.util.Map; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicReference; @@ -12,7 +13,6 @@ import java.util.function.Function; import io.modelcontextprotocol.spec.McpClientTransport; -import io.modelcontextprotocol.spec.McpError; import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.spec.McpSchema.CallToolRequest; import io.modelcontextprotocol.spec.McpSchema.CallToolResult; @@ -113,33 +113,18 @@ void tearDown() { static final Object DUMMY_RETURN_VALUE = new Object(); - void verifyNotificationTimesOut(Consumer operation, String action) { - verifyCallTimesOut(client -> { + void verifyNotificationSucceedsWithImplicitInitialization(Consumer operation, String action) { + verifyCallSucceedsWithImplicitInitialization(client -> { operation.accept(client); return DUMMY_RETURN_VALUE; }, action); } - void verifyCallTimesOut(Function blockingOperation, String action) { + void verifyCallSucceedsWithImplicitInitialization(Function blockingOperation, String action) { withClient(createMcpTransport(), mcpSyncClient -> { - // This scheduler is not replaced by virtual time scheduler - Scheduler customScheduler = Schedulers.newBoundedElastic(1, 1, "actualBoundedElastic"); - - StepVerifier.withVirtualTime(() -> Mono.fromSupplier(() -> blockingOperation.apply(mcpSyncClient)) + StepVerifier.create(Mono.fromSupplier(() -> blockingOperation.apply(mcpSyncClient)) // Offload the blocking call to the real scheduler - .subscribeOn(customScheduler)) - .expectSubscription() - // This works without actually waiting but executes all the - // tasks pending execution on the VirtualTimeScheduler. - // It is possible to execute the blocking code from the operation - // because it is blocked on a dedicated Scheduler and the main - // flow is not blocked and uses the VirtualTimeScheduler. - .thenAwait(getInitializationTimeout()) - .consumeErrorWith(e -> assertThat(e).isInstanceOf(McpError.class) - .hasMessage("Client must be initialized before " + action)) - .verify(); - - customScheduler.dispose(); + .subscribeOn(Schedulers.boundedElastic())).expectNextCount(1).verifyComplete(); }); } @@ -155,7 +140,7 @@ void testConstructorWithInvalidArguments() { @Test void testListToolsWithoutInitialization() { - verifyCallTimesOut(client -> client.listTools(null), "listing tools"); + verifyCallSucceedsWithImplicitInitialization(client -> client.listTools(null), "listing tools"); } @Test @@ -176,8 +161,8 @@ void testListTools() { @Test void testCallToolsWithoutInitialization() { - verifyCallTimesOut(client -> client.callTool(new CallToolRequest("add", Map.of("a", 3, "b", 4))), - "calling tools"); + verifyCallSucceedsWithImplicitInitialization( + client -> client.callTool(new CallToolRequest("add", Map.of("a", 3, "b", 4))), "calling tools"); } @Test @@ -201,7 +186,7 @@ void testCallTools() { @Test void testPingWithoutInitialization() { - verifyCallTimesOut(client -> client.ping(), "pinging the server"); + verifyCallSucceedsWithImplicitInitialization(client -> client.ping(), "pinging the server"); } @Test @@ -215,7 +200,7 @@ void testPing() { @Test void testCallToolWithoutInitialization() { CallToolRequest callToolRequest = new CallToolRequest("echo", Map.of("message", TEST_MESSAGE)); - verifyCallTimesOut(client -> client.callTool(callToolRequest), "calling tools"); + verifyCallSucceedsWithImplicitInitialization(client -> client.callTool(callToolRequest), "calling tools"); } @Test @@ -244,7 +229,7 @@ void testCallToolWithInvalidTool() { @Test void testRootsListChangedWithoutInitialization() { - verifyNotificationTimesOut(client -> client.rootsListChangedNotification(), + verifyNotificationSucceedsWithImplicitInitialization(client -> client.rootsListChangedNotification(), "sending roots list changed notification"); } @@ -258,7 +243,7 @@ void testRootsListChanged() { @Test void testListResourcesWithoutInitialization() { - verifyCallTimesOut(client -> client.listResources(null), "listing resources"); + verifyCallSucceedsWithImplicitInitialization(client -> client.listResources(null), "listing resources"); } @Test @@ -334,8 +319,14 @@ void testRemoveNonExistentRoot() { @Test void testReadResourceWithoutInitialization() { - Resource resource = new Resource("test://uri", "Test Resource", null, null, null); - verifyCallTimesOut(client -> client.readResource(resource), "reading resources"); + AtomicReference> resources = new AtomicReference<>(); + withClient(createMcpTransport(), mcpSyncClient -> { + mcpSyncClient.initialize(); + resources.set(mcpSyncClient.listResources().resources()); + }); + + verifyCallSucceedsWithImplicitInitialization(client -> client.readResource(resources.get().get(0)), + "reading resources"); } @Test @@ -356,7 +347,8 @@ void testReadResource() { @Test void testListResourceTemplatesWithoutInitialization() { - verifyCallTimesOut(client -> client.listResourceTemplates(null), "listing resource templates"); + verifyCallSucceedsWithImplicitInitialization(client -> client.listResourceTemplates(null), + "listing resource templates"); } @Test @@ -414,8 +406,8 @@ void testNotificationHandlers() { @Test void testLoggingLevelsWithoutInitialization() { - verifyNotificationTimesOut(client -> client.setLoggingLevel(McpSchema.LoggingLevel.DEBUG), - "setting logging level"); + verifyNotificationSucceedsWithImplicitInitialization( + client -> client.setLoggingLevel(McpSchema.LoggingLevel.DEBUG), "setting logging level"); } @Test diff --git a/mcp/src/test/java/io/modelcontextprotocol/client/HttpSseMcpAsyncClientTests.java b/mcp/src/test/java/io/modelcontextprotocol/client/HttpSseMcpAsyncClientTests.java index fdff4b77..1b66a98c 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/client/HttpSseMcpAsyncClientTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/client/HttpSseMcpAsyncClientTests.java @@ -22,7 +22,8 @@ class HttpSseMcpAsyncClientTests extends AbstractMcpAsyncClientTests { // Uses the https://github.com/tzolov/mcp-everything-server-docker-image @SuppressWarnings("resource") - GenericContainer container = new GenericContainer<>("docker.io/tzolov/mcp-everything-server:v1") + GenericContainer container = new GenericContainer<>("docker.io/tzolov/mcp-everything-server:v2") + .withCommand("node dist/index.js sse") .withLogConsumer(outputFrame -> System.out.println(outputFrame.getUtf8String())) .withExposedPorts(3001) .waitingFor(Wait.forHttp("/").forStatusCode(404)); diff --git a/mcp/src/test/java/io/modelcontextprotocol/client/HttpSseMcpSyncClientTests.java b/mcp/src/test/java/io/modelcontextprotocol/client/HttpSseMcpSyncClientTests.java index 204cf298..8646c1b4 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/client/HttpSseMcpSyncClientTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/client/HttpSseMcpSyncClientTests.java @@ -22,7 +22,8 @@ class HttpSseMcpSyncClientTests extends AbstractMcpSyncClientTests { // Uses the https://github.com/tzolov/mcp-everything-server-docker-image @SuppressWarnings("resource") - GenericContainer container = new GenericContainer<>("docker.io/tzolov/mcp-everything-server:v1") + GenericContainer container = new GenericContainer<>("docker.io/tzolov/mcp-everything-server:v2") + .withCommand("node dist/index.js sse") .withLogConsumer(outputFrame -> System.out.println(outputFrame.getUtf8String())) .withExposedPorts(3001) .waitingFor(Wait.forHttp("/").forStatusCode(404)); diff --git a/mcp/src/test/java/io/modelcontextprotocol/client/McpAsyncClientResponseHandlerTests.java b/mcp/src/test/java/io/modelcontextprotocol/client/McpAsyncClientResponseHandlerTests.java index 4510b152..e6cde8e3 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/client/McpAsyncClientResponseHandlerTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/client/McpAsyncClientResponseHandlerTests.java @@ -19,6 +19,8 @@ import io.modelcontextprotocol.spec.McpSchema.InitializeResult; import io.modelcontextprotocol.spec.McpSchema.Root; import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.EnumSource; import reactor.core.publisher.Mono; import static io.modelcontextprotocol.spec.McpSchema.METHOD_INITIALIZE; @@ -349,4 +351,152 @@ void testSamplingCreateMessageRequestHandlingWithNullHandler() { .hasMessage("Sampling handler must not be null when client capabilities include sampling"); } + @Test + @SuppressWarnings("unchecked") + void testElicitationCreateRequestHandling() { + MockMcpClientTransport transport = initializationEnabledTransport(); + + // Create a test elicitation handler that echoes back the input + Function> elicitationHandler = request -> { + assertThat(request.message()).isNotEmpty(); + assertThat(request.requestedSchema()).isInstanceOf(Map.class); + assertThat(request.requestedSchema().get("type")).isEqualTo("object"); + + var properties = request.requestedSchema().get("properties"); + assertThat(properties).isNotNull(); + assertThat(((Map) properties).get("message")).isInstanceOf(Map.class); + + return Mono.just(McpSchema.ElicitResult.builder() + .message(McpSchema.ElicitResult.Action.ACCEPT) + .content(Map.of("message", request.message())) + .build()); + }; + + // Create client with elicitation capability and handler + McpAsyncClient asyncMcpClient = McpClient.async(transport) + .capabilities(ClientCapabilities.builder().elicitation().build()) + .elicitation(elicitationHandler) + .build(); + + assertThat(asyncMcpClient.initialize().block()).isNotNull(); + + // Create a mock elicitation + var elicitRequest = McpSchema.ElicitRequest.builder() + .message("Test message") + .requestedSchema(Map.of("type", "object", "properties", Map.of("message", Map.of("type", "string")))) + .build(); + + // Simulate incoming request + McpSchema.JSONRPCRequest request = new McpSchema.JSONRPCRequest(McpSchema.JSONRPC_VERSION, + McpSchema.METHOD_ELICITATION_CREATE, "test-id", elicitRequest); + transport.simulateIncomingMessage(request); + + // Verify response + McpSchema.JSONRPCMessage sentMessage = transport.getLastSentMessage(); + assertThat(sentMessage).isInstanceOf(McpSchema.JSONRPCResponse.class); + + McpSchema.JSONRPCResponse response = (McpSchema.JSONRPCResponse) sentMessage; + assertThat(response.id()).isEqualTo("test-id"); + assertThat(response.error()).isNull(); + + McpSchema.ElicitResult result = transport.unmarshalFrom(response.result(), new TypeReference<>() { + }); + assertThat(result).isNotNull(); + assertThat(result.action()).isEqualTo(McpSchema.ElicitResult.Action.ACCEPT); + assertThat(result.content()).isEqualTo(Map.of("message", "Test message")); + + asyncMcpClient.closeGracefully(); + } + + @ParameterizedTest + @EnumSource(value = McpSchema.ElicitResult.Action.class, names = { "DECLINE", "CANCEL" }) + void testElicitationFailRequestHandling(McpSchema.ElicitResult.Action action) { + MockMcpClientTransport transport = initializationEnabledTransport(); + + // Create a test elicitation handler to decline the request + Function> elicitationHandler = request -> Mono + .just(McpSchema.ElicitResult.builder().message(action).build()); + + // Create client with elicitation capability and handler + McpAsyncClient asyncMcpClient = McpClient.async(transport) + .capabilities(ClientCapabilities.builder().elicitation().build()) + .elicitation(elicitationHandler) + .build(); + + assertThat(asyncMcpClient.initialize().block()).isNotNull(); + + // Create a mock elicitation + var elicitRequest = McpSchema.ElicitRequest.builder() + .message("Test message") + .requestedSchema(Map.of("type", "object", "properties", Map.of("message", Map.of("type", "string")))) + .build(); + + // Simulate incoming request + McpSchema.JSONRPCRequest request = new McpSchema.JSONRPCRequest(McpSchema.JSONRPC_VERSION, + McpSchema.METHOD_ELICITATION_CREATE, "test-id", elicitRequest); + transport.simulateIncomingMessage(request); + + // Verify response + McpSchema.JSONRPCMessage sentMessage = transport.getLastSentMessage(); + assertThat(sentMessage).isInstanceOf(McpSchema.JSONRPCResponse.class); + + McpSchema.JSONRPCResponse response = (McpSchema.JSONRPCResponse) sentMessage; + assertThat(response.id()).isEqualTo("test-id"); + assertThat(response.error()).isNull(); + + McpSchema.ElicitResult result = transport.unmarshalFrom(response.result(), new TypeReference<>() { + }); + assertThat(result).isNotNull(); + assertThat(result.action()).isEqualTo(action); + assertThat(result.content()).isNull(); + + asyncMcpClient.closeGracefully(); + } + + @Test + void testElicitationCreateRequestHandlingWithoutCapability() { + MockMcpClientTransport transport = initializationEnabledTransport(); + + // Create client without elicitation capability + McpAsyncClient asyncMcpClient = McpClient.async(transport) + .capabilities(ClientCapabilities.builder().build()) // No elicitation + // capability + .build(); + + assertThat(asyncMcpClient.initialize().block()).isNotNull(); + + // Create a mock elicitation + var elicitRequest = new McpSchema.ElicitRequest("test", + Map.of("type", "object", "properties", Map.of("test", Map.of("type", "boolean", "defaultValue", true, + "description", "test-description", "title", "test-title")))); + + // Simulate incoming request + McpSchema.JSONRPCRequest request = new McpSchema.JSONRPCRequest(McpSchema.JSONRPC_VERSION, + McpSchema.METHOD_ELICITATION_CREATE, "test-id", elicitRequest); + transport.simulateIncomingMessage(request); + + // Verify error response + McpSchema.JSONRPCMessage sentMessage = transport.getLastSentMessage(); + assertThat(sentMessage).isInstanceOf(McpSchema.JSONRPCResponse.class); + + McpSchema.JSONRPCResponse response = (McpSchema.JSONRPCResponse) sentMessage; + assertThat(response.id()).isEqualTo("test-id"); + assertThat(response.result()).isNull(); + assertThat(response.error()).isNotNull(); + assertThat(response.error().message()).contains("Method not found: elicitation/create"); + + asyncMcpClient.closeGracefully(); + } + + @Test + void testElicitationCreateRequestHandlingWithNullHandler() { + MockMcpClientTransport transport = new MockMcpClientTransport(); + + // Create client with elicitation capability but null handler + assertThatThrownBy(() -> McpClient.async(transport) + .capabilities(ClientCapabilities.builder().elicitation().build()) + .build()).isInstanceOf(McpError.class) + .hasMessage("Elicitation handler must not be null when client capabilities include elicitation"); + } + } diff --git a/mcp/src/test/java/io/modelcontextprotocol/client/StdioMcpAsyncClientTests.java b/mcp/src/test/java/io/modelcontextprotocol/client/StdioMcpAsyncClientTests.java index c3908013..8c0069d6 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/client/StdioMcpAsyncClientTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/client/StdioMcpAsyncClientTests.java @@ -25,12 +25,12 @@ protected McpClientTransport createMcpTransport() { ServerParameters stdioParams; if (System.getProperty("os.name").toLowerCase().contains("win")) { stdioParams = ServerParameters.builder("cmd.exe") - .args("/c", "npx.cmd", "-y", "@modelcontextprotocol/server-everything", "dir") + .args("/c", "npx.cmd", "-y", "@modelcontextprotocol/server-everything", "stdio") .build(); } else { stdioParams = ServerParameters.builder("npx") - .args("-y", "@modelcontextprotocol/server-everything", "dir") + .args("-y", "@modelcontextprotocol/server-everything", "stdio") .build(); } return new StdioClientTransport(stdioParams); diff --git a/mcp/src/test/java/io/modelcontextprotocol/client/StdioMcpSyncClientTests.java b/mcp/src/test/java/io/modelcontextprotocol/client/StdioMcpSyncClientTests.java index 8e75c4a3..4b5f4f9c 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/client/StdioMcpSyncClientTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/client/StdioMcpSyncClientTests.java @@ -33,12 +33,12 @@ protected McpClientTransport createMcpTransport() { ServerParameters stdioParams; if (System.getProperty("os.name").toLowerCase().contains("win")) { stdioParams = ServerParameters.builder("cmd.exe") - .args("/c", "npx.cmd", "-y", "@modelcontextprotocol/server-everything", "dir") + .args("/c", "npx.cmd", "-y", "@modelcontextprotocol/server-everything", "stdio") .build(); } else { stdioParams = ServerParameters.builder("npx") - .args("-y", "@modelcontextprotocol/server-everything", "dir") + .args("-y", "@modelcontextprotocol/server-everything", "stdio") .build(); } return new StdioClientTransport(stdioParams); @@ -68,7 +68,7 @@ void customErrorHandlerShouldReceiveErrors() throws InterruptedException { } protected Duration getInitializationTimeout() { - return Duration.ofSeconds(6); + return Duration.ofSeconds(10); } } diff --git a/mcp/src/test/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransportTests.java b/mcp/src/test/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransportTests.java index 762264de..1b1c7201 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransportTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransportTests.java @@ -51,7 +51,8 @@ class HttpClientSseClientTransportTests { static String host = "http://localhost:3001"; @SuppressWarnings("resource") - GenericContainer container = new GenericContainer<>("docker.io/tzolov/mcp-everything-server:v1") + GenericContainer container = new GenericContainer<>("docker.io/tzolov/mcp-everything-server:v2") + .withCommand("node dist/index.js sse") .withLogConsumer(outputFrame -> System.out.println(outputFrame.getUtf8String())) .withExposedPorts(3001) .waitingFor(Wait.forHttp("/").forStatusCode(404)); diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProviderIntegrationTests.java b/mcp/src/test/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProviderIntegrationTests.java index 2ff6325a..dc9d1cfa 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProviderIntegrationTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProviderIntegrationTests.java @@ -24,6 +24,8 @@ import io.modelcontextprotocol.spec.McpSchema.ClientCapabilities; import io.modelcontextprotocol.spec.McpSchema.CreateMessageRequest; import io.modelcontextprotocol.spec.McpSchema.CreateMessageResult; +import io.modelcontextprotocol.spec.McpSchema.ElicitRequest; +import io.modelcontextprotocol.spec.McpSchema.ElicitResult; import io.modelcontextprotocol.spec.McpSchema.InitializeResult; import io.modelcontextprotocol.spec.McpSchema.ModelPreferences; import io.modelcontextprotocol.spec.McpSchema.Role; @@ -339,6 +341,217 @@ void testCreateMessageWithRequestTimeoutFail() throws InterruptedException { mcpServer.close(); } + // --------------------------------------- + // Elicitation Tests + // --------------------------------------- + @Test + @Disabled + void testCreateElicitationWithoutElicitationCapabilities() { + + McpServerFeatures.AsyncToolSpecification tool = new McpServerFeatures.AsyncToolSpecification( + new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request) -> { + + exchange.createElicitation(mock(ElicitRequest.class)).block(); + + return Mono.just(mock(CallToolResult.class)); + }); + + var server = McpServer.async(mcpServerTransportProvider).serverInfo("test-server", "1.0.0").tools(tool).build(); + + try ( + // Create client without elicitation capabilities + var client = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")).build()) { + + assertThat(client.initialize()).isNotNull(); + + try { + client.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); + } + catch (McpError e) { + assertThat(e).isInstanceOf(McpError.class) + .hasMessage("Client must be configured with elicitation capabilities"); + } + } + server.closeGracefully().block(); + } + + @Test + void testCreateElicitationSuccess() { + + Function elicitationHandler = request -> { + assertThat(request.message()).isNotEmpty(); + assertThat(request.requestedSchema()).isNotNull(); + + return new ElicitResult(ElicitResult.Action.ACCEPT, Map.of("message", request.message())); + }; + + CallToolResult callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), + null); + + McpServerFeatures.AsyncToolSpecification tool = new McpServerFeatures.AsyncToolSpecification( + new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request) -> { + + var elicitationRequest = ElicitRequest.builder() + .message("Test message") + .requestedSchema( + Map.of("type", "object", "properties", Map.of("message", Map.of("type", "string")))) + .build(); + + StepVerifier.create(exchange.createElicitation(elicitationRequest)).consumeNextWith(result -> { + assertThat(result).isNotNull(); + assertThat(result.action()).isEqualTo(ElicitResult.Action.ACCEPT); + assertThat(result.content().get("message")).isEqualTo("Test message"); + }).verifyComplete(); + + return Mono.just(callResponse); + }); + + var mcpServer = McpServer.async(mcpServerTransportProvider) + .serverInfo("test-server", "1.0.0") + .tools(tool) + .build(); + + try (var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")) + .capabilities(ClientCapabilities.builder().elicitation().build()) + .elicitation(elicitationHandler) + .build()) { + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + CallToolResult response = mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); + + assertThat(response).isNotNull(); + assertThat(response).isEqualTo(callResponse); + } + mcpServer.closeGracefully().block(); + } + + @Test + void testCreateElicitationWithRequestTimeoutSuccess() { + + // Client + + Function elicitationHandler = request -> { + assertThat(request.message()).isNotEmpty(); + assertThat(request.requestedSchema()).isNotNull(); + try { + TimeUnit.SECONDS.sleep(2); + } + catch (InterruptedException e) { + throw new RuntimeException(e); + } + return new ElicitResult(ElicitResult.Action.ACCEPT, Map.of("message", request.message())); + }; + + var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")) + .capabilities(ClientCapabilities.builder().elicitation().build()) + .elicitation(elicitationHandler) + .build(); + + // Server + + CallToolResult callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), + null); + + McpServerFeatures.AsyncToolSpecification tool = new McpServerFeatures.AsyncToolSpecification( + new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request) -> { + + var elicitationRequest = ElicitRequest.builder() + .message("Test message") + .requestedSchema( + Map.of("type", "object", "properties", Map.of("message", Map.of("type", "string")))) + .build(); + + StepVerifier.create(exchange.createElicitation(elicitationRequest)).consumeNextWith(result -> { + assertThat(result).isNotNull(); + assertThat(result.action()).isEqualTo(ElicitResult.Action.ACCEPT); + assertThat(result.content().get("message")).isEqualTo("Test message"); + }).verifyComplete(); + + return Mono.just(callResponse); + }); + + var mcpServer = McpServer.async(mcpServerTransportProvider) + .serverInfo("test-server", "1.0.0") + .requestTimeout(Duration.ofSeconds(3)) + .tools(tool) + .build(); + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + CallToolResult response = mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); + + assertThat(response).isNotNull(); + assertThat(response).isEqualTo(callResponse); + + mcpClient.closeGracefully(); + mcpServer.closeGracefully().block(); + } + + @Test + void testCreateElicitationWithRequestTimeoutFail() { + + // Client + + Function elicitationHandler = request -> { + assertThat(request.message()).isNotEmpty(); + assertThat(request.requestedSchema()).isNotNull(); + try { + TimeUnit.SECONDS.sleep(2); + } + catch (InterruptedException e) { + throw new RuntimeException(e); + } + return new ElicitResult(ElicitResult.Action.ACCEPT, Map.of("message", request.message())); + }; + + var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")) + .capabilities(ClientCapabilities.builder().elicitation().build()) + .elicitation(elicitationHandler) + .build(); + + // Server + + CallToolResult callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), + null); + + McpServerFeatures.AsyncToolSpecification tool = new McpServerFeatures.AsyncToolSpecification( + new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request) -> { + + var elicitationRequest = ElicitRequest.builder() + .message("Test message") + .requestedSchema( + Map.of("type", "object", "properties", Map.of("message", Map.of("type", "string")))) + .build(); + + StepVerifier.create(exchange.createElicitation(elicitationRequest)).consumeNextWith(result -> { + assertThat(result).isNotNull(); + assertThat(result.action()).isEqualTo(ElicitResult.Action.ACCEPT); + assertThat(result.content().get("message")).isEqualTo("Test message"); + }).verifyComplete(); + + return Mono.just(callResponse); + }); + + var mcpServer = McpServer.async(mcpServerTransportProvider) + .serverInfo("test-server", "1.0.0") + .requestTimeout(Duration.ofSeconds(1)) + .tools(tool) + .build(); + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + assertThatExceptionOfType(McpError.class).isThrownBy(() -> { + mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); + }).withMessageContaining("Timeout"); + + mcpClient.closeGracefully(); + mcpServer.closeGracefully().block(); + } + // --------------------------------------- // Roots Tests // --------------------------------------- diff --git a/mcp/src/test/java/io/modelcontextprotocol/spec/McpSchemaTests.java b/mcp/src/test/java/io/modelcontextprotocol/spec/McpSchemaTests.java index ff78c1bf..99015d8c 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/spec/McpSchemaTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/spec/McpSchemaTests.java @@ -807,6 +807,40 @@ void testCreateMessageResult() throws Exception { {"role":"assistant","content":{"type":"text","text":"Assistant response"},"model":"gpt-4","stopReason":"endTurn"}""")); } + // Elicitation Tests + + @Test + void testCreateElicitationRequest() throws Exception { + McpSchema.ElicitRequest request = McpSchema.ElicitRequest.builder() + .requestedSchema(Map.of("type", "object", "required", List.of("a"), "properties", + Map.of("foo", Map.of("type", "string")))) + .build(); + + String value = mapper.writeValueAsString(request); + + assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) + .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) + .isObject() + .isEqualTo(json(""" + {"requestedSchema":{"properties":{"foo":{"type":"string"}},"required":["a"],"type":"object"}}""")); + } + + @Test + void testCreateElicitationResult() throws Exception { + McpSchema.ElicitResult result = McpSchema.ElicitResult.builder() + .content(Map.of("foo", "bar")) + .message(McpSchema.ElicitResult.Action.ACCEPT) + .build(); + + String value = mapper.writeValueAsString(result); + + assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) + .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) + .isObject() + .isEqualTo(json(""" + {"action":"accept","content":{"foo":"bar"}}""")); + } + // Roots Tests @Test diff --git a/pom.xml b/pom.xml index 63845740..3fd0857e 100644 --- a/pom.xml +++ b/pom.xml @@ -6,7 +6,7 @@ io.modelcontextprotocol.sdk mcp-parent - 0.10.0-SNAPSHOT + 0.11.0-SNAPSHOT pom https://github.com/modelcontextprotocol/java-sdk @@ -63,7 +63,8 @@ 5.10.2 5.17.0 1.20.4 - 1.17.5 + 1.17.5 + 1.21.0 2.0.16 1.5.15