diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
new file mode 100644
index 00000000..7c73d9f3
--- /dev/null
+++ b/.github/workflows/ci.yml
@@ -0,0 +1,22 @@
+name: CI
+
+on:
+ pull_request: {}
+
+jobs:
+ build:
+ name: Build branch
+ runs-on: ubuntu-latest
+ steps:
+ - name: Checkout source code
+ uses: actions/checkout@v4
+
+ - name: Set up JDK 17
+ uses: actions/setup-java@v4
+ with:
+ java-version: '17'
+ distribution: 'temurin'
+ cache: 'maven'
+
+ - name: Build
+ run: mvn verify
diff --git a/.github/workflows/continuous-integration.yml b/.github/workflows/publish-snapshot.yml
similarity index 98%
rename from .github/workflows/continuous-integration.yml
rename to .github/workflows/publish-snapshot.yml
index e0939f08..5d9b4aa3 100644
--- a/.github/workflows/continuous-integration.yml
+++ b/.github/workflows/publish-snapshot.yml
@@ -1,4 +1,4 @@
-name: CI/CD build
+name: Publish Snapshot
on:
push:
diff --git a/CODE_OF_CONDUCT.md b/CODE_OF_CONDUCT.md
new file mode 100644
index 00000000..6009a645
--- /dev/null
+++ b/CODE_OF_CONDUCT.md
@@ -0,0 +1,119 @@
+# Contributor Covenant Code of Conduct
+
+## Our Pledge
+
+We as members, contributors, and leaders pledge to make participation in our community a
+harassment-free experience for everyone, regardless of age, body size, visible or
+invisible disability, ethnicity, sex characteristics, gender identity and expression,
+level of experience, education, socio-economic status, nationality, personal appearance,
+race, religion, or sexual identity and orientation.
+
+We pledge to act and interact in ways that contribute to an open, welcoming, diverse,
+inclusive, and healthy community.
+
+## Our Standards
+
+Examples of behavior that contributes to a positive environment for our community
+include:
+
+- Demonstrating empathy and kindness toward other people
+- Being respectful of differing opinions, viewpoints, and experiences
+- Giving and gracefully accepting constructive feedback
+- Accepting responsibility and apologizing to those affected by our mistakes, and
+ learning from the experience
+- Focusing on what is best not just for us as individuals, but for the overall community
+
+Examples of unacceptable behavior include:
+
+- The use of sexualized language or imagery, and sexual attention or advances of any kind
+- Trolling, insulting or derogatory comments, and personal or political attacks
+- Public or private harassment
+- Publishing others' private information, such as a physical or email address, without
+ their explicit permission
+- Other conduct which could reasonably be considered inappropriate in a professional
+ setting
+
+## Enforcement Responsibilities
+
+Community leaders are responsible for clarifying and enforcing our standards of
+acceptable behavior and will take appropriate and fair corrective action in response to
+any behavior that they deem inappropriate, threatening, offensive, or harmful.
+
+Community leaders have the right and responsibility to remove, edit, or reject comments,
+commits, code, wiki edits, issues, and other contributions that are not aligned to this
+Code of Conduct, and will communicate reasons for moderation decisions when appropriate.
+
+## Scope
+
+This Code of Conduct applies within all community spaces, and also applies when an
+individual is officially representing the community in public spaces. Examples of
+representing our community include using an official e-mail address, posting via an
+official social media account, or acting as an appointed representative at an online or
+offline event.
+
+## Enforcement
+
+Instances of abusive, harassing, or otherwise unacceptable behavior may be reported to
+the community leaders responsible for enforcement at mcp-coc@anthropic.com. All
+complaints will be reviewed and investigated promptly and fairly.
+
+All community leaders are obligated to respect the privacy and security of the reporter
+of any incident.
+
+## Enforcement Guidelines
+
+Community leaders will follow these Community Impact Guidelines in determining the
+consequences for any action they deem in violation of this Code of Conduct:
+
+### 1. Correction
+
+**Community Impact**: Use of inappropriate language or other behavior deemed
+unprofessional or unwelcome in the community.
+
+**Consequence**: A private, written warning from community leaders, providing clarity
+around the nature of the violation and an explanation of why the behavior was
+inappropriate. A public apology may be requested.
+
+### 2. Warning
+
+**Community Impact**: A violation through a single incident or series of actions.
+
+**Consequence**: A warning with consequences for continued behavior. No interaction with
+the people involved, including unsolicited interaction with those enforcing the Code of
+Conduct, for a specified period of time. This includes avoiding interactions in community
+spaces as well as external channels like social media. Violating these terms may lead to
+a temporary or permanent ban.
+
+### 3. Temporary Ban
+
+**Community Impact**: A serious violation of community standards, including sustained
+inappropriate behavior.
+
+**Consequence**: A temporary ban from any sort of interaction or public communication
+with the community for a specified period of time. No public or private interaction with
+the people involved, including unsolicited interaction with those enforcing the Code of
+Conduct, is allowed during this period. Violating these terms may lead to a permanent
+ban.
+
+### 4. Permanent Ban
+
+**Community Impact**: Demonstrating a pattern of violation of community standards,
+including sustained inappropriate behavior, harassment of an individual, or aggression
+toward or disparagement of classes of individuals.
+
+**Consequence**: A permanent ban from any sort of public interaction within the
+community.
+
+## Attribution
+
+This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 2.0,
+available at https://www.contributor-covenant.org/version/2/0/code_of_conduct.html.
+
+Community Impact Guidelines were inspired by
+[Mozilla's code of conduct enforcement ladder](https://github.com/mozilla/diversity).
+
+[homepage]: https://www.contributor-covenant.org
+
+For answers to common questions about this code of conduct, see the FAQ at
+https://www.contributor-covenant.org/faq. Translations are available at
+https://www.contributor-covenant.org/translations.
\ No newline at end of file
diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md
new file mode 100644
index 00000000..517f3255
--- /dev/null
+++ b/CONTRIBUTING.md
@@ -0,0 +1,94 @@
+# Contributing to Model Context Protocol Java SDK
+
+Thank you for your interest in contributing to the Model Context Protocol Java SDK!
+This document outlines how to contribute to this project.
+
+## Prerequisites
+
+The following software is required to work on the codebase:
+
+- `Java 17` or above
+- `Docker`
+- `npx`
+
+## Getting Started
+
+1. Fork the repository
+2. Clone your fork:
+
+```bash
+git clone https://github.com/YOUR-USERNAME/java-sdk.git
+cd java-sdk
+```
+
+3. Build from source:
+
+```bash
+./mvnw clean install -DskipTests # skip the tests
+./mvnw test # run tests
+```
+
+## Reporting Issues
+
+Please create an issue in the repository if you discover a bug or would like to
+propose an enhancement. Bug reports should have a reproducer in the form of a code
+sample or a repository attached that the maintainers or contributors can work with to
+address the problem.
+
+## Making Changes
+
+1. Create a new branch:
+
+```bash
+git checkout -b feature/your-feature-name
+```
+
+2. Make your changes
+3. Validate your changes:
+
+```bash
+./mvnw clean test
+```
+
+### Change Proposal Guidelines
+
+#### Principles of MCP
+
+1. **Simple + Minimal**: It is much easier to add things to the codebase than it is to
+ remove them. To maintain simplicity, we keep a high bar for adding new concepts and
+ primitives as each addition requires maintenance and compatibility consideration.
+2. **Concrete**: Code changes need to be based on specific usage and implementation
+ challenges and not on speculative ideas. Most importantly, the SDK is meant to
+ implement the MCP specification.
+
+## Submitting Changes
+
+1. For non-trivial changes, please clarify with the maintainers in an issue whether
+ you can contribute the change and the desired scope of the change.
+2. For trivial changes (for example a couple of lines or documentation changes) there
+ is no need to open an issue first.
+3. Push your changes to your fork.
+4. Submit a pull request to the main repository.
+5. Follow the pull request template.
+6. Wait for review.
+7. For any follow-up work, please add new commits instead of force-pushing. This will
+ allow the reviewer to focus on incremental changes instead of having to restart the
+ review process.
+
+## Code of Conduct
+
+This project follows a Code of Conduct. Please review it in
+[CODE_OF_CONDUCT.md](CODE_OF_CONDUCT.md).
+
+## Questions
+
+If you have questions, please create a discussion in the repository.
+
+## License
+
+By contributing, you agree that your contributions will be licensed under the MIT
+License.
+
+## Security
+
+Please review our [Security Policy](SECURITY.md) for reporting security issues.
\ No newline at end of file
diff --git a/README.md b/README.md
index caa6bf0c..0cd3f84a 100644
--- a/README.md
+++ b/README.md
@@ -1,10 +1,10 @@
# MCP Java SDK
-[](https://github.com/modelcontextprotocol/java-sdk/actions/workflows/continuous-integration.yml)
+[](https://github.com/modelcontextprotocol/java-sdk/actions/workflows/publish-snapshot.yml)
A set of projects that provide Java SDK integration for the [Model Context Protocol](https://modelcontextprotocol.org/docs/concepts/architecture).
This SDK enables Java applications to interact with AI models and tools through a standardized interface, supporting both synchronous and asynchronous communication patterns.
-## 📚 Reference Documentation
+## 📚 Reference Documentation
#### MCP Java SDK documentation
For comprehensive guides and SDK API documentation, visit the [MCP Java SDK Reference Documentation](https://modelcontextprotocol.io/sdk/java/mcp-overview).
@@ -30,11 +30,8 @@ To run the tests you have to pre-install `Docker` and `npx`.
## Contributing
-Contributions are welcome! Please:
-
-1. Fork the repository
-2. Create a feature branch
-3. Submit a Pull Request
+Contributions are welcome!
+Please follow the [Contributing Guidelines](CONTRIBUTING.md).
## Team
diff --git a/SECURITY.md b/SECURITY.md
new file mode 100644
index 00000000..74e9880f
--- /dev/null
+++ b/SECURITY.md
@@ -0,0 +1,21 @@
+# Security Policy
+
+Thank you for helping us keep the SDKs and systems they interact with secure.
+
+## Reporting Security Issues
+
+This SDK is maintained by [Anthropic](https://www.anthropic.com/) as part of the Model
+Context Protocol project.
+
+The security of our systems and user data is Anthropic’s top priority. We appreciate the
+work of security researchers acting in good faith in identifying and reporting potential
+vulnerabilities.
+
+Our security program is managed on HackerOne and we ask that any validated vulnerability
+in this functionality be reported through their
+[submission form](https://hackerone.com/anthropic-vdp/reports/new?type=team&report_type=vulnerability).
+
+## Vulnerability Disclosure Program
+
+Our Vulnerability Program Guidelines are defined on our
+[HackerOne program page](https://hackerone.com/anthropic-vdp).
\ No newline at end of file
diff --git a/mcp-bom/pom.xml b/mcp-bom/pom.xml
index 3b2ad42c..7214dacd 100644
--- a/mcp-bom/pom.xml
+++ b/mcp-bom/pom.xml
@@ -7,7 +7,7 @@
io.modelcontextprotocol.sdkmcp-parent
- 0.8.0-SNAPSHOT
+ 0.11.0-SNAPSHOTmcp-bom
diff --git a/mcp-spring/mcp-spring-webflux/pom.xml b/mcp-spring/mcp-spring-webflux/pom.xml
index 4d9f96e5..a8b92bd0 100644
--- a/mcp-spring/mcp-spring-webflux/pom.xml
+++ b/mcp-spring/mcp-spring-webflux/pom.xml
@@ -6,7 +6,7 @@
io.modelcontextprotocol.sdkmcp-parent
- 0.8.0-SNAPSHOT
+ 0.11.0-SNAPSHOT../../pom.xmlmcp-spring-webflux
@@ -25,13 +25,13 @@
io.modelcontextprotocol.sdkmcp
- 0.8.0-SNAPSHOT
+ 0.11.0-SNAPSHOTio.modelcontextprotocol.sdkmcp-test
- 0.8.0-SNAPSHOT
+ 0.11.0-SNAPSHOTtest
@@ -82,6 +82,12 @@
${mockito.version}test
+
+ net.bytebuddy
+ byte-buddy
+ ${byte-buddy.version}
+ test
+ io.projectreactorreactor-test
diff --git a/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/client/transport/WebFluxSseClientTransport.java b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/client/transport/WebFluxSseClientTransport.java
index 8ea65fd7..37abe295 100644
--- a/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/client/transport/WebFluxSseClientTransport.java
+++ b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/client/transport/WebFluxSseClientTransport.java
@@ -9,7 +9,7 @@
import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.ObjectMapper;
-import io.modelcontextprotocol.spec.ClientMcpTransport;
+import io.modelcontextprotocol.spec.McpClientTransport;
import io.modelcontextprotocol.spec.McpError;
import io.modelcontextprotocol.spec.McpSchema;
import io.modelcontextprotocol.spec.McpSchema.JSONRPCMessage;
@@ -58,7 +58,7 @@
* "https://spec.modelcontextprotocol.io/specification/basic/transports/#http-with-sse">MCP
* HTTP with SSE Transport Specification
*/
-public class WebFluxSseClientTransport implements ClientMcpTransport {
+public class WebFluxSseClientTransport implements McpClientTransport {
private static final Logger logger = LoggerFactory.getLogger(WebFluxSseClientTransport.class);
@@ -79,7 +79,7 @@ public class WebFluxSseClientTransport implements ClientMcpTransport {
* Default SSE endpoint path as specified by the MCP transport specification. This
* endpoint is used to establish the SSE connection with the server.
*/
- private static final String SSE_ENDPOINT = "/sse";
+ private static final String DEFAULT_SSE_ENDPOINT = "/sse";
/**
* Type reference for parsing SSE events containing string data.
@@ -117,6 +117,12 @@ public class WebFluxSseClientTransport implements ClientMcpTransport {
*/
protected final Sinks.One messageEndpointSink = Sinks.one();
+ /**
+ * The SSE endpoint URI provided by the server. Used for sending outbound messages via
+ * HTTP POST requests.
+ */
+ private String sseEndpoint;
+
/**
* Constructs a new SseClientTransport with the specified WebClient builder. Uses a
* default ObjectMapper instance for JSON processing.
@@ -137,11 +143,27 @@ public WebFluxSseClientTransport(WebClient.Builder webClientBuilder) {
* @throws IllegalArgumentException if either parameter is null
*/
public WebFluxSseClientTransport(WebClient.Builder webClientBuilder, ObjectMapper objectMapper) {
+ this(webClientBuilder, objectMapper, DEFAULT_SSE_ENDPOINT);
+ }
+
+ /**
+ * Constructs a new SseClientTransport with the specified WebClient builder and
+ * ObjectMapper. Initializes both inbound and outbound message processing pipelines.
+ * @param webClientBuilder the WebClient.Builder to use for creating the WebClient
+ * instance
+ * @param objectMapper the ObjectMapper to use for JSON processing
+ * @param sseEndpoint the SSE endpoint URI to use for establishing the connection
+ * @throws IllegalArgumentException if either parameter is null
+ */
+ public WebFluxSseClientTransport(WebClient.Builder webClientBuilder, ObjectMapper objectMapper,
+ String sseEndpoint) {
Assert.notNull(objectMapper, "ObjectMapper must not be null");
Assert.notNull(webClientBuilder, "WebClient.Builder must not be null");
+ Assert.hasText(sseEndpoint, "SSE endpoint must not be null or empty");
this.objectMapper = objectMapper;
this.webClient = webClientBuilder.build();
+ this.sseEndpoint = sseEndpoint;
}
/**
@@ -254,7 +276,7 @@ public Mono sendMessage(JSONRPCMessage message) {
protected Flux> eventStream() {// @formatter:off
return this.webClient
.get()
- .uri(SSE_ENDPOINT)
+ .uri(this.sseEndpoint)
.accept(MediaType.TEXT_EVENT_STREAM)
.retrieve()
.bodyToFlux(SSE_TYPE)
@@ -321,4 +343,66 @@ public T unmarshalFrom(Object data, TypeReference typeRef) {
return this.objectMapper.convertValue(data, typeRef);
}
+ /**
+ * Creates a new builder for {@link WebFluxSseClientTransport}.
+ * @param webClientBuilder the WebClient.Builder to use for creating the WebClient
+ * instance
+ * @return a new builder instance
+ */
+ public static Builder builder(WebClient.Builder webClientBuilder) {
+ return new Builder(webClientBuilder);
+ }
+
+ /**
+ * Builder for {@link WebFluxSseClientTransport}.
+ */
+ public static class Builder {
+
+ private final WebClient.Builder webClientBuilder;
+
+ private String sseEndpoint = DEFAULT_SSE_ENDPOINT;
+
+ private ObjectMapper objectMapper = new ObjectMapper();
+
+ /**
+ * Creates a new builder with the specified WebClient.Builder.
+ * @param webClientBuilder the WebClient.Builder to use
+ */
+ public Builder(WebClient.Builder webClientBuilder) {
+ Assert.notNull(webClientBuilder, "WebClient.Builder must not be null");
+ this.webClientBuilder = webClientBuilder;
+ }
+
+ /**
+ * Sets the SSE endpoint path.
+ * @param sseEndpoint the SSE endpoint path
+ * @return this builder
+ */
+ public Builder sseEndpoint(String sseEndpoint) {
+ Assert.hasText(sseEndpoint, "sseEndpoint must not be empty");
+ this.sseEndpoint = sseEndpoint;
+ return this;
+ }
+
+ /**
+ * Sets the object mapper for JSON serialization/deserialization.
+ * @param objectMapper the object mapper
+ * @return this builder
+ */
+ public Builder objectMapper(ObjectMapper objectMapper) {
+ Assert.notNull(objectMapper, "objectMapper must not be null");
+ this.objectMapper = objectMapper;
+ return this;
+ }
+
+ /**
+ * Builds a new {@link WebFluxSseClientTransport} instance.
+ * @return a new transport instance
+ */
+ public WebFluxSseClientTransport build() {
+ return new WebFluxSseClientTransport(webClientBuilder, objectMapper, sseEndpoint);
+ }
+
+ }
+
}
diff --git a/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxSseServerTransport.java b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxSseServerTransport.java
deleted file mode 100644
index bed7293e..00000000
--- a/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxSseServerTransport.java
+++ /dev/null
@@ -1,410 +0,0 @@
-package io.modelcontextprotocol.server.transport;
-
-import java.io.IOException;
-import java.time.Duration;
-import java.util.List;
-import java.util.UUID;
-import java.util.concurrent.ConcurrentHashMap;
-import java.util.function.Function;
-
-import com.fasterxml.jackson.core.type.TypeReference;
-import com.fasterxml.jackson.databind.ObjectMapper;
-import io.modelcontextprotocol.spec.McpError;
-import io.modelcontextprotocol.spec.McpSchema;
-import io.modelcontextprotocol.spec.ServerMcpTransport;
-import io.modelcontextprotocol.util.Assert;
-import org.slf4j.Logger;
-import org.slf4j.LoggerFactory;
-import reactor.core.publisher.Flux;
-import reactor.core.publisher.Mono;
-import reactor.core.publisher.Sinks;
-
-import org.springframework.http.HttpStatus;
-import org.springframework.http.MediaType;
-import org.springframework.http.codec.ServerSentEvent;
-import org.springframework.web.reactive.function.server.RouterFunction;
-import org.springframework.web.reactive.function.server.RouterFunctions;
-import org.springframework.web.reactive.function.server.ServerRequest;
-import org.springframework.web.reactive.function.server.ServerResponse;
-
-/**
- * Server-side implementation of the MCP (Model Context Protocol) HTTP transport using
- * Server-Sent Events (SSE). This implementation provides a bidirectional communication
- * channel between MCP clients and servers using HTTP POST for client-to-server messages
- * and SSE for server-to-client messages.
- *
- *
- * Key features:
- *
- *
Implements the {@link ServerMcpTransport} interface for MCP server transport
- * functionality
- *
Uses WebFlux for non-blocking request handling and SSE support
- *
Maintains client sessions for reliable message delivery
- *
Supports graceful shutdown with session cleanup
- *
Thread-safe message broadcasting to multiple clients
- *
- *
- *
- * The transport sets up two main endpoints:
- *
- *
SSE endpoint (/sse) - For establishing SSE connections with clients
- *
Message endpoint (configurable) - For receiving JSON-RPC messages from clients
- *
- *
- *
- * This implementation is thread-safe and can handle multiple concurrent client
- * connections. It uses {@link ConcurrentHashMap} for session management and Reactor's
- * {@link Sinks} for thread-safe message broadcasting.
- *
- * @author Christian Tzolov
- * @author Alexandros Pappas
- * @see ServerMcpTransport
- * @see ServerSentEvent
- */
-public class WebFluxSseServerTransport implements ServerMcpTransport {
-
- private static final Logger logger = LoggerFactory.getLogger(WebFluxSseServerTransport.class);
-
- /**
- * Event type for JSON-RPC messages sent through the SSE connection.
- */
- public static final String MESSAGE_EVENT_TYPE = "message";
-
- /**
- * Event type for sending the message endpoint URI to clients.
- */
- public static final String ENDPOINT_EVENT_TYPE = "endpoint";
-
- /**
- * Default SSE endpoint path as specified by the MCP transport specification.
- */
- public static final String DEFAULT_SSE_ENDPOINT = "/sse";
-
- private final ObjectMapper objectMapper;
-
- private final String messageEndpoint;
-
- private final String sseEndpoint;
-
- private final RouterFunction> routerFunction;
-
- /**
- * Map of active client sessions, keyed by session ID.
- */
- private final ConcurrentHashMap sessions = new ConcurrentHashMap<>();
-
- /**
- * Flag indicating if the transport is shutting down.
- */
- private volatile boolean isClosing = false;
-
- private Function, Mono> connectHandler;
-
- /**
- * Constructs a new WebFlux SSE server transport instance.
- * @param objectMapper The ObjectMapper to use for JSON serialization/deserialization
- * of MCP messages. Must not be null.
- * @param messageEndpoint The endpoint URI where clients should send their JSON-RPC
- * messages. This endpoint will be communicated to clients during SSE connection
- * setup. Must not be null.
- * @throws IllegalArgumentException if either parameter is null
- */
- public WebFluxSseServerTransport(ObjectMapper objectMapper, String messageEndpoint, String sseEndpoint) {
- Assert.notNull(objectMapper, "ObjectMapper must not be null");
- Assert.notNull(messageEndpoint, "Message endpoint must not be null");
- Assert.notNull(sseEndpoint, "SSE endpoint must not be null");
-
- this.objectMapper = objectMapper;
- this.messageEndpoint = messageEndpoint;
- this.sseEndpoint = sseEndpoint;
- this.routerFunction = RouterFunctions.route()
- .GET(this.sseEndpoint, this::handleSseConnection)
- .POST(this.messageEndpoint, this::handleMessage)
- .build();
- }
-
- /**
- * Constructs a new WebFlux SSE server transport instance with the default SSE
- * endpoint.
- * @param objectMapper The ObjectMapper to use for JSON serialization/deserialization
- * of MCP messages. Must not be null.
- * @param messageEndpoint The endpoint URI where clients should send their JSON-RPC
- * messages. This endpoint will be communicated to clients during SSE connection
- * setup. Must not be null.
- * @throws IllegalArgumentException if either parameter is null
- */
- public WebFluxSseServerTransport(ObjectMapper objectMapper, String messageEndpoint) {
- this(objectMapper, messageEndpoint, DEFAULT_SSE_ENDPOINT);
- }
-
- /**
- * Configures the message handler for this transport. In the WebFlux SSE
- * implementation, this method stores the handler for processing incoming messages but
- * doesn't establish any connections since the server accepts connections rather than
- * initiating them.
- * @param handler A function that processes incoming JSON-RPC messages and returns
- * responses. This handler will be called for each message received through the
- * message endpoint.
- * @return An empty Mono since the server doesn't initiate connections
- */
- @Override
- public Mono connect(Function, Mono> handler) {
- this.connectHandler = handler;
- // Server-side transport doesn't initiate connections
- return Mono.empty().then();
- }
-
- /**
- * Broadcasts a JSON-RPC message to all connected clients through their SSE
- * connections. The message is serialized to JSON and sent as a server-sent event to
- * each active session.
- *
- *
- * The method:
- *
- *
Serializes the message to JSON
- *
Creates a server-sent event with the message data
- *
Attempts to send the event to all active sessions
- *
Tracks and reports any delivery failures
- *
- * @param message The JSON-RPC message to broadcast
- * @return A Mono that completes when the message has been sent to all sessions, or
- * errors if any session fails to receive the message
- */
- @Override
- public Mono sendMessage(McpSchema.JSONRPCMessage message) {
- if (sessions.isEmpty()) {
- logger.debug("No active sessions to broadcast message to");
- return Mono.empty();
- }
-
- return Mono.create(sink -> {
- try {// @formatter:off
- String jsonText = objectMapper.writeValueAsString(message);
- ServerSentEvent
+
+ net.bytebuddy
+ byte-buddy
+ ${byte-buddy.version}
+ test
+ org.testcontainersjunit-jupiter
diff --git a/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcSseServerTransport.java b/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcSseServerTransportProvider.java
similarity index 56%
rename from mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcSseServerTransport.java
rename to mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcSseServerTransportProvider.java
index 00928ec7..fc86cfaa 100644
--- a/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcSseServerTransport.java
+++ b/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcSseServerTransportProvider.java
@@ -6,18 +6,21 @@
import java.io.IOException;
import java.time.Duration;
+import java.util.Map;
import java.util.UUID;
import java.util.concurrent.ConcurrentHashMap;
-import java.util.function.Function;
import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.ObjectMapper;
import io.modelcontextprotocol.spec.McpError;
import io.modelcontextprotocol.spec.McpSchema;
-import io.modelcontextprotocol.spec.ServerMcpTransport;
+import io.modelcontextprotocol.spec.McpServerTransport;
+import io.modelcontextprotocol.spec.McpServerTransportProvider;
+import io.modelcontextprotocol.spec.McpServerSession;
import io.modelcontextprotocol.util.Assert;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
+import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
import org.springframework.http.HttpStatus;
@@ -60,12 +63,12 @@
*
* @author Christian Tzolov
* @author Alexandros Pappas
- * @see ServerMcpTransport
+ * @see McpServerTransportProvider
* @see RouterFunction
*/
-public class WebMvcSseServerTransport implements ServerMcpTransport {
+public class WebMvcSseServerTransportProvider implements McpServerTransportProvider {
- private static final Logger logger = LoggerFactory.getLogger(WebMvcSseServerTransport.class);
+ private static final Logger logger = LoggerFactory.getLogger(WebMvcSseServerTransportProvider.class);
/**
* Event type for JSON-RPC messages sent through the SSE connection.
@@ -88,12 +91,16 @@ public class WebMvcSseServerTransport implements ServerMcpTransport {
private final String sseEndpoint;
+ private final String baseUrl;
+
private final RouterFunction routerFunction;
+ private McpServerSession.Factory sessionFactory;
+
/**
* Map of active client sessions, keyed by session ID.
*/
- private final ConcurrentHashMap sessions = new ConcurrentHashMap<>();
+ private final ConcurrentHashMap sessions = new ConcurrentHashMap<>();
/**
* Flag indicating if the transport is shutting down.
@@ -101,25 +108,54 @@ public class WebMvcSseServerTransport implements ServerMcpTransport {
private volatile boolean isClosing = false;
/**
- * The function to process incoming JSON-RPC messages and produce responses.
+ * Constructs a new WebMvcSseServerTransportProvider instance with the default SSE
+ * endpoint.
+ * @param objectMapper The ObjectMapper to use for JSON serialization/deserialization
+ * of messages.
+ * @param messageEndpoint The endpoint URI where clients should send their JSON-RPC
+ * messages via HTTP POST. This endpoint will be communicated to clients through the
+ * SSE connection's initial endpoint event.
+ * @throws IllegalArgumentException if either objectMapper or messageEndpoint is null
*/
- private Function, Mono> connectHandler;
+ public WebMvcSseServerTransportProvider(ObjectMapper objectMapper, String messageEndpoint) {
+ this(objectMapper, messageEndpoint, DEFAULT_SSE_ENDPOINT);
+ }
/**
- * Constructs a new WebMvcSseServerTransport instance.
+ * Constructs a new WebMvcSseServerTransportProvider instance.
* @param objectMapper The ObjectMapper to use for JSON serialization/deserialization
* of messages.
* @param messageEndpoint The endpoint URI where clients should send their JSON-RPC
* messages via HTTP POST. This endpoint will be communicated to clients through the
* SSE connection's initial endpoint event.
- * @throws IllegalArgumentException if either objectMapper or messageEndpoint is null
+ * @param sseEndpoint The endpoint URI where clients establish their SSE connections.
+ * @throws IllegalArgumentException if any parameter is null
*/
- public WebMvcSseServerTransport(ObjectMapper objectMapper, String messageEndpoint, String sseEndpoint) {
+ public WebMvcSseServerTransportProvider(ObjectMapper objectMapper, String messageEndpoint, String sseEndpoint) {
+ this(objectMapper, "", messageEndpoint, sseEndpoint);
+ }
+
+ /**
+ * Constructs a new WebMvcSseServerTransportProvider instance.
+ * @param objectMapper The ObjectMapper to use for JSON serialization/deserialization
+ * of messages.
+ * @param baseUrl The base URL for the message endpoint, used to construct the full
+ * endpoint URL for clients.
+ * @param messageEndpoint The endpoint URI where clients should send their JSON-RPC
+ * messages via HTTP POST. This endpoint will be communicated to clients through the
+ * SSE connection's initial endpoint event.
+ * @param sseEndpoint The endpoint URI where clients establish their SSE connections.
+ * @throws IllegalArgumentException if any parameter is null
+ */
+ public WebMvcSseServerTransportProvider(ObjectMapper objectMapper, String baseUrl, String messageEndpoint,
+ String sseEndpoint) {
Assert.notNull(objectMapper, "ObjectMapper must not be null");
+ Assert.notNull(baseUrl, "Message base URL must not be null");
Assert.notNull(messageEndpoint, "Message endpoint must not be null");
Assert.notNull(sseEndpoint, "SSE endpoint must not be null");
this.objectMapper = objectMapper;
+ this.baseUrl = baseUrl;
this.messageEndpoint = messageEndpoint;
this.sseEndpoint = sseEndpoint;
this.routerFunction = RouterFunctions.route()
@@ -128,69 +164,68 @@ public WebMvcSseServerTransport(ObjectMapper objectMapper, String messageEndpoin
.build();
}
- /**
- * Constructs a new WebMvcSseServerTransport instance with the default SSE endpoint.
- * @param objectMapper The ObjectMapper to use for JSON serialization/deserialization
- * of messages.
- * @param messageEndpoint The endpoint URI where clients should send their JSON-RPC
- * messages via HTTP POST. This endpoint will be communicated to clients through the
- * SSE connection's initial endpoint event.
- * @throws IllegalArgumentException if either objectMapper or messageEndpoint is null
- */
- public WebMvcSseServerTransport(ObjectMapper objectMapper, String messageEndpoint) {
- this(objectMapper, messageEndpoint, DEFAULT_SSE_ENDPOINT);
+ @Override
+ public void setSessionFactory(McpServerSession.Factory sessionFactory) {
+ this.sessionFactory = sessionFactory;
}
/**
- * Sets up the message handler for this transport. In the WebMVC SSE implementation,
- * this method only stores the handler for later use, as connections are initiated by
- * clients rather than the server.
- * @param connectionHandler The function to process incoming JSON-RPC messages and
- * produce responses
- * @return An empty Mono since the server doesn't initiate connections
+ * Broadcasts a notification to all connected clients through their SSE connections.
+ * The message is serialized to JSON and sent as an SSE event with type "message". If
+ * any errors occur during sending to a particular client, they are logged but don't
+ * prevent sending to other clients.
+ * @param method The method name for the notification
+ * @param params The parameters for the notification
+ * @return A Mono that completes when the broadcast attempt is finished
*/
@Override
- public Mono connect(
- Function, Mono> connectionHandler) {
- this.connectHandler = connectionHandler;
- // Server-side transport doesn't initiate connections
- return Mono.empty();
+ public Mono notifyClients(String method, Object params) {
+ if (sessions.isEmpty()) {
+ logger.debug("No active sessions to broadcast message to");
+ return Mono.empty();
+ }
+
+ logger.debug("Attempting to broadcast message to {} active sessions", sessions.size());
+
+ return Flux.fromIterable(sessions.values())
+ .flatMap(session -> session.sendNotification(method, params)
+ .doOnError(
+ e -> logger.error("Failed to send message to session {}: {}", session.getId(), e.getMessage()))
+ .onErrorComplete())
+ .then();
}
/**
- * Broadcasts a message to all connected clients through their SSE connections. The
- * message is serialized to JSON and sent as an SSE event with type "message". If any
- * errors occur during sending to a particular client, they are logged but don't
- * prevent sending to other clients.
- * @param message The JSON-RPC message to broadcast to all connected clients
- * @return A Mono that completes when the broadcast attempt is finished
+ * Initiates a graceful shutdown of the transport. This method:
+ *
+ *
Sets the closing flag to prevent new connections
+ *
Closes all active SSE connections
+ *
Removes all session records
+ *
+ * @return A Mono that completes when all cleanup operations are finished
*/
@Override
- public Mono sendMessage(McpSchema.JSONRPCMessage message) {
- return Mono.fromRunnable(() -> {
- if (sessions.isEmpty()) {
- logger.debug("No active sessions to broadcast message to");
- return;
- }
+ public Mono closeGracefully() {
+ return Flux.fromIterable(sessions.values()).doFirst(() -> {
+ this.isClosing = true;
+ logger.debug("Initiating graceful shutdown with {} active sessions", sessions.size());
+ })
+ .flatMap(McpServerSession::closeGracefully)
+ .then()
+ .doOnSuccess(v -> logger.debug("Graceful shutdown completed"));
+ }
- try {
- String jsonText = objectMapper.writeValueAsString(message);
- logger.debug("Attempting to broadcast message to {} active sessions", sessions.size());
-
- sessions.values().forEach(session -> {
- try {
- session.sseBuilder.id(session.id).event(MESSAGE_EVENT_TYPE).data(jsonText);
- }
- catch (Exception e) {
- logger.error("Failed to send message to session {}: {}", session.id, e.getMessage());
- session.sseBuilder.error(e);
- }
- });
- }
- catch (IOException e) {
- logger.error("Failed to serialize message: {}", e.getMessage());
- }
- });
+ /**
+ * Returns the RouterFunction that defines the HTTP endpoints for this transport. The
+ * router function handles two endpoints:
+ *
+ *
GET /sse - For establishing SSE connections
+ *
POST [messageEndpoint] - For receiving JSON-RPC messages from clients
+ *
+ * @return The configured RouterFunction for handling HTTP requests
+ */
+ public RouterFunction getRouterFunction() {
+ return this.routerFunction;
}
/**
@@ -198,7 +233,7 @@ public Mono sendMessage(McpSchema.JSONRPCMessage message) {
* establishing an SSE connection. This method:
*
*
Generates a unique session ID
- *
Creates a new ClientSession with an SSE builder
+ *
Creates a new session with a WebMvcMcpSessionTransport
*
Sends an initial endpoint event to inform the client where to send
* messages
Deserializes the request body into a JSON-RPC message
- *
Processes the message through the configured connect handler
+ *
Processes the message through the session's handle method
*
Returns appropriate HTTP responses based on the processing result
*
* @param request The incoming server request containing the JSON-RPC message
@@ -262,14 +300,23 @@ private ServerResponse handleMessage(ServerRequest request) {
return ServerResponse.status(HttpStatus.SERVICE_UNAVAILABLE).body("Server is shutting down");
}
+ if (request.param("sessionId").isEmpty()) {
+ return ServerResponse.badRequest().body(new McpError("Session ID missing in message endpoint"));
+ }
+
+ String sessionId = request.param("sessionId").get();
+ McpServerSession session = sessions.get(sessionId);
+
+ if (session == null) {
+ return ServerResponse.status(HttpStatus.NOT_FOUND).body(new McpError("Session not found: " + sessionId));
+ }
+
try {
String body = request.body(String.class);
McpSchema.JSONRPCMessage message = McpSchema.deserializeJsonRpcMessage(objectMapper, body);
- // Convert the message to a Mono, apply the handler, and block for the
- // response
- @SuppressWarnings("unused")
- McpSchema.JSONRPCMessage response = Mono.just(message).transform(connectHandler).block();
+ // Process the message through the session's handle method
+ session.handle(message).block(); // Block for WebMVC compatibility
return ServerResponse.ok().build();
}
@@ -284,99 +331,90 @@ private ServerResponse handleMessage(ServerRequest request) {
}
/**
- * Represents an active client session with its associated SSE connection. Each
- * session maintains:
- *
- *
A unique session identifier
- *
An SSE builder for sending server events to the client
- *
Logging of session lifecycle events
- *
+ * Implementation of McpServerTransport for WebMVC SSE sessions. This class handles
+ * the transport-level communication for a specific client session.
*/
- private static class ClientSession {
+ private class WebMvcMcpSessionTransport implements McpServerTransport {
- private final String id;
+ private final String sessionId;
private final SseBuilder sseBuilder;
/**
- * Creates a new client session with the specified ID and SSE builder.
- * @param id The unique identifier for this session
+ * Creates a new session transport with the specified ID and SSE builder.
+ * @param sessionId The unique identifier for this session
* @param sseBuilder The SSE builder for sending server events to the client
*/
- ClientSession(String id, SseBuilder sseBuilder) {
- this.id = id;
+ WebMvcMcpSessionTransport(String sessionId, SseBuilder sseBuilder) {
+ this.sessionId = sessionId;
this.sseBuilder = sseBuilder;
- logger.debug("Session {} initialized with SSE emitter", id);
+ logger.debug("Session transport {} initialized with SSE builder", sessionId);
}
/**
- * Closes this session by completing the SSE connection. Any errors during
- * completion are logged but do not prevent the session from being marked as
- * closed.
+ * Sends a JSON-RPC message to the client through the SSE connection.
+ * @param message The JSON-RPC message to send
+ * @return A Mono that completes when the message has been sent
*/
- void close() {
- logger.debug("Closing session: {}", id);
+ @Override
+ public Mono sendMessage(McpSchema.JSONRPCMessage message) {
+ return Mono.fromRunnable(() -> {
+ try {
+ String jsonText = objectMapper.writeValueAsString(message);
+ sseBuilder.id(sessionId).event(MESSAGE_EVENT_TYPE).data(jsonText);
+ logger.debug("Message sent to session {}", sessionId);
+ }
+ catch (Exception e) {
+ logger.error("Failed to send message to session {}: {}", sessionId, e.getMessage());
+ sseBuilder.error(e);
+ }
+ });
+ }
+
+ /**
+ * Converts data from one type to another using the configured ObjectMapper.
+ * @param data The source data object to convert
+ * @param typeRef The target type reference
+ * @return The converted object of type T
+ * @param The target type
+ */
+ @Override
+ public T unmarshalFrom(Object data, TypeReference typeRef) {
+ return objectMapper.convertValue(data, typeRef);
+ }
+
+ /**
+ * Initiates a graceful shutdown of the transport.
+ * @return A Mono that completes when the shutdown is complete
+ */
+ @Override
+ public Mono closeGracefully() {
+ return Mono.fromRunnable(() -> {
+ logger.debug("Closing session transport: {}", sessionId);
+ try {
+ sseBuilder.complete();
+ logger.debug("Successfully completed SSE builder for session {}", sessionId);
+ }
+ catch (Exception e) {
+ logger.warn("Failed to complete SSE builder for session {}: {}", sessionId, e.getMessage());
+ }
+ });
+ }
+
+ /**
+ * Closes the transport immediately.
+ */
+ @Override
+ public void close() {
try {
sseBuilder.complete();
- logger.debug("Successfully completed SSE emitter for session {}", id);
+ logger.debug("Successfully completed SSE builder for session {}", sessionId);
}
catch (Exception e) {
- logger.warn("Failed to complete SSE emitter for session {}: {}", id, e.getMessage());
- // sseBuilder.error(e);
+ logger.warn("Failed to complete SSE builder for session {}: {}", sessionId, e.getMessage());
}
}
}
- /**
- * Converts data from one type to another using the configured ObjectMapper. This is
- * particularly useful for handling complex JSON-RPC parameter types.
- * @param data The source data object to convert
- * @param typeRef The target type reference
- * @return The converted object of type T
- * @param The target type
- */
- @Override
- public T unmarshalFrom(Object data, TypeReference typeRef) {
- return this.objectMapper.convertValue(data, typeRef);
- }
-
- /**
- * Initiates a graceful shutdown of the transport. This method:
- *
- *
Sets the closing flag to prevent new connections
- *
Closes all active SSE connections
- *
Removes all session records
- *
- * @return A Mono that completes when all cleanup operations are finished
- */
- @Override
- public Mono closeGracefully() {
- return Mono.fromRunnable(() -> {
- this.isClosing = true;
- logger.debug("Initiating graceful shutdown with {} active sessions", sessions.size());
-
- sessions.values().forEach(session -> {
- String sessionId = session.id;
- session.close();
- sessions.remove(sessionId);
- });
-
- logger.debug("Graceful shutdown completed");
- });
- }
-
- /**
- * Returns the RouterFunction that defines the HTTP endpoints for this transport. The
- * router function handles two endpoints:
- *
- *
GET /sse - For establishing SSE connections
- *
POST [messageEndpoint] - For receiving JSON-RPC messages from clients
- *
- * @return The configured RouterFunction for handling HTTP requests
- */
- public RouterFunction getRouterFunction() {
- return this.routerFunction;
- }
-
}
diff --git a/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/TomcatTestUtil.java b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/TomcatTestUtil.java
new file mode 100644
index 00000000..ccf9e2d7
--- /dev/null
+++ b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/TomcatTestUtil.java
@@ -0,0 +1,68 @@
+/*
+* Copyright 2025 - 2025 the original author or authors.
+*/
+package io.modelcontextprotocol.server;
+
+import java.io.IOException;
+import java.net.InetSocketAddress;
+import java.net.ServerSocket;
+
+import org.apache.catalina.Context;
+import org.apache.catalina.startup.Tomcat;
+
+import org.springframework.web.context.support.AnnotationConfigWebApplicationContext;
+import org.springframework.web.servlet.DispatcherServlet;
+
+/**
+ * @author Christian Tzolov
+ */
+public class TomcatTestUtil {
+
+ TomcatTestUtil() {
+ // Prevent instantiation
+ }
+
+ public record TomcatServer(Tomcat tomcat, AnnotationConfigWebApplicationContext appContext) {
+ }
+
+ public static TomcatServer createTomcatServer(String contextPath, int port, Class> componentClass) {
+
+ // Set up Tomcat first
+ var tomcat = new Tomcat();
+ tomcat.setPort(port);
+
+ // Set Tomcat base directory to java.io.tmpdir to avoid permission issues
+ String baseDir = System.getProperty("java.io.tmpdir");
+ tomcat.setBaseDir(baseDir);
+
+ // Use the same directory for document base
+ Context context = tomcat.addContext(contextPath, baseDir);
+
+ // Create and configure Spring WebMvc context
+ var appContext = new AnnotationConfigWebApplicationContext();
+ appContext.register(componentClass);
+ appContext.setServletContext(context.getServletContext());
+ appContext.refresh();
+
+ // Create DispatcherServlet with our Spring context
+ DispatcherServlet dispatcherServlet = new DispatcherServlet(appContext);
+
+ // Add servlet to Tomcat and get the wrapper
+ var wrapper = Tomcat.addServlet(context, "dispatcherServlet", dispatcherServlet);
+ wrapper.setLoadOnStartup(1);
+ wrapper.setAsyncSupported(true);
+ context.addServletMappingDecoded("/*", "dispatcherServlet");
+
+ try {
+ // Configure and start the connector with async support
+ var connector = tomcat.getConnector();
+ connector.setAsyncTimeout(3000); // 3 seconds timeout for async requests
+ }
+ catch (Exception e) {
+ throw new RuntimeException("Failed to start Tomcat", e);
+ }
+
+ return new TomcatServer(tomcat, appContext);
+ }
+
+}
diff --git a/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseAsyncServerTransportTests.java b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseAsyncServerTransportTests.java
index a819920c..6a6ad17e 100644
--- a/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseAsyncServerTransportTests.java
+++ b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseAsyncServerTransportTests.java
@@ -5,8 +5,8 @@
package io.modelcontextprotocol.server;
import com.fasterxml.jackson.databind.ObjectMapper;
-import io.modelcontextprotocol.server.transport.WebMvcSseServerTransport;
-import io.modelcontextprotocol.spec.ServerMcpTransport;
+import io.modelcontextprotocol.server.transport.WebMvcSseServerTransportProvider;
+import io.modelcontextprotocol.spec.McpServerTransportProvider;
import org.apache.catalina.Context;
import org.apache.catalina.LifecycleException;
import org.apache.catalina.startup.Tomcat;
@@ -25,24 +25,24 @@ class WebMvcSseAsyncServerTransportTests extends AbstractMcpAsyncServerTests {
private static final String MESSAGE_ENDPOINT = "/mcp/message";
- private static final int PORT = 8181;
+ private static final int PORT = TestUtil.findAvailablePort();
private Tomcat tomcat;
- private WebMvcSseServerTransport transport;
+ private McpServerTransportProvider transportProvider;
@Configuration
@EnableWebMvc
static class TestConfig {
@Bean
- public WebMvcSseServerTransport webMvcSseServerTransport() {
- return new WebMvcSseServerTransport(new ObjectMapper(), MESSAGE_ENDPOINT);
+ public WebMvcSseServerTransportProvider webMvcSseServerTransportProvider() {
+ return new WebMvcSseServerTransportProvider(new ObjectMapper(), MESSAGE_ENDPOINT);
}
@Bean
- public RouterFunction routerFunction(WebMvcSseServerTransport transport) {
- return transport.getRouterFunction();
+ public RouterFunction routerFunction(WebMvcSseServerTransportProvider transportProvider) {
+ return transportProvider.getRouterFunction();
}
}
@@ -50,7 +50,7 @@ public RouterFunction routerFunction(WebMvcSseServerTransport tr
private AnnotationConfigWebApplicationContext appContext;
@Override
- protected ServerMcpTransport createMcpTransport() {
+ protected McpServerTransportProvider createMcpTransportProvider() {
// Set up Tomcat first
tomcat = new Tomcat();
tomcat.setPort(PORT);
@@ -69,11 +69,10 @@ protected ServerMcpTransport createMcpTransport() {
appContext.refresh();
// Get the transport from Spring context
- transport = appContext.getBean(WebMvcSseServerTransport.class);
+ transportProvider = appContext.getBean(WebMvcSseServerTransportProvider.class);
// Create DispatcherServlet with our Spring context
DispatcherServlet dispatcherServlet = new DispatcherServlet(appContext);
- // dispatcherServlet.setThrowExceptionIfNoHandlerFound(true);
// Add servlet to Tomcat and get the wrapper
var wrapper = Tomcat.addServlet(context, "dispatcherServlet", dispatcherServlet);
@@ -88,7 +87,7 @@ protected ServerMcpTransport createMcpTransport() {
throw new RuntimeException("Failed to start Tomcat", e);
}
- return transport;
+ return transportProvider;
}
@Override
@@ -97,8 +96,8 @@ protected void onStart() {
@Override
protected void onClose() {
- if (transport != null) {
- transport.closeGracefully().block();
+ if (transportProvider != null) {
+ transportProvider.closeGracefully().block();
}
if (appContext != null) {
appContext.close();
diff --git a/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseCustomContextPathTests.java b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseCustomContextPathTests.java
new file mode 100644
index 00000000..1b5218cc
--- /dev/null
+++ b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseCustomContextPathTests.java
@@ -0,0 +1,105 @@
+/*
+ * Copyright 2024 - 2024 the original author or authors.
+ */
+package io.modelcontextprotocol.server;
+
+import com.fasterxml.jackson.databind.ObjectMapper;
+import io.modelcontextprotocol.client.McpClient;
+import io.modelcontextprotocol.client.transport.HttpClientSseClientTransport;
+import io.modelcontextprotocol.server.transport.WebMvcSseServerTransportProvider;
+import io.modelcontextprotocol.spec.McpSchema;
+import org.apache.catalina.LifecycleException;
+import org.apache.catalina.LifecycleState;
+import org.junit.jupiter.api.AfterEach;
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.Test;
+
+import org.springframework.context.annotation.Bean;
+import org.springframework.context.annotation.Configuration;
+import org.springframework.web.servlet.config.annotation.EnableWebMvc;
+import org.springframework.web.servlet.function.RouterFunction;
+import org.springframework.web.servlet.function.ServerResponse;
+
+import static org.assertj.core.api.Assertions.assertThat;
+
+class WebMvcSseCustomContextPathTests {
+
+ private static final String CUSTOM_CONTEXT_PATH = "/app/1";
+
+ private static final int PORT = TestUtil.findAvailablePort();
+
+ private static final String MESSAGE_ENDPOINT = "/mcp/message";
+
+ private WebMvcSseServerTransportProvider mcpServerTransportProvider;
+
+ McpClient.SyncSpec clientBuilder;
+
+ private TomcatTestUtil.TomcatServer tomcatServer;
+
+ @BeforeEach
+ public void before() {
+
+ tomcatServer = TomcatTestUtil.createTomcatServer(CUSTOM_CONTEXT_PATH, PORT, TestConfig.class);
+
+ try {
+ tomcatServer.tomcat().start();
+ assertThat(tomcatServer.tomcat().getServer().getState()).isEqualTo(LifecycleState.STARTED);
+ }
+ catch (Exception e) {
+ throw new RuntimeException("Failed to start Tomcat", e);
+ }
+
+ var clientTransport = HttpClientSseClientTransport.builder("http://localhost:" + PORT)
+ .sseEndpoint(CUSTOM_CONTEXT_PATH + WebMvcSseServerTransportProvider.DEFAULT_SSE_ENDPOINT)
+ .build();
+
+ clientBuilder = McpClient.sync(clientTransport);
+
+ mcpServerTransportProvider = tomcatServer.appContext().getBean(WebMvcSseServerTransportProvider.class);
+ }
+
+ @AfterEach
+ public void after() {
+ if (mcpServerTransportProvider != null) {
+ mcpServerTransportProvider.closeGracefully().block();
+ }
+ if (tomcatServer.appContext() != null) {
+ tomcatServer.appContext().close();
+ }
+ if (tomcatServer.tomcat() != null) {
+ try {
+ tomcatServer.tomcat().stop();
+ tomcatServer.tomcat().destroy();
+ }
+ catch (LifecycleException e) {
+ throw new RuntimeException("Failed to stop Tomcat", e);
+ }
+ }
+ }
+
+ @Test
+ void testCustomContextPath() {
+ McpServer.async(mcpServerTransportProvider).serverInfo("test-server", "1.0.0").build();
+ var client = clientBuilder.clientInfo(new McpSchema.Implementation("Sample " + "client", "0.0.0")).build();
+ assertThat(client.initialize()).isNotNull();
+ }
+
+ @Configuration
+ @EnableWebMvc
+ static class TestConfig {
+
+ @Bean
+ public WebMvcSseServerTransportProvider webMvcSseServerTransportProvider() {
+
+ return new WebMvcSseServerTransportProvider(new ObjectMapper(), CUSTOM_CONTEXT_PATH, MESSAGE_ENDPOINT,
+ WebMvcSseServerTransportProvider.DEFAULT_SSE_ENDPOINT);
+ }
+
+ @Bean
+ public RouterFunction routerFunction(WebMvcSseServerTransportProvider transportProvider) {
+ return transportProvider.getRouterFunction();
+ }
+
+ }
+
+}
diff --git a/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseIntegrationTests.java b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseIntegrationTests.java
index 62f69637..b12d6843 100644
--- a/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseIntegrationTests.java
+++ b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseIntegrationTests.java
@@ -6,13 +6,14 @@
import java.time.Duration;
import java.util.List;
import java.util.Map;
+import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Function;
import com.fasterxml.jackson.databind.ObjectMapper;
import io.modelcontextprotocol.client.McpClient;
import io.modelcontextprotocol.client.transport.HttpClientSseClientTransport;
-import io.modelcontextprotocol.server.transport.WebMvcSseServerTransport;
+import io.modelcontextprotocol.server.transport.WebMvcSseServerTransportProvider;
import io.modelcontextprotocol.spec.McpError;
import io.modelcontextprotocol.spec.McpSchema;
import io.modelcontextprotocol.spec.McpSchema.CallToolResult;
@@ -20,39 +21,38 @@
import io.modelcontextprotocol.spec.McpSchema.CreateMessageRequest;
import io.modelcontextprotocol.spec.McpSchema.CreateMessageResult;
import io.modelcontextprotocol.spec.McpSchema.InitializeResult;
+import io.modelcontextprotocol.spec.McpSchema.ModelPreferences;
import io.modelcontextprotocol.spec.McpSchema.Role;
import io.modelcontextprotocol.spec.McpSchema.Root;
import io.modelcontextprotocol.spec.McpSchema.ServerCapabilities;
import io.modelcontextprotocol.spec.McpSchema.Tool;
-import org.apache.catalina.Context;
import org.apache.catalina.LifecycleException;
import org.apache.catalina.LifecycleState;
-import org.apache.catalina.startup.Tomcat;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
+import reactor.core.publisher.Mono;
import reactor.test.StepVerifier;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.web.client.RestClient;
-import org.springframework.web.context.support.AnnotationConfigWebApplicationContext;
-import org.springframework.web.servlet.DispatcherServlet;
import org.springframework.web.servlet.config.annotation.EnableWebMvc;
import org.springframework.web.servlet.function.RouterFunction;
import org.springframework.web.servlet.function.ServerResponse;
import static org.assertj.core.api.Assertions.assertThat;
-import static org.assertj.core.api.Assertions.assertThatThrownBy;
+import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
import static org.awaitility.Awaitility.await;
+import static org.mockito.Mockito.mock;
-public class WebMvcSseIntegrationTests {
+class WebMvcSseIntegrationTests {
- private static final int PORT = 8183;
+ private static final int PORT = TestUtil.findAvailablePort();
private static final String MESSAGE_ENDPOINT = "/mcp/message";
- private WebMvcSseServerTransport mcpServerTransport;
+ private WebMvcSseServerTransportProvider mcpServerTransportProvider;
McpClient.SyncSpec clientBuilder;
@@ -61,80 +61,51 @@ public class WebMvcSseIntegrationTests {
static class TestConfig {
@Bean
- public WebMvcSseServerTransport webMvcSseServerTransport() {
- return new WebMvcSseServerTransport(new ObjectMapper(), MESSAGE_ENDPOINT);
+ public WebMvcSseServerTransportProvider webMvcSseServerTransportProvider() {
+ return new WebMvcSseServerTransportProvider(new ObjectMapper(), MESSAGE_ENDPOINT);
}
@Bean
- public RouterFunction routerFunction(WebMvcSseServerTransport transport) {
- return transport.getRouterFunction();
+ public RouterFunction routerFunction(WebMvcSseServerTransportProvider transportProvider) {
+ return transportProvider.getRouterFunction();
}
}
- private Tomcat tomcat;
-
- private AnnotationConfigWebApplicationContext appContext;
+ private TomcatTestUtil.TomcatServer tomcatServer;
@BeforeEach
public void before() {
- // Set up Tomcat first
- tomcat = new Tomcat();
- tomcat.setPort(PORT);
-
- // Set Tomcat base directory to java.io.tmpdir to avoid permission issues
- String baseDir = System.getProperty("java.io.tmpdir");
- tomcat.setBaseDir(baseDir);
-
- // Use the same directory for document base
- Context context = tomcat.addContext("", baseDir);
-
- // Create and configure Spring WebMvc context
- appContext = new AnnotationConfigWebApplicationContext();
- appContext.register(TestConfig.class);
- appContext.setServletContext(context.getServletContext());
- appContext.refresh();
-
- // Get the transport from Spring context
- mcpServerTransport = appContext.getBean(WebMvcSseServerTransport.class);
-
- // Create DispatcherServlet with our Spring context
- DispatcherServlet dispatcherServlet = new DispatcherServlet(appContext);
- // dispatcherServlet.setThrowExceptionIfNoHandlerFound(true);
-
- // Add servlet to Tomcat and get the wrapper
- var wrapper = Tomcat.addServlet(context, "dispatcherServlet", dispatcherServlet);
- wrapper.setLoadOnStartup(1);
- wrapper.setAsyncSupported(true);
- context.addServletMappingDecoded("/*", "dispatcherServlet");
+ tomcatServer = TomcatTestUtil.createTomcatServer("", PORT, TestConfig.class);
try {
- // Configure and start the connector with async support
- var connector = tomcat.getConnector();
- connector.setAsyncTimeout(3000); // 3 seconds timeout for async requests
- tomcat.start();
- assertThat(tomcat.getServer().getState() == LifecycleState.STARTED);
+ tomcatServer.tomcat().start();
+ assertThat(tomcatServer.tomcat().getServer().getState()).isEqualTo(LifecycleState.STARTED);
}
catch (Exception e) {
throw new RuntimeException("Failed to start Tomcat", e);
}
- this.clientBuilder = McpClient.sync(new HttpClientSseClientTransport("http://localhost:" + PORT));
+ clientBuilder = McpClient.sync(HttpClientSseClientTransport.builder("http://localhost:" + PORT).build());
+
+ // Get the transport from Spring context
+ mcpServerTransportProvider = tomcatServer.appContext().getBean(WebMvcSseServerTransportProvider.class);
+
}
@AfterEach
public void after() {
- if (mcpServerTransport != null) {
- mcpServerTransport.closeGracefully().block();
+ if (mcpServerTransportProvider != null) {
+ mcpServerTransportProvider.closeGracefully().block();
}
- if (appContext != null) {
- appContext.close();
+ if (tomcatServer.appContext() != null) {
+ tomcatServer.appContext().close();
}
- if (tomcat != null) {
+ if (tomcatServer.tomcat() != null) {
try {
- tomcat.stop();
- tomcat.destroy();
+ tomcatServer.tomcat().stop();
+ tomcatServer.tomcat().destroy();
}
catch (LifecycleException e) {
throw new RuntimeException("Failed to stop Tomcat", e);
@@ -146,81 +117,244 @@ public void after() {
// Sampling Tests
// ---------------------------------------
@Test
- void testCreateMessageWithoutInitialization() {
- var mcpAsyncServer = McpServer.async(mcpServerTransport).serverInfo("test-server", "1.0.0").build();
+ void testCreateMessageWithoutSamplingCapabilities() {
+
+ McpServerFeatures.AsyncToolSpecification tool = new McpServerFeatures.AsyncToolSpecification(
+ new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request) -> {
- var messages = List
- .of(new McpSchema.SamplingMessage(McpSchema.Role.USER, new McpSchema.TextContent("Test message")));
- var modelPrefs = new McpSchema.ModelPreferences(List.of(), 1.0, 1.0, 1.0);
+ exchange.createMessage(mock(McpSchema.CreateMessageRequest.class)).block();
- var request = new McpSchema.CreateMessageRequest(messages, modelPrefs, null,
- McpSchema.CreateMessageRequest.ContextInclusionStrategy.NONE, null, 100, List.of(), Map.of());
+ return Mono.just(mock(CallToolResult.class));
+ });
+
+ //@formatter:off
+ var server = McpServer.async(mcpServerTransportProvider)
+ .serverInfo("test-server", "1.0.0")
+ .tools(tool)
+ .build();
+
+ try (
+ // Create client without sampling capabilities
+ var client = clientBuilder
+ .clientInfo(new McpSchema.Implementation("Sample " + "client", "0.0.0"))
+ .build()) {//@formatter:on
- StepVerifier.create(mcpAsyncServer.createMessage(request)).verifyErrorSatisfies(error -> {
- assertThat(error).isInstanceOf(McpError.class)
- .hasMessage("Client must be initialized. Call the initialize method first!");
- });
+ assertThat(client.initialize()).isNotNull();
+
+ try {
+ client.callTool(new McpSchema.CallToolRequest("tool1", Map.of()));
+ }
+ catch (McpError e) {
+ assertThat(e).isInstanceOf(McpError.class)
+ .hasMessage("Client must be configured with sampling capabilities");
+ }
+ }
+ server.close();
}
@Test
- void testCreateMessageWithoutSamplingCapabilities() {
+ void testCreateMessageSuccess() {
+
+ Function samplingHandler = request -> {
+ assertThat(request.messages()).hasSize(1);
+ assertThat(request.messages().get(0).content()).isInstanceOf(McpSchema.TextContent.class);
- var mcpAsyncServer = McpServer.async(mcpServerTransport).serverInfo("test-server", "1.0.0").build();
+ return new CreateMessageResult(Role.USER, new McpSchema.TextContent("Test message"), "MockModelName",
+ CreateMessageResult.StopReason.STOP_SEQUENCE);
+ };
+
+ CallToolResult callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")),
+ null);
+
+ McpServerFeatures.AsyncToolSpecification tool = new McpServerFeatures.AsyncToolSpecification(
+ new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request) -> {
+
+ var createMessageRequest = McpSchema.CreateMessageRequest.builder()
+ .messages(List.of(new McpSchema.SamplingMessage(McpSchema.Role.USER,
+ new McpSchema.TextContent("Test message"))))
+ .modelPreferences(ModelPreferences.builder()
+ .hints(List.of())
+ .costPriority(1.0)
+ .speedPriority(1.0)
+ .intelligencePriority(1.0)
+ .build())
+ .build();
+
+ StepVerifier.create(exchange.createMessage(createMessageRequest)).consumeNextWith(result -> {
+ assertThat(result).isNotNull();
+ assertThat(result.role()).isEqualTo(Role.USER);
+ assertThat(result.content()).isInstanceOf(McpSchema.TextContent.class);
+ assertThat(((McpSchema.TextContent) result.content()).text()).isEqualTo("Test message");
+ assertThat(result.model()).isEqualTo("MockModelName");
+ assertThat(result.stopReason()).isEqualTo(CreateMessageResult.StopReason.STOP_SEQUENCE);
+ }).verifyComplete();
+
+ return Mono.just(callResponse);
+ });
- var client = clientBuilder.clientInfo(new McpSchema.Implementation("Sample " + "client", "0.0.0")).build();
+ //@formatter:off
+ var mcpServer = McpServer.async(mcpServerTransportProvider)
+ .serverInfo("test-server", "1.0.0")
+ .tools(tool)
+ .build();
- InitializeResult initResult = client.initialize();
- assertThat(initResult).isNotNull();
+ try (
+ var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0"))
+ .capabilities(ClientCapabilities.builder().sampling().build())
+ .sampling(samplingHandler)
+ .build()) {//@formatter:on
- var messages = List
- .of(new McpSchema.SamplingMessage(McpSchema.Role.USER, new McpSchema.TextContent("Test message")));
- var modelPrefs = new McpSchema.ModelPreferences(List.of(), 1.0, 1.0, 1.0);
+ InitializeResult initResult = mcpClient.initialize();
+ assertThat(initResult).isNotNull();
- var request = new McpSchema.CreateMessageRequest(messages, modelPrefs, null,
- McpSchema.CreateMessageRequest.ContextInclusionStrategy.NONE, null, 100, List.of(), Map.of());
+ CallToolResult response = mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of()));
- StepVerifier.create(mcpAsyncServer.createMessage(request)).verifyErrorSatisfies(error -> {
- assertThat(error).isInstanceOf(McpError.class)
- .hasMessage("Client must be configured with sampling capabilities");
- });
+ assertThat(response).isNotNull().isEqualTo(callResponse);
+ }
+ mcpServer.close();
}
@Test
- void testCreateMessageSuccess() throws InterruptedException {
+ void testCreateMessageWithRequestTimeoutSuccess() throws InterruptedException {
- var mcpAsyncServer = McpServer.async(mcpServerTransport).serverInfo("test-server", "1.0.0").build();
+ // Client
Function samplingHandler = request -> {
assertThat(request.messages()).hasSize(1);
assertThat(request.messages().get(0).content()).isInstanceOf(McpSchema.TextContent.class);
+ try {
+ TimeUnit.SECONDS.sleep(2);
+ }
+ catch (InterruptedException e) {
+ throw new RuntimeException(e);
+ }
+ return new CreateMessageResult(Role.USER, new McpSchema.TextContent("Test message"), "MockModelName",
+ CreateMessageResult.StopReason.STOP_SEQUENCE);
+ };
+
+ var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0"))
+ .capabilities(ClientCapabilities.builder().sampling().build())
+ .sampling(samplingHandler)
+ .build();
+
+ // Server
+
+ CallToolResult callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")),
+ null);
+
+ McpServerFeatures.AsyncToolSpecification tool = new McpServerFeatures.AsyncToolSpecification(
+ new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request) -> {
+
+ var craeteMessageRequest = McpSchema.CreateMessageRequest.builder()
+ .messages(List.of(new McpSchema.SamplingMessage(McpSchema.Role.USER,
+ new McpSchema.TextContent("Test message"))))
+ .modelPreferences(ModelPreferences.builder()
+ .hints(List.of())
+ .costPriority(1.0)
+ .speedPriority(1.0)
+ .intelligencePriority(1.0)
+ .build())
+ .build();
+
+ StepVerifier.create(exchange.createMessage(craeteMessageRequest)).consumeNextWith(result -> {
+ assertThat(result).isNotNull();
+ assertThat(result.role()).isEqualTo(Role.USER);
+ assertThat(result.content()).isInstanceOf(McpSchema.TextContent.class);
+ assertThat(((McpSchema.TextContent) result.content()).text()).isEqualTo("Test message");
+ assertThat(result.model()).isEqualTo("MockModelName");
+ assertThat(result.stopReason()).isEqualTo(CreateMessageResult.StopReason.STOP_SEQUENCE);
+ }).verifyComplete();
+
+ return Mono.just(callResponse);
+ });
+
+ var mcpServer = McpServer.async(mcpServerTransportProvider)
+ .serverInfo("test-server", "1.0.0")
+ .requestTimeout(Duration.ofSeconds(4))
+ .tools(tool)
+ .build();
+
+ InitializeResult initResult = mcpClient.initialize();
+ assertThat(initResult).isNotNull();
+ CallToolResult response = mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of()));
+
+ assertThat(response).isNotNull();
+ assertThat(response).isEqualTo(callResponse);
+
+ mcpClient.close();
+ mcpServer.close();
+ }
+
+ @Test
+ void testCreateMessageWithRequestTimeoutFail() throws InterruptedException {
+
+ // Client
+
+ Function samplingHandler = request -> {
+ assertThat(request.messages()).hasSize(1);
+ assertThat(request.messages().get(0).content()).isInstanceOf(McpSchema.TextContent.class);
+ try {
+ TimeUnit.SECONDS.sleep(2);
+ }
+ catch (InterruptedException e) {
+ throw new RuntimeException(e);
+ }
return new CreateMessageResult(Role.USER, new McpSchema.TextContent("Test message"), "MockModelName",
CreateMessageResult.StopReason.STOP_SEQUENCE);
};
- var client = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0"))
+ var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0"))
.capabilities(ClientCapabilities.builder().sampling().build())
.sampling(samplingHandler)
.build();
- InitializeResult initResult = client.initialize();
+ // Server
+
+ CallToolResult callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")),
+ null);
+
+ McpServerFeatures.AsyncToolSpecification tool = new McpServerFeatures.AsyncToolSpecification(
+ new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request) -> {
+
+ var craeteMessageRequest = McpSchema.CreateMessageRequest.builder()
+ .messages(List.of(new McpSchema.SamplingMessage(McpSchema.Role.USER,
+ new McpSchema.TextContent("Test message"))))
+ .modelPreferences(ModelPreferences.builder()
+ .hints(List.of())
+ .costPriority(1.0)
+ .speedPriority(1.0)
+ .intelligencePriority(1.0)
+ .build())
+ .build();
+
+ StepVerifier.create(exchange.createMessage(craeteMessageRequest)).consumeNextWith(result -> {
+ assertThat(result).isNotNull();
+ assertThat(result.role()).isEqualTo(Role.USER);
+ assertThat(result.content()).isInstanceOf(McpSchema.TextContent.class);
+ assertThat(((McpSchema.TextContent) result.content()).text()).isEqualTo("Test message");
+ assertThat(result.model()).isEqualTo("MockModelName");
+ assertThat(result.stopReason()).isEqualTo(CreateMessageResult.StopReason.STOP_SEQUENCE);
+ }).verifyComplete();
+
+ return Mono.just(callResponse);
+ });
+
+ var mcpServer = McpServer.async(mcpServerTransportProvider)
+ .serverInfo("test-server", "1.0.0")
+ .requestTimeout(Duration.ofSeconds(1))
+ .tools(tool)
+ .build();
+
+ InitializeResult initResult = mcpClient.initialize();
assertThat(initResult).isNotNull();
- var messages = List
- .of(new McpSchema.SamplingMessage(McpSchema.Role.USER, new McpSchema.TextContent("Test message")));
- var modelPrefs = new McpSchema.ModelPreferences(List.of(), 1.0, 1.0, 1.0);
-
- var request = new McpSchema.CreateMessageRequest(messages, modelPrefs, null,
- McpSchema.CreateMessageRequest.ContextInclusionStrategy.NONE, null, 100, List.of(), Map.of());
-
- StepVerifier.create(mcpAsyncServer.createMessage(request)).consumeNextWith(result -> {
- assertThat(result).isNotNull();
- assertThat(result.role()).isEqualTo(Role.USER);
- assertThat(result.content()).isInstanceOf(McpSchema.TextContent.class);
- assertThat(((McpSchema.TextContent) result.content()).text()).isEqualTo("Test message");
- assertThat(result.model()).isEqualTo("MockModelName");
- assertThat(result.stopReason()).isEqualTo(CreateMessageResult.StopReason.STOP_SEQUENCE);
- }).verifyComplete();
+ assertThatExceptionOfType(McpError.class).isThrownBy(() -> {
+ mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of()));
+ }).withMessageContaining("Timeout");
+
+ mcpClient.close();
+ mcpServer.close();
}
// ---------------------------------------
@@ -231,117 +365,129 @@ void testRootsSuccess() {
List roots = List.of(new Root("uri1://", "root1"), new Root("uri2://", "root2"));
AtomicReference> rootsRef = new AtomicReference<>();
- var mcpServer = McpServer.sync(mcpServerTransport)
- .rootsChangeConsumer(rootsUpdate -> rootsRef.set(rootsUpdate))
- .build();
- var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().roots(true).build())
- .roots(roots)
+ var mcpServer = McpServer.sync(mcpServerTransportProvider)
+ .rootsChangeHandler((exchange, rootsUpdate) -> rootsRef.set(rootsUpdate))
.build();
- InitializeResult initResult = mcpClient.initialize();
- assertThat(initResult).isNotNull();
+ try (var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().roots(true).build())
+ .roots(roots)
+ .build()) {
- assertThat(rootsRef.get()).isNull();
+ InitializeResult initResult = mcpClient.initialize();
+ assertThat(initResult).isNotNull();
- assertThat(mcpServer.listRoots().roots()).containsAll(roots);
+ assertThat(rootsRef.get()).isNull();
- mcpClient.rootsListChangedNotification();
+ mcpClient.rootsListChangedNotification();
- await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> {
- assertThat(rootsRef.get()).containsAll(roots);
- });
+ await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> {
+ assertThat(rootsRef.get()).containsAll(roots);
+ });
- // Remove a root
- mcpClient.removeRoot(roots.get(0).uri());
+ // Remove a root
+ mcpClient.removeRoot(roots.get(0).uri());
- await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> {
- assertThat(rootsRef.get()).containsAll(List.of(roots.get(1)));
- });
+ await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> {
+ assertThat(rootsRef.get()).containsAll(List.of(roots.get(1)));
+ });
- // Add a new root
- var root3 = new Root("uri3://", "root3");
- mcpClient.addRoot(root3);
+ // Add a new root
+ var root3 = new Root("uri3://", "root3");
+ mcpClient.addRoot(root3);
- await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> {
- assertThat(rootsRef.get()).containsAll(List.of(roots.get(1), root3));
- });
+ await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> {
+ assertThat(rootsRef.get()).containsAll(List.of(roots.get(1), root3));
+ });
+ }
- mcpClient.close();
mcpServer.close();
}
@Test
void testRootsWithoutCapability() {
- var mcpServer = McpServer.sync(mcpServerTransport).rootsChangeConsumer(rootsUpdate -> {
- }).build();
- // Create client without roots capability
- // No roots capability
- var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().build()).build();
+ McpServerFeatures.SyncToolSpecification tool = new McpServerFeatures.SyncToolSpecification(
+ new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request) -> {
- InitializeResult initResult = mcpClient.initialize();
- assertThat(initResult).isNotNull();
+ exchange.listRoots(); // try to list roots
- // Attempt to list roots should fail
- assertThatThrownBy(() -> mcpServer.listRoots().roots()).isInstanceOf(McpError.class)
- .hasMessage("Roots not supported");
+ return mock(CallToolResult.class);
+ });
+
+ var mcpServer = McpServer.sync(mcpServerTransportProvider).rootsChangeHandler((exchange, rootsUpdate) -> {
+ }).tools(tool).build();
+
+ try (
+ // Create client without roots capability
+ // No roots capability
+ var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().build()).build()) {
+
+ assertThat(mcpClient.initialize()).isNotNull();
+
+ // Attempt to list roots should fail
+ try {
+ mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of()));
+ }
+ catch (McpError e) {
+ assertThat(e).isInstanceOf(McpError.class).hasMessage("Roots not supported");
+ }
+ }
- mcpClient.close();
mcpServer.close();
}
@Test
- void testRootsWithEmptyRootsList() {
+ void testRootsNotificationWithEmptyRootsList() {
AtomicReference> rootsRef = new AtomicReference<>();
- var mcpServer = McpServer.sync(mcpServerTransport)
- .rootsChangeConsumer(rootsUpdate -> rootsRef.set(rootsUpdate))
+
+ var mcpServer = McpServer.sync(mcpServerTransportProvider)
+ .rootsChangeHandler((exchange, rootsUpdate) -> rootsRef.set(rootsUpdate))
.build();
- var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().roots(true).build())
+ try (var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().roots(true).build())
.roots(List.of()) // Empty roots list
- .build();
+ .build()) {
- InitializeResult initResult = mcpClient.initialize();
- assertThat(initResult).isNotNull();
+ InitializeResult initResult = mcpClient.initialize();
+ assertThat(initResult).isNotNull();
- mcpClient.rootsListChangedNotification();
+ mcpClient.rootsListChangedNotification();
- await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> {
- assertThat(rootsRef.get()).isEmpty();
- });
+ await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> {
+ assertThat(rootsRef.get()).isEmpty();
+ });
+ }
- mcpClient.close();
mcpServer.close();
}
@Test
- void testRootsWithMultipleConsumers() {
+ void testRootsWithMultipleHandlers() {
List roots = List.of(new Root("uri1://", "root1"));
AtomicReference> rootsRef1 = new AtomicReference<>();
AtomicReference> rootsRef2 = new AtomicReference<>();
- var mcpServer = McpServer.sync(mcpServerTransport)
- .rootsChangeConsumer(rootsUpdate -> rootsRef1.set(rootsUpdate))
- .rootsChangeConsumer(rootsUpdate -> rootsRef2.set(rootsUpdate))
+ var mcpServer = McpServer.sync(mcpServerTransportProvider)
+ .rootsChangeHandler((exchange, rootsUpdate) -> rootsRef1.set(rootsUpdate))
+ .rootsChangeHandler((exchange, rootsUpdate) -> rootsRef2.set(rootsUpdate))
.build();
- var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().roots(true).build())
+ try (var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().roots(true).build())
.roots(roots)
- .build();
+ .build()) {
- InitializeResult initResult = mcpClient.initialize();
- assertThat(initResult).isNotNull();
+ assertThat(mcpClient.initialize()).isNotNull();
- mcpClient.rootsListChangedNotification();
+ mcpClient.rootsListChangedNotification();
- await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> {
- assertThat(rootsRef1.get()).containsAll(roots);
- assertThat(rootsRef2.get()).containsAll(roots);
- });
+ await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> {
+ assertThat(rootsRef1.get()).containsAll(roots);
+ assertThat(rootsRef2.get()).containsAll(roots);
+ });
+ }
- mcpClient.close();
mcpServer.close();
}
@@ -350,28 +496,26 @@ void testRootsServerCloseWithActiveSubscription() {
List roots = List.of(new Root("uri1://", "root1"));
AtomicReference> rootsRef = new AtomicReference<>();
- var mcpServer = McpServer.sync(mcpServerTransport)
- .rootsChangeConsumer(rootsUpdate -> rootsRef.set(rootsUpdate))
+
+ var mcpServer = McpServer.sync(mcpServerTransportProvider)
+ .rootsChangeHandler((exchange, rootsUpdate) -> rootsRef.set(rootsUpdate))
.build();
- var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().roots(true).build())
+ try (var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().roots(true).build())
.roots(roots)
- .build();
+ .build()) {
- InitializeResult initResult = mcpClient.initialize();
- assertThat(initResult).isNotNull();
+ InitializeResult initResult = mcpClient.initialize();
+ assertThat(initResult).isNotNull();
- mcpClient.rootsListChangedNotification();
+ mcpClient.rootsListChangedNotification();
- await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> {
- assertThat(rootsRef.get()).containsAll(roots);
- });
+ await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> {
+ assertThat(rootsRef.get()).containsAll(roots);
+ });
+ }
- // Close server while subscription is active
mcpServer.close();
-
- // Verify client can handle server closure gracefully
- mcpClient.close();
}
// ---------------------------------------
@@ -390,36 +534,35 @@ void testRootsServerCloseWithActiveSubscription() {
void testToolCallSuccess() {
var callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), null);
- McpServerFeatures.SyncToolRegistration tool1 = new McpServerFeatures.SyncToolRegistration(
- new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), request -> {
+ McpServerFeatures.SyncToolSpecification tool1 = new McpServerFeatures.SyncToolSpecification(
+ new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request) -> {
// perform a blocking call to a remote service
String response = RestClient.create()
.get()
- .uri("https://github.com/modelcontextprotocol/specification/blob/main/README.md")
+ .uri("https://raw.githubusercontent.com/modelcontextprotocol/java-sdk/refs/heads/main/README.md")
.retrieve()
.body(String.class);
assertThat(response).isNotBlank();
return callResponse;
});
- var mcpServer = McpServer.sync(mcpServerTransport)
+ var mcpServer = McpServer.sync(mcpServerTransportProvider)
.capabilities(ServerCapabilities.builder().tools(true).build())
.tools(tool1)
.build();
- var mcpClient = clientBuilder.build();
+ try (var mcpClient = clientBuilder.build()) {
- InitializeResult initResult = mcpClient.initialize();
- assertThat(initResult).isNotNull();
+ InitializeResult initResult = mcpClient.initialize();
+ assertThat(initResult).isNotNull();
- assertThat(mcpClient.listTools().tools()).contains(tool1.tool());
+ assertThat(mcpClient.listTools().tools()).contains(tool1.tool());
- CallToolResult response = mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of()));
+ CallToolResult response = mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of()));
- assertThat(response).isNotNull();
- assertThat(response).isEqualTo(callResponse);
+ assertThat(response).isNotNull().isEqualTo(callResponse);
+ }
- mcpClient.close();
mcpServer.close();
}
@@ -427,80 +570,82 @@ void testToolCallSuccess() {
void testToolListChangeHandlingSuccess() {
var callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), null);
- McpServerFeatures.SyncToolRegistration tool1 = new McpServerFeatures.SyncToolRegistration(
- new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), request -> {
+ McpServerFeatures.SyncToolSpecification tool1 = new McpServerFeatures.SyncToolSpecification(
+ new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request) -> {
// perform a blocking call to a remote service
String response = RestClient.create()
.get()
- .uri("https://github.com/modelcontextprotocol/specification/blob/main/README.md")
+ .uri("https://raw.githubusercontent.com/modelcontextprotocol/java-sdk/refs/heads/main/README.md")
.retrieve()
.body(String.class);
assertThat(response).isNotBlank();
return callResponse;
});
- var mcpServer = McpServer.sync(mcpServerTransport)
+ AtomicReference> rootsRef = new AtomicReference<>();
+
+ var mcpServer = McpServer.sync(mcpServerTransportProvider)
.capabilities(ServerCapabilities.builder().tools(true).build())
.tools(tool1)
.build();
- AtomicReference> rootsRef = new AtomicReference<>();
- var mcpClient = clientBuilder.toolsChangeConsumer(toolsUpdate -> {
+ try (var mcpClient = clientBuilder.toolsChangeConsumer(toolsUpdate -> {
// perform a blocking call to a remote service
String response = RestClient.create()
.get()
- .uri("https://github.com/modelcontextprotocol/specification/blob/main/README.md")
+ .uri("https://raw.githubusercontent.com/modelcontextprotocol/java-sdk/refs/heads/main/README.md")
.retrieve()
.body(String.class);
assertThat(response).isNotBlank();
rootsRef.set(toolsUpdate);
- }).build();
+ }).build()) {
- InitializeResult initResult = mcpClient.initialize();
- assertThat(initResult).isNotNull();
+ InitializeResult initResult = mcpClient.initialize();
+ assertThat(initResult).isNotNull();
- assertThat(rootsRef.get()).isNull();
+ assertThat(rootsRef.get()).isNull();
- assertThat(mcpClient.listTools().tools()).contains(tool1.tool());
+ assertThat(mcpClient.listTools().tools()).contains(tool1.tool());
- mcpServer.notifyToolsListChanged();
+ mcpServer.notifyToolsListChanged();
- await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> {
- assertThat(rootsRef.get()).containsAll(List.of(tool1.tool()));
- });
+ await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> {
+ assertThat(rootsRef.get()).containsAll(List.of(tool1.tool()));
+ });
- // Remove a tool
- mcpServer.removeTool("tool1");
+ // Remove a tool
+ mcpServer.removeTool("tool1");
- await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> {
- assertThat(rootsRef.get()).isEmpty();
- });
+ await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> {
+ assertThat(rootsRef.get()).isEmpty();
+ });
- // Add a new tool
- McpServerFeatures.SyncToolRegistration tool2 = new McpServerFeatures.SyncToolRegistration(
- new McpSchema.Tool("tool2", "tool2 description", emptyJsonSchema), request -> callResponse);
+ // Add a new tool
+ McpServerFeatures.SyncToolSpecification tool2 = new McpServerFeatures.SyncToolSpecification(
+ new McpSchema.Tool("tool2", "tool2 description", emptyJsonSchema),
+ (exchange, request) -> callResponse);
- mcpServer.addTool(tool2);
+ mcpServer.addTool(tool2);
- await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> {
- assertThat(rootsRef.get()).containsAll(List.of(tool2.tool()));
- });
+ await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> {
+ assertThat(rootsRef.get()).containsAll(List.of(tool2.tool()));
+ });
+ }
- mcpClient.close();
mcpServer.close();
}
@Test
void testInitialize() {
- var mcpServer = McpServer.sync(mcpServerTransport).build();
+ var mcpServer = McpServer.sync(mcpServerTransportProvider).build();
- var mcpClient = clientBuilder.build();
+ try (var mcpClient = clientBuilder.build()) {
- InitializeResult initResult = mcpClient.initialize();
- assertThat(initResult).isNotNull();
+ InitializeResult initResult = mcpClient.initialize();
+ assertThat(initResult).isNotNull();
+ }
- mcpClient.close();
mcpServer.close();
}
diff --git a/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseSyncServerTransportTests.java b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseSyncServerTransportTests.java
index 249b4dea..1964703c 100644
--- a/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseSyncServerTransportTests.java
+++ b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseSyncServerTransportTests.java
@@ -5,8 +5,7 @@
package io.modelcontextprotocol.server;
import com.fasterxml.jackson.databind.ObjectMapper;
-import io.modelcontextprotocol.server.transport.WebMvcSseServerTransport;
-import io.modelcontextprotocol.spec.ServerMcpTransport;
+import io.modelcontextprotocol.server.transport.WebMvcSseServerTransportProvider;
import org.apache.catalina.Context;
import org.apache.catalina.LifecycleException;
import org.apache.catalina.startup.Tomcat;
@@ -25,24 +24,24 @@ class WebMvcSseSyncServerTransportTests extends AbstractMcpSyncServerTests {
private static final String MESSAGE_ENDPOINT = "/mcp/message";
- private static final int PORT = 8181;
+ private static final int PORT = TestUtil.findAvailablePort();
private Tomcat tomcat;
- private WebMvcSseServerTransport transport;
+ private WebMvcSseServerTransportProvider transportProvider;
@Configuration
@EnableWebMvc
static class TestConfig {
@Bean
- public WebMvcSseServerTransport webMvcSseServerTransport() {
- return new WebMvcSseServerTransport(new ObjectMapper(), MESSAGE_ENDPOINT);
+ public WebMvcSseServerTransportProvider webMvcSseServerTransportProvider() {
+ return new WebMvcSseServerTransportProvider(new ObjectMapper(), MESSAGE_ENDPOINT);
}
@Bean
- public RouterFunction routerFunction(WebMvcSseServerTransport transport) {
- return transport.getRouterFunction();
+ public RouterFunction routerFunction(WebMvcSseServerTransportProvider transportProvider) {
+ return transportProvider.getRouterFunction();
}
}
@@ -50,7 +49,7 @@ public RouterFunction routerFunction(WebMvcSseServerTransport tr
private AnnotationConfigWebApplicationContext appContext;
@Override
- protected ServerMcpTransport createMcpTransport() {
+ protected WebMvcSseServerTransportProvider createMcpTransportProvider() {
// Set up Tomcat first
tomcat = new Tomcat();
tomcat.setPort(PORT);
@@ -69,11 +68,10 @@ protected ServerMcpTransport createMcpTransport() {
appContext.refresh();
// Get the transport from Spring context
- transport = appContext.getBean(WebMvcSseServerTransport.class);
+ transportProvider = appContext.getBean(WebMvcSseServerTransportProvider.class);
// Create DispatcherServlet with our Spring context
DispatcherServlet dispatcherServlet = new DispatcherServlet(appContext);
- // dispatcherServlet.setThrowExceptionIfNoHandlerFound(true);
// Add servlet to Tomcat and get the wrapper
var wrapper = Tomcat.addServlet(context, "dispatcherServlet", dispatcherServlet);
@@ -88,7 +86,7 @@ protected ServerMcpTransport createMcpTransport() {
throw new RuntimeException("Failed to start Tomcat", e);
}
- return transport;
+ return transportProvider;
}
@Override
@@ -97,8 +95,8 @@ protected void onStart() {
@Override
protected void onClose() {
- if (transport != null) {
- transport.closeGracefully().block();
+ if (transportProvider != null) {
+ transportProvider.closeGracefully().block();
}
if (appContext != null) {
appContext.close();
diff --git a/mcp-test/pom.xml b/mcp-test/pom.xml
index 717f0319..a6e5bdb0 100644
--- a/mcp-test/pom.xml
+++ b/mcp-test/pom.xml
@@ -6,7 +6,7 @@
io.modelcontextprotocol.sdkmcp-parent
- 0.8.0-SNAPSHOT
+ 0.11.0-SNAPSHOTmcp-testjar
@@ -24,7 +24,7 @@
io.modelcontextprotocol.sdkmcp
- 0.8.0-SNAPSHOT
+ 0.11.0-SNAPSHOT
@@ -80,6 +80,7 @@
logback-classic${logback.version}
+
diff --git a/mcp-test/src/main/java/io/modelcontextprotocol/MockMcpTransport.java b/mcp-test/src/main/java/io/modelcontextprotocol/MockMcpTransport.java
index d4e48ea7..5484a63c 100644
--- a/mcp-test/src/main/java/io/modelcontextprotocol/MockMcpTransport.java
+++ b/mcp-test/src/main/java/io/modelcontextprotocol/MockMcpTransport.java
@@ -11,19 +11,22 @@
import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.ObjectMapper;
-import io.modelcontextprotocol.spec.ClientMcpTransport;
+import io.modelcontextprotocol.spec.McpClientTransport;
import io.modelcontextprotocol.spec.McpSchema;
-import io.modelcontextprotocol.spec.ServerMcpTransport;
import io.modelcontextprotocol.spec.McpSchema.JSONRPCNotification;
import io.modelcontextprotocol.spec.McpSchema.JSONRPCRequest;
+import io.modelcontextprotocol.spec.McpServerTransport;
import reactor.core.publisher.Mono;
import reactor.core.publisher.Sinks;
/**
- * A mock implementation of the {@link ClientMcpTransport} and {@link ServerMcpTransport}
+ * A mock implementation of the {@link McpClientTransport} and {@link McpServerTransport}
* interfaces.
+ *
+ * @deprecated not used. to be removed in the future.
*/
-public class MockMcpTransport implements ClientMcpTransport, ServerMcpTransport {
+@Deprecated
+public class MockMcpTransport implements McpClientTransport, McpServerTransport {
private final Sinks.Many inbound = Sinks.many().unicast().onBackpressureBuffer();
diff --git a/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientTests.java b/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientTests.java
index cdcba4d1..5452c8ea 100644
--- a/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientTests.java
+++ b/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientTests.java
@@ -6,10 +6,13 @@
import java.time.Duration;
import java.util.Map;
+import java.util.Objects;
import java.util.concurrent.atomic.AtomicBoolean;
+import java.util.concurrent.atomic.AtomicReference;
+import java.util.function.Consumer;
import java.util.function.Function;
-import io.modelcontextprotocol.spec.ClientMcpTransport;
+import io.modelcontextprotocol.spec.McpClientTransport;
import io.modelcontextprotocol.spec.McpError;
import io.modelcontextprotocol.spec.McpSchema;
import io.modelcontextprotocol.spec.McpSchema.CallToolRequest;
@@ -28,6 +31,7 @@
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test;
+import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
import reactor.test.StepVerifier;
@@ -44,13 +48,9 @@
*/
public abstract class AbstractMcpAsyncClientTests {
- private McpAsyncClient mcpAsyncClient;
-
- protected ClientMcpTransport mcpTransport;
-
private static final String ECHO_TEST_MESSAGE = "Hello MCP Spring AI!";
- abstract protected ClientMcpTransport createMcpTransport();
+ abstract protected McpClientTransport createMcpTransport();
protected void onStart() {
}
@@ -58,275 +58,334 @@ protected void onStart() {
protected void onClose() {
}
- protected Duration getTimeoutDuration() {
+ protected Duration getRequestTimeout() {
+ return Duration.ofSeconds(14);
+ }
+
+ protected Duration getInitializationTimeout() {
return Duration.ofSeconds(2);
}
- @BeforeEach
- void setUp() {
- onStart();
- this.mcpTransport = createMcpTransport();
+ McpAsyncClient client(McpClientTransport transport) {
+ return client(transport, Function.identity());
+ }
+
+ McpAsyncClient client(McpClientTransport transport, Function customizer) {
+ AtomicReference client = new AtomicReference<>();
assertThatCode(() -> {
- mcpAsyncClient = McpClient.async(mcpTransport)
- .requestTimeout(getTimeoutDuration())
- .capabilities(ClientCapabilities.builder().roots(true).build())
- .build();
+ McpClient.AsyncSpec builder = McpClient.async(transport)
+ .requestTimeout(getRequestTimeout())
+ .initializationTimeout(getInitializationTimeout())
+ .capabilities(ClientCapabilities.builder().roots(true).build());
+ builder = customizer.apply(builder);
+ client.set(builder.build());
}).doesNotThrowAnyException();
+
+ return client.get();
+ }
+
+ void withClient(McpClientTransport transport, Consumer c) {
+ withClient(transport, Function.identity(), c);
+ }
+
+ void withClient(McpClientTransport transport, Function customizer,
+ Consumer c) {
+ var client = client(transport, customizer);
+ try {
+ c.accept(client);
+ }
+ finally {
+ StepVerifier.create(client.closeGracefully()).expectComplete().verify(Duration.ofSeconds(10));
+ }
+ }
+
+ @BeforeEach
+ void setUp() {
+ onStart();
}
@AfterEach
void tearDown() {
- if (mcpAsyncClient != null) {
- assertThatCode(() -> mcpAsyncClient.closeGracefully().block(Duration.ofSeconds(10)))
- .doesNotThrowAnyException();
- }
onClose();
}
+ void verifyInitializationTimeout(Function> operation, String action) {
+ withClient(createMcpTransport(), mcpAsyncClient -> {
+ StepVerifier.withVirtualTime(() -> operation.apply(mcpAsyncClient))
+ .expectSubscription()
+ .thenAwait(getInitializationTimeout())
+ .consumeErrorWith(e -> assertThat(e).isInstanceOf(McpError.class)
+ .hasMessage("Client must be initialized before " + action))
+ .verify();
+ });
+ }
+
@Test
void testConstructorWithInvalidArguments() {
- assertThatThrownBy(() -> McpClient.sync(null).build()).isInstanceOf(IllegalArgumentException.class)
+ assertThatThrownBy(() -> McpClient.async(null).build()).isInstanceOf(IllegalArgumentException.class)
.hasMessage("Transport must not be null");
- assertThatThrownBy(() -> McpClient.sync(mcpTransport).requestTimeout(null).build())
+ assertThatThrownBy(() -> McpClient.async(createMcpTransport()).requestTimeout(null).build())
.isInstanceOf(IllegalArgumentException.class)
.hasMessage("Request timeout must not be null");
}
@Test
void testListToolsWithoutInitialization() {
- assertThatThrownBy(() -> mcpAsyncClient.listTools(null).block()).isInstanceOf(McpError.class)
- .hasMessage("Client must be initialized before listing tools");
+ verifyInitializationTimeout(client -> client.listTools(null), "listing tools");
}
@Test
void testListTools() {
- mcpAsyncClient.initialize().block(Duration.ofSeconds(10));
+ withClient(createMcpTransport(), mcpAsyncClient -> {
+ StepVerifier.create(mcpAsyncClient.initialize().then(mcpAsyncClient.listTools(null)))
+ .consumeNextWith(result -> {
+ assertThat(result.tools()).isNotNull().isNotEmpty();
- StepVerifier.create(mcpAsyncClient.listTools(null)).consumeNextWith(result -> {
- assertThat(result.tools()).isNotNull().isNotEmpty();
-
- Tool firstTool = result.tools().get(0);
- assertThat(firstTool.name()).isNotNull();
- assertThat(firstTool.description()).isNotNull();
- }).verifyComplete();
+ Tool firstTool = result.tools().get(0);
+ assertThat(firstTool.name()).isNotNull();
+ assertThat(firstTool.description()).isNotNull();
+ })
+ .verifyComplete();
+ });
}
@Test
void testPingWithoutInitialization() {
- assertThatThrownBy(() -> mcpAsyncClient.ping().block()).isInstanceOf(McpError.class)
- .hasMessage("Client must be initialized before pinging the server");
+ verifyInitializationTimeout(client -> client.ping(), "pinging the server");
}
@Test
void testPing() {
- mcpAsyncClient.initialize().block(Duration.ofSeconds(10));
- assertThatCode(() -> mcpAsyncClient.ping().block()).doesNotThrowAnyException();
+ withClient(createMcpTransport(), mcpAsyncClient -> {
+ StepVerifier.create(mcpAsyncClient.initialize().then(mcpAsyncClient.ping()))
+ .expectNextCount(1)
+ .verifyComplete();
+ });
}
@Test
void testCallToolWithoutInitialization() {
CallToolRequest callToolRequest = new CallToolRequest("echo", Map.of("message", ECHO_TEST_MESSAGE));
-
- assertThatThrownBy(() -> mcpAsyncClient.callTool(callToolRequest).block()).isInstanceOf(McpError.class)
- .hasMessage("Client must be initialized before calling tools");
+ verifyInitializationTimeout(client -> client.callTool(callToolRequest), "calling tools");
}
@Test
void testCallTool() {
- mcpAsyncClient.initialize().block(Duration.ofSeconds(10));
+ withClient(createMcpTransport(), mcpAsyncClient -> {
+ CallToolRequest callToolRequest = new CallToolRequest("echo", Map.of("message", ECHO_TEST_MESSAGE));
- CallToolRequest callToolRequest = new CallToolRequest("echo", Map.of("message", ECHO_TEST_MESSAGE));
-
- StepVerifier.create(mcpAsyncClient.callTool(callToolRequest)).consumeNextWith(callToolResult -> {
- assertThat(callToolResult).isNotNull().satisfies(result -> {
- assertThat(result.content()).isNotNull();
- assertThat(result.isError()).isNull();
- });
- }).verifyComplete();
+ StepVerifier.create(mcpAsyncClient.initialize().then(mcpAsyncClient.callTool(callToolRequest)))
+ .consumeNextWith(callToolResult -> {
+ assertThat(callToolResult).isNotNull().satisfies(result -> {
+ assertThat(result.content()).isNotNull();
+ assertThat(result.isError()).isNull();
+ });
+ })
+ .verifyComplete();
+ });
}
@Test
void testCallToolWithInvalidTool() {
- mcpAsyncClient.initialize().block(Duration.ofSeconds(10));
-
- CallToolRequest invalidRequest = new CallToolRequest("nonexistent_tool", Map.of("message", ECHO_TEST_MESSAGE));
+ withClient(createMcpTransport(), mcpAsyncClient -> {
+ CallToolRequest invalidRequest = new CallToolRequest("nonexistent_tool",
+ Map.of("message", ECHO_TEST_MESSAGE));
- assertThatThrownBy(() -> mcpAsyncClient.callTool(invalidRequest).block()).isInstanceOf(Exception.class);
+ StepVerifier.create(mcpAsyncClient.initialize().then(mcpAsyncClient.callTool(invalidRequest)))
+ .consumeErrorWith(
+ e -> assertThat(e).isInstanceOf(McpError.class).hasMessage("Unknown tool: nonexistent_tool"))
+ .verify();
+ });
}
@Test
void testListResourcesWithoutInitialization() {
- assertThatThrownBy(() -> mcpAsyncClient.listResources(null).block()).isInstanceOf(McpError.class)
- .hasMessage("Client must be initialized before listing resources");
+ verifyInitializationTimeout(client -> client.listResources(null), "listing resources");
}
@Test
void testListResources() {
- mcpAsyncClient.initialize().block(Duration.ofSeconds(10));
+ withClient(createMcpTransport(), mcpAsyncClient -> {
+ StepVerifier.create(mcpAsyncClient.initialize().then(mcpAsyncClient.listResources(null)))
+ .consumeNextWith(resources -> {
+ assertThat(resources).isNotNull().satisfies(result -> {
+ assertThat(result.resources()).isNotNull();
- StepVerifier.create(mcpAsyncClient.listResources(null)).consumeNextWith(resources -> {
- assertThat(resources).isNotNull().satisfies(result -> {
- assertThat(result.resources()).isNotNull();
-
- if (!result.resources().isEmpty()) {
- Resource firstResource = result.resources().get(0);
- assertThat(firstResource.uri()).isNotNull();
- assertThat(firstResource.name()).isNotNull();
- }
- });
- }).verifyComplete();
+ if (!result.resources().isEmpty()) {
+ Resource firstResource = result.resources().get(0);
+ assertThat(firstResource.uri()).isNotNull();
+ assertThat(firstResource.name()).isNotNull();
+ }
+ });
+ })
+ .verifyComplete();
+ });
}
@Test
void testMcpAsyncClientState() {
- assertThat(mcpAsyncClient).isNotNull();
+ withClient(createMcpTransport(), mcpAsyncClient -> {
+ assertThat(mcpAsyncClient).isNotNull();
+ });
}
@Test
void testListPromptsWithoutInitialization() {
- assertThatThrownBy(() -> mcpAsyncClient.listPrompts(null).block()).isInstanceOf(McpError.class)
- .hasMessage("Client must be initialized before listing prompts");
+ verifyInitializationTimeout(client -> client.listPrompts(null), "listing " + "prompts");
}
@Test
void testListPrompts() {
- mcpAsyncClient.initialize().block(Duration.ofSeconds(10));
-
- StepVerifier.create(mcpAsyncClient.listPrompts(null)).consumeNextWith(prompts -> {
- assertThat(prompts).isNotNull().satisfies(result -> {
- assertThat(result.prompts()).isNotNull();
+ withClient(createMcpTransport(), mcpAsyncClient -> {
+ StepVerifier.create(mcpAsyncClient.initialize().then(mcpAsyncClient.listPrompts(null)))
+ .consumeNextWith(prompts -> {
+ assertThat(prompts).isNotNull().satisfies(result -> {
+ assertThat(result.prompts()).isNotNull();
- if (!result.prompts().isEmpty()) {
- Prompt firstPrompt = result.prompts().get(0);
- assertThat(firstPrompt.name()).isNotNull();
- assertThat(firstPrompt.description()).isNotNull();
- }
- });
- }).verifyComplete();
+ if (!result.prompts().isEmpty()) {
+ Prompt firstPrompt = result.prompts().get(0);
+ assertThat(firstPrompt.name()).isNotNull();
+ assertThat(firstPrompt.description()).isNotNull();
+ }
+ });
+ })
+ .verifyComplete();
+ });
}
@Test
void testGetPromptWithoutInitialization() {
GetPromptRequest request = new GetPromptRequest("simple_prompt", Map.of());
-
- assertThatThrownBy(() -> mcpAsyncClient.getPrompt(request).block()).isInstanceOf(McpError.class)
- .hasMessage("Client must be initialized before getting prompts");
+ verifyInitializationTimeout(client -> client.getPrompt(request), "getting " + "prompts");
}
@Test
void testGetPrompt() {
- mcpAsyncClient.initialize().block(Duration.ofSeconds(10));
-
- StepVerifier.create(mcpAsyncClient.getPrompt(new GetPromptRequest("simple_prompt", Map.of())))
- .consumeNextWith(prompt -> {
- assertThat(prompt).isNotNull().satisfies(result -> {
- assertThat(result.messages()).isNotEmpty();
- assertThat(result.messages()).hasSize(1);
- });
- })
- .verifyComplete();
+ withClient(createMcpTransport(), mcpAsyncClient -> {
+ StepVerifier
+ .create(mcpAsyncClient.initialize()
+ .then(mcpAsyncClient.getPrompt(new GetPromptRequest("simple_prompt", Map.of()))))
+ .consumeNextWith(prompt -> {
+ assertThat(prompt).isNotNull().satisfies(result -> {
+ assertThat(result.messages()).isNotEmpty();
+ assertThat(result.messages()).hasSize(1);
+ });
+ })
+ .verifyComplete();
+ });
}
@Test
void testRootsListChangedWithoutInitialization() {
- assertThatThrownBy(() -> mcpAsyncClient.rootsListChangedNotification().block()).isInstanceOf(McpError.class)
- .hasMessage("Client must be initialized before sending roots list changed notification");
+ verifyInitializationTimeout(client -> client.rootsListChangedNotification(),
+ "sending roots list changed notification");
}
@Test
void testRootsListChanged() {
- mcpAsyncClient.initialize().block(Duration.ofSeconds(10));
-
- assertThatCode(() -> mcpAsyncClient.rootsListChangedNotification().block()).doesNotThrowAnyException();
+ withClient(createMcpTransport(), mcpAsyncClient -> {
+ StepVerifier.create(mcpAsyncClient.initialize().then(mcpAsyncClient.rootsListChangedNotification()))
+ .verifyComplete();
+ });
}
@Test
void testInitializeWithRootsListProviders() {
- var transport = createMcpTransport();
-
- var client = McpClient.async(transport)
- .requestTimeout(getTimeoutDuration())
- .roots(new Root("file:///test/path", "test-root"))
- .build();
-
- assertThatCode(() -> client.initialize().block(Duration.ofSeconds(10))).doesNotThrowAnyException();
-
- assertThatCode(() -> client.closeGracefully().block(Duration.ofSeconds(10))).doesNotThrowAnyException();
+ withClient(createMcpTransport(), builder -> builder.roots(new Root("file:///test/path", "test-root")),
+ client -> {
+ StepVerifier.create(client.initialize().then(client.closeGracefully())).verifyComplete();
+ });
}
@Test
void testAddRoot() {
- Root newRoot = new Root("file:///new/test/path", "new-test-root");
- assertThatCode(() -> mcpAsyncClient.addRoot(newRoot).block()).doesNotThrowAnyException();
+ withClient(createMcpTransport(), mcpAsyncClient -> {
+ Root newRoot = new Root("file:///new/test/path", "new-test-root");
+ StepVerifier.create(mcpAsyncClient.addRoot(newRoot)).verifyComplete();
+ });
}
@Test
void testAddRootWithNullValue() {
- assertThatThrownBy(() -> mcpAsyncClient.addRoot(null).block()).hasMessageContaining("Root must not be null");
+ withClient(createMcpTransport(), mcpAsyncClient -> {
+ StepVerifier.create(mcpAsyncClient.addRoot(null))
+ .consumeErrorWith(e -> assertThat(e).isInstanceOf(McpError.class).hasMessage("Root must not be null"))
+ .verify();
+ });
}
@Test
void testRemoveRoot() {
- Root root = new Root("file:///test/path/to/remove", "root-to-remove");
- assertThatCode(() -> {
- mcpAsyncClient.addRoot(root).block();
- mcpAsyncClient.removeRoot(root.uri()).block();
- }).doesNotThrowAnyException();
+ withClient(createMcpTransport(), mcpAsyncClient -> {
+ Root root = new Root("file:///test/path/to/remove", "root-to-remove");
+ StepVerifier.create(mcpAsyncClient.addRoot(root)).verifyComplete();
+
+ StepVerifier.create(mcpAsyncClient.removeRoot(root.uri())).verifyComplete();
+ });
}
@Test
void testRemoveNonExistentRoot() {
- assertThatThrownBy(() -> mcpAsyncClient.removeRoot("nonexistent-uri").block())
- .hasMessageContaining("Root with uri 'nonexistent-uri' not found");
+ withClient(createMcpTransport(), mcpAsyncClient -> {
+ StepVerifier.create(mcpAsyncClient.removeRoot("nonexistent-uri"))
+ .consumeErrorWith(e -> assertThat(e).isInstanceOf(McpError.class)
+ .hasMessage("Root with uri 'nonexistent-uri' not found"))
+ .verify();
+ });
}
@Test
@Disabled
void testReadResource() {
- StepVerifier.create(mcpAsyncClient.listResources()).consumeNextWith(resources -> {
- if (!resources.resources().isEmpty()) {
- Resource firstResource = resources.resources().get(0);
- StepVerifier.create(mcpAsyncClient.readResource(firstResource)).consumeNextWith(result -> {
- assertThat(result).isNotNull();
- assertThat(result.contents()).isNotNull();
- }).verifyComplete();
- }
- }).verifyComplete();
+ withClient(createMcpTransport(), mcpAsyncClient -> {
+ StepVerifier.create(mcpAsyncClient.listResources()).consumeNextWith(resources -> {
+ if (!resources.resources().isEmpty()) {
+ Resource firstResource = resources.resources().get(0);
+ StepVerifier.create(mcpAsyncClient.readResource(firstResource)).consumeNextWith(result -> {
+ assertThat(result).isNotNull();
+ assertThat(result.contents()).isNotNull();
+ }).verifyComplete();
+ }
+ }).verifyComplete();
+ });
}
@Test
void testListResourceTemplatesWithoutInitialization() {
- assertThatThrownBy(() -> mcpAsyncClient.listResourceTemplates().block()).isInstanceOf(McpError.class)
- .hasMessage("Client must be initialized before listing resource templates");
+ verifyInitializationTimeout(client -> client.listResourceTemplates(), "listing resource templates");
}
@Test
void testListResourceTemplates() {
- mcpAsyncClient.initialize().block(Duration.ofSeconds(10));
-
- StepVerifier.create(mcpAsyncClient.listResourceTemplates()).consumeNextWith(result -> {
- assertThat(result).isNotNull();
- assertThat(result.resourceTemplates()).isNotNull();
- }).verifyComplete();
+ withClient(createMcpTransport(), mcpAsyncClient -> {
+ StepVerifier.create(mcpAsyncClient.initialize().then(mcpAsyncClient.listResourceTemplates()))
+ .consumeNextWith(result -> {
+ assertThat(result).isNotNull();
+ assertThat(result.resourceTemplates()).isNotNull();
+ })
+ .verifyComplete();
+ });
}
// @Test
void testResourceSubscription() {
- StepVerifier.create(mcpAsyncClient.listResources()).consumeNextWith(resources -> {
- if (!resources.resources().isEmpty()) {
- Resource firstResource = resources.resources().get(0);
-
- // Test subscribe
- StepVerifier.create(mcpAsyncClient.subscribeResource(new SubscribeRequest(firstResource.uri())))
- .verifyComplete();
-
- // Test unsubscribe
- StepVerifier.create(mcpAsyncClient.unsubscribeResource(new UnsubscribeRequest(firstResource.uri())))
- .verifyComplete();
- }
- }).verifyComplete();
+ withClient(createMcpTransport(), mcpAsyncClient -> {
+ StepVerifier.create(mcpAsyncClient.listResources()).consumeNextWith(resources -> {
+ if (!resources.resources().isEmpty()) {
+ Resource firstResource = resources.resources().get(0);
+
+ // Test subscribe
+ StepVerifier.create(mcpAsyncClient.subscribeResource(new SubscribeRequest(firstResource.uri())))
+ .verifyComplete();
+
+ // Test unsubscribe
+ StepVerifier.create(mcpAsyncClient.unsubscribeResource(new UnsubscribeRequest(firstResource.uri())))
+ .verifyComplete();
+ }
+ }).verifyComplete();
+ });
}
@Test
@@ -335,42 +394,35 @@ void testNotificationHandlers() {
AtomicBoolean resourcesNotificationReceived = new AtomicBoolean(false);
AtomicBoolean promptsNotificationReceived = new AtomicBoolean(false);
- var transport = createMcpTransport();
- var client = McpClient.async(transport)
- .requestTimeout(getTimeoutDuration())
- .toolsChangeConsumer(tools -> Mono.fromRunnable(() -> toolsNotificationReceived.set(true)))
- .resourcesChangeConsumer(resources -> Mono.fromRunnable(() -> resourcesNotificationReceived.set(true)))
- .promptsChangeConsumer(prompts -> Mono.fromRunnable(() -> promptsNotificationReceived.set(true)))
- .build();
-
- assertThatCode(() -> {
- client.initialize().block();
- client.closeGracefully().block();
- }).doesNotThrowAnyException();
+ withClient(createMcpTransport(),
+ builder -> builder
+ .toolsChangeConsumer(tools -> Mono.fromRunnable(() -> toolsNotificationReceived.set(true)))
+ .resourcesChangeConsumer(
+ resources -> Mono.fromRunnable(() -> resourcesNotificationReceived.set(true)))
+ .promptsChangeConsumer(prompts -> Mono.fromRunnable(() -> promptsNotificationReceived.set(true))),
+ mcpAsyncClient -> {
+ StepVerifier.create(mcpAsyncClient.initialize())
+ .expectNextMatches(Objects::nonNull)
+ .verifyComplete();
+ });
}
@Test
void testInitializeWithSamplingCapability() {
- var transport = createMcpTransport();
-
- var capabilities = ClientCapabilities.builder().sampling().build();
-
- var client = McpClient.async(transport)
- .requestTimeout(getTimeoutDuration())
- .capabilities(capabilities)
- .sampling(request -> Mono.just(CreateMessageResult.builder().message("test").model("test-model").build()))
+ ClientCapabilities capabilities = ClientCapabilities.builder().sampling().build();
+ CreateMessageResult createMessageResult = CreateMessageResult.builder()
+ .message("test")
+ .model("test-model")
.build();
-
- assertThatCode(() -> {
- client.initialize().block(Duration.ofSeconds(10));
- client.closeGracefully().block(Duration.ofSeconds(10));
- }).doesNotThrowAnyException();
+ withClient(createMcpTransport(),
+ builder -> builder.capabilities(capabilities).sampling(request -> Mono.just(createMessageResult)),
+ client -> {
+ StepVerifier.create(client.initialize()).expectNextMatches(Objects::nonNull).verifyComplete();
+ });
}
@Test
void testInitializeWithAllCapabilities() {
- var transport = createMcpTransport();
-
var capabilities = ClientCapabilities.builder()
.experimental(Map.of("feature", "test"))
.roots(true)
@@ -379,18 +431,14 @@ void testInitializeWithAllCapabilities() {
Function> samplingHandler = request -> Mono
.just(CreateMessageResult.builder().message("test").model("test-model").build());
- var client = McpClient.async(transport)
- .requestTimeout(getTimeoutDuration())
- .capabilities(capabilities)
- .sampling(samplingHandler)
- .build();
- assertThatCode(() -> {
- var result = client.initialize().block(Duration.ofSeconds(10));
- assertThat(result).isNotNull();
- assertThat(result.capabilities()).isNotNull();
- client.closeGracefully().block(Duration.ofSeconds(10));
- }).doesNotThrowAnyException();
+ withClient(createMcpTransport(), builder -> builder.capabilities(capabilities).sampling(samplingHandler),
+ client ->
+
+ StepVerifier.create(client.initialize()).assertNext(result -> {
+ assertThat(result).isNotNull();
+ assertThat(result.capabilities()).isNotNull();
+ }).verifyComplete());
}
// ---------------------------------------
@@ -399,41 +447,41 @@ void testInitializeWithAllCapabilities() {
@Test
void testLoggingLevelsWithoutInitialization() {
- assertThatThrownBy(() -> mcpAsyncClient.setLoggingLevel(McpSchema.LoggingLevel.DEBUG).block())
- .isInstanceOf(McpError.class)
- .hasMessage("Client must be initialized before setting logging level");
+ verifyInitializationTimeout(client -> client.setLoggingLevel(McpSchema.LoggingLevel.DEBUG),
+ "setting logging level");
}
@Test
void testLoggingLevels() {
- mcpAsyncClient.initialize().block(Duration.ofSeconds(10));
-
- // Test all logging levels
- for (McpSchema.LoggingLevel level : McpSchema.LoggingLevel.values()) {
- StepVerifier.create(mcpAsyncClient.setLoggingLevel(level)).verifyComplete();
- }
+ withClient(createMcpTransport(), mcpAsyncClient -> {
+ StepVerifier
+ .create(mcpAsyncClient.initialize()
+ .thenMany(Flux.fromArray(McpSchema.LoggingLevel.values()).flatMap(mcpAsyncClient::setLoggingLevel)))
+ .verifyComplete();
+ });
}
@Test
void testLoggingConsumer() {
AtomicBoolean logReceived = new AtomicBoolean(false);
- var transport = createMcpTransport();
- var client = McpClient.async(transport)
- .requestTimeout(getTimeoutDuration())
- .loggingConsumer(notification -> Mono.fromRunnable(() -> logReceived.set(true)))
- .build();
+ withClient(createMcpTransport(),
+ builder -> builder.loggingConsumer(notification -> Mono.fromRunnable(() -> logReceived.set(true))),
+ client -> {
+ StepVerifier.create(client.initialize()).expectNextMatches(Objects::nonNull).verifyComplete();
+ StepVerifier.create(client.closeGracefully()).verifyComplete();
+
+ });
- assertThatCode(() -> {
- client.initialize().block(Duration.ofSeconds(10));
- client.closeGracefully().block(Duration.ofSeconds(10));
- }).doesNotThrowAnyException();
}
@Test
void testLoggingWithNullNotification() {
- assertThatThrownBy(() -> mcpAsyncClient.setLoggingLevel(null).block())
- .hasMessageContaining("Logging level must not be null");
+ withClient(createMcpTransport(), mcpAsyncClient -> {
+ StepVerifier.create(mcpAsyncClient.setLoggingLevel(null))
+ .expectErrorMatches(error -> error.getMessage().contains("Logging level must not be null"))
+ .verify();
+ });
}
}
diff --git a/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpSyncClientTests.java b/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpSyncClientTests.java
index aeed06cb..128441f8 100644
--- a/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpSyncClientTests.java
+++ b/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpSyncClientTests.java
@@ -7,8 +7,11 @@
import java.time.Duration;
import java.util.Map;
import java.util.concurrent.atomic.AtomicBoolean;
+import java.util.concurrent.atomic.AtomicReference;
+import java.util.function.Consumer;
+import java.util.function.Function;
-import io.modelcontextprotocol.spec.ClientMcpTransport;
+import io.modelcontextprotocol.spec.McpClientTransport;
import io.modelcontextprotocol.spec.McpError;
import io.modelcontextprotocol.spec.McpSchema;
import io.modelcontextprotocol.spec.McpSchema.CallToolRequest;
@@ -27,6 +30,10 @@
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
+import reactor.core.publisher.Mono;
+import reactor.core.scheduler.Scheduler;
+import reactor.core.scheduler.Schedulers;
+import reactor.test.StepVerifier;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatCode;
@@ -40,269 +47,345 @@
*/
public abstract class AbstractMcpSyncClientTests {
- private McpSyncClient mcpSyncClient;
-
private static final String TEST_MESSAGE = "Hello MCP Spring AI!";
- protected ClientMcpTransport mcpTransport;
+ abstract protected McpClientTransport createMcpTransport();
- abstract protected ClientMcpTransport createMcpTransport();
+ protected void onStart() {
+ }
- abstract protected void onStart();
+ protected void onClose() {
+ }
- abstract protected void onClose();
+ protected Duration getRequestTimeout() {
+ return Duration.ofSeconds(14);
+ }
- protected Duration getTimeoutDuration() {
+ protected Duration getInitializationTimeout() {
return Duration.ofSeconds(2);
}
+ McpSyncClient client(McpClientTransport transport) {
+ return client(transport, Function.identity());
+ }
+
+ McpSyncClient client(McpClientTransport transport, Function customizer) {
+ AtomicReference client = new AtomicReference<>();
+
+ assertThatCode(() -> {
+ McpClient.SyncSpec builder = McpClient.sync(transport)
+ .requestTimeout(getRequestTimeout())
+ .initializationTimeout(getInitializationTimeout())
+ .capabilities(ClientCapabilities.builder().roots(true).build());
+ builder = customizer.apply(builder);
+ client.set(builder.build());
+ }).doesNotThrowAnyException();
+
+ return client.get();
+ }
+
+ void withClient(McpClientTransport transport, Consumer c) {
+ withClient(transport, Function.identity(), c);
+ }
+
+ void withClient(McpClientTransport transport, Function customizer,
+ Consumer c) {
+ var client = client(transport, customizer);
+ try {
+ c.accept(client);
+ }
+ finally {
+ assertThat(client.closeGracefully()).isTrue();
+ }
+ }
+
@BeforeEach
void setUp() {
onStart();
- this.mcpTransport = createMcpTransport();
- assertThatCode(() -> {
- mcpSyncClient = McpClient.sync(mcpTransport)
- .requestTimeout(getTimeoutDuration())
- .capabilities(ClientCapabilities.builder().roots(true).build())
- .build();
- }).doesNotThrowAnyException();
}
@AfterEach
void tearDown() {
- if (mcpSyncClient != null) {
- assertThatCode(() -> mcpSyncClient.close()).doesNotThrowAnyException();
- }
onClose();
}
+ static final Object DUMMY_RETURN_VALUE = new Object();
+
+ void verifyNotificationTimesOut(Consumer operation, String action) {
+ verifyCallTimesOut(client -> {
+ operation.accept(client);
+ return DUMMY_RETURN_VALUE;
+ }, action);
+ }
+
+ void verifyCallTimesOut(Function blockingOperation, String action) {
+ withClient(createMcpTransport(), mcpSyncClient -> {
+ // This scheduler is not replaced by virtual time scheduler
+ Scheduler customScheduler = Schedulers.newBoundedElastic(1, 1, "actualBoundedElastic");
+
+ StepVerifier.withVirtualTime(() -> Mono.fromSupplier(() -> blockingOperation.apply(mcpSyncClient))
+ // Offload the blocking call to the real scheduler
+ .subscribeOn(customScheduler))
+ .expectSubscription()
+ // This works without actually waiting but executes all the
+ // tasks pending execution on the VirtualTimeScheduler.
+ // It is possible to execute the blocking code from the operation
+ // because it is blocked on a dedicated Scheduler and the main
+ // flow is not blocked and uses the VirtualTimeScheduler.
+ .thenAwait(getInitializationTimeout())
+ .consumeErrorWith(e -> assertThat(e).isInstanceOf(McpError.class)
+ .hasMessage("Client must be initialized before " + action))
+ .verify();
+
+ customScheduler.dispose();
+ });
+ }
+
@Test
void testConstructorWithInvalidArguments() {
assertThatThrownBy(() -> McpClient.sync(null).build()).isInstanceOf(IllegalArgumentException.class)
.hasMessage("Transport must not be null");
- assertThatThrownBy(() -> McpClient.sync(mcpTransport).requestTimeout(null).build())
+ assertThatThrownBy(() -> McpClient.sync(createMcpTransport()).requestTimeout(null).build())
.isInstanceOf(IllegalArgumentException.class)
.hasMessage("Request timeout must not be null");
}
@Test
void testListToolsWithoutInitialization() {
- assertThatThrownBy(() -> mcpSyncClient.listTools(null)).isInstanceOf(McpError.class)
- .hasMessage("Client must be initialized before listing tools");
+ verifyCallTimesOut(client -> client.listTools(null), "listing tools");
}
@Test
void testListTools() {
- mcpSyncClient.initialize();
- ListToolsResult tools = mcpSyncClient.listTools(null);
+ withClient(createMcpTransport(), mcpSyncClient -> {
+ mcpSyncClient.initialize();
+ ListToolsResult tools = mcpSyncClient.listTools(null);
- assertThat(tools).isNotNull().satisfies(result -> {
- assertThat(result.tools()).isNotNull().isNotEmpty();
+ assertThat(tools).isNotNull().satisfies(result -> {
+ assertThat(result.tools()).isNotNull().isNotEmpty();
- Tool firstTool = result.tools().get(0);
- assertThat(firstTool.name()).isNotNull();
- assertThat(firstTool.description()).isNotNull();
+ Tool firstTool = result.tools().get(0);
+ assertThat(firstTool.name()).isNotNull();
+ assertThat(firstTool.description()).isNotNull();
+ });
});
}
@Test
void testCallToolsWithoutInitialization() {
- assertThatThrownBy(() -> mcpSyncClient.callTool(new CallToolRequest("add", Map.of("a", 3, "b", 4))))
- .isInstanceOf(McpError.class)
- .hasMessage("Client must be initialized before calling tools");
+ verifyCallTimesOut(client -> client.callTool(new CallToolRequest("add", Map.of("a", 3, "b", 4))),
+ "calling tools");
}
@Test
void testCallTools() {
- mcpSyncClient.initialize();
- CallToolResult toolResult = mcpSyncClient.callTool(new CallToolRequest("add", Map.of("a", 3, "b", 4)));
+ withClient(createMcpTransport(), mcpSyncClient -> {
+ mcpSyncClient.initialize();
+ CallToolResult toolResult = mcpSyncClient.callTool(new CallToolRequest("add", Map.of("a", 3, "b", 4)));
- assertThat(toolResult).isNotNull().satisfies(result -> {
+ assertThat(toolResult).isNotNull().satisfies(result -> {
- assertThat(result.content()).hasSize(1);
+ assertThat(result.content()).hasSize(1);
- TextContent content = (TextContent) result.content().get(0);
+ TextContent content = (TextContent) result.content().get(0);
- assertThat(content).isNotNull();
- assertThat(content.text()).isNotNull();
- assertThat(content.text()).contains("7");
+ assertThat(content).isNotNull();
+ assertThat(content.text()).isNotNull();
+ assertThat(content.text()).contains("7");
+ });
});
}
@Test
void testPingWithoutInitialization() {
- assertThatThrownBy(() -> mcpSyncClient.ping()).isInstanceOf(McpError.class)
- .hasMessage("Client must be initialized before pinging the server");
+ verifyCallTimesOut(client -> client.ping(), "pinging the server");
}
@Test
void testPing() {
- mcpSyncClient.initialize();
- assertThatCode(() -> mcpSyncClient.ping()).doesNotThrowAnyException();
+ withClient(createMcpTransport(), mcpSyncClient -> {
+ mcpSyncClient.initialize();
+ assertThatCode(() -> mcpSyncClient.ping()).doesNotThrowAnyException();
+ });
}
@Test
void testCallToolWithoutInitialization() {
CallToolRequest callToolRequest = new CallToolRequest("echo", Map.of("message", TEST_MESSAGE));
-
- assertThatThrownBy(() -> mcpSyncClient.callTool(callToolRequest)).isInstanceOf(McpError.class)
- .hasMessage("Client must be initialized before calling tools");
+ verifyCallTimesOut(client -> client.callTool(callToolRequest), "calling tools");
}
@Test
void testCallTool() {
- mcpSyncClient.initialize();
- CallToolRequest callToolRequest = new CallToolRequest("echo", Map.of("message", TEST_MESSAGE));
+ withClient(createMcpTransport(), mcpSyncClient -> {
+ mcpSyncClient.initialize();
+ CallToolRequest callToolRequest = new CallToolRequest("echo", Map.of("message", TEST_MESSAGE));
- CallToolResult callToolResult = mcpSyncClient.callTool(callToolRequest);
+ CallToolResult callToolResult = mcpSyncClient.callTool(callToolRequest);
- assertThat(callToolResult).isNotNull().satisfies(result -> {
- assertThat(result.content()).isNotNull();
- assertThat(result.isError()).isNull();
+ assertThat(callToolResult).isNotNull().satisfies(result -> {
+ assertThat(result.content()).isNotNull();
+ assertThat(result.isError()).isNull();
+ });
});
}
@Test
void testCallToolWithInvalidTool() {
- CallToolRequest invalidRequest = new CallToolRequest("nonexistent_tool", Map.of("message", TEST_MESSAGE));
+ withClient(createMcpTransport(), mcpSyncClient -> {
+ CallToolRequest invalidRequest = new CallToolRequest("nonexistent_tool", Map.of("message", TEST_MESSAGE));
- assertThatThrownBy(() -> mcpSyncClient.callTool(invalidRequest)).isInstanceOf(Exception.class);
+ assertThatThrownBy(() -> mcpSyncClient.callTool(invalidRequest)).isInstanceOf(Exception.class);
+ });
}
@Test
void testRootsListChangedWithoutInitialization() {
- assertThatThrownBy(() -> mcpSyncClient.rootsListChangedNotification()).isInstanceOf(McpError.class)
- .hasMessage("Client must be initialized before sending roots list changed notification");
+ verifyNotificationTimesOut(client -> client.rootsListChangedNotification(),
+ "sending roots list changed notification");
}
@Test
void testRootsListChanged() {
- mcpSyncClient.initialize();
- assertThatCode(() -> mcpSyncClient.rootsListChangedNotification()).doesNotThrowAnyException();
+ withClient(createMcpTransport(), mcpSyncClient -> {
+ mcpSyncClient.initialize();
+ assertThatCode(() -> mcpSyncClient.rootsListChangedNotification()).doesNotThrowAnyException();
+ });
}
@Test
void testListResourcesWithoutInitialization() {
- assertThatThrownBy(() -> mcpSyncClient.listResources(null)).isInstanceOf(McpError.class)
- .hasMessage("Client must be initialized before listing resources");
+ verifyCallTimesOut(client -> client.listResources(null), "listing resources");
}
@Test
void testListResources() {
- mcpSyncClient.initialize();
- ListResourcesResult resources = mcpSyncClient.listResources(null);
-
- assertThat(resources).isNotNull().satisfies(result -> {
- assertThat(result.resources()).isNotNull();
-
- if (!result.resources().isEmpty()) {
- Resource firstResource = result.resources().get(0);
- assertThat(firstResource.uri()).isNotNull();
- assertThat(firstResource.name()).isNotNull();
- }
+ withClient(createMcpTransport(), mcpSyncClient -> {
+ mcpSyncClient.initialize();
+ ListResourcesResult resources = mcpSyncClient.listResources(null);
+
+ assertThat(resources).isNotNull().satisfies(result -> {
+ assertThat(result.resources()).isNotNull();
+
+ if (!result.resources().isEmpty()) {
+ Resource firstResource = result.resources().get(0);
+ assertThat(firstResource.uri()).isNotNull();
+ assertThat(firstResource.name()).isNotNull();
+ }
+ });
});
}
@Test
void testClientSessionState() {
- assertThat(mcpSyncClient).isNotNull();
+ withClient(createMcpTransport(), mcpSyncClient -> {
+ assertThat(mcpSyncClient).isNotNull();
+ });
}
@Test
void testInitializeWithRootsListProviders() {
- var transport = createMcpTransport();
-
- var client = McpClient.sync(transport)
- .requestTimeout(getTimeoutDuration())
- .roots(new Root("file:///test/path", "test-root"))
- .build();
+ withClient(createMcpTransport(), builder -> builder.roots(new Root("file:///test/path", "test-root")),
+ mcpSyncClient -> {
- assertThatCode(() -> {
- client.initialize();
- client.close();
- }).doesNotThrowAnyException();
+ assertThatCode(() -> {
+ mcpSyncClient.initialize();
+ mcpSyncClient.close();
+ }).doesNotThrowAnyException();
+ });
}
@Test
void testAddRoot() {
- Root newRoot = new Root("file:///new/test/path", "new-test-root");
- assertThatCode(() -> mcpSyncClient.addRoot(newRoot)).doesNotThrowAnyException();
+ withClient(createMcpTransport(), mcpSyncClient -> {
+ Root newRoot = new Root("file:///new/test/path", "new-test-root");
+ assertThatCode(() -> mcpSyncClient.addRoot(newRoot)).doesNotThrowAnyException();
+ });
}
@Test
void testAddRootWithNullValue() {
- assertThatThrownBy(() -> mcpSyncClient.addRoot(null)).hasMessageContaining("Root must not be null");
+ withClient(createMcpTransport(), mcpSyncClient -> {
+ assertThatThrownBy(() -> mcpSyncClient.addRoot(null)).hasMessageContaining("Root must not be null");
+ });
}
@Test
void testRemoveRoot() {
- Root root = new Root("file:///test/path/to/remove", "root-to-remove");
- assertThatCode(() -> {
- mcpSyncClient.addRoot(root);
- mcpSyncClient.removeRoot(root.uri());
- }).doesNotThrowAnyException();
+ withClient(createMcpTransport(), mcpSyncClient -> {
+ Root root = new Root("file:///test/path/to/remove", "root-to-remove");
+ assertThatCode(() -> {
+ mcpSyncClient.addRoot(root);
+ mcpSyncClient.removeRoot(root.uri());
+ }).doesNotThrowAnyException();
+ });
}
@Test
void testRemoveNonExistentRoot() {
- assertThatThrownBy(() -> mcpSyncClient.removeRoot("nonexistent-uri"))
- .hasMessageContaining("Root with uri 'nonexistent-uri' not found");
+ withClient(createMcpTransport(), mcpSyncClient -> {
+ assertThatThrownBy(() -> mcpSyncClient.removeRoot("nonexistent-uri"))
+ .hasMessageContaining("Root with uri 'nonexistent-uri' not found");
+ });
}
@Test
void testReadResourceWithoutInitialization() {
- assertThatThrownBy(() -> {
- Resource resource = new Resource("test://uri", "Test Resource", null, null, null);
- mcpSyncClient.readResource(resource);
- }).isInstanceOf(McpError.class).hasMessage("Client must be initialized before reading resources");
+ Resource resource = new Resource("test://uri", "Test Resource", null, null, null);
+ verifyCallTimesOut(client -> client.readResource(resource), "reading resources");
}
@Test
void testReadResource() {
- mcpSyncClient.initialize();
- ListResourcesResult resources = mcpSyncClient.listResources(null);
+ withClient(createMcpTransport(), mcpSyncClient -> {
+ mcpSyncClient.initialize();
+ ListResourcesResult resources = mcpSyncClient.listResources(null);
- if (!resources.resources().isEmpty()) {
- Resource firstResource = resources.resources().get(0);
- ReadResourceResult result = mcpSyncClient.readResource(firstResource);
+ if (!resources.resources().isEmpty()) {
+ Resource firstResource = resources.resources().get(0);
+ ReadResourceResult result = mcpSyncClient.readResource(firstResource);
- assertThat(result).isNotNull();
- assertThat(result.contents()).isNotNull();
- }
+ assertThat(result).isNotNull();
+ assertThat(result.contents()).isNotNull();
+ }
+ });
}
@Test
void testListResourceTemplatesWithoutInitialization() {
- assertThatThrownBy(() -> mcpSyncClient.listResourceTemplates(null)).isInstanceOf(McpError.class)
- .hasMessage("Client must be initialized before listing resource templates");
+ verifyCallTimesOut(client -> client.listResourceTemplates(null), "listing resource templates");
}
@Test
void testListResourceTemplates() {
- mcpSyncClient.initialize();
- ListResourceTemplatesResult result = mcpSyncClient.listResourceTemplates(null);
+ withClient(createMcpTransport(), mcpSyncClient -> {
+ mcpSyncClient.initialize();
+ ListResourceTemplatesResult result = mcpSyncClient.listResourceTemplates(null);
- assertThat(result).isNotNull();
- assertThat(result.resourceTemplates()).isNotNull();
+ assertThat(result).isNotNull();
+ assertThat(result.resourceTemplates()).isNotNull();
+ });
}
// @Test
void testResourceSubscription() {
- ListResourcesResult resources = mcpSyncClient.listResources(null);
+ withClient(createMcpTransport(), mcpSyncClient -> {
+ ListResourcesResult resources = mcpSyncClient.listResources(null);
- if (!resources.resources().isEmpty()) {
- Resource firstResource = resources.resources().get(0);
+ if (!resources.resources().isEmpty()) {
+ Resource firstResource = resources.resources().get(0);
- // Test subscribe
- assertThatCode(() -> mcpSyncClient.subscribeResource(new SubscribeRequest(firstResource.uri())))
- .doesNotThrowAnyException();
+ // Test subscribe
+ assertThatCode(() -> mcpSyncClient.subscribeResource(new SubscribeRequest(firstResource.uri())))
+ .doesNotThrowAnyException();
- // Test unsubscribe
- assertThatCode(() -> mcpSyncClient.unsubscribeResource(new UnsubscribeRequest(firstResource.uri())))
- .doesNotThrowAnyException();
- }
+ // Test unsubscribe
+ assertThatCode(() -> mcpSyncClient.unsubscribeResource(new UnsubscribeRequest(firstResource.uri())))
+ .doesNotThrowAnyException();
+ }
+ });
}
@Test
@@ -311,18 +394,17 @@ void testNotificationHandlers() {
AtomicBoolean resourcesNotificationReceived = new AtomicBoolean(false);
AtomicBoolean promptsNotificationReceived = new AtomicBoolean(false);
- var transport = createMcpTransport();
- var client = McpClient.sync(transport)
- .requestTimeout(getTimeoutDuration())
- .toolsChangeConsumer(tools -> toolsNotificationReceived.set(true))
- .resourcesChangeConsumer(resources -> resourcesNotificationReceived.set(true))
- .promptsChangeConsumer(prompts -> promptsNotificationReceived.set(true))
- .build();
+ withClient(createMcpTransport(),
+ builder -> builder.toolsChangeConsumer(tools -> toolsNotificationReceived.set(true))
+ .resourcesChangeConsumer(resources -> resourcesNotificationReceived.set(true))
+ .promptsChangeConsumer(prompts -> promptsNotificationReceived.set(true)),
+ client -> {
- assertThatCode(() -> {
- client.initialize();
- client.close();
- }).doesNotThrowAnyException();
+ assertThatCode(() -> {
+ client.initialize();
+ client.close();
+ }).doesNotThrowAnyException();
+ });
}
// ---------------------------------------
@@ -331,40 +413,37 @@ void testNotificationHandlers() {
@Test
void testLoggingLevelsWithoutInitialization() {
- assertThatThrownBy(() -> mcpSyncClient.setLoggingLevel(McpSchema.LoggingLevel.DEBUG))
- .isInstanceOf(McpError.class)
- .hasMessage("Client must be initialized before setting logging level");
+ verifyNotificationTimesOut(client -> client.setLoggingLevel(McpSchema.LoggingLevel.DEBUG),
+ "setting logging level");
}
@Test
void testLoggingLevels() {
- mcpSyncClient.initialize();
- // Test all logging levels
- for (McpSchema.LoggingLevel level : McpSchema.LoggingLevel.values()) {
- assertThatCode(() -> mcpSyncClient.setLoggingLevel(level)).doesNotThrowAnyException();
- }
+ withClient(createMcpTransport(), mcpSyncClient -> {
+ mcpSyncClient.initialize();
+ // Test all logging levels
+ for (McpSchema.LoggingLevel level : McpSchema.LoggingLevel.values()) {
+ assertThatCode(() -> mcpSyncClient.setLoggingLevel(level)).doesNotThrowAnyException();
+ }
+ });
}
@Test
void testLoggingConsumer() {
AtomicBoolean logReceived = new AtomicBoolean(false);
- var transport = createMcpTransport();
-
- var client = McpClient.sync(transport)
- .requestTimeout(getTimeoutDuration())
- .loggingConsumer(notification -> logReceived.set(true))
- .build();
-
- assertThatCode(() -> {
- client.initialize();
- client.close();
- }).doesNotThrowAnyException();
+ withClient(createMcpTransport(), builder -> builder.requestTimeout(getRequestTimeout())
+ .loggingConsumer(notification -> logReceived.set(true)), client -> {
+ assertThatCode(() -> {
+ client.initialize();
+ client.close();
+ }).doesNotThrowAnyException();
+ });
}
@Test
void testLoggingWithNullNotification() {
- assertThatThrownBy(() -> mcpSyncClient.setLoggingLevel(null))
- .hasMessageContaining("Logging level must not be null");
+ withClient(createMcpTransport(), mcpSyncClient -> assertThatThrownBy(() -> mcpSyncClient.setLoggingLevel(null))
+ .hasMessageContaining("Logging level must not be null"));
}
}
diff --git a/mcp-test/src/main/java/io/modelcontextprotocol/server/AbstractMcpAsyncServerTests.java b/mcp-test/src/main/java/io/modelcontextprotocol/server/AbstractMcpAsyncServerTests.java
index ca5783d0..025cfeac 100644
--- a/mcp-test/src/main/java/io/modelcontextprotocol/server/AbstractMcpAsyncServerTests.java
+++ b/mcp-test/src/main/java/io/modelcontextprotocol/server/AbstractMcpAsyncServerTests.java
@@ -17,8 +17,7 @@
import io.modelcontextprotocol.spec.McpSchema.Resource;
import io.modelcontextprotocol.spec.McpSchema.ServerCapabilities;
import io.modelcontextprotocol.spec.McpSchema.Tool;
-import io.modelcontextprotocol.spec.McpTransport;
-import io.modelcontextprotocol.spec.ServerMcpTransport;
+import io.modelcontextprotocol.spec.McpServerTransportProvider;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
@@ -31,10 +30,11 @@
/**
* Test suite for the {@link McpAsyncServer} that can be used with different
- * {@link McpTransport} implementations.
+ * {@link io.modelcontextprotocol.spec.McpServerTransportProvider} implementations.
*
* @author Christian Tzolov
*/
+// KEEP IN SYNC with the class in mcp-test module
public abstract class AbstractMcpAsyncServerTests {
private static final String TEST_TOOL_NAME = "test-tool";
@@ -43,7 +43,7 @@ public abstract class AbstractMcpAsyncServerTests {
private static final String TEST_PROMPT_NAME = "test-prompt";
- abstract protected ServerMcpTransport createMcpTransport();
+ abstract protected McpServerTransportProvider createMcpTransportProvider();
protected void onStart() {
}
@@ -66,24 +66,26 @@ void tearDown() {
@Test
void testConstructorWithInvalidArguments() {
- assertThatThrownBy(() -> McpServer.async(null)).isInstanceOf(IllegalArgumentException.class)
- .hasMessage("Transport must not be null");
+ assertThatThrownBy(() -> McpServer.async((McpServerTransportProvider) null))
+ .isInstanceOf(IllegalArgumentException.class)
+ .hasMessage("Transport provider must not be null");
- assertThatThrownBy(() -> McpServer.async(createMcpTransport()).serverInfo((McpSchema.Implementation) null))
+ assertThatThrownBy(
+ () -> McpServer.async(createMcpTransportProvider()).serverInfo((McpSchema.Implementation) null))
.isInstanceOf(IllegalArgumentException.class)
.hasMessage("Server info must not be null");
}
@Test
void testGracefulShutdown() {
- var mcpAsyncServer = McpServer.async(createMcpTransport()).serverInfo("test-server", "1.0.0").build();
+ var mcpAsyncServer = McpServer.async(createMcpTransportProvider()).serverInfo("test-server", "1.0.0").build();
StepVerifier.create(mcpAsyncServer.closeGracefully()).verifyComplete();
}
@Test
void testImmediateClose() {
- var mcpAsyncServer = McpServer.async(createMcpTransport()).serverInfo("test-server", "1.0.0").build();
+ var mcpAsyncServer = McpServer.async(createMcpTransportProvider()).serverInfo("test-server", "1.0.0").build();
assertThatCode(() -> mcpAsyncServer.close()).doesNotThrowAnyException();
}
@@ -102,13 +104,13 @@ void testImmediateClose() {
@Test
void testAddTool() {
Tool newTool = new McpSchema.Tool("new-tool", "New test tool", emptyJsonSchema);
- var mcpAsyncServer = McpServer.async(createMcpTransport())
+ var mcpAsyncServer = McpServer.async(createMcpTransportProvider())
.serverInfo("test-server", "1.0.0")
.capabilities(ServerCapabilities.builder().tools(true).build())
.build();
- StepVerifier.create(mcpAsyncServer.addTool(new McpServerFeatures.AsyncToolRegistration(newTool,
- args -> Mono.just(new CallToolResult(List.of(), false)))))
+ StepVerifier.create(mcpAsyncServer.addTool(new McpServerFeatures.AsyncToolSpecification(newTool,
+ (exchange, args) -> Mono.just(new CallToolResult(List.of(), false)))))
.verifyComplete();
assertThatCode(() -> mcpAsyncServer.closeGracefully().block(Duration.ofSeconds(10))).doesNotThrowAnyException();
@@ -118,14 +120,15 @@ void testAddTool() {
void testAddDuplicateTool() {
Tool duplicateTool = new McpSchema.Tool(TEST_TOOL_NAME, "Duplicate tool", emptyJsonSchema);
- var mcpAsyncServer = McpServer.async(createMcpTransport())
+ var mcpAsyncServer = McpServer.async(createMcpTransportProvider())
.serverInfo("test-server", "1.0.0")
.capabilities(ServerCapabilities.builder().tools(true).build())
- .tool(duplicateTool, args -> Mono.just(new CallToolResult(List.of(), false)))
+ .tool(duplicateTool, (exchange, args) -> Mono.just(new CallToolResult(List.of(), false)))
.build();
- StepVerifier.create(mcpAsyncServer.addTool(new McpServerFeatures.AsyncToolRegistration(duplicateTool,
- args -> Mono.just(new CallToolResult(List.of(), false)))))
+ StepVerifier
+ .create(mcpAsyncServer.addTool(new McpServerFeatures.AsyncToolSpecification(duplicateTool,
+ (exchange, args) -> Mono.just(new CallToolResult(List.of(), false)))))
.verifyErrorSatisfies(error -> {
assertThat(error).isInstanceOf(McpError.class)
.hasMessage("Tool with name '" + TEST_TOOL_NAME + "' already exists");
@@ -138,10 +141,10 @@ void testAddDuplicateTool() {
void testRemoveTool() {
Tool too = new McpSchema.Tool(TEST_TOOL_NAME, "Duplicate tool", emptyJsonSchema);
- var mcpAsyncServer = McpServer.async(createMcpTransport())
+ var mcpAsyncServer = McpServer.async(createMcpTransportProvider())
.serverInfo("test-server", "1.0.0")
.capabilities(ServerCapabilities.builder().tools(true).build())
- .tool(too, args -> Mono.just(new CallToolResult(List.of(), false)))
+ .tool(too, (exchange, args) -> Mono.just(new CallToolResult(List.of(), false)))
.build();
StepVerifier.create(mcpAsyncServer.removeTool(TEST_TOOL_NAME)).verifyComplete();
@@ -151,7 +154,7 @@ void testRemoveTool() {
@Test
void testRemoveNonexistentTool() {
- var mcpAsyncServer = McpServer.async(createMcpTransport())
+ var mcpAsyncServer = McpServer.async(createMcpTransportProvider())
.serverInfo("test-server", "1.0.0")
.capabilities(ServerCapabilities.builder().tools(true).build())
.build();
@@ -167,10 +170,10 @@ void testRemoveNonexistentTool() {
void testNotifyToolsListChanged() {
Tool too = new McpSchema.Tool(TEST_TOOL_NAME, "Duplicate tool", emptyJsonSchema);
- var mcpAsyncServer = McpServer.async(createMcpTransport())
+ var mcpAsyncServer = McpServer.async(createMcpTransportProvider())
.serverInfo("test-server", "1.0.0")
.capabilities(ServerCapabilities.builder().tools(true).build())
- .tool(too, args -> Mono.just(new CallToolResult(List.of(), false)))
+ .tool(too, (exchange, args) -> Mono.just(new CallToolResult(List.of(), false)))
.build();
StepVerifier.create(mcpAsyncServer.notifyToolsListChanged()).verifyComplete();
@@ -184,7 +187,7 @@ void testNotifyToolsListChanged() {
@Test
void testNotifyResourcesListChanged() {
- var mcpAsyncServer = McpServer.async(createMcpTransport()).serverInfo("test-server", "1.0.0").build();
+ var mcpAsyncServer = McpServer.async(createMcpTransportProvider()).serverInfo("test-server", "1.0.0").build();
StepVerifier.create(mcpAsyncServer.notifyResourcesListChanged()).verifyComplete();
@@ -193,29 +196,29 @@ void testNotifyResourcesListChanged() {
@Test
void testAddResource() {
- var mcpAsyncServer = McpServer.async(createMcpTransport())
+ var mcpAsyncServer = McpServer.async(createMcpTransportProvider())
.serverInfo("test-server", "1.0.0")
.capabilities(ServerCapabilities.builder().resources(true, false).build())
.build();
Resource resource = new Resource(TEST_RESOURCE_URI, "Test Resource", "text/plain", "Test resource description",
null);
- McpServerFeatures.AsyncResourceRegistration registration = new McpServerFeatures.AsyncResourceRegistration(
- resource, req -> Mono.just(new ReadResourceResult(List.of())));
+ McpServerFeatures.AsyncResourceSpecification specification = new McpServerFeatures.AsyncResourceSpecification(
+ resource, (exchange, req) -> Mono.just(new ReadResourceResult(List.of())));
- StepVerifier.create(mcpAsyncServer.addResource(registration)).verifyComplete();
+ StepVerifier.create(mcpAsyncServer.addResource(specification)).verifyComplete();
assertThatCode(() -> mcpAsyncServer.closeGracefully().block(Duration.ofSeconds(10))).doesNotThrowAnyException();
}
@Test
- void testAddResourceWithNullRegistration() {
- var mcpAsyncServer = McpServer.async(createMcpTransport())
+ void testAddResourceWithNullSpecification() {
+ var mcpAsyncServer = McpServer.async(createMcpTransportProvider())
.serverInfo("test-server", "1.0.0")
.capabilities(ServerCapabilities.builder().resources(true, false).build())
.build();
- StepVerifier.create(mcpAsyncServer.addResource((McpServerFeatures.AsyncResourceRegistration) null))
+ StepVerifier.create(mcpAsyncServer.addResource((McpServerFeatures.AsyncResourceSpecification) null))
.verifyErrorSatisfies(error -> {
assertThat(error).isInstanceOf(McpError.class).hasMessage("Resource must not be null");
});
@@ -226,16 +229,16 @@ void testAddResourceWithNullRegistration() {
@Test
void testAddResourceWithoutCapability() {
// Create a server without resource capabilities
- McpAsyncServer serverWithoutResources = McpServer.async(createMcpTransport())
+ McpAsyncServer serverWithoutResources = McpServer.async(createMcpTransportProvider())
.serverInfo("test-server", "1.0.0")
.build();
Resource resource = new Resource(TEST_RESOURCE_URI, "Test Resource", "text/plain", "Test resource description",
null);
- McpServerFeatures.AsyncResourceRegistration registration = new McpServerFeatures.AsyncResourceRegistration(
- resource, req -> Mono.just(new ReadResourceResult(List.of())));
+ McpServerFeatures.AsyncResourceSpecification specification = new McpServerFeatures.AsyncResourceSpecification(
+ resource, (exchange, req) -> Mono.just(new ReadResourceResult(List.of())));
- StepVerifier.create(serverWithoutResources.addResource(registration)).verifyErrorSatisfies(error -> {
+ StepVerifier.create(serverWithoutResources.addResource(specification)).verifyErrorSatisfies(error -> {
assertThat(error).isInstanceOf(McpError.class)
.hasMessage("Server must be configured with resource capabilities");
});
@@ -244,7 +247,7 @@ void testAddResourceWithoutCapability() {
@Test
void testRemoveResourceWithoutCapability() {
// Create a server without resource capabilities
- McpAsyncServer serverWithoutResources = McpServer.async(createMcpTransport())
+ McpAsyncServer serverWithoutResources = McpServer.async(createMcpTransportProvider())
.serverInfo("test-server", "1.0.0")
.build();
@@ -260,7 +263,7 @@ void testRemoveResourceWithoutCapability() {
@Test
void testNotifyPromptsListChanged() {
- var mcpAsyncServer = McpServer.async(createMcpTransport()).serverInfo("test-server", "1.0.0").build();
+ var mcpAsyncServer = McpServer.async(createMcpTransportProvider()).serverInfo("test-server", "1.0.0").build();
StepVerifier.create(mcpAsyncServer.notifyPromptsListChanged()).verifyComplete();
@@ -268,31 +271,31 @@ void testNotifyPromptsListChanged() {
}
@Test
- void testAddPromptWithNullRegistration() {
- var mcpAsyncServer = McpServer.async(createMcpTransport())
+ void testAddPromptWithNullSpecification() {
+ var mcpAsyncServer = McpServer.async(createMcpTransportProvider())
.serverInfo("test-server", "1.0.0")
.capabilities(ServerCapabilities.builder().prompts(false).build())
.build();
- StepVerifier.create(mcpAsyncServer.addPrompt((McpServerFeatures.AsyncPromptRegistration) null))
+ StepVerifier.create(mcpAsyncServer.addPrompt((McpServerFeatures.AsyncPromptSpecification) null))
.verifyErrorSatisfies(error -> {
- assertThat(error).isInstanceOf(McpError.class).hasMessage("Prompt registration must not be null");
+ assertThat(error).isInstanceOf(McpError.class).hasMessage("Prompt specification must not be null");
});
}
@Test
void testAddPromptWithoutCapability() {
// Create a server without prompt capabilities
- McpAsyncServer serverWithoutPrompts = McpServer.async(createMcpTransport())
+ McpAsyncServer serverWithoutPrompts = McpServer.async(createMcpTransportProvider())
.serverInfo("test-server", "1.0.0")
.build();
Prompt prompt = new Prompt(TEST_PROMPT_NAME, "Test Prompt", List.of());
- McpServerFeatures.AsyncPromptRegistration registration = new McpServerFeatures.AsyncPromptRegistration(prompt,
- req -> Mono.just(new GetPromptResult("Test prompt description", List
+ McpServerFeatures.AsyncPromptSpecification specification = new McpServerFeatures.AsyncPromptSpecification(
+ prompt, (exchange, req) -> Mono.just(new GetPromptResult("Test prompt description", List
.of(new PromptMessage(McpSchema.Role.ASSISTANT, new McpSchema.TextContent("Test content"))))));
- StepVerifier.create(serverWithoutPrompts.addPrompt(registration)).verifyErrorSatisfies(error -> {
+ StepVerifier.create(serverWithoutPrompts.addPrompt(specification)).verifyErrorSatisfies(error -> {
assertThat(error).isInstanceOf(McpError.class)
.hasMessage("Server must be configured with prompt capabilities");
});
@@ -301,7 +304,7 @@ void testAddPromptWithoutCapability() {
@Test
void testRemovePromptWithoutCapability() {
// Create a server without prompt capabilities
- McpAsyncServer serverWithoutPrompts = McpServer.async(createMcpTransport())
+ McpAsyncServer serverWithoutPrompts = McpServer.async(createMcpTransportProvider())
.serverInfo("test-server", "1.0.0")
.build();
@@ -316,14 +319,14 @@ void testRemovePrompt() {
String TEST_PROMPT_NAME_TO_REMOVE = "TEST_PROMPT_NAME678";
Prompt prompt = new Prompt(TEST_PROMPT_NAME_TO_REMOVE, "Test Prompt", List.of());
- McpServerFeatures.AsyncPromptRegistration registration = new McpServerFeatures.AsyncPromptRegistration(prompt,
- req -> Mono.just(new GetPromptResult("Test prompt description", List
+ McpServerFeatures.AsyncPromptSpecification specification = new McpServerFeatures.AsyncPromptSpecification(
+ prompt, (exchange, req) -> Mono.just(new GetPromptResult("Test prompt description", List
.of(new PromptMessage(McpSchema.Role.ASSISTANT, new McpSchema.TextContent("Test content"))))));
- var mcpAsyncServer = McpServer.async(createMcpTransport())
+ var mcpAsyncServer = McpServer.async(createMcpTransportProvider())
.serverInfo("test-server", "1.0.0")
.capabilities(ServerCapabilities.builder().prompts(true).build())
- .prompts(registration)
+ .prompts(specification)
.build();
StepVerifier.create(mcpAsyncServer.removePrompt(TEST_PROMPT_NAME_TO_REMOVE)).verifyComplete();
@@ -333,7 +336,7 @@ void testRemovePrompt() {
@Test
void testRemoveNonexistentPrompt() {
- var mcpAsyncServer2 = McpServer.async(createMcpTransport())
+ var mcpAsyncServer2 = McpServer.async(createMcpTransportProvider())
.serverInfo("test-server", "1.0.0")
.capabilities(ServerCapabilities.builder().prompts(true).build())
.build();
@@ -352,14 +355,14 @@ void testRemoveNonexistentPrompt() {
// ---------------------------------------
@Test
- void testRootsChangeConsumers() {
+ void testRootsChangeHandlers() {
// Test with single consumer
var rootsReceived = new McpSchema.Root[1];
var consumerCalled = new boolean[1];
- var singleConsumerServer = McpServer.async(createMcpTransport())
+ var singleConsumerServer = McpServer.async(createMcpTransportProvider())
.serverInfo("test-server", "1.0.0")
- .rootsChangeConsumers(List.of(roots -> Mono.fromRunnable(() -> {
+ .rootsChangeHandlers(List.of((exchange, roots) -> Mono.fromRunnable(() -> {
consumerCalled[0] = true;
if (!roots.isEmpty()) {
rootsReceived[0] = roots.get(0);
@@ -377,12 +380,12 @@ void testRootsChangeConsumers() {
var consumer2Called = new boolean[1];
var rootsContent = new List[1];
- var multipleConsumersServer = McpServer.async(createMcpTransport())
+ var multipleConsumersServer = McpServer.async(createMcpTransportProvider())
.serverInfo("test-server", "1.0.0")
- .rootsChangeConsumers(List.of(roots -> Mono.fromRunnable(() -> {
+ .rootsChangeHandlers(List.of((exchange, roots) -> Mono.fromRunnable(() -> {
consumer1Called[0] = true;
rootsContent[0] = roots;
- }), roots -> Mono.fromRunnable(() -> consumer2Called[0] = true)))
+ }), (exchange, roots) -> Mono.fromRunnable(() -> consumer2Called[0] = true)))
.build();
assertThat(multipleConsumersServer).isNotNull();
@@ -391,9 +394,9 @@ void testRootsChangeConsumers() {
onClose();
// Test error handling
- var errorHandlingServer = McpServer.async(createMcpTransport())
+ var errorHandlingServer = McpServer.async(createMcpTransportProvider())
.serverInfo("test-server", "1.0.0")
- .rootsChangeConsumers(List.of(roots -> {
+ .rootsChangeHandlers(List.of((exchange, roots) -> {
throw new RuntimeException("Test error");
}))
.build();
@@ -404,60 +407,13 @@ void testRootsChangeConsumers() {
onClose();
// Test without consumers
- var noConsumersServer = McpServer.async(createMcpTransport()).serverInfo("test-server", "1.0.0").build();
+ var noConsumersServer = McpServer.async(createMcpTransportProvider())
+ .serverInfo("test-server", "1.0.0")
+ .build();
assertThat(noConsumersServer).isNotNull();
assertThatCode(() -> noConsumersServer.closeGracefully().block(Duration.ofSeconds(10)))
.doesNotThrowAnyException();
}
- // ---------------------------------------
- // Logging Tests
- // ---------------------------------------
-
- @Test
- void testLoggingLevels() {
- var mcpAsyncServer = McpServer.async(createMcpTransport())
- .serverInfo("test-server", "1.0.0")
- .capabilities(ServerCapabilities.builder().logging().build())
- .build();
-
- // Test all logging levels
- for (McpSchema.LoggingLevel level : McpSchema.LoggingLevel.values()) {
- var notification = McpSchema.LoggingMessageNotification.builder()
- .level(level)
- .logger("test-logger")
- .data("Test message with level " + level)
- .build();
-
- StepVerifier.create(mcpAsyncServer.loggingNotification(notification)).verifyComplete();
- }
- }
-
- @Test
- void testLoggingWithoutCapability() {
- var mcpAsyncServer = McpServer.async(createMcpTransport())
- .serverInfo("test-server", "1.0.0")
- .capabilities(ServerCapabilities.builder().build()) // No logging capability
- .build();
-
- var notification = McpSchema.LoggingMessageNotification.builder()
- .level(McpSchema.LoggingLevel.INFO)
- .logger("test-logger")
- .data("Test log message")
- .build();
-
- StepVerifier.create(mcpAsyncServer.loggingNotification(notification)).verifyComplete();
- }
-
- @Test
- void testLoggingWithNullNotification() {
- var mcpAsyncServer = McpServer.async(createMcpTransport())
- .serverInfo("test-server", "1.0.0")
- .capabilities(ServerCapabilities.builder().logging().build())
- .build();
-
- StepVerifier.create(mcpAsyncServer.loggingNotification(null)).verifyError(McpError.class);
- }
-
}
diff --git a/mcp-test/src/main/java/io/modelcontextprotocol/server/AbstractMcpSyncServerTests.java b/mcp-test/src/main/java/io/modelcontextprotocol/server/AbstractMcpSyncServerTests.java
index f8b95750..e313454b 100644
--- a/mcp-test/src/main/java/io/modelcontextprotocol/server/AbstractMcpSyncServerTests.java
+++ b/mcp-test/src/main/java/io/modelcontextprotocol/server/AbstractMcpSyncServerTests.java
@@ -16,8 +16,7 @@
import io.modelcontextprotocol.spec.McpSchema.Resource;
import io.modelcontextprotocol.spec.McpSchema.ServerCapabilities;
import io.modelcontextprotocol.spec.McpSchema.Tool;
-import io.modelcontextprotocol.spec.McpTransport;
-import io.modelcontextprotocol.spec.ServerMcpTransport;
+import io.modelcontextprotocol.spec.McpServerTransportProvider;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
@@ -28,10 +27,11 @@
/**
* Test suite for the {@link McpSyncServer} that can be used with different
- * {@link McpTransport} implementations.
+ * {@link io.modelcontextprotocol.spec.McpServerTransportProvider} implementations.
*
* @author Christian Tzolov
*/
+// KEEP IN SYNC with the class in mcp-test module
public abstract class AbstractMcpSyncServerTests {
private static final String TEST_TOOL_NAME = "test-tool";
@@ -40,7 +40,7 @@ public abstract class AbstractMcpSyncServerTests {
private static final String TEST_PROMPT_NAME = "test-prompt";
- abstract protected ServerMcpTransport createMcpTransport();
+ abstract protected McpServerTransportProvider createMcpTransportProvider();
protected void onStart() {
}
@@ -64,31 +64,32 @@ void tearDown() {
@Test
void testConstructorWithInvalidArguments() {
- assertThatThrownBy(() -> McpServer.sync(null)).isInstanceOf(IllegalArgumentException.class)
- .hasMessage("Transport must not be null");
+ assertThatThrownBy(() -> McpServer.sync((McpServerTransportProvider) null))
+ .isInstanceOf(IllegalArgumentException.class)
+ .hasMessage("Transport provider must not be null");
- assertThatThrownBy(() -> McpServer.sync(createMcpTransport()).serverInfo(null))
+ assertThatThrownBy(() -> McpServer.sync(createMcpTransportProvider()).serverInfo(null))
.isInstanceOf(IllegalArgumentException.class)
.hasMessage("Server info must not be null");
}
@Test
void testGracefulShutdown() {
- var mcpSyncServer = McpServer.sync(createMcpTransport()).serverInfo("test-server", "1.0.0").build();
+ var mcpSyncServer = McpServer.sync(createMcpTransportProvider()).serverInfo("test-server", "1.0.0").build();
assertThatCode(() -> mcpSyncServer.closeGracefully()).doesNotThrowAnyException();
}
@Test
void testImmediateClose() {
- var mcpSyncServer = McpServer.sync(createMcpTransport()).serverInfo("test-server", "1.0.0").build();
+ var mcpSyncServer = McpServer.sync(createMcpTransportProvider()).serverInfo("test-server", "1.0.0").build();
assertThatCode(() -> mcpSyncServer.close()).doesNotThrowAnyException();
}
@Test
void testGetAsyncServer() {
- var mcpSyncServer = McpServer.sync(createMcpTransport()).serverInfo("test-server", "1.0.0").build();
+ var mcpSyncServer = McpServer.sync(createMcpTransportProvider()).serverInfo("test-server", "1.0.0").build();
assertThat(mcpSyncServer.getAsyncServer()).isNotNull();
@@ -109,14 +110,14 @@ void testGetAsyncServer() {
@Test
void testAddTool() {
- var mcpSyncServer = McpServer.sync(createMcpTransport())
+ var mcpSyncServer = McpServer.sync(createMcpTransportProvider())
.serverInfo("test-server", "1.0.0")
.capabilities(ServerCapabilities.builder().tools(true).build())
.build();
Tool newTool = new McpSchema.Tool("new-tool", "New test tool", emptyJsonSchema);
- assertThatCode(() -> mcpSyncServer
- .addTool(new McpServerFeatures.SyncToolRegistration(newTool, args -> new CallToolResult(List.of(), false))))
+ assertThatCode(() -> mcpSyncServer.addTool(new McpServerFeatures.SyncToolSpecification(newTool,
+ (exchange, args) -> new CallToolResult(List.of(), false))))
.doesNotThrowAnyException();
assertThatCode(() -> mcpSyncServer.closeGracefully()).doesNotThrowAnyException();
@@ -126,14 +127,14 @@ void testAddTool() {
void testAddDuplicateTool() {
Tool duplicateTool = new McpSchema.Tool(TEST_TOOL_NAME, "Duplicate tool", emptyJsonSchema);
- var mcpSyncServer = McpServer.sync(createMcpTransport())
+ var mcpSyncServer = McpServer.sync(createMcpTransportProvider())
.serverInfo("test-server", "1.0.0")
.capabilities(ServerCapabilities.builder().tools(true).build())
- .tool(duplicateTool, args -> new CallToolResult(List.of(), false))
+ .tool(duplicateTool, (exchange, args) -> new CallToolResult(List.of(), false))
.build();
- assertThatThrownBy(() -> mcpSyncServer.addTool(new McpServerFeatures.SyncToolRegistration(duplicateTool,
- args -> new CallToolResult(List.of(), false))))
+ assertThatThrownBy(() -> mcpSyncServer.addTool(new McpServerFeatures.SyncToolSpecification(duplicateTool,
+ (exchange, args) -> new CallToolResult(List.of(), false))))
.isInstanceOf(McpError.class)
.hasMessage("Tool with name '" + TEST_TOOL_NAME + "' already exists");
@@ -144,10 +145,10 @@ void testAddDuplicateTool() {
void testRemoveTool() {
Tool tool = new McpSchema.Tool(TEST_TOOL_NAME, "Test tool", emptyJsonSchema);
- var mcpSyncServer = McpServer.sync(createMcpTransport())
+ var mcpSyncServer = McpServer.sync(createMcpTransportProvider())
.serverInfo("test-server", "1.0.0")
.capabilities(ServerCapabilities.builder().tools(true).build())
- .tool(tool, args -> new CallToolResult(List.of(), false))
+ .tool(tool, (exchange, args) -> new CallToolResult(List.of(), false))
.build();
assertThatCode(() -> mcpSyncServer.removeTool(TEST_TOOL_NAME)).doesNotThrowAnyException();
@@ -157,7 +158,7 @@ void testRemoveTool() {
@Test
void testRemoveNonexistentTool() {
- var mcpSyncServer = McpServer.sync(createMcpTransport())
+ var mcpSyncServer = McpServer.sync(createMcpTransportProvider())
.serverInfo("test-server", "1.0.0")
.capabilities(ServerCapabilities.builder().tools(true).build())
.build();
@@ -170,7 +171,7 @@ void testRemoveNonexistentTool() {
@Test
void testNotifyToolsListChanged() {
- var mcpSyncServer = McpServer.sync(createMcpTransport()).serverInfo("test-server", "1.0.0").build();
+ var mcpSyncServer = McpServer.sync(createMcpTransportProvider()).serverInfo("test-server", "1.0.0").build();
assertThatCode(() -> mcpSyncServer.notifyToolsListChanged()).doesNotThrowAnyException();
@@ -183,7 +184,7 @@ void testNotifyToolsListChanged() {
@Test
void testNotifyResourcesListChanged() {
- var mcpSyncServer = McpServer.sync(createMcpTransport()).serverInfo("test-server", "1.0.0").build();
+ var mcpSyncServer = McpServer.sync(createMcpTransportProvider()).serverInfo("test-server", "1.0.0").build();
assertThatCode(() -> mcpSyncServer.notifyResourcesListChanged()).doesNotThrowAnyException();
@@ -192,29 +193,29 @@ void testNotifyResourcesListChanged() {
@Test
void testAddResource() {
- var mcpSyncServer = McpServer.sync(createMcpTransport())
+ var mcpSyncServer = McpServer.sync(createMcpTransportProvider())
.serverInfo("test-server", "1.0.0")
.capabilities(ServerCapabilities.builder().resources(true, false).build())
.build();
Resource resource = new Resource(TEST_RESOURCE_URI, "Test Resource", "text/plain", "Test resource description",
null);
- McpServerFeatures.SyncResourceRegistration registration = new McpServerFeatures.SyncResourceRegistration(
- resource, req -> new ReadResourceResult(List.of()));
+ McpServerFeatures.SyncResourceSpecification specification = new McpServerFeatures.SyncResourceSpecification(
+ resource, (exchange, req) -> new ReadResourceResult(List.of()));
- assertThatCode(() -> mcpSyncServer.addResource(registration)).doesNotThrowAnyException();
+ assertThatCode(() -> mcpSyncServer.addResource(specification)).doesNotThrowAnyException();
assertThatCode(() -> mcpSyncServer.closeGracefully()).doesNotThrowAnyException();
}
@Test
- void testAddResourceWithNullRegistration() {
- var mcpSyncServer = McpServer.sync(createMcpTransport())
+ void testAddResourceWithNullSpecification() {
+ var mcpSyncServer = McpServer.sync(createMcpTransportProvider())
.serverInfo("test-server", "1.0.0")
.capabilities(ServerCapabilities.builder().resources(true, false).build())
.build();
- assertThatThrownBy(() -> mcpSyncServer.addResource((McpServerFeatures.SyncResourceRegistration) null))
+ assertThatThrownBy(() -> mcpSyncServer.addResource((McpServerFeatures.SyncResourceSpecification) null))
.isInstanceOf(McpError.class)
.hasMessage("Resource must not be null");
@@ -223,20 +224,24 @@ void testAddResourceWithNullRegistration() {
@Test
void testAddResourceWithoutCapability() {
- var serverWithoutResources = McpServer.sync(createMcpTransport()).serverInfo("test-server", "1.0.0").build();
+ var serverWithoutResources = McpServer.sync(createMcpTransportProvider())
+ .serverInfo("test-server", "1.0.0")
+ .build();
Resource resource = new Resource(TEST_RESOURCE_URI, "Test Resource", "text/plain", "Test resource description",
null);
- McpServerFeatures.SyncResourceRegistration registration = new McpServerFeatures.SyncResourceRegistration(
- resource, req -> new ReadResourceResult(List.of()));
+ McpServerFeatures.SyncResourceSpecification specification = new McpServerFeatures.SyncResourceSpecification(
+ resource, (exchange, req) -> new ReadResourceResult(List.of()));
- assertThatThrownBy(() -> serverWithoutResources.addResource(registration)).isInstanceOf(McpError.class)
+ assertThatThrownBy(() -> serverWithoutResources.addResource(specification)).isInstanceOf(McpError.class)
.hasMessage("Server must be configured with resource capabilities");
}
@Test
void testRemoveResourceWithoutCapability() {
- var serverWithoutResources = McpServer.sync(createMcpTransport()).serverInfo("test-server", "1.0.0").build();
+ var serverWithoutResources = McpServer.sync(createMcpTransportProvider())
+ .serverInfo("test-server", "1.0.0")
+ .build();
assertThatThrownBy(() -> serverWithoutResources.removeResource(TEST_RESOURCE_URI)).isInstanceOf(McpError.class)
.hasMessage("Server must be configured with resource capabilities");
@@ -248,7 +253,7 @@ void testRemoveResourceWithoutCapability() {
@Test
void testNotifyPromptsListChanged() {
- var mcpSyncServer = McpServer.sync(createMcpTransport()).serverInfo("test-server", "1.0.0").build();
+ var mcpSyncServer = McpServer.sync(createMcpTransportProvider()).serverInfo("test-server", "1.0.0").build();
assertThatCode(() -> mcpSyncServer.notifyPromptsListChanged()).doesNotThrowAnyException();
@@ -256,33 +261,37 @@ void testNotifyPromptsListChanged() {
}
@Test
- void testAddPromptWithNullRegistration() {
- var mcpSyncServer = McpServer.sync(createMcpTransport())
+ void testAddPromptWithNullSpecification() {
+ var mcpSyncServer = McpServer.sync(createMcpTransportProvider())
.serverInfo("test-server", "1.0.0")
.capabilities(ServerCapabilities.builder().prompts(false).build())
.build();
- assertThatThrownBy(() -> mcpSyncServer.addPrompt((McpServerFeatures.SyncPromptRegistration) null))
+ assertThatThrownBy(() -> mcpSyncServer.addPrompt((McpServerFeatures.SyncPromptSpecification) null))
.isInstanceOf(McpError.class)
- .hasMessage("Prompt registration must not be null");
+ .hasMessage("Prompt specification must not be null");
}
@Test
void testAddPromptWithoutCapability() {
- var serverWithoutPrompts = McpServer.sync(createMcpTransport()).serverInfo("test-server", "1.0.0").build();
+ var serverWithoutPrompts = McpServer.sync(createMcpTransportProvider())
+ .serverInfo("test-server", "1.0.0")
+ .build();
Prompt prompt = new Prompt(TEST_PROMPT_NAME, "Test Prompt", List.of());
- McpServerFeatures.SyncPromptRegistration registration = new McpServerFeatures.SyncPromptRegistration(prompt,
- req -> new GetPromptResult("Test prompt description", List
+ McpServerFeatures.SyncPromptSpecification specification = new McpServerFeatures.SyncPromptSpecification(prompt,
+ (exchange, req) -> new GetPromptResult("Test prompt description", List
.of(new PromptMessage(McpSchema.Role.ASSISTANT, new McpSchema.TextContent("Test content")))));
- assertThatThrownBy(() -> serverWithoutPrompts.addPrompt(registration)).isInstanceOf(McpError.class)
+ assertThatThrownBy(() -> serverWithoutPrompts.addPrompt(specification)).isInstanceOf(McpError.class)
.hasMessage("Server must be configured with prompt capabilities");
}
@Test
void testRemovePromptWithoutCapability() {
- var serverWithoutPrompts = McpServer.sync(createMcpTransport()).serverInfo("test-server", "1.0.0").build();
+ var serverWithoutPrompts = McpServer.sync(createMcpTransportProvider())
+ .serverInfo("test-server", "1.0.0")
+ .build();
assertThatThrownBy(() -> serverWithoutPrompts.removePrompt(TEST_PROMPT_NAME)).isInstanceOf(McpError.class)
.hasMessage("Server must be configured with prompt capabilities");
@@ -291,14 +300,14 @@ void testRemovePromptWithoutCapability() {
@Test
void testRemovePrompt() {
Prompt prompt = new Prompt(TEST_PROMPT_NAME, "Test Prompt", List.of());
- McpServerFeatures.SyncPromptRegistration registration = new McpServerFeatures.SyncPromptRegistration(prompt,
- req -> new GetPromptResult("Test prompt description", List
+ McpServerFeatures.SyncPromptSpecification specification = new McpServerFeatures.SyncPromptSpecification(prompt,
+ (exchange, req) -> new GetPromptResult("Test prompt description", List
.of(new PromptMessage(McpSchema.Role.ASSISTANT, new McpSchema.TextContent("Test content")))));
- var mcpSyncServer = McpServer.sync(createMcpTransport())
+ var mcpSyncServer = McpServer.sync(createMcpTransportProvider())
.serverInfo("test-server", "1.0.0")
.capabilities(ServerCapabilities.builder().prompts(true).build())
- .prompts(registration)
+ .prompts(specification)
.build();
assertThatCode(() -> mcpSyncServer.removePrompt(TEST_PROMPT_NAME)).doesNotThrowAnyException();
@@ -308,7 +317,7 @@ void testRemovePrompt() {
@Test
void testRemoveNonexistentPrompt() {
- var mcpSyncServer = McpServer.sync(createMcpTransport())
+ var mcpSyncServer = McpServer.sync(createMcpTransportProvider())
.serverInfo("test-server", "1.0.0")
.capabilities(ServerCapabilities.builder().prompts(true).build())
.build();
@@ -324,14 +333,14 @@ void testRemoveNonexistentPrompt() {
// ---------------------------------------
@Test
- void testRootsChangeConsumers() {
+ void testRootsChangeHandlers() {
// Test with single consumer
var rootsReceived = new McpSchema.Root[1];
var consumerCalled = new boolean[1];
- var singleConsumerServer = McpServer.sync(createMcpTransport())
+ var singleConsumerServer = McpServer.sync(createMcpTransportProvider())
.serverInfo("test-server", "1.0.0")
- .rootsChangeConsumers(List.of(roots -> {
+ .rootsChangeHandlers(List.of((exchange, roots) -> {
consumerCalled[0] = true;
if (!roots.isEmpty()) {
rootsReceived[0] = roots.get(0);
@@ -348,12 +357,12 @@ void testRootsChangeConsumers() {
var consumer2Called = new boolean[1];
var rootsContent = new List[1];
- var multipleConsumersServer = McpServer.sync(createMcpTransport())
+ var multipleConsumersServer = McpServer.sync(createMcpTransportProvider())
.serverInfo("test-server", "1.0.0")
- .rootsChangeConsumers(List.of(roots -> {
+ .rootsChangeHandlers(List.of((exchange, roots) -> {
consumer1Called[0] = true;
rootsContent[0] = roots;
- }, roots -> consumer2Called[0] = true))
+ }, (exchange, roots) -> consumer2Called[0] = true))
.build();
assertThat(multipleConsumersServer).isNotNull();
@@ -361,9 +370,9 @@ void testRootsChangeConsumers() {
onClose();
// Test error handling
- var errorHandlingServer = McpServer.sync(createMcpTransport())
+ var errorHandlingServer = McpServer.sync(createMcpTransportProvider())
.serverInfo("test-server", "1.0.0")
- .rootsChangeConsumers(List.of(roots -> {
+ .rootsChangeHandlers(List.of((exchange, roots) -> {
throw new RuntimeException("Test error");
}))
.build();
@@ -373,59 +382,10 @@ void testRootsChangeConsumers() {
onClose();
// Test without consumers
- var noConsumersServer = McpServer.sync(createMcpTransport()).serverInfo("test-server", "1.0.0").build();
+ var noConsumersServer = McpServer.sync(createMcpTransportProvider()).serverInfo("test-server", "1.0.0").build();
assertThat(noConsumersServer).isNotNull();
assertThatCode(() -> noConsumersServer.closeGracefully()).doesNotThrowAnyException();
}
- // ---------------------------------------
- // Logging Tests
- // ---------------------------------------
-
- @Test
- void testLoggingLevels() {
- var mcpSyncServer = McpServer.sync(createMcpTransport())
- .serverInfo("test-server", "1.0.0")
- .capabilities(ServerCapabilities.builder().logging().build())
- .build();
-
- // Test all logging levels
- for (McpSchema.LoggingLevel level : McpSchema.LoggingLevel.values()) {
- var notification = McpSchema.LoggingMessageNotification.builder()
- .level(level)
- .logger("test-logger")
- .data("Test message with level " + level)
- .build();
-
- assertThatCode(() -> mcpSyncServer.loggingNotification(notification)).doesNotThrowAnyException();
- }
- }
-
- @Test
- void testLoggingWithoutCapability() {
- var mcpSyncServer = McpServer.sync(createMcpTransport())
- .serverInfo("test-server", "1.0.0")
- .capabilities(ServerCapabilities.builder().build()) // No logging capability
- .build();
-
- var notification = McpSchema.LoggingMessageNotification.builder()
- .level(McpSchema.LoggingLevel.INFO)
- .logger("test-logger")
- .data("Test log message")
- .build();
-
- assertThatCode(() -> mcpSyncServer.loggingNotification(notification)).doesNotThrowAnyException();
- }
-
- @Test
- void testLoggingWithNullNotification() {
- var mcpSyncServer = McpServer.sync(createMcpTransport())
- .serverInfo("test-server", "1.0.0")
- .capabilities(ServerCapabilities.builder().logging().build())
- .build();
-
- assertThatThrownBy(() -> mcpSyncServer.loggingNotification(null)).isInstanceOf(McpError.class);
- }
-
}
diff --git a/mcp-test/src/main/java/io/modelcontextprotocol/server/TestUtil.java b/mcp-test/src/main/java/io/modelcontextprotocol/server/TestUtil.java
new file mode 100644
index 00000000..0085f31e
--- /dev/null
+++ b/mcp-test/src/main/java/io/modelcontextprotocol/server/TestUtil.java
@@ -0,0 +1,31 @@
+/*
+* Copyright 2025 - 2025 the original author or authors.
+*/
+package io.modelcontextprotocol.server;
+
+import java.io.IOException;
+import java.net.InetSocketAddress;
+import java.net.ServerSocket;
+
+public class TestUtil {
+
+ TestUtil() {
+ // Prevent instantiation
+ }
+
+ /**
+ * Finds an available port on the local machine.
+ * @return an available port number
+ * @throws IllegalStateException if no available port can be found
+ */
+ public static int findAvailablePort() {
+ try (final ServerSocket socket = new ServerSocket()) {
+ socket.bind(new InetSocketAddress(0));
+ return socket.getLocalPort();
+ }
+ catch (final IOException e) {
+ throw new IllegalStateException("Cannot bind to an available port!", e);
+ }
+ }
+
+}
diff --git a/mcp/pom.xml b/mcp/pom.xml
index 2170ffef..77343282 100644
--- a/mcp/pom.xml
+++ b/mcp/pom.xml
@@ -6,7 +6,7 @@
io.modelcontextprotocol.sdkmcp-parent
- 0.8.0-SNAPSHOT
+ 0.11.0-SNAPSHOTmcpjar
@@ -97,7 +97,7 @@
test
-
@@ -126,12 +126,26 @@
${junit.version}test
+
+ org.junit.jupiter
+ junit-jupiter-params
+ ${junit.version}
+ test
+ org.mockitomockito-core${mockito.version}test
+
+
+
+ net.bytebuddy
+ byte-buddy
+ ${byte-buddy.version}
+ test
+ io.projectreactorreactor-test
diff --git a/mcp/src/main/java/io/modelcontextprotocol/client/McpAsyncClient.java b/mcp/src/main/java/io/modelcontextprotocol/client/McpAsyncClient.java
index b301aa93..e3a997ba 100644
--- a/mcp/src/main/java/io/modelcontextprotocol/client/McpAsyncClient.java
+++ b/mcp/src/main/java/io/modelcontextprotocol/client/McpAsyncClient.java
@@ -14,10 +14,10 @@
import java.util.function.Function;
import com.fasterxml.jackson.core.type.TypeReference;
-import io.modelcontextprotocol.spec.ClientMcpTransport;
-import io.modelcontextprotocol.spec.DefaultMcpSession;
-import io.modelcontextprotocol.spec.DefaultMcpSession.NotificationHandler;
-import io.modelcontextprotocol.spec.DefaultMcpSession.RequestHandler;
+import io.modelcontextprotocol.spec.McpClientSession;
+import io.modelcontextprotocol.spec.McpClientSession.NotificationHandler;
+import io.modelcontextprotocol.spec.McpClientSession.RequestHandler;
+import io.modelcontextprotocol.spec.McpClientTransport;
import io.modelcontextprotocol.spec.McpError;
import io.modelcontextprotocol.spec.McpSchema;
import io.modelcontextprotocol.spec.McpSchema.ClientCapabilities;
@@ -71,9 +71,10 @@
*
* @author Dariusz Jędrzejczyk
* @author Christian Tzolov
+ * @author Jihoon Kim
* @see McpClient
* @see McpSchema
- * @see DefaultMcpSession
+ * @see McpClientSession
*/
public class McpAsyncClient {
@@ -88,7 +89,6 @@ public class McpAsyncClient {
/**
* The max timeout to await for the client-server connection to be initialized.
- * Usually x2 the request timeout. // TODO should we make it configurable?
*/
private final Duration initializationTimeout;
@@ -96,7 +96,7 @@ public class McpAsyncClient {
* The MCP session implementation that manages bidirectional JSON-RPC communication
* between clients and servers.
*/
- private final DefaultMcpSession mcpSession;
+ private final McpClientSession mcpSession;
/**
* Client capabilities.
@@ -113,6 +113,11 @@ public class McpAsyncClient {
*/
private McpSchema.ServerCapabilities serverCapabilities;
+ /**
+ * Server instructions.
+ */
+ private String serverInstructions;
+
/**
* Server implementation information.
*/
@@ -151,18 +156,21 @@ public class McpAsyncClient {
* timeout.
* @param transport the transport to use.
* @param requestTimeout the session request-response timeout.
+ * @param initializationTimeout the max timeout to await for the client-server
* @param features the MCP Client supported features.
*/
- McpAsyncClient(ClientMcpTransport transport, Duration requestTimeout, McpClientFeatures.Async features) {
+ McpAsyncClient(McpClientTransport transport, Duration requestTimeout, Duration initializationTimeout,
+ McpClientFeatures.Async features) {
Assert.notNull(transport, "Transport must not be null");
Assert.notNull(requestTimeout, "Request timeout must not be null");
+ Assert.notNull(initializationTimeout, "Initialization timeout must not be null");
this.clientInfo = features.clientInfo();
this.clientCapabilities = features.clientCapabilities();
this.transport = transport;
this.roots = new ConcurrentHashMap<>(features.roots());
- this.initializationTimeout = requestTimeout.multipliedBy(2);
+ this.initializationTimeout = initializationTimeout;
// Request Handlers
Map> requestHandlers = new HashMap<>();
@@ -226,7 +234,7 @@ public class McpAsyncClient {
notificationHandlers.put(McpSchema.METHOD_NOTIFICATION_MESSAGE,
asyncLoggingNotificationHandler(loggingConsumersFinal));
- this.mcpSession = new DefaultMcpSession(requestTimeout, transport, requestHandlers, notificationHandlers);
+ this.mcpSession = new McpClientSession(requestTimeout, transport, requestHandlers, notificationHandlers);
}
@@ -238,6 +246,15 @@ public McpSchema.ServerCapabilities getServerCapabilities() {
return this.serverCapabilities;
}
+ /**
+ * Get the server instructions that provide guidance to the client on how to interact
+ * with this server.
+ * @return The server instructions
+ */
+ public String getServerInstructions() {
+ return this.serverInstructions;
+ }
+
/**
* Get the server implementation information.
* @return The server implementation details
@@ -300,9 +317,9 @@ public Mono closeGracefully() {
* The client MUST initiate this phase by sending an initialize request containing:
* The protocol version the client supports, client's capabilities and clients
* implementation information.
- *
+ *
* The server MUST respond with its own capabilities and information.
- *
+ *
* After successful initialization, the client MUST send an initialized notification
* to indicate it is ready to begin normal operations.
* @return the initialize result.
@@ -326,6 +343,7 @@ public Mono initialize() {
return result.flatMap(initializeResult -> {
this.serverCapabilities = initializeResult.capabilities();
+ this.serverInstructions = initializeResult.instructions();
this.serverInfo = initializeResult.serverInfo();
logger.info("Server response with Protocol: {}, Capabilities: {}, Info: {} and Instructions {}",
@@ -362,7 +380,7 @@ private Mono withInitializationCheck(String actionName,
}
// --------------------------
- // Basic Utilites
+ // Basic Utilities
// --------------------------
/**
@@ -749,6 +767,14 @@ private NotificationHandler asyncPromptsChangeNotificationHandler(
// --------------------------
// Logging
// --------------------------
+ /**
+ * Create a notification handler for logging notifications from the server. This
+ * handler automatically distributes logging messages to all registered consumers.
+ * @param loggingConsumers List of consumers that will be notified when a logging
+ * message is received. Each consumer receives the logging message notification.
+ * @return A NotificationHandler that processes log notifications by distributing the
+ * message to all registered consumers
+ */
private NotificationHandler asyncLoggingNotificationHandler(
List>> loggingConsumers) {
@@ -771,13 +797,14 @@ private NotificationHandler asyncLoggingNotificationHandler(
* @see McpSchema.LoggingLevel
*/
public Mono setLoggingLevel(LoggingLevel loggingLevel) {
- Assert.notNull(loggingLevel, "Logging level must not be null");
+ if (loggingLevel == null) {
+ return Mono.error(new McpError("Logging level must not be null"));
+ }
return this.withInitializationCheck("setting logging level", initializedResult -> {
- String levelName = this.transport.unmarshalFrom(loggingLevel, new TypeReference() {
- });
- Map params = Map.of("level", levelName);
- return this.mcpSession.sendNotification(McpSchema.METHOD_LOGGING_SET_LEVEL, params);
+ var params = new McpSchema.SetLevelRequest(loggingLevel);
+ return this.mcpSession.sendRequest(McpSchema.METHOD_LOGGING_SET_LEVEL, params, new TypeReference() {
+ }).then();
});
}
@@ -790,4 +817,25 @@ void setProtocolVersions(List protocolVersions) {
this.protocolVersions = protocolVersions;
}
+ // --------------------------
+ // Completions
+ // --------------------------
+ private static final TypeReference COMPLETION_COMPLETE_RESULT_TYPE_REF = new TypeReference<>() {
+ };
+
+ /**
+ * Sends a completion/complete request to generate value suggestions based on a given
+ * reference and argument. This is typically used to provide auto-completion options
+ * for user input fields.
+ * @param completeRequest The request containing the prompt or resource reference and
+ * argument for which to generate completions.
+ * @return A Mono that completes with the result containing completion suggestions.
+ * @see McpSchema.CompleteRequest
+ * @see McpSchema.CompleteResult
+ */
+ public Mono completeCompletion(McpSchema.CompleteRequest completeRequest) {
+ return this.withInitializationCheck("complete completions", initializedResult -> this.mcpSession
+ .sendRequest(McpSchema.METHOD_COMPLETION_COMPLETE, completeRequest, COMPLETION_COMPLETE_RESULT_TYPE_REF));
+ }
+
}
diff --git a/mcp/src/main/java/io/modelcontextprotocol/client/McpClient.java b/mcp/src/main/java/io/modelcontextprotocol/client/McpClient.java
index 7ab01b70..a1dc1168 100644
--- a/mcp/src/main/java/io/modelcontextprotocol/client/McpClient.java
+++ b/mcp/src/main/java/io/modelcontextprotocol/client/McpClient.java
@@ -12,7 +12,7 @@
import java.util.function.Consumer;
import java.util.function.Function;
-import io.modelcontextprotocol.spec.ClientMcpTransport;
+import io.modelcontextprotocol.spec.McpClientTransport;
import io.modelcontextprotocol.spec.McpSchema;
import io.modelcontextprotocol.spec.McpTransport;
import io.modelcontextprotocol.spec.McpSchema.ClientCapabilities;
@@ -114,7 +114,7 @@ public interface McpClient {
* @return A new builder instance for configuring the client
* @throws IllegalArgumentException if transport is null
*/
- static SyncSpec sync(ClientMcpTransport transport) {
+ static SyncSpec sync(McpClientTransport transport) {
return new SyncSpec(transport);
}
@@ -131,7 +131,7 @@ static SyncSpec sync(ClientMcpTransport transport) {
* @return A new builder instance for configuring the client
* @throws IllegalArgumentException if transport is null
*/
- static AsyncSpec async(ClientMcpTransport transport) {
+ static AsyncSpec async(McpClientTransport transport) {
return new AsyncSpec(transport);
}
@@ -153,10 +153,12 @@ static AsyncSpec async(ClientMcpTransport transport) {
*/
class SyncSpec {
- private final ClientMcpTransport transport;
+ private final McpClientTransport transport;
private Duration requestTimeout = Duration.ofSeconds(20); // Default timeout
+ private Duration initializationTimeout = Duration.ofSeconds(20);
+
private ClientCapabilities capabilities;
private Implementation clientInfo = new Implementation("Java SDK MCP Client", "1.0.0");
@@ -173,7 +175,7 @@ class SyncSpec {
private Function samplingHandler;
- private SyncSpec(ClientMcpTransport transport) {
+ private SyncSpec(McpClientTransport transport) {
Assert.notNull(transport, "Transport must not be null");
this.transport = transport;
}
@@ -193,6 +195,18 @@ public SyncSpec requestTimeout(Duration requestTimeout) {
return this;
}
+ /**
+ * @param initializationTimeout The duration to wait for the initialization
+ * lifecycle step to complete.
+ * @return This builder instance for method chaining
+ * @throws IllegalArgumentException if initializationTimeout is null
+ */
+ public SyncSpec initializationTimeout(Duration initializationTimeout) {
+ Assert.notNull(initializationTimeout, "Initialization timeout must not be null");
+ this.initializationTimeout = initializationTimeout;
+ return this;
+ }
+
/**
* Sets the client capabilities that will be advertised to the server during
* connection initialization. Capabilities define what features the client
@@ -354,7 +368,8 @@ public McpSyncClient build() {
McpClientFeatures.Async asyncFeatures = McpClientFeatures.Async.fromSync(syncFeatures);
- return new McpSyncClient(new McpAsyncClient(transport, this.requestTimeout, asyncFeatures));
+ return new McpSyncClient(
+ new McpAsyncClient(transport, this.requestTimeout, this.initializationTimeout, asyncFeatures));
}
}
@@ -377,10 +392,12 @@ public McpSyncClient build() {
*/
class AsyncSpec {
- private final ClientMcpTransport transport;
+ private final McpClientTransport transport;
private Duration requestTimeout = Duration.ofSeconds(20); // Default timeout
+ private Duration initializationTimeout = Duration.ofSeconds(20);
+
private ClientCapabilities capabilities;
private Implementation clientInfo = new Implementation("Spring AI MCP Client", "0.3.1");
@@ -397,7 +414,7 @@ class AsyncSpec {
private Function> samplingHandler;
- private AsyncSpec(ClientMcpTransport transport) {
+ private AsyncSpec(McpClientTransport transport) {
Assert.notNull(transport, "Transport must not be null");
this.transport = transport;
}
@@ -417,6 +434,18 @@ public AsyncSpec requestTimeout(Duration requestTimeout) {
return this;
}
+ /**
+ * @param initializationTimeout The duration to wait for the initialization
+ * lifecycle step to complete.
+ * @return This builder instance for method chaining
+ * @throws IllegalArgumentException if initializationTimeout is null
+ */
+ public AsyncSpec initializationTimeout(Duration initializationTimeout) {
+ Assert.notNull(initializationTimeout, "Initialization timeout must not be null");
+ this.initializationTimeout = initializationTimeout;
+ return this;
+ }
+
/**
* Sets the client capabilities that will be advertised to the server during
* connection initialization. Capabilities define what features the client
@@ -574,7 +603,7 @@ public AsyncSpec loggingConsumers(
* @return a new instance of {@link McpAsyncClient}.
*/
public McpAsyncClient build() {
- return new McpAsyncClient(this.transport, this.requestTimeout,
+ return new McpAsyncClient(this.transport, this.requestTimeout, this.initializationTimeout,
new McpClientFeatures.Async(this.clientInfo, this.capabilities, this.roots,
this.toolsChangeConsumers, this.resourcesChangeConsumers, this.promptsChangeConsumers,
this.loggingConsumers, this.samplingHandler));
diff --git a/mcp/src/main/java/io/modelcontextprotocol/client/McpSyncClient.java b/mcp/src/main/java/io/modelcontextprotocol/client/McpSyncClient.java
index e5d964b7..a8fb979e 100644
--- a/mcp/src/main/java/io/modelcontextprotocol/client/McpSyncClient.java
+++ b/mcp/src/main/java/io/modelcontextprotocol/client/McpSyncClient.java
@@ -6,7 +6,6 @@
import java.time.Duration;
-import io.modelcontextprotocol.spec.ClientMcpTransport;
import io.modelcontextprotocol.spec.McpSchema;
import io.modelcontextprotocol.spec.McpSchema.ClientCapabilities;
import io.modelcontextprotocol.spec.McpSchema.GetPromptRequest;
@@ -47,6 +46,7 @@
*
* @author Dariusz Jędrzejczyk
* @author Christian Tzolov
+ * @author Jihoon Kim
* @see McpClient
* @see McpAsyncClient
* @see McpSchema
@@ -66,11 +66,8 @@ public class McpSyncClient implements AutoCloseable {
* Create a new McpSyncClient with the given delegate.
* @param delegate the asynchronous kernel on top of which this synchronous client
* provides a blocking API.
- * @deprecated Use {@link McpClient#sync(ClientMcpTransport)} to obtain an instance.
*/
- @Deprecated
- // TODO make the constructor package private post-deprecation
- public McpSyncClient(McpAsyncClient delegate) {
+ McpSyncClient(McpAsyncClient delegate) {
Assert.notNull(delegate, "The delegate can not be null");
this.delegate = delegate;
}
@@ -83,6 +80,15 @@ public McpSchema.ServerCapabilities getServerCapabilities() {
return this.delegate.getServerCapabilities();
}
+ /**
+ * Get the server instructions that provide guidance to the client on how to interact
+ * with this server.
+ * @return The instructions
+ */
+ public String getServerInstructions() {
+ return this.delegate.getServerInstructions();
+ }
+
/**
* Get the server implementation information.
* @return The server implementation details
@@ -91,6 +97,14 @@ public McpSchema.Implementation getServerInfo() {
return this.delegate.getServerInfo();
}
+ /**
+ * Check if the client-server connection is initialized.
+ * @return true if the client-server connection is initialized
+ */
+ public boolean isInitialized() {
+ return this.delegate.isInitialized();
+ }
+
/**
* Get the client capabilities that define the supported features and functionality.
* @return The client capabilities
@@ -329,4 +343,14 @@ public void setLoggingLevel(McpSchema.LoggingLevel loggingLevel) {
this.delegate.setLoggingLevel(loggingLevel).block();
}
+ /**
+ * Send a completion/complete request.
+ * @param completeRequest the completion request contains the prompt or resource
+ * reference and arguments for generating suggestions.
+ * @return the completion result containing suggested values.
+ */
+ public McpSchema.CompleteResult completeCompletion(McpSchema.CompleteRequest completeRequest) {
+ return this.delegate.completeCompletion(completeRequest).block();
+ }
+
}
diff --git a/mcp/src/main/java/io/modelcontextprotocol/client/transport/FlowSseClient.java b/mcp/src/main/java/io/modelcontextprotocol/client/transport/FlowSseClient.java
index 7fc67993..50af35c7 100644
--- a/mcp/src/main/java/io/modelcontextprotocol/client/transport/FlowSseClient.java
+++ b/mcp/src/main/java/io/modelcontextprotocol/client/transport/FlowSseClient.java
@@ -39,6 +39,8 @@ public class FlowSseClient {
private final HttpClient httpClient;
+ private final HttpRequest.Builder requestBuilder;
+
/**
* Pattern to extract the data content from SSE data field lines. Matches lines
* starting with "data:" and captures the remaining content.
@@ -92,7 +94,17 @@ public interface SseEventHandler {
* @param httpClient the {@link HttpClient} instance to use for SSE connections
*/
public FlowSseClient(HttpClient httpClient) {
+ this(httpClient, HttpRequest.newBuilder());
+ }
+
+ /**
+ * Creates a new FlowSseClient with the specified HTTP client and request builder.
+ * @param httpClient the {@link HttpClient} instance to use for SSE connections
+ * @param requestBuilder the {@link HttpRequest.Builder} to use for SSE requests
+ */
+ public FlowSseClient(HttpClient httpClient, HttpRequest.Builder requestBuilder) {
this.httpClient = httpClient;
+ this.requestBuilder = requestBuilder;
}
/**
@@ -109,8 +121,7 @@ public FlowSseClient(HttpClient httpClient) {
* @throws RuntimeException if the connection fails with a non-200 status code
*/
public void subscribe(String url, SseEventHandler eventHandler) {
- HttpRequest request = HttpRequest.newBuilder()
- .uri(URI.create(url))
+ HttpRequest request = this.requestBuilder.uri(URI.create(url))
.header("Accept", "text/event-stream")
.header("Cache-Control", "no-cache")
.GET()
diff --git a/mcp/src/main/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransport.java b/mcp/src/main/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransport.java
index 35da5197..99cf2a62 100644
--- a/mcp/src/main/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransport.java
+++ b/mcp/src/main/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransport.java
@@ -3,18 +3,6 @@
*/
package io.modelcontextprotocol.client.transport;
-import com.fasterxml.jackson.core.type.TypeReference;
-import com.fasterxml.jackson.databind.ObjectMapper;
-import io.modelcontextprotocol.client.transport.FlowSseClient.SseEvent;
-import io.modelcontextprotocol.spec.ClientMcpTransport;
-import io.modelcontextprotocol.spec.McpError;
-import io.modelcontextprotocol.spec.McpSchema;
-import io.modelcontextprotocol.spec.McpSchema.JSONRPCMessage;
-import io.modelcontextprotocol.util.Assert;
-import org.slf4j.Logger;
-import org.slf4j.LoggerFactory;
-import reactor.core.publisher.Mono;
-
import java.io.IOException;
import java.net.URI;
import java.net.http.HttpClient;
@@ -25,8 +13,22 @@
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicReference;
+import java.util.function.Consumer;
import java.util.function.Function;
+import com.fasterxml.jackson.core.type.TypeReference;
+import com.fasterxml.jackson.databind.ObjectMapper;
+import io.modelcontextprotocol.client.transport.FlowSseClient.SseEvent;
+import io.modelcontextprotocol.spec.McpClientTransport;
+import io.modelcontextprotocol.spec.McpError;
+import io.modelcontextprotocol.spec.McpSchema;
+import io.modelcontextprotocol.spec.McpSchema.JSONRPCMessage;
+import io.modelcontextprotocol.util.Assert;
+import io.modelcontextprotocol.util.Utils;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+import reactor.core.publisher.Mono;
+
/**
* Server-Sent Events (SSE) implementation of the
* {@link io.modelcontextprotocol.spec.McpTransport} that follows the MCP HTTP with SSE
@@ -52,9 +54,9 @@
*
* @author Christian Tzolov
* @see io.modelcontextprotocol.spec.McpTransport
- * @see io.modelcontextprotocol.spec.ClientMcpTransport
+ * @see io.modelcontextprotocol.spec.McpClientTransport
*/
-public class HttpClientSseClientTransport implements ClientMcpTransport {
+public class HttpClientSseClientTransport implements McpClientTransport {
private static final Logger logger = LoggerFactory.getLogger(HttpClientSseClientTransport.class);
@@ -65,10 +67,13 @@ public class HttpClientSseClientTransport implements ClientMcpTransport {
private static final String ENDPOINT_EVENT_TYPE = "endpoint";
/** Default SSE endpoint path */
- private static final String SSE_ENDPOINT = "/sse";
+ private static final String DEFAULT_SSE_ENDPOINT = "/sse";
/** Base URI for the MCP server */
- private final String baseUri;
+ private final URI baseUri;
+
+ /** SSE endpoint path */
+ private final String sseEndpoint;
/** SSE client for handling server-sent events. Uses the /sse endpoint */
private final FlowSseClient sseClient;
@@ -79,6 +84,9 @@ public class HttpClientSseClientTransport implements ClientMcpTransport {
*/
private final HttpClient httpClient;
+ /** HTTP request builder for building requests to send messages to the server */
+ private final HttpRequest.Builder requestBuilder;
+
/** JSON object mapper for message serialization/deserialization */
protected ObjectMapper objectMapper;
@@ -97,7 +105,10 @@ public class HttpClientSseClientTransport implements ClientMcpTransport {
/**
* Creates a new transport instance with default HTTP client and object mapper.
* @param baseUri the base URI of the MCP server
+ * @deprecated Use {@link HttpClientSseClientTransport#builder(String)} instead. This
+ * constructor will be removed in future versions.
*/
+ @Deprecated(forRemoval = true)
public HttpClientSseClientTransport(String baseUri) {
this(HttpClient.newBuilder(), baseUri, new ObjectMapper());
}
@@ -108,15 +119,208 @@ public HttpClientSseClientTransport(String baseUri) {
* @param baseUri the base URI of the MCP server
* @param objectMapper the object mapper for JSON serialization/deserialization
* @throws IllegalArgumentException if objectMapper or clientBuilder is null
+ * @deprecated Use {@link HttpClientSseClientTransport#builder(String)} instead. This
+ * constructor will be removed in future versions.
*/
+ @Deprecated(forRemoval = true)
public HttpClientSseClientTransport(HttpClient.Builder clientBuilder, String baseUri, ObjectMapper objectMapper) {
+ this(clientBuilder, baseUri, DEFAULT_SSE_ENDPOINT, objectMapper);
+ }
+
+ /**
+ * Creates a new transport instance with custom HTTP client builder and object mapper.
+ * @param clientBuilder the HTTP client builder to use
+ * @param baseUri the base URI of the MCP server
+ * @param sseEndpoint the SSE endpoint path
+ * @param objectMapper the object mapper for JSON serialization/deserialization
+ * @throws IllegalArgumentException if objectMapper or clientBuilder is null
+ * @deprecated Use {@link HttpClientSseClientTransport#builder(String)} instead. This
+ * constructor will be removed in future versions.
+ */
+ @Deprecated(forRemoval = true)
+ public HttpClientSseClientTransport(HttpClient.Builder clientBuilder, String baseUri, String sseEndpoint,
+ ObjectMapper objectMapper) {
+ this(clientBuilder, HttpRequest.newBuilder(), baseUri, sseEndpoint, objectMapper);
+ }
+
+ /**
+ * Creates a new transport instance with custom HTTP client builder, object mapper,
+ * and headers.
+ * @param clientBuilder the HTTP client builder to use
+ * @param requestBuilder the HTTP request builder to use
+ * @param baseUri the base URI of the MCP server
+ * @param sseEndpoint the SSE endpoint path
+ * @param objectMapper the object mapper for JSON serialization/deserialization
+ * @throws IllegalArgumentException if objectMapper, clientBuilder, or headers is null
+ * @deprecated Use {@link HttpClientSseClientTransport#builder(String)} instead. This
+ * constructor will be removed in future versions.
+ */
+ @Deprecated(forRemoval = true)
+ public HttpClientSseClientTransport(HttpClient.Builder clientBuilder, HttpRequest.Builder requestBuilder,
+ String baseUri, String sseEndpoint, ObjectMapper objectMapper) {
+ this(clientBuilder.connectTimeout(Duration.ofSeconds(10)).build(), requestBuilder, baseUri, sseEndpoint,
+ objectMapper);
+ }
+
+ /**
+ * Creates a new transport instance with custom HTTP client builder, object mapper,
+ * and headers.
+ * @param httpClient the HTTP client to use
+ * @param requestBuilder the HTTP request builder to use
+ * @param baseUri the base URI of the MCP server
+ * @param sseEndpoint the SSE endpoint path
+ * @param objectMapper the object mapper for JSON serialization/deserialization
+ * @throws IllegalArgumentException if objectMapper, clientBuilder, or headers is null
+ */
+ HttpClientSseClientTransport(HttpClient httpClient, HttpRequest.Builder requestBuilder, String baseUri,
+ String sseEndpoint, ObjectMapper objectMapper) {
Assert.notNull(objectMapper, "ObjectMapper must not be null");
Assert.hasText(baseUri, "baseUri must not be empty");
- Assert.notNull(clientBuilder, "clientBuilder must not be null");
- this.baseUri = baseUri;
+ Assert.hasText(sseEndpoint, "sseEndpoint must not be empty");
+ Assert.notNull(httpClient, "httpClient must not be null");
+ Assert.notNull(requestBuilder, "requestBuilder must not be null");
+ this.baseUri = URI.create(baseUri);
+ this.sseEndpoint = sseEndpoint;
this.objectMapper = objectMapper;
- this.httpClient = clientBuilder.connectTimeout(Duration.ofSeconds(10)).build();
- this.sseClient = new FlowSseClient(this.httpClient);
+ this.httpClient = httpClient;
+ this.requestBuilder = requestBuilder;
+
+ this.sseClient = new FlowSseClient(this.httpClient, requestBuilder);
+ }
+
+ /**
+ * Creates a new builder for {@link HttpClientSseClientTransport}.
+ * @param baseUri the base URI of the MCP server
+ * @return a new builder instance
+ */
+ public static Builder builder(String baseUri) {
+ return new Builder().baseUri(baseUri);
+ }
+
+ /**
+ * Builder for {@link HttpClientSseClientTransport}.
+ */
+ public static class Builder {
+
+ private String baseUri;
+
+ private String sseEndpoint = DEFAULT_SSE_ENDPOINT;
+
+ private HttpClient.Builder clientBuilder = HttpClient.newBuilder()
+ .version(HttpClient.Version.HTTP_1_1)
+ .connectTimeout(Duration.ofSeconds(10));
+
+ private ObjectMapper objectMapper = new ObjectMapper();
+
+ private HttpRequest.Builder requestBuilder = HttpRequest.newBuilder()
+ .header("Content-Type", "application/json");
+
+ /**
+ * Creates a new builder instance.
+ */
+ Builder() {
+ // Default constructor
+ }
+
+ /**
+ * Creates a new builder with the specified base URI.
+ * @param baseUri the base URI of the MCP server
+ * @deprecated Use {@link HttpClientSseClientTransport#builder(String)} instead.
+ * This constructor is deprecated and will be removed or made {@code protected} or
+ * {@code private} in a future release.
+ */
+ @Deprecated(forRemoval = true)
+ public Builder(String baseUri) {
+ Assert.hasText(baseUri, "baseUri must not be empty");
+ this.baseUri = baseUri;
+ }
+
+ /**
+ * Sets the base URI.
+ * @param baseUri the base URI
+ * @return this builder
+ */
+ Builder baseUri(String baseUri) {
+ Assert.hasText(baseUri, "baseUri must not be empty");
+ this.baseUri = baseUri;
+ return this;
+ }
+
+ /**
+ * Sets the SSE endpoint path.
+ * @param sseEndpoint the SSE endpoint path
+ * @return this builder
+ */
+ public Builder sseEndpoint(String sseEndpoint) {
+ Assert.hasText(sseEndpoint, "sseEndpoint must not be empty");
+ this.sseEndpoint = sseEndpoint;
+ return this;
+ }
+
+ /**
+ * Sets the HTTP client builder.
+ * @param clientBuilder the HTTP client builder
+ * @return this builder
+ */
+ public Builder clientBuilder(HttpClient.Builder clientBuilder) {
+ Assert.notNull(clientBuilder, "clientBuilder must not be null");
+ this.clientBuilder = clientBuilder;
+ return this;
+ }
+
+ /**
+ * Customizes the HTTP client builder.
+ * @param clientCustomizer the consumer to customize the HTTP client builder
+ * @return this builder
+ */
+ public Builder customizeClient(final Consumer clientCustomizer) {
+ Assert.notNull(clientCustomizer, "clientCustomizer must not be null");
+ clientCustomizer.accept(clientBuilder);
+ return this;
+ }
+
+ /**
+ * Sets the HTTP request builder.
+ * @param requestBuilder the HTTP request builder
+ * @return this builder
+ */
+ public Builder requestBuilder(HttpRequest.Builder requestBuilder) {
+ Assert.notNull(requestBuilder, "requestBuilder must not be null");
+ this.requestBuilder = requestBuilder;
+ return this;
+ }
+
+ /**
+ * Customizes the HTTP client builder.
+ * @param requestCustomizer the consumer to customize the HTTP request builder
+ * @return this builder
+ */
+ public Builder customizeRequest(final Consumer requestCustomizer) {
+ Assert.notNull(requestCustomizer, "requestCustomizer must not be null");
+ requestCustomizer.accept(requestBuilder);
+ return this;
+ }
+
+ /**
+ * Sets the object mapper for JSON serialization/deserialization.
+ * @param objectMapper the object mapper
+ * @return this builder
+ */
+ public Builder objectMapper(ObjectMapper objectMapper) {
+ Assert.notNull(objectMapper, "objectMapper must not be null");
+ this.objectMapper = objectMapper;
+ return this;
+ }
+
+ /**
+ * Builds a new {@link HttpClientSseClientTransport} instance.
+ * @return a new transport instance
+ */
+ public HttpClientSseClientTransport build() {
+ return new HttpClientSseClientTransport(clientBuilder.build(), requestBuilder, baseUri, sseEndpoint,
+ objectMapper);
+ }
+
}
/**
@@ -137,7 +341,8 @@ public Mono connect(Function, Mono> h
CompletableFuture future = new CompletableFuture<>();
connectionFuture.set(future);
- sseClient.subscribe(this.baseUri + SSE_ENDPOINT, new FlowSseClient.SseEventHandler() {
+ URI clientUri = Utils.resolveUri(this.baseUri, this.sseEndpoint);
+ sseClient.subscribe(clientUri.toString(), new FlowSseClient.SseEventHandler() {
@Override
public void onEvent(SseEvent event) {
if (isClosing) {
@@ -209,9 +414,8 @@ public Mono sendMessage(JSONRPCMessage message) {
try {
String jsonText = this.objectMapper.writeValueAsString(message);
- HttpRequest request = HttpRequest.newBuilder()
- .uri(URI.create(this.baseUri + endpoint))
- .header("Content-Type", "application/json")
+ URI requestUri = Utils.resolveUri(baseUri, endpoint);
+ HttpRequest request = this.requestBuilder.uri(requestUri)
.POST(HttpRequest.BodyPublishers.ofString(jsonText))
.build();
@@ -251,7 +455,7 @@ public Mono closeGracefully() {
}
/**
- * Unmarshals data to the specified type using the configured object mapper.
+ * Unmarshal data to the specified type using the configured object mapper.
* @param data the data to unmarshal
* @param typeRef the type reference for the target type
* @param the target type
diff --git a/mcp/src/main/java/io/modelcontextprotocol/client/transport/StdioClientTransport.java b/mcp/src/main/java/io/modelcontextprotocol/client/transport/StdioClientTransport.java
index 614c6512..9d71cbb4 100644
--- a/mcp/src/main/java/io/modelcontextprotocol/client/transport/StdioClientTransport.java
+++ b/mcp/src/main/java/io/modelcontextprotocol/client/transport/StdioClientTransport.java
@@ -11,14 +11,13 @@
import java.time.Duration;
import java.util.ArrayList;
import java.util.List;
-import java.util.concurrent.CompletableFuture;
import java.util.concurrent.Executors;
import java.util.function.Consumer;
import java.util.function.Function;
import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.ObjectMapper;
-import io.modelcontextprotocol.spec.ClientMcpTransport;
+import io.modelcontextprotocol.spec.McpClientTransport;
import io.modelcontextprotocol.spec.McpSchema;
import io.modelcontextprotocol.spec.McpSchema.JSONRPCMessage;
import io.modelcontextprotocol.util.Assert;
@@ -38,7 +37,7 @@
* @author Christian Tzolov
* @author Dariusz Jędrzejczyk
*/
-public class StdioClientTransport implements ClientMcpTransport {
+public class StdioClientTransport implements McpClientTransport {
private static final Logger logger = LoggerFactory.getLogger(StdioClientTransport.class);
@@ -293,7 +292,7 @@ private void startInboundProcessing() {
*/
private void startOutboundProcessing() {
this.handleOutbound(messages -> messages
- // this bit is important since writes come from user threads and we
+ // this bit is important since writes come from user threads, and we
// want to ensure that the actual writing happens on a dedicated thread
.publishOn(outboundScheduler)
.handle((message, s) -> {
@@ -353,14 +352,15 @@ public Mono closeGracefully() {
// Give a short time for any pending messages to be processed
return Mono.delay(Duration.ofMillis(100));
- })).then(Mono.fromFuture(() -> {
+ })).then(Mono.defer(() -> {
logger.debug("Sending TERM to process");
if (this.process != null) {
this.process.destroy();
- return process.onExit();
+ return Mono.fromFuture(process.onExit());
}
else {
- return CompletableFuture.failedFuture(new RuntimeException("Process not started"));
+ logger.warn("Process not started");
+ return Mono.empty();
}
})).doOnNext(process -> {
if (process.exitValue() != 0) {
diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServer.java b/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServer.java
index 7b691678..1efa13de 100644
--- a/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServer.java
+++ b/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServer.java
@@ -5,25 +5,31 @@
package io.modelcontextprotocol.server;
import java.time.Duration;
+import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
+import java.util.UUID;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.CopyOnWriteArrayList;
-import java.util.function.Function;
+import java.util.function.BiFunction;
import com.fasterxml.jackson.core.type.TypeReference;
-import io.modelcontextprotocol.spec.DefaultMcpSession;
+import com.fasterxml.jackson.databind.ObjectMapper;
+import io.modelcontextprotocol.spec.McpClientSession;
import io.modelcontextprotocol.spec.McpError;
import io.modelcontextprotocol.spec.McpSchema;
-import io.modelcontextprotocol.spec.ServerMcpTransport;
-import io.modelcontextprotocol.spec.DefaultMcpSession.NotificationHandler;
import io.modelcontextprotocol.spec.McpSchema.CallToolResult;
-import io.modelcontextprotocol.spec.McpSchema.ClientCapabilities;
import io.modelcontextprotocol.spec.McpSchema.LoggingLevel;
import io.modelcontextprotocol.spec.McpSchema.LoggingMessageNotification;
+import io.modelcontextprotocol.spec.McpSchema.ResourceTemplate;
+import io.modelcontextprotocol.spec.McpSchema.SetLevelRequest;
import io.modelcontextprotocol.spec.McpSchema.Tool;
+import io.modelcontextprotocol.spec.McpServerSession;
+import io.modelcontextprotocol.spec.McpServerTransportProvider;
+import io.modelcontextprotocol.util.DeafaultMcpUriTemplateManagerFactory;
+import io.modelcontextprotocol.util.McpUriTemplateManagerFactory;
import io.modelcontextprotocol.util.Utils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -67,69 +73,71 @@
*
* @author Christian Tzolov
* @author Dariusz Jędrzejczyk
+ * @author Jihoon Kim
* @see McpServer
* @see McpSchema
- * @see DefaultMcpSession
+ * @see McpClientSession
*/
public class McpAsyncServer {
private static final Logger logger = LoggerFactory.getLogger(McpAsyncServer.class);
- /**
- * The MCP session implementation that manages bidirectional JSON-RPC communication
- * between clients and servers.
- */
- private final DefaultMcpSession mcpSession;
+ private final McpServerTransportProvider mcpTransportProvider;
- private final ServerMcpTransport transport;
+ private final ObjectMapper objectMapper;
private final McpSchema.ServerCapabilities serverCapabilities;
private final McpSchema.Implementation serverInfo;
- private McpSchema.ClientCapabilities clientCapabilities;
-
- private McpSchema.Implementation clientInfo;
+ private final String instructions;
- /**
- * Thread-safe list of tool handlers that can be modified at runtime.
- */
- private final CopyOnWriteArrayList tools = new CopyOnWriteArrayList<>();
+ private final CopyOnWriteArrayList tools = new CopyOnWriteArrayList<>();
private final CopyOnWriteArrayList resourceTemplates = new CopyOnWriteArrayList<>();
- private final ConcurrentHashMap resources = new ConcurrentHashMap<>();
+ private final ConcurrentHashMap resources = new ConcurrentHashMap<>();
- private final ConcurrentHashMap prompts = new ConcurrentHashMap<>();
+ private final ConcurrentHashMap prompts = new ConcurrentHashMap<>();
+ // FIXME: this field is deprecated and should be remvoed together with the
+ // broadcasting loggingNotification.
private LoggingLevel minLoggingLevel = LoggingLevel.DEBUG;
- /**
- * Supported protocol versions.
- */
+ private final ConcurrentHashMap completions = new ConcurrentHashMap<>();
+
private List protocolVersions = List.of(McpSchema.LATEST_PROTOCOL_VERSION);
+ private McpUriTemplateManagerFactory uriTemplateManagerFactory = new DeafaultMcpUriTemplateManagerFactory();
+
/**
- * Create a new McpAsyncServer with the given transport and capabilities.
- * @param mcpTransport The transport layer implementation for MCP communication.
+ * Create a new McpAsyncServer with the given transport provider and capabilities.
+ * @param mcpTransportProvider The transport layer implementation for MCP
+ * communication.
* @param features The MCP server supported features.
+ * @param objectMapper The ObjectMapper to use for JSON serialization/deserialization
*/
- McpAsyncServer(ServerMcpTransport mcpTransport, McpServerFeatures.Async features) {
-
+ McpAsyncServer(McpServerTransportProvider mcpTransportProvider, ObjectMapper objectMapper,
+ McpServerFeatures.Async features, Duration requestTimeout,
+ McpUriTemplateManagerFactory uriTemplateManagerFactory) {
+ this.mcpTransportProvider = mcpTransportProvider;
+ this.objectMapper = objectMapper;
this.serverInfo = features.serverInfo();
this.serverCapabilities = features.serverCapabilities();
+ this.instructions = features.instructions();
this.tools.addAll(features.tools());
this.resources.putAll(features.resources());
this.resourceTemplates.addAll(features.resourceTemplates());
this.prompts.putAll(features.prompts());
+ this.completions.putAll(features.completions());
+ this.uriTemplateManagerFactory = uriTemplateManagerFactory;
- Map> requestHandlers = new HashMap<>();
+ Map> requestHandlers = new HashMap<>();
// Initialize request handlers for standard MCP methods
- requestHandlers.put(McpSchema.METHOD_INITIALIZE, asyncInitializeRequestHandler());
// Ping MUST respond with an empty data, but not NULL response.
- requestHandlers.put(McpSchema.METHOD_PING, (params) -> Mono.just(""));
+ requestHandlers.put(McpSchema.METHOD_PING, (exchange, params) -> Mono.just(Map.of()));
// Add tools API handlers if the tool capability is enabled
if (this.serverCapabilities.tools() != null) {
@@ -155,57 +163,61 @@ public class McpAsyncServer {
requestHandlers.put(McpSchema.METHOD_LOGGING_SET_LEVEL, setLoggerRequestHandler());
}
- Map notificationHandlers = new HashMap<>();
+ // Add completion API handlers if the completion capability is enabled
+ if (this.serverCapabilities.completions() != null) {
+ requestHandlers.put(McpSchema.METHOD_COMPLETION_COMPLETE, completionCompleteRequestHandler());
+ }
- notificationHandlers.put(McpSchema.METHOD_NOTIFICATION_INITIALIZED, (params) -> Mono.empty());
+ Map notificationHandlers = new HashMap<>();
- List, Mono>> rootsChangeConsumers = features.rootsChangeConsumers();
+ notificationHandlers.put(McpSchema.METHOD_NOTIFICATION_INITIALIZED, (exchange, params) -> Mono.empty());
+
+ List, Mono>> rootsChangeConsumers = features
+ .rootsChangeConsumers();
if (Utils.isEmpty(rootsChangeConsumers)) {
- rootsChangeConsumers = List.of((roots) -> Mono.fromRunnable(() -> logger
+ rootsChangeConsumers = List.of((exchange, roots) -> Mono.fromRunnable(() -> logger
.warn("Roots list changed notification, but no consumers provided. Roots list changed: {}", roots)));
}
notificationHandlers.put(McpSchema.METHOD_NOTIFICATION_ROOTS_LIST_CHANGED,
asyncRootsListChangedNotificationHandler(rootsChangeConsumers));
- this.transport = mcpTransport;
- this.mcpSession = new DefaultMcpSession(Duration.ofSeconds(10), mcpTransport, requestHandlers,
- notificationHandlers);
+ mcpTransportProvider.setSessionFactory(
+ transport -> new McpServerSession(UUID.randomUUID().toString(), requestTimeout, transport,
+ this::asyncInitializeRequestHandler, Mono::empty, requestHandlers, notificationHandlers));
}
// ---------------------------------------
// Lifecycle Management
// ---------------------------------------
- private DefaultMcpSession.RequestHandler asyncInitializeRequestHandler() {
- return params -> {
- McpSchema.InitializeRequest initializeRequest = transport.unmarshalFrom(params,
- new TypeReference() {
- });
- this.clientCapabilities = initializeRequest.capabilities();
- this.clientInfo = initializeRequest.clientInfo();
+ private Mono asyncInitializeRequestHandler(
+ McpSchema.InitializeRequest initializeRequest) {
+ return Mono.defer(() -> {
logger.info("Client initialize request - Protocol: {}, Capabilities: {}, Info: {}",
initializeRequest.protocolVersion(), initializeRequest.capabilities(),
initializeRequest.clientInfo());
- // The server MUST respond with the highest protocol version it supports if
+ // The server MUST respond with the highest protocol version it supports
+ // if
// it does not support the requested (e.g. Client) version.
String serverProtocolVersion = this.protocolVersions.get(this.protocolVersions.size() - 1);
if (this.protocolVersions.contains(initializeRequest.protocolVersion())) {
- // If the server supports the requested protocol version, it MUST respond
+ // If the server supports the requested protocol version, it MUST
+ // respond
// with the same version.
serverProtocolVersion = initializeRequest.protocolVersion();
}
else {
logger.warn(
- "Client requested unsupported protocol version: {}, so the server will sugggest the {} version instead",
+ "Client requested unsupported protocol version: {}, so the server will suggest the {} version instead",
initializeRequest.protocolVersion(), serverProtocolVersion);
}
return Mono.just(new McpSchema.InitializeResult(serverProtocolVersion, this.serverCapabilities,
- this.serverInfo, null));
- };
+ this.serverInfo, this.instructions));
+ });
}
/**
@@ -224,67 +236,31 @@ public McpSchema.Implementation getServerInfo() {
return this.serverInfo;
}
- /**
- * Get the client capabilities that define the supported features and functionality.
- * @return The client capabilities
- */
- public ClientCapabilities getClientCapabilities() {
- return this.clientCapabilities;
- }
-
- /**
- * Get the client implementation information.
- * @return The client implementation details
- */
- public McpSchema.Implementation getClientInfo() {
- return this.clientInfo;
- }
-
/**
* Gracefully closes the server, allowing any in-progress operations to complete.
* @return A Mono that completes when the server has been closed
*/
public Mono closeGracefully() {
- return this.mcpSession.closeGracefully();
+ return this.mcpTransportProvider.closeGracefully();
}
/**
* Close the server immediately.
*/
public void close() {
- this.mcpSession.close();
- }
-
- private static final TypeReference LIST_ROOTS_RESULT_TYPE_REF = new TypeReference<>() {
- };
-
- /**
- * Retrieves the list of all roots provided by the client.
- * @return A Mono that emits the list of roots result.
- */
- public Mono listRoots() {
- return this.listRoots(null);
- }
-
- /**
- * Retrieves a paginated list of roots provided by the server.
- * @param cursor Optional pagination cursor from a previous list request
- * @return A Mono that emits the list of roots result containing
- */
- public Mono listRoots(String cursor) {
- return this.mcpSession.sendRequest(McpSchema.METHOD_ROOTS_LIST, new McpSchema.PaginatedRequest(cursor),
- LIST_ROOTS_RESULT_TYPE_REF);
+ this.mcpTransportProvider.close();
}
- private NotificationHandler asyncRootsListChangedNotificationHandler(
- List, Mono>> rootsChangeConsumers) {
- return params -> listRoots().flatMap(listRootsResult -> Flux.fromIterable(rootsChangeConsumers)
- .flatMap(consumer -> consumer.apply(listRootsResult.roots()))
- .onErrorResume(error -> {
- logger.error("Error handling roots list change notification", error);
- return Mono.empty();
- })
- .then());
+ private McpServerSession.NotificationHandler asyncRootsListChangedNotificationHandler(
+ List, Mono>> rootsChangeConsumers) {
+ return (exchange, params) -> exchange.listRoots()
+ .flatMap(listRootsResult -> Flux.fromIterable(rootsChangeConsumers)
+ .flatMap(consumer -> consumer.apply(exchange, listRootsResult.roots()))
+ .onErrorResume(error -> {
+ logger.error("Error handling roots list change notification", error);
+ return Mono.empty();
+ })
+ .then());
}
// ---------------------------------------
@@ -292,18 +268,18 @@ private NotificationHandler asyncRootsListChangedNotificationHandler(
// ---------------------------------------
/**
- * Add a new tool registration at runtime.
- * @param toolRegistration The tool registration to add
+ * Add a new tool specification at runtime.
+ * @param toolSpecification The tool specification to add
* @return Mono that completes when clients have been notified of the change
*/
- public Mono addTool(McpServerFeatures.AsyncToolRegistration toolRegistration) {
- if (toolRegistration == null) {
- return Mono.error(new McpError("Tool registration must not be null"));
+ public Mono addTool(McpServerFeatures.AsyncToolSpecification toolSpecification) {
+ if (toolSpecification == null) {
+ return Mono.error(new McpError("Tool specification must not be null"));
}
- if (toolRegistration.tool() == null) {
+ if (toolSpecification.tool() == null) {
return Mono.error(new McpError("Tool must not be null"));
}
- if (toolRegistration.call() == null) {
+ if (toolSpecification.call() == null) {
return Mono.error(new McpError("Tool call handler must not be null"));
}
if (this.serverCapabilities.tools() == null) {
@@ -312,13 +288,13 @@ public Mono addTool(McpServerFeatures.AsyncToolRegistration toolRegistrati
return Mono.defer(() -> {
// Check for duplicate tool names
- if (this.tools.stream().anyMatch(th -> th.tool().name().equals(toolRegistration.tool().name()))) {
+ if (this.tools.stream().anyMatch(th -> th.tool().name().equals(toolSpecification.tool().name()))) {
return Mono
- .error(new McpError("Tool with name '" + toolRegistration.tool().name() + "' already exists"));
+ .error(new McpError("Tool with name '" + toolSpecification.tool().name() + "' already exists"));
}
- this.tools.add(toolRegistration);
- logger.debug("Added tool handler: {}", toolRegistration.tool().name());
+ this.tools.add(toolSpecification);
+ logger.debug("Added tool handler: {}", toolSpecification.tool().name());
if (this.serverCapabilities.tools().listChanged()) {
return notifyToolsListChanged();
@@ -341,7 +317,8 @@ public Mono removeTool(String toolName) {
}
return Mono.defer(() -> {
- boolean removed = this.tools.removeIf(toolRegistration -> toolRegistration.tool().name().equals(toolName));
+ boolean removed = this.tools
+ .removeIf(toolSpecification -> toolSpecification.tool().name().equals(toolName));
if (removed) {
logger.debug("Removed tool handler: {}", toolName);
if (this.serverCapabilities.tools().listChanged()) {
@@ -358,32 +335,32 @@ public Mono removeTool(String toolName) {
* @return A Mono that completes when all clients have been notified
*/
public Mono notifyToolsListChanged() {
- return this.mcpSession.sendNotification(McpSchema.METHOD_NOTIFICATION_TOOLS_LIST_CHANGED, null);
+ return this.mcpTransportProvider.notifyClients(McpSchema.METHOD_NOTIFICATION_TOOLS_LIST_CHANGED, null);
}
- private DefaultMcpSession.RequestHandler toolsListRequestHandler() {
- return params -> {
- List tools = this.tools.stream().map(McpServerFeatures.AsyncToolRegistration::tool).toList();
+ private McpServerSession.RequestHandler toolsListRequestHandler() {
+ return (exchange, params) -> {
+ List tools = this.tools.stream().map(McpServerFeatures.AsyncToolSpecification::tool).toList();
return Mono.just(new McpSchema.ListToolsResult(tools, null));
};
}
- private DefaultMcpSession.RequestHandler toolsCallRequestHandler() {
- return params -> {
- McpSchema.CallToolRequest callToolRequest = transport.unmarshalFrom(params,
+ private McpServerSession.RequestHandler toolsCallRequestHandler() {
+ return (exchange, params) -> {
+ McpSchema.CallToolRequest callToolRequest = objectMapper.convertValue(params,
new TypeReference() {
});
- Optional toolRegistration = this.tools.stream()
+ Optional toolSpecification = this.tools.stream()
.filter(tr -> callToolRequest.name().equals(tr.tool().name()))
.findAny();
- if (toolRegistration.isEmpty()) {
+ if (toolSpecification.isEmpty()) {
return Mono.error(new McpError("Tool not found: " + callToolRequest.name()));
}
- return toolRegistration.map(tool -> tool.call().apply(callToolRequest.arguments()))
+ return toolSpecification.map(tool -> tool.call().apply(exchange, callToolRequest.arguments()))
.orElse(Mono.error(new McpError("Tool not found: " + callToolRequest.name())));
};
}
@@ -394,11 +371,11 @@ private DefaultMcpSession.RequestHandler toolsCallRequestHandler
/**
* Add a new resource handler at runtime.
- * @param resourceHandler The resource handler to add
+ * @param resourceSpecification The resource handler to add
* @return Mono that completes when clients have been notified of the change
*/
- public Mono addResource(McpServerFeatures.AsyncResourceRegistration resourceHandler) {
- if (resourceHandler == null || resourceHandler.resource() == null) {
+ public Mono addResource(McpServerFeatures.AsyncResourceSpecification resourceSpecification) {
+ if (resourceSpecification == null || resourceSpecification.resource() == null) {
return Mono.error(new McpError("Resource must not be null"));
}
@@ -407,11 +384,11 @@ public Mono addResource(McpServerFeatures.AsyncResourceRegistration resour
}
return Mono.defer(() -> {
- if (this.resources.putIfAbsent(resourceHandler.resource().uri(), resourceHandler) != null) {
- return Mono
- .error(new McpError("Resource with URI '" + resourceHandler.resource().uri() + "' already exists"));
+ if (this.resources.putIfAbsent(resourceSpecification.resource().uri(), resourceSpecification) != null) {
+ return Mono.error(new McpError(
+ "Resource with URI '" + resourceSpecification.resource().uri() + "' already exists"));
}
- logger.debug("Added resource handler: {}", resourceHandler.resource().uri());
+ logger.debug("Added resource handler: {}", resourceSpecification.resource().uri());
if (this.serverCapabilities.resources().listChanged()) {
return notifyResourcesListChanged();
}
@@ -433,7 +410,7 @@ public Mono removeResource(String resourceUri) {
}
return Mono.defer(() -> {
- McpServerFeatures.AsyncResourceRegistration removed = this.resources.remove(resourceUri);
+ McpServerFeatures.AsyncResourceSpecification removed = this.resources.remove(resourceUri);
if (removed != null) {
logger.debug("Removed resource handler: {}", resourceUri);
if (this.serverCapabilities.resources().listChanged()) {
@@ -450,35 +427,59 @@ public Mono removeResource(String resourceUri) {
* @return A Mono that completes when all clients have been notified
*/
public Mono notifyResourcesListChanged() {
- return this.mcpSession.sendNotification(McpSchema.METHOD_NOTIFICATION_RESOURCES_LIST_CHANGED, null);
+ return this.mcpTransportProvider.notifyClients(McpSchema.METHOD_NOTIFICATION_RESOURCES_LIST_CHANGED, null);
}
- private DefaultMcpSession.RequestHandler resourcesListRequestHandler() {
- return params -> {
+ private McpServerSession.RequestHandler resourcesListRequestHandler() {
+ return (exchange, params) -> {
var resourceList = this.resources.values()
.stream()
- .map(McpServerFeatures.AsyncResourceRegistration::resource)
+ .map(McpServerFeatures.AsyncResourceSpecification::resource)
.toList();
return Mono.just(new McpSchema.ListResourcesResult(resourceList, null));
};
}
- private DefaultMcpSession.RequestHandler resourceTemplateListRequestHandler() {
- return params -> Mono.just(new McpSchema.ListResourceTemplatesResult(this.resourceTemplates, null));
+ private McpServerSession.RequestHandler resourceTemplateListRequestHandler() {
+ return (exchange, params) -> Mono
+ .just(new McpSchema.ListResourceTemplatesResult(this.getResourceTemplates(), null));
+
+ }
+
+ private List getResourceTemplates() {
+ var list = new ArrayList<>(this.resourceTemplates);
+ List resourceTemplates = this.resources.keySet()
+ .stream()
+ .filter(uri -> uri.contains("{"))
+ .map(uri -> {
+ var resource = this.resources.get(uri).resource();
+ var template = new McpSchema.ResourceTemplate(resource.uri(), resource.name(), resource.description(),
+ resource.mimeType(), resource.annotations());
+ return template;
+ })
+ .toList();
+
+ list.addAll(resourceTemplates);
+ return list;
}
- private DefaultMcpSession.RequestHandler resourcesReadRequestHandler() {
- return params -> {
- McpSchema.ReadResourceRequest resourceRequest = transport.unmarshalFrom(params,
+ private McpServerSession.RequestHandler resourcesReadRequestHandler() {
+ return (exchange, params) -> {
+ McpSchema.ReadResourceRequest resourceRequest = objectMapper.convertValue(params,
new TypeReference() {
});
var resourceUri = resourceRequest.uri();
- McpServerFeatures.AsyncResourceRegistration registration = this.resources.get(resourceUri);
- if (registration != null) {
- return registration.readHandler().apply(resourceRequest);
- }
- return Mono.error(new McpError("Resource not found: " + resourceUri));
+
+ McpServerFeatures.AsyncResourceSpecification specification = this.resources.values()
+ .stream()
+ .filter(resourceSpecification -> this.uriTemplateManagerFactory
+ .create(resourceSpecification.resource().uri())
+ .matches(resourceUri))
+ .findFirst()
+ .orElseThrow(() -> new McpError("Resource not found: " + resourceUri));
+
+ return specification.readHandler().apply(exchange, resourceRequest);
};
}
@@ -488,26 +489,26 @@ private DefaultMcpSession.RequestHandler resources
/**
* Add a new prompt handler at runtime.
- * @param promptRegistration The prompt handler to add
+ * @param promptSpecification The prompt handler to add
* @return Mono that completes when clients have been notified of the change
*/
- public Mono addPrompt(McpServerFeatures.AsyncPromptRegistration promptRegistration) {
- if (promptRegistration == null) {
- return Mono.error(new McpError("Prompt registration must not be null"));
+ public Mono addPrompt(McpServerFeatures.AsyncPromptSpecification promptSpecification) {
+ if (promptSpecification == null) {
+ return Mono.error(new McpError("Prompt specification must not be null"));
}
if (this.serverCapabilities.prompts() == null) {
return Mono.error(new McpError("Server must be configured with prompt capabilities"));
}
return Mono.defer(() -> {
- McpServerFeatures.AsyncPromptRegistration registration = this.prompts
- .putIfAbsent(promptRegistration.prompt().name(), promptRegistration);
- if (registration != null) {
+ McpServerFeatures.AsyncPromptSpecification specification = this.prompts
+ .putIfAbsent(promptSpecification.prompt().name(), promptSpecification);
+ if (specification != null) {
return Mono.error(
- new McpError("Prompt with name '" + promptRegistration.prompt().name() + "' already exists"));
+ new McpError("Prompt with name '" + promptSpecification.prompt().name() + "' already exists"));
}
- logger.debug("Added prompt handler: {}", promptRegistration.prompt().name());
+ logger.debug("Added prompt handler: {}", promptSpecification.prompt().name());
// Servers that declared the listChanged capability SHOULD send a
// notification,
@@ -533,7 +534,7 @@ public Mono removePrompt(String promptName) {
}
return Mono.defer(() -> {
- McpServerFeatures.AsyncPromptRegistration removed = this.prompts.remove(promptName);
+ McpServerFeatures.AsyncPromptSpecification removed = this.prompts.remove(promptName);
if (removed != null) {
logger.debug("Removed prompt handler: {}", promptName);
@@ -553,38 +554,38 @@ public Mono removePrompt(String promptName) {
* @return A Mono that completes when all clients have been notified
*/
public Mono notifyPromptsListChanged() {
- return this.mcpSession.sendNotification(McpSchema.METHOD_NOTIFICATION_PROMPTS_LIST_CHANGED, null);
+ return this.mcpTransportProvider.notifyClients(McpSchema.METHOD_NOTIFICATION_PROMPTS_LIST_CHANGED, null);
}
- private DefaultMcpSession.RequestHandler promptsListRequestHandler() {
- return params -> {
+ private McpServerSession.RequestHandler promptsListRequestHandler() {
+ return (exchange, params) -> {
// TODO: Implement pagination
- // McpSchema.PaginatedRequest request = transport.unmarshalFrom(params,
+ // McpSchema.PaginatedRequest request = objectMapper.convertValue(params,
// new TypeReference() {
// });
var promptList = this.prompts.values()
.stream()
- .map(McpServerFeatures.AsyncPromptRegistration::prompt)
+ .map(McpServerFeatures.AsyncPromptSpecification::prompt)
.toList();
return Mono.just(new McpSchema.ListPromptsResult(promptList, null));
};
}
- private DefaultMcpSession.RequestHandler promptsGetRequestHandler() {
- return params -> {
- McpSchema.GetPromptRequest promptRequest = transport.unmarshalFrom(params,
+ private McpServerSession.RequestHandler promptsGetRequestHandler() {
+ return (exchange, params) -> {
+ McpSchema.GetPromptRequest promptRequest = objectMapper.convertValue(params,
new TypeReference() {
});
// Implement prompt retrieval logic here
- McpServerFeatures.AsyncPromptRegistration registration = this.prompts.get(promptRequest.name());
- if (registration == null) {
+ McpServerFeatures.AsyncPromptSpecification specification = this.prompts.get(promptRequest.name());
+ if (specification == null) {
return Mono.error(new McpError("Prompt not found: " + promptRequest.name()));
}
- return registration.promptHandler().apply(promptRequest);
+ return specification.promptHandler().apply(exchange, promptRequest);
};
}
@@ -593,77 +594,139 @@ private DefaultMcpSession.RequestHandler promptsGetRe
// ---------------------------------------
/**
- * Send a logging message notification to all connected clients. Messages below the
- * current minimum logging level will be filtered out.
+ * This implementation would, incorrectly, broadcast the logging message to all
+ * connected clients, using a single minLoggingLevel for all of them. Similar to the
+ * sampling and roots, the logging level should be set per client session and use the
+ * ServerExchange to send the logging message to the right client.
* @param loggingMessageNotification The logging message to send
* @return A Mono that completes when the notification has been sent
+ * @deprecated Use
+ * {@link McpAsyncServerExchange#loggingNotification(LoggingMessageNotification)}
+ * instead.
*/
+ @Deprecated
public Mono loggingNotification(LoggingMessageNotification loggingMessageNotification) {
if (loggingMessageNotification == null) {
return Mono.error(new McpError("Logging message must not be null"));
}
- Map params = this.transport.unmarshalFrom(loggingMessageNotification,
- new TypeReference