Skip to content
Merged
22 changes: 21 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -444,7 +444,11 @@ app.post('/mcp', async (req, res) => {
onsessioninitialized: (sessionId) => {
// Store the transport by session ID
transports[sessionId] = transport;
}
},
// DNS rebinding protection is disabled by default for backwards compatibility. If you are running this server
// locally, make sure to set:
// enableDnsRebindingProtection: true,
// allowedHosts: ['127.0.0.1'],
});

// Clean up transport when closed
Expand Down Expand Up @@ -596,6 +600,22 @@ This stateless approach is useful for:
- RESTful scenarios where each request is independent
- Horizontally scaled deployments without shared session state

#### DNS Rebinding Protection

The Streamable HTTP transport includes DNS rebinding protection to prevent security vulnerabilities. By default, this protection is **disabled** for backwards compatibility.

**Important**: If you are running this server locally, enable DNS rebinding protection:

```typescript
const transport = new StreamableHTTPServerTransport({
sessionIdGenerator: () => randomUUID(),
enableDnsRebindingProtection: true,

allowedHosts: ['127.0.0.1', ...],
allowedOrigins: ['https://yourdomain.com', 'https://www.yourdomain.com']
});
```

### Testing and Debugging

To test your server, you can use the [MCP Inspector](https://github.com/modelcontextprotocol/inspector). See its README for more information.
Expand Down
262 changes: 261 additions & 1 deletion src/server/sse.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -453,4 +453,264 @@ describe('SSEServerTransport', () => {
expect.stringContaining(`data: /messages?sessionId=${transport.sessionId}`));
});
});
});

