diff --git a/src/inMemory.test.ts b/src/inMemory.test.ts index f7e9e979..baf43446 100644 --- a/src/inMemory.test.ts +++ b/src/inMemory.test.ts @@ -1,5 +1,6 @@ import { InMemoryTransport } from "./inMemory.js"; import { JSONRPCMessage } from "./types.js"; +import { AuthInfo } from "./server/auth/types.js"; describe("InMemoryTransport", () => { let clientTransport: InMemoryTransport; @@ -35,6 +36,32 @@ describe("InMemoryTransport", () => { expect(receivedMessage).toEqual(message); }); + test("should send message with auth info from client to server", async () => { + const message: JSONRPCMessage = { + jsonrpc: "2.0", + method: "test", + id: 1, + }; + + const authInfo: AuthInfo = { + token: "test-token", + clientId: "test-client", + scopes: ["read", "write"], + expiresAt: Date.now() / 1000 + 3600, + }; + + let receivedMessage: JSONRPCMessage | undefined; + let receivedAuthInfo: AuthInfo | undefined; + serverTransport.onmessage = (msg, extra) => { + receivedMessage = msg; + receivedAuthInfo = extra?.authInfo; + }; + + await clientTransport.send(message, { authInfo }); + expect(receivedMessage).toEqual(message); + expect(receivedAuthInfo).toEqual(authInfo); + }); + test("should send message from server to client", async () => { const message: JSONRPCMessage = { jsonrpc: "2.0", diff --git a/src/inMemory.ts b/src/inMemory.ts index 106a9e7e..5dd6e81e 100644 --- a/src/inMemory.ts +++ b/src/inMemory.ts @@ -1,16 +1,22 @@ import { Transport } from "./shared/transport.js"; -import { JSONRPCMessage } from "./types.js"; +import { JSONRPCMessage, RequestId } from "./types.js"; +import { AuthInfo } from "./server/auth/types.js"; + +interface QueuedMessage { + message: JSONRPCMessage; + extra?: { authInfo?: AuthInfo }; +} /** * In-memory transport for creating clients and servers that talk to each other within the same process. */ export class InMemoryTransport implements Transport { private _otherTransport?: InMemoryTransport; - private _messageQueue: JSONRPCMessage[] = []; + private _messageQueue: QueuedMessage[] = []; onclose?: () => void; onerror?: (error: Error) => void; - onmessage?: (message: JSONRPCMessage) => void; + onmessage?: (message: JSONRPCMessage, extra?: { authInfo?: AuthInfo }) => void; sessionId?: string; /** @@ -27,10 +33,8 @@ export class InMemoryTransport implements Transport { async start(): Promise { // Process any messages that were queued before start was called while (this._messageQueue.length > 0) { - const message = this._messageQueue.shift(); - if (message) { - this.onmessage?.(message); - } + const queuedMessage = this._messageQueue.shift()!; + this.onmessage?.(queuedMessage.message, queuedMessage.extra); } } @@ -41,15 +45,19 @@ export class InMemoryTransport implements Transport { this.onclose?.(); } - async send(message: JSONRPCMessage): Promise { + /** + * Sends a message with optional auth info. + * This is useful for testing authentication scenarios. + */ + async send(message: JSONRPCMessage, options?: { relatedRequestId?: RequestId, authInfo?: AuthInfo }): Promise { if (!this._otherTransport) { throw new Error("Not connected"); } if (this._otherTransport.onmessage) { - this._otherTransport.onmessage(message); + this._otherTransport.onmessage(message, { authInfo: options?.authInfo }); } else { - this._otherTransport._messageQueue.push(message); + this._otherTransport._messageQueue.push({ message, extra: { authInfo: options?.authInfo } }); } } } diff --git a/src/server/auth/types.ts b/src/server/auth/types.ts index 93c5a493..c25c2b60 100644 --- a/src/server/auth/types.ts +++ b/src/server/auth/types.ts @@ -21,4 +21,10 @@ export interface AuthInfo { * When the token expires (in seconds since epoch). */ expiresAt?: number; + + /** + * Additional data associated with the token. + * This field should be used for any additional data that needs to be attached to the auth info. + */ + extra?: Record; } \ No newline at end of file diff --git a/src/server/sse.ts b/src/server/sse.ts index 46948b47..03f6fefc 100644 --- a/src/server/sse.ts +++ b/src/server/sse.ts @@ -4,6 +4,7 @@ import { Transport } from "../shared/transport.js"; import { JSONRPCMessage, JSONRPCMessageSchema } from "../types.js"; import getRawBody from "raw-body"; import contentType from "content-type"; +import { AuthInfo } from "./auth/types.js"; import { URL } from 'url'; const MAXIMUM_MESSAGE_SIZE = "4mb"; @@ -19,7 +20,7 @@ export class SSEServerTransport implements Transport { onclose?: () => void; onerror?: (error: Error) => void; - onmessage?: (message: JSONRPCMessage) => void; + onmessage?: (message: JSONRPCMessage, extra?: { authInfo?: AuthInfo }) => void; /** * Creates a new SSE server transport, which will direct the client to POST messages to the relative or absolute URL identified by `_endpoint`. @@ -76,7 +77,7 @@ export class SSEServerTransport implements Transport { * This should be called when a POST request is made to send a message to the server. */ async handlePostMessage( - req: IncomingMessage, + req: IncomingMessage & { auth?: AuthInfo }, res: ServerResponse, parsedBody?: unknown, ): Promise { @@ -85,6 +86,7 @@ export class SSEServerTransport implements Transport { res.writeHead(500).end(message); throw new Error(message); } + const authInfo: AuthInfo | undefined = req.auth; let body: string | unknown; try { @@ -104,7 +106,7 @@ export class SSEServerTransport implements Transport { } try { - await this.handleMessage(typeof body === 'string' ? JSON.parse(body) : body); + await this.handleMessage(typeof body === 'string' ? JSON.parse(body) : body, { authInfo }); } catch { res.writeHead(400).end(`Invalid message: ${body}`); return; @@ -116,7 +118,7 @@ export class SSEServerTransport implements Transport { /** * Handle a client message, regardless of how it arrived. This can be used to inform the server of messages that arrive via a means different than HTTP POST. */ - async handleMessage(message: unknown): Promise { + async handleMessage(message: unknown, extra?: { authInfo?: AuthInfo }): Promise { let parsedMessage: JSONRPCMessage; try { parsedMessage = JSONRPCMessageSchema.parse(message); @@ -125,7 +127,7 @@ export class SSEServerTransport implements Transport { throw error; } - this.onmessage?.(parsedMessage); + this.onmessage?.(parsedMessage, extra); } async close(): Promise { diff --git a/src/shared/protocol.ts b/src/shared/protocol.ts index 07bfc02c..91fa8366 100644 --- a/src/shared/protocol.ts +++ b/src/shared/protocol.ts @@ -23,6 +23,7 @@ import { ServerCapabilities, } from "../types.js"; import { Transport } from "./transport.js"; +import { AuthInfo } from "../server/auth/types.js"; /** * Callback for progress notifications. @@ -109,6 +110,11 @@ export type RequestHandlerExtra { + this._transport.onmessage = (message, extra) => { if (isJSONRPCResponse(message) || isJSONRPCError(message)) { this._onresponse(message); } else if (isJSONRPCRequest(message)) { - this._onrequest(message); + this._onrequest(message, extra); } else if (isJSONRPCNotification(message)) { this._onnotification(message); } else { @@ -326,7 +332,7 @@ export abstract class Protocol< ); } - private _onrequest(request: JSONRPCRequest): void { + private _onrequest(request: JSONRPCRequest, extra?: { authInfo?: AuthInfo }): void { const handler = this._requestHandlers.get(request.method) ?? this.fallbackRequestHandler; @@ -351,20 +357,20 @@ export abstract class Protocol< const abortController = new AbortController(); this._requestHandlerAbortControllers.set(request.id, abortController); - // Create extra object with both abort signal and sessionId from transport - const extra: RequestHandlerExtra = { + const fullExtra: RequestHandlerExtra = { signal: abortController.signal, sessionId: this._transport?.sessionId, sendNotification: (notification) => this.notification(notification, { relatedRequestId: request.id }), sendRequest: (r, resultSchema, options?) => - this.request(r, resultSchema, { ...options, relatedRequestId: request.id }) + this.request(r, resultSchema, { ...options, relatedRequestId: request.id }), + authInfo: extra?.authInfo, }; // Starting with Promise.resolve() puts any synchronous errors into the monad as well. Promise.resolve() - .then(() => handler(request, extra)) + .then(() => handler(request, fullExtra)) .then( (result) => { if (abortController.signal.aborted) { diff --git a/src/shared/transport.ts b/src/shared/transport.ts index e464653b..c2732391 100644 --- a/src/shared/transport.ts +++ b/src/shared/transport.ts @@ -1,3 +1,4 @@ +import { AuthInfo } from "../server/auth/types.js"; import { JSONRPCMessage, RequestId } from "../types.js"; /** @@ -41,8 +42,11 @@ export interface Transport { /** * Callback for when a message (request or response) is received over the connection. + * + * Includes the authInfo if the transport is authenticated. + * */ - onmessage?: (message: JSONRPCMessage) => void; + onmessage?: (message: JSONRPCMessage, extra?: { authInfo?: AuthInfo }) => void; /** * The session ID generated for this connection.