Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
212 changes: 211 additions & 1 deletion src/server/streamableHttp.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,8 @@ interface TestServerConfig {
enableJsonResponse?: boolean;
customRequestHandler?: (req: IncomingMessage, res: ServerResponse, parsedBody?: unknown) => Promise<void>;
eventStore?: EventStore;
onsessionclosed?: (sessionId: string) => void;
onsessioninitialized?: (sessionId: string) => void | Promise<void>;
onsessionclosed?: (sessionId: string) => void | Promise<void>;
}

/**
Expand Down Expand Up @@ -59,6 +60,7 @@ async function createTestServer(config: TestServerConfig = { sessionIdGenerator:
sessionIdGenerator: config.sessionIdGenerator,
enableJsonResponse: config.enableJsonResponse ?? false,
eventStore: config.eventStore,
onsessioninitialized: config.onsessioninitialized,
onsessionclosed: config.onsessionclosed
});

Expand Down Expand Up @@ -114,6 +116,7 @@ async function createTestAuthServer(config: TestServerConfig = { sessionIdGenera
sessionIdGenerator: config.sessionIdGenerator,
enableJsonResponse: config.enableJsonResponse ?? false,
eventStore: config.eventStore,
onsessioninitialized: config.onsessioninitialized,
onsessionclosed: config.onsessionclosed
});

Expand Down Expand Up @@ -1666,6 +1669,213 @@ describe("StreamableHTTPServerTransport onsessionclosed callback", () => {
});
});

// Test async callbacks for onsessioninitialized and onsessionclosed
describe("StreamableHTTPServerTransport async callbacks", () => {
it("should support async onsessioninitialized callback", async () => {
const initializationOrder: string[] = [];

// Create server with async onsessioninitialized callback
const result = await createTestServer({
sessionIdGenerator: () => randomUUID(),
onsessioninitialized: async (sessionId: string) => {
initializationOrder.push('async-start');
// Simulate async operation
await new Promise(resolve => setTimeout(resolve, 10));
initializationOrder.push('async-end');
initializationOrder.push(sessionId);
},
});

const tempServer = result.server;
const tempUrl = result.baseUrl;

// Initialize to trigger the callback
const initResponse = await sendPostRequest(tempUrl, TEST_MESSAGES.initialize);
const tempSessionId = initResponse.headers.get("mcp-session-id");

// Give time for async callback to complete
await new Promise(resolve => setTimeout(resolve, 50));

expect(initializationOrder).toEqual(['async-start', 'async-end', tempSessionId]);

// Clean up
tempServer.close();
});

it("should support sync onsessioninitialized callback (backwards compatibility)", async () => {
const capturedSessionId: string[] = [];

// Create server with sync onsessioninitialized callback
const result = await createTestServer({
sessionIdGenerator: () => randomUUID(),
onsessioninitialized: (sessionId: string) => {
capturedSessionId.push(sessionId);
},
});

const tempServer = result.server;
const tempUrl = result.baseUrl;

// Initialize to trigger the callback
const initResponse = await sendPostRequest(tempUrl, TEST_MESSAGES.initialize);
const tempSessionId = initResponse.headers.get("mcp-session-id");

expect(capturedSessionId).toEqual([tempSessionId]);

// Clean up
tempServer.close();
});

it("should support async onsessionclosed callback", async () => {
const closureOrder: string[] = [];

// Create server with async onsessionclosed callback
const result = await createTestServer({
sessionIdGenerator: () => randomUUID(),
onsessionclosed: async (sessionId: string) => {
closureOrder.push('async-close-start');
// Simulate async operation
await new Promise(resolve => setTimeout(resolve, 10));
closureOrder.push('async-close-end');
closureOrder.push(sessionId);
},
});

const tempServer = result.server;
const tempUrl = result.baseUrl;

// Initialize to get a session ID
const initResponse = await sendPostRequest(tempUrl, TEST_MESSAGES.initialize);
const tempSessionId = initResponse.headers.get("mcp-session-id");
expect(tempSessionId).toBeDefined();

// DELETE the session
const deleteResponse = await fetch(tempUrl, {
method: "DELETE",
headers: {
"mcp-session-id": tempSessionId || "",
"mcp-protocol-version": "2025-03-26",
},
});

expect(deleteResponse.status).toBe(200);

// Give time for async callback to complete
await new Promise(resolve => setTimeout(resolve, 50));

expect(closureOrder).toEqual(['async-close-start', 'async-close-end', tempSessionId]);

// Clean up
tempServer.close();
});

it("should propagate errors from async onsessioninitialized callback", async () => {
const consoleErrorSpy = jest.spyOn(console, 'error').mockImplementation();

// Create server with async onsessioninitialized callback that throws
const result = await createTestServer({
sessionIdGenerator: () => randomUUID(),
onsessioninitialized: async (_sessionId: string) => {
throw new Error('Async initialization error');
},
});

const tempServer = result.server;
const tempUrl = result.baseUrl;

// Initialize should fail when callback throws
const initResponse = await sendPostRequest(tempUrl, TEST_MESSAGES.initialize);
expect(initResponse.status).toBe(400);

// Clean up
consoleErrorSpy.mockRestore();
tempServer.close();
});

it("should propagate errors from async onsessionclosed callback", async () => {
const consoleErrorSpy = jest.spyOn(console, 'error').mockImplementation();

// Create server with async onsessionclosed callback that throws
const result = await createTestServer({
sessionIdGenerator: () => randomUUID(),
onsessionclosed: async (_sessionId: string) => {
throw new Error('Async closure error');
},
});

const tempServer = result.server;
const tempUrl = result.baseUrl;

// Initialize to get a session ID
const initResponse = await sendPostRequest(tempUrl, TEST_MESSAGES.initialize);
const tempSessionId = initResponse.headers.get("mcp-session-id");

// DELETE should fail when callback throws
const deleteResponse = await fetch(tempUrl, {
method: "DELETE",
headers: {
"mcp-session-id": tempSessionId || "",
"mcp-protocol-version": "2025-03-26",
},
});

expect(deleteResponse.status).toBe(500);

// Clean up
consoleErrorSpy.mockRestore();
tempServer.close();
});

it("should handle both async callbacks together", async () => {
const events: string[] = [];

// Create server with both async callbacks
const result = await createTestServer({
sessionIdGenerator: () => randomUUID(),
onsessioninitialized: async (sessionId: string) => {
await new Promise(resolve => setTimeout(resolve, 5));
events.push(`initialized:${sessionId}`);
},
onsessionclosed: async (sessionId: string) => {
await new Promise(resolve => setTimeout(resolve, 5));
events.push(`closed:${sessionId}`);
},
});

const tempServer = result.server;
const tempUrl = result.baseUrl;

// Initialize to trigger first callback
const initResponse = await sendPostRequest(tempUrl, TEST_MESSAGES.initialize);
const tempSessionId = initResponse.headers.get("mcp-session-id");

// Wait for async callback
await new Promise(resolve => setTimeout(resolve, 20));

expect(events).toContain(`initialized:${tempSessionId}`);

// DELETE to trigger second callback
const deleteResponse = await fetch(tempUrl, {
method: "DELETE",
headers: {
"mcp-session-id": tempSessionId || "",
"mcp-protocol-version": "2025-03-26",
},
});

expect(deleteResponse.status).toBe(200);

// Wait for async callback
await new Promise(resolve => setTimeout(resolve, 20));

expect(events).toContain(`closed:${tempSessionId}`);
expect(events).toHaveLength(2);

// Clean up
tempServer.close();
});
});

// Test DNS rebinding protection
describe("StreamableHTTPServerTransport DNS rebinding protection", () => {
let server: Server;
Expand Down
12 changes: 6 additions & 6 deletions src/server/streamableHttp.ts
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ export interface StreamableHTTPServerTransportOptions {
* and need to keep track of them.
* @param sessionId The generated session ID
*/
onsessioninitialized?: (sessionId: string) => void;
onsessioninitialized?: (sessionId: string) => void | Promise<void>;