describe('DNS rebinding protection', () => {
beforeEach(() => {
jest.clearAllMocks();
});

describe('Host header validation', () => {
it('should accept requests with allowed host headers', async () => {
const mockRes = createMockResponse();
const transport = new SSEServerTransport('/messages', mockRes, {
allowedHosts: ['localhost:3000', 'example.com'],
enableDnsRebindingProtection: true,
});
await transport.start();

const mockReq = createMockRequest({
headers: {
host: 'localhost:3000',
'content-type': 'application/json',
}
});
const mockHandleRes = createMockResponse();

await transport.handlePostMessage(mockReq, mockHandleRes, { jsonrpc: '2.0', method: 'test' });

expect(mockHandleRes.writeHead).toHaveBeenCalledWith(202);
expect(mockHandleRes.end).toHaveBeenCalledWith('Accepted');
});

it('should reject requests with disallowed host headers', async () => {
const mockRes = createMockResponse();
const transport = new SSEServerTransport('/messages', mockRes, {
allowedHosts: ['localhost:3000'],
enableDnsRebindingProtection: true,
});
await transport.start();

const mockReq = createMockRequest({
headers: {
host: 'evil.com',
'content-type': 'application/json',
}
});
const mockHandleRes = createMockResponse();

await transport.handlePostMessage(mockReq, mockHandleRes, { jsonrpc: '2.0', method: 'test' });

expect(mockHandleRes.writeHead).toHaveBeenCalledWith(403);
expect(mockHandleRes.end).toHaveBeenCalledWith('Invalid Host header: evil.com');
});

it('should reject requests without host header when allowedHosts is configured', async () => {
const mockRes = createMockResponse();
const transport = new SSEServerTransport('/messages', mockRes, {
allowedHosts: ['localhost:3000'],
enableDnsRebindingProtection: true,
});
await transport.start();

const mockReq = createMockRequest({
headers: {
'content-type': 'application/json',
}
});
const mockHandleRes = createMockResponse();

await transport.handlePostMessage(mockReq, mockHandleRes, { jsonrpc: '2.0', method: 'test' });

expect(mockHandleRes.writeHead).toHaveBeenCalledWith(403);
expect(mockHandleRes.end).toHaveBeenCalledWith('Invalid Host header: undefined');
});
});

describe('Origin header validation', () => {
it('should accept requests with allowed origin headers', async () => {
const mockRes = createMockResponse();
const transport = new SSEServerTransport('/messages', mockRes, {
allowedOrigins: ['http://localhost:3000', 'https://example.com'],
enableDnsRebindingProtection: true,
});
await transport.start();

const mockReq = createMockRequest({
headers: {
origin: 'http://localhost:3000',
'content-type': 'application/json',
}
});
const mockHandleRes = createMockResponse();

await transport.handlePostMessage(mockReq, mockHandleRes, { jsonrpc: '2.0', method: 'test' });

expect(mockHandleRes.writeHead).toHaveBeenCalledWith(202);
expect(mockHandleRes.end).toHaveBeenCalledWith('Accepted');
});

it('should reject requests with disallowed origin headers', async () => {
const mockRes = createMockResponse();
const transport = new SSEServerTransport('/messages', mockRes, {
allowedOrigins: ['http://localhost:3000'],
enableDnsRebindingProtection: true,
});
await transport.start();

const mockReq = createMockRequest({
headers: {
origin: 'http://evil.com',
'content-type': 'application/json',
}
});
const mockHandleRes = createMockResponse();

await transport.handlePostMessage(mockReq, mockHandleRes, { jsonrpc: '2.0', method: 'test' });

expect(mockHandleRes.writeHead).toHaveBeenCalledWith(403);
expect(mockHandleRes.end).toHaveBeenCalledWith('Invalid Origin header: http://evil.com');
});
});

describe('Content-Type validation', () => {
it('should accept requests with application/json content-type', async () => {
const mockRes = createMockResponse();
const transport = new SSEServerTransport('/messages', mockRes);
await transport.start();

const mockReq = createMockRequest({
headers: {
'content-type': 'application/json',
}
});
const mockHandleRes = createMockResponse();

await transport.handlePostMessage(mockReq, mockHandleRes, { jsonrpc: '2.0', method: 'test' });

expect(mockHandleRes.writeHead).toHaveBeenCalledWith(202);
expect(mockHandleRes.end).toHaveBeenCalledWith('Accepted');
});

it('should accept requests with application/json with charset', async () => {
const mockRes = createMockResponse();
const transport = new SSEServerTransport('/messages', mockRes);
await transport.start();

const mockReq = createMockRequest({
headers: {
'content-type': 'application/json; charset=utf-8',
}
});
const mockHandleRes = createMockResponse();

await transport.handlePostMessage(mockReq, mockHandleRes, { jsonrpc: '2.0', method: 'test' });

expect(mockHandleRes.writeHead).toHaveBeenCalledWith(202);
expect(mockHandleRes.end).toHaveBeenCalledWith('Accepted');
});

it('should reject requests with non-application/json content-type when protection is enabled', async () => {
const mockRes = createMockResponse();
const transport = new SSEServerTransport('/messages', mockRes);
await transport.start();

const mockReq = createMockRequest({
headers: {
'content-type': 'text/plain',
}
});
const mockHandleRes = createMockResponse();

await transport.handlePostMessage(mockReq, mockHandleRes, { jsonrpc: '2.0', method: 'test' });

expect(mockHandleRes.writeHead).toHaveBeenCalledWith(400);
expect(mockHandleRes.end).toHaveBeenCalledWith('Error: Unsupported content-type: text/plain');
});
});

describe('enableDnsRebindingProtection option', () => {
it('should skip all validations when enableDnsRebindingProtection is false', async () => {
const mockRes = createMockResponse();
const transport = new SSEServerTransport('/messages', mockRes, {
allowedHosts: ['localhost:3000'],
allowedOrigins: ['http://localhost:3000'],
enableDnsRebindingProtection: false,
});
await transport.start();

const mockReq = createMockRequest({
headers: {
host: 'evil.com',
origin: 'http://evil.com',
'content-type': 'text/plain',
}
});
const mockHandleRes = createMockResponse();

await transport.handlePostMessage(mockReq, mockHandleRes, { jsonrpc: '2.0', method: 'test' });

// Should pass even with invalid headers because protection is disabled
expect(mockHandleRes.writeHead).toHaveBeenCalledWith(400);
// The error should be from content-type parsing, not DNS rebinding protection
expect(mockHandleRes.end).toHaveBeenCalledWith('Error: Unsupported content-type: text/plain');
});
});

describe('Combined validations', () => {
it('should validate both host and origin when both are configured', async () => {
const mockRes = createMockResponse();
const transport = new SSEServerTransport('/messages', mockRes, {
allowedHosts: ['localhost:3000'],
allowedOrigins: ['http://localhost:3000'],
enableDnsRebindingProtection: true,
});
await transport.start();

// Valid host, invalid origin
const mockReq1 = createMockRequest({
headers: {
host: 'localhost:3000',
origin: 'http://evil.com',
'content-type': 'application/json',
}
});
const mockHandleRes1 = createMockResponse();

await transport.handlePostMessage(mockReq1, mockHandleRes1, { jsonrpc: '2.0', method: 'test' });

expect(mockHandleRes1.writeHead).toHaveBeenCalledWith(403);
expect(mockHandleRes1.end).toHaveBeenCalledWith('Invalid Origin header: http://evil.com');

// Invalid host, valid origin
const mockReq2 = createMockRequest({
headers: {
host: 'evil.com',
origin: 'http://localhost:3000',
'content-type': 'application/json',
}
});
const mockHandleRes2 = createMockResponse();

await transport.handlePostMessage(mockReq2, mockHandleRes2, { jsonrpc: '2.0', method: 'test' });

expect(mockHandleRes2.writeHead).toHaveBeenCalledWith(403);
expect(mockHandleRes2.end).toHaveBeenCalledWith('Invalid Host header: evil.com');

// Both valid
const mockReq3 = createMockRequest({
headers: {
host: 'localhost:3000',
origin: 'http://localhost:3000',
'content-type': 'application/json',
}
});
const mockHandleRes3 = createMockResponse();

await transport.handlePostMessage(mockReq3, mockHandleRes3, { jsonrpc: '2.0', method: 'test' });

expect(mockHandleRes3.writeHead).toHaveBeenCalledWith(202);
expect(mockHandleRes3.end).toHaveBeenCalledWith('Accepted');
});
});
});
});
64 changes: 64 additions & 0 deletions src/server/sse.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,29 @@ import { URL } from 'url';

const MAXIMUM_MESSAGE_SIZE = "4mb";

/**
* Configuration options for SSEServerTransport.
*/
export interface SSEServerTransportOptions {
/**
* List of allowed host header values for DNS rebinding protection.
* If not specified, host validation is disabled.
*/
allowedHosts?: string[];

/**
* List of allowed origin header values for DNS rebinding protection.
* If not specified, origin validation is disabled.
*/
allowedOrigins?: string[];

/**
* Enable DNS rebinding protection (requires allowedHosts and/or allowedOrigins to be configured).
* Default is false for backwards compatibility.
*/
enableDnsRebindingProtection?: boolean;
}

/**
* Server transport for SSE: this will send messages over an SSE connection and receive messages from HTTP POST requests.
*
Expand All @@ -17,6 +40,7 @@ const MAXIMUM_MESSAGE_SIZE = "4mb";
export class SSEServerTransport implements Transport {
private _sseResponse?: ServerResponse;
private _sessionId: string;
private _options: SSEServerTransportOptions;
onclose?: () => void;
onerror?: (error: Error) => void;
onmessage?: (message: JSONRPCMessage, extra?: MessageExtraInfo) => void;
Expand All @@ -27,8 +51,39 @@ export class SSEServerTransport implements Transport {
constructor(
private _endpoint: string,
private res: ServerResponse,
options?: SSEServerTransportOptions,
) {
this._sessionId = randomUUID();
this._options = options || {enableDnsRebindingProtection: false};
}

/**
* Validates request headers for DNS rebinding protection.
* @returns Error message if validation fails, undefined if validation passes.
*/
private validateRequestHeaders(req: IncomingMessage): string | undefined {
// Skip validation if protection is not enabled
if (!this._options.enableDnsRebindingProtection) {
return undefined;
}

// Validate Host header if allowedHosts is configured
if (this._options.allowedHosts && this._options.allowedHosts.length > 0) {
const hostHeader = req.headers.host;
if (!hostHeader || !this._options.allowedHosts.includes(hostHeader)) {
return `Invalid Host header: ${hostHeader}`;
}
}

// Validate Origin header if allowedOrigins is configured
if (this._options.allowedOrigins && this._options.allowedOrigins.length > 0) {
const originHeader = req.headers.origin;
if (!originHeader || !this._options.allowedOrigins.includes(originHeader)) {
return `Invalid Origin header: ${originHeader}`;
}
}

return undefined;
}

/**
Expand Down Expand Up @@ -85,6 +140,15 @@ export class SSEServerTransport implements Transport {
res.writeHead(500).end(message);
throw new Error(message);
}

// Validate request headers for DNS rebinding protection
const validationError = this.validateRequestHeaders(req);
if (validationError) {
res.writeHead(403).end(validationError);
this.onerror?.(new Error(validationError));
return;
}

const authInfo: AuthInfo | undefined = req.auth;
const requestInfo: RequestInfo = { headers: req.headers };

Expand Down
Loading
Loading