From 5b5b783da57e2999f91dbc02f80a42bff2448d64 Mon Sep 17 00:00:00 2001 From: Justin Spahr-Summers Date: Sat, 5 Apr 2025 12:54:48 +0100 Subject: [PATCH] Add ways to associate related requests and notifications --- src/shared/protocol.ts | 38 ++++++++++++++++++++++++++++---------- src/shared/transport.ts | 4 ++-- 2 files changed, 30 insertions(+), 12 deletions(-) diff --git a/src/shared/protocol.ts b/src/shared/protocol.ts index a6e47184..43a3f1d6 100644 --- a/src/shared/protocol.ts +++ b/src/shared/protocol.ts @@ -78,6 +78,8 @@ export type RequestOptions = { * If not specified, there is no maximum total timeout. */ maxTotalTimeout?: number; + + relatedRequestId?: RequestId; }; /** @@ -364,7 +366,7 @@ export abstract class Protocol< private _onprogress(notification: ProgressNotification): void { const { progressToken, ...params } = notification.params; const messageId = Number(progressToken); - + const handler = this._progressHandlers.get(messageId); if (!handler) { this._onerror(new Error(`Received a progress notification for an unknown token: ${JSON.stringify(notification)}`)); @@ -373,7 +375,7 @@ export abstract class Protocol< const responseHandler = this._responseHandlers.get(messageId); const timeoutInfo = this._timeoutInfo.get(messageId); - + if (timeoutInfo && responseHandler && timeoutInfo.resetTimeoutOnProgress) { try { this._resetTimeout(messageId); @@ -460,6 +462,8 @@ export abstract class Protocol< resultSchema: T, options?: RequestOptions, ): Promise> { + const { relatedRequestId } = options ?? {}; + return new Promise((resolve, reject) => { if (!this._transport) { reject(new Error("Not connected")); @@ -500,7 +504,7 @@ export abstract class Protocol< requestId: messageId, reason: String(reason), }, - }) + }, { relatedRequestId }) .catch((error) => this._onerror(new Error(`Failed to send cancellation: ${error}`)), ); @@ -538,7 +542,7 @@ export abstract class Protocol< this._setupTimeout(messageId, timeout, options?.maxTotalTimeout, timeoutHandler, options?.resetTimeoutOnProgress ?? false); - this._transport.send(jsonrpcRequest).catch((error) => { + this._transport.send(jsonrpcRequest, { relatedRequestId }).catch((error) => { this._cleanupTimeout(messageId); reject(error); }); @@ -548,7 +552,7 @@ export abstract class Protocol< /** * Emits a notification, which is a one-way message that does not expect a response. */ - async notification(notification: SendNotificationT): Promise { + async notification(notification: SendNotificationT, options?: { relatedRequestId?: RequestId }): Promise { if (!this._transport) { throw new Error("Not connected"); } @@ -560,7 +564,7 @@ export abstract class Protocol< jsonrpc: "2.0", }; - await this._transport.send(jsonrpcNotification); + await this._transport.send(jsonrpcNotification, options); } /** @@ -572,18 +576,32 @@ export abstract class Protocol< T extends ZodObject<{ method: ZodLiteral; }>, + U extends ZodType, >( requestSchema: T, handler: ( request: z.infer, - extra: RequestHandlerExtra, + extra: RequestHandlerExtra & { + sendNotification: (notification: SendNotificationT) => Promise, + sendRequest: (request: SendRequestT, resultSchema: U, options?: RequestOptions) => Promise>, + }, ) => SendResultT | Promise, ): void { const method = requestSchema.shape.method.value; this.assertRequestHandlerCapability(method); - this._requestHandlers.set(method, (request, extra) => - Promise.resolve(handler(requestSchema.parse(request), extra)), - ); + + this._requestHandlers.set(method, (request, extra) => { + return Promise.resolve(handler(requestSchema.parse(request), { + ...extra, + + sendNotification: + (notification: SendNotificationT) => + this.notification(notification, { relatedRequestId: request.id }), + + sendRequest: (r: SendRequestT, resultSchema: U, options?: RequestOptions) => + this.request(r, resultSchema, { ...options, relatedRequestId: request.id }) + })); + }); } /** diff --git a/src/shared/transport.ts b/src/shared/transport.ts index b80e2a51..aa29f490 100644 --- a/src/shared/transport.ts +++ b/src/shared/transport.ts @@ -1,4 +1,4 @@ -import { JSONRPCMessage } from "../types.js"; +import { JSONRPCMessage, RequestId } from "../types.js"; /** * Describes the minimal contract for a MCP transport that a client or server can communicate over. @@ -16,7 +16,7 @@ export interface Transport { /** * Sends a JSON-RPC message (request or response). */ - send(message: JSONRPCMessage): Promise; + send(message: JSONRPCMessage, options?: { relatedRequestId?: RequestId }): Promise; /** * Closes the connection.