/**
* A callback for session close events
Expand All @@ -59,7 +59,7 @@ export interface StreamableHTTPServerTransportOptions {
* session open/running.
* @param sessionId The session ID that was closed
*/
onsessionclosed?: (sessionId: string) => void;
onsessionclosed?: (sessionId: string) => void | Promise<void>;

/**
* If true, the server will return JSON responses instead of starting an SSE stream.
Expand Down Expand Up @@ -138,8 +138,8 @@ export class StreamableHTTPServerTransport implements Transport {
private _enableJsonResponse: boolean = false;
private _standaloneSseStreamId: string = '_GET_stream';
private _eventStore?: EventStore;
private _onsessioninitialized?: (sessionId: string) => void;
private _onsessionclosed?: (sessionId: string) => void;
private _onsessioninitialized?: (sessionId: string) => void | Promise<void>;
private _onsessionclosed?: (sessionId: string) => void | Promise<void>;
private _allowedHosts?: string[];
private _allowedOrigins?: string[];
private _enableDnsRebindingProtection: boolean;
Expand Down Expand Up @@ -460,7 +460,7 @@ export class StreamableHTTPServerTransport implements Transport {
// If we have a session ID and an onsessioninitialized handler, call it immediately
// This is needed in cases where the server needs to keep track of multiple sessions
if (this.sessionId && this._onsessioninitialized) {
this._onsessioninitialized(this.sessionId);
await Promise.resolve(this._onsessioninitialized(this.sessionId));
}

}
Expand Down Expand Up @@ -552,7 +552,7 @@ export class StreamableHTTPServerTransport implements Transport {
if (!this.validateProtocolVersion(req, res)) {
return;
}
this._onsessionclosed?.(this.sessionId!);
await Promise.resolve(this._onsessionclosed?.(this.sessionId!));
await this.close();
res.writeHead(200).end();
}
Expand Down