Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -133,3 +133,6 @@ out

.DS_Store
dist/

# claude
.claude/
5 changes: 5 additions & 0 deletions src/client/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,11 @@ export class Client<

this._instructions = result.instructions;

// Handle session assignment from server
if (result.sessionId) {
this.createSession(result.sessionId, result.sessionTimeout);
}

await this.notification({
method: "notifications/initialized",
});
Expand Down
9 changes: 8 additions & 1 deletion src/client/streamableHttp.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import { Transport, FetchLike } from "../shared/transport.js";
import { SessionState } from "../shared/protocol.js";
import { isInitializedNotification, isJSONRPCRequest, isJSONRPCResponse, JSONRPCMessage, JSONRPCMessageSchema } from "../types.js";
import { auth, AuthResult, extractResourceMetadataUrl, OAuthClientProvider, UnauthorizedError } from "./auth.js";
import { EventSourceParserStream } from "eventsource-parser/stream";
Expand Down Expand Up @@ -129,6 +130,7 @@ export class StreamableHTTPClientTransport implements Transport {
private _authProvider?: OAuthClientProvider;
private _fetch?: FetchLike;
private _sessionId?: string;
private _sessionState?: SessionState; // For protocol-level session support
private _reconnectionOptions: StreamableHTTPReconnectionOptions;
private _protocolVersion?: string;

Expand Down Expand Up @@ -504,7 +506,12 @@ export class StreamableHTTPClientTransport implements Transport {
}

get sessionId(): string | undefined {
return this._sessionId;
// Prefer protocol-level session state, fallback to legacy _sessionId
return this._sessionState?.sessionId || this._sessionId;
}

setSessionState(sessionState: SessionState): void {
this._sessionState = sessionState;
}

/**
Expand Down
19 changes: 18 additions & 1 deletion src/inMemory.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import { Transport } from "./shared/transport.js";
import { JSONRPCMessage, RequestId } from "./types.js";
import { AuthInfo } from "./server/auth/types.js";
import { SessionState } from "./shared/protocol.js";

interface QueuedMessage {
message: JSONRPCMessage;
Expand All @@ -17,7 +18,23 @@ export class InMemoryTransport implements Transport {
onclose?: () => void;
onerror?: (error: Error) => void;
onmessage?: (message: JSONRPCMessage, extra?: { authInfo?: AuthInfo }) => void;
sessionId?: string;

private _sessionState?: SessionState;

get sessionId(): string | undefined {
return this._sessionState?.sessionId;
}

getLegacySessionOptions(): undefined {
// InMemoryTransport has no legacy session configuration
return undefined;
}

setSessionState(sessionState: SessionState): void {
// Store session state for sessionId getter
// InMemoryTransport doesn't use session state for other purposes
this._sessionState = sessionState;
}

/**
* Creates a pair of linked in-memory transports that can communicate with each other. One should be passed to a Client and one to a Server.
Expand Down
10 changes: 6 additions & 4 deletions src/integration-tests/taskResumability.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -186,16 +186,18 @@ describe('Transport resumability', () => {
name: 'run-notifications',
arguments: {
count: 3,
interval: 10
interval: 50 // Increased interval for more reliable timing
}
}
}, CallToolResultSchema, {
resumptionToken: lastEventId,
onresumptiontoken: onLastEventIdUpdate
});

// Wait for some notifications to arrive (not all) - shorter wait time
await new Promise(resolve => setTimeout(resolve, 20));
// Wait for some notifications to arrive (not all)
// With 50ms interval, first notification should arrive immediately,
// second at 50ms. We wait 75ms to ensure we get at least 1-2 notifications
await new Promise(resolve => setTimeout(resolve, 75));

// Verify we received some notifications and lastEventId was updated
expect(notifications.length).toBeGreaterThan(0);
Expand All @@ -219,7 +221,7 @@ describe('Transport resumability', () => {


// Add a short delay to ensure clean disconnect before reconnecting
await new Promise(resolve => setTimeout(resolve, 10));
await new Promise(resolve => setTimeout(resolve, 50));

// Wait for the rejection to be handled
await catchPromise;
Expand Down
130 changes: 128 additions & 2 deletions src/server/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,10 @@ import {
Protocol,
ProtocolOptions,
RequestOptions,
SessionOptions,
SessionState,
} from "../shared/protocol.js";
import { Transport } from "../shared/transport.js";
import {
ClientCapabilities,
CreateMessageRequest,
Expand Down Expand Up @@ -32,6 +35,8 @@ import {
ServerRequest,
ServerResult,
SUPPORTED_PROTOCOL_VERSIONS,
SessionTerminateRequestSchema,
SessionTerminateRequest,
} from "../types.js";
import Ajv from "ajv";

Expand Down Expand Up @@ -85,31 +90,88 @@ export class Server<
private _clientVersion?: Implementation;
private _capabilities: ServerCapabilities;
private _instructions?: string;
private _sessionOptions?: SessionOptions;

/**
* Callback for when initialization has fully completed (i.e., the client has sent an `initialized` notification).
*/
oninitialized?: () => void;

/**
* Returns the connected transport instance.
* Used for session-to-server routing in examples.
*/
getTransport() {
return this.transport;
}

/**
* Initializes this server with the given name and version information.
*/
constructor(
private _serverInfo: Implementation,
options?: ServerOptions,
) {
super(options);
// Extract session options before passing to super
const { sessions, ...protocolOptions } = options ?? {};
super(protocolOptions);
this._sessionOptions = sessions;
this._capabilities = options?.capabilities ?? {};
this._instructions = options?.instructions;

this.setRequestHandler(InitializeRequestSchema, (request) =>
this._oninitialize(request),
);
this.setRequestHandler(SessionTerminateRequestSchema, (request) =>
this._onSessionTerminate(request),
);
this.setNotificationHandler(InitializedNotificationSchema, () =>
this.oninitialized?.(),
);
}

/**
* Handles initialization request synchronously for HTTP transport backward compatibility.
* This bypasses the Protocol's async request handling to allow immediate error detection.
* @internal
*/
async handleInitializeSync(request: InitializeRequest): Promise<InitializeResult> {
// Call the internal initialization handler directly
const result = await this._oninitialize(request);
return result;
}

/**
* Connect to a transport, handling legacy session options from the transport.
*/
async connect(transport: Transport): Promise<void> {
// Handle legacy session options delegation from transport
const legacySessionOptions = transport.getLegacySessionOptions?.();
if (legacySessionOptions) {
if (this._sessionOptions) {
// Both server session options and transport legacy session options provided. Using server options.
} else {
this._sessionOptions = legacySessionOptions;
}
}

// Register synchronous initialization handler if transport supports it
if (transport.setInitializeHandler) {
transport.setInitializeHandler((request: InitializeRequest) =>
this.handleInitializeSync(request)
);
}

// Register synchronous termination handler if transport supports it
if (transport.setTerminateHandler) {
transport.setTerminateHandler((sessionId?: string) =>
this.terminateSession(sessionId)
);
}

await super.connect(transport);
}

/**
* Registers new capabilities. This can only be called before connecting to a transport.
*
Expand Down Expand Up @@ -269,12 +331,76 @@ export class Server<
? requestedVersion
: LATEST_PROTOCOL_VERSION;

return {
const result: InitializeResult = {
protocolVersion,
capabilities: this.getCapabilities(),
serverInfo: this._serverInfo,
...(this._instructions && { instructions: this._instructions }),
};

// Generate session if supported
if (this._sessionOptions?.sessionIdGenerator) {
const sessionId = this._sessionOptions.sessionIdGenerator();
result.sessionId = sessionId;
result.sessionTimeout = this._sessionOptions.sessionTimeout;

await this.initializeSession(sessionId, this._sessionOptions.sessionTimeout);
}

return result;
}

private async initializeSession(sessionId: string, timeout?: number): Promise<void> {
// Create the session
this.createSession(sessionId, timeout);

// Try to call the initialization callback, but if it fails,
// store the error in session state and rethrow
try {
await this._sessionOptions?.onsessioninitialized?.(sessionId);
} catch (error) {
// Store the error in session state for the transport to check
const sessionState = this.getSessionState();
if (sessionState) {
sessionState.callbackError = error instanceof Error ? error : new Error(String(error));
}
throw error;
}
}

protected async terminateSession(sessionId?: string): Promise<void> {
// Get the current session ID before termination
const currentSessionId = this.getSessionState()?.sessionId;

// Call parent's terminateSession to clear the session state
await super.terminateSession(sessionId);

// Now call the callback if we had a session
if (currentSessionId) {
try {
await this._sessionOptions?.onsessionclosed?.(currentSessionId);
} catch (error) {
// Re-create minimal session state just to store the error for transport to check
const sessionState: SessionState = {
sessionId: currentSessionId,
createdAt: Date.now(),
lastActivity: Date.now(),
callbackError: error instanceof Error ? error : new Error(String(error))
};
// Notify transport of the error state
this.transport?.setSessionState?.(sessionState);
throw error;
}
}
}

private async _onSessionTerminate(
request: SessionTerminateRequest
): Promise<object> {
// Use the same termination logic as the protocol method
// sessionId comes directly from the protocol request
await this.terminateSession(request.sessionId);
return {};
}

/**
Expand Down
25 changes: 21 additions & 4 deletions src/server/mcp.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@ import {
LoggingMessageNotificationSchema,
Notification,
TextContent,
ElicitRequestSchema
ElicitRequestSchema,
InitializeResultSchema
} from "../types.js";
import { ResourceTemplate } from "./mcp.js";
import { completable } from "./completable.js";
Expand Down Expand Up @@ -1342,14 +1343,18 @@ describe("tool()", () => {
const mcpServer = new McpServer({
name: "test server",
version: "1.0",
}, {
sessions: {
sessionIdGenerator: () => "test-session-123"
}
});

const client = new Client({
name: "test client",
version: "1.0",
});

let receivedSessionId: string | undefined;
let receivedSessionId: string | number | undefined;
mcpServer.tool("test-tool", async (extra) => {
receivedSessionId = extra.sessionId;
return {
Expand All @@ -1363,20 +1368,32 @@ describe("tool()", () => {
});

const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair();
// Set a test sessionId on the server transport
serverTransport.sessionId = "test-session-123";

await Promise.all([
client.connect(clientTransport),
mcpServer.server.connect(serverTransport),
]);

// Initialize to create session
await client.request(
{
method: "initialize",
params: {
protocolVersion: "2025-06-18",
capabilities: {},
clientInfo: { name: "test client", version: "1.0" }
}
},
InitializeResultSchema
);

await client.request(
{
method: "tools/call",
params: {
name: "test-tool",
},
sessionId: "test-session-123", // Protocol-level session approach
},
CallToolResultSchema,
);
Expand Down
8 changes: 8 additions & 0 deletions src/server/mcp.ts
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,14 @@ export class McpServer {
await this.server.close();
}

/**
* Returns the connected transport instance.
* Used for session-to-server routing in examples.
*/
getTransport() {
return this.server.getTransport();
}

private _toolHandlersInitialized = false;

private setToolRequestHandlers() {
Expand Down
Loading