From 23729304fb332ba07e80fdea37a0d22a5c02b55e Mon Sep 17 00:00:00 2001 From: Himanshu Singh Date: Tue, 26 Aug 2025 11:51:43 +0200 Subject: [PATCH 01/13] chore: rename connectionManager to mcpConnectionManager --- ...tionManager.ts => mcpConnectionManager.ts} | 11 ++++--- src/common/session.ts | 8 ++--- src/tools/atlas/connect/connectCluster.ts | 2 +- src/transports/base.ts | 4 +-- .../common/connectionManager.oidc.test.ts | 2 +- .../common/connectionManager.test.ts | 29 ++++++++++--------- tests/integration/helpers.ts | 8 ++--- tests/integration/telemetry.test.ts | 4 +-- tests/unit/common/session.test.ts | 4 +-- tests/unit/resources/common/debug.test.ts | 4 +-- 10 files changed, 41 insertions(+), 35 deletions(-) rename src/common/{connectionManager.ts => mcpConnectionManager.ts} (96%) diff --git a/src/common/connectionManager.ts b/src/common/mcpConnectionManager.ts similarity index 96% rename from src/common/connectionManager.ts rename to src/common/mcpConnectionManager.ts index 23183768..2d2c2be8 100644 --- a/src/common/connectionManager.ts +++ b/src/common/mcpConnectionManager.ts @@ -63,7 +63,7 @@ export type AnyConnectionState = | ConnectionStateDisconnected | ConnectionStateErrored; -export interface ConnectionManagerEvents { +export interface MCPConnectionManagerEvents { "connection-requested": [AnyConnectionState]; "connection-succeeded": [ConnectionStateConnected]; "connection-timed-out": [ConnectionStateErrored]; @@ -71,7 +71,7 @@ export interface ConnectionManagerEvents { "connection-errored": [ConnectionStateErrored]; } -export class ConnectionManager extends EventEmitter { +export class MCPConnectionManager extends EventEmitter { private state: AnyConnectionState; private deviceId: DeviceId; private clientName: string; @@ -158,7 +158,10 @@ export class ConnectionManager extends EventEmitter { } try { - const connectionType = ConnectionManager.inferConnectionTypeFromSettings(this.userConfig, connectionInfo); + const connectionType = MCPConnectionManager.inferConnectionTypeFromSettings( + this.userConfig, + connectionInfo + ); if (connectionType.startsWith("oidc")) { void this.pingAndForget(serviceProvider); @@ -212,7 +215,7 @@ export class ConnectionManager extends EventEmitter { return this.state; } - changeState( + changeState( event: Event, newState: State ): State { diff --git a/src/common/session.ts b/src/common/session.ts index 5080c05a..40594cbb 100644 --- a/src/common/session.ts +++ b/src/common/session.ts @@ -7,10 +7,10 @@ import { LogId } from "./logger.js"; import EventEmitter from "events"; import type { AtlasClusterConnectionInfo, - ConnectionManager, + MCPConnectionManager, ConnectionSettings, ConnectionStateConnected, -} from "./connectionManager.js"; +} from "./mcpConnectionManager.js"; import type { NodeDriverServiceProvider } from "@mongosh/service-provider-node-driver"; import { ErrorCodes, MongoDBError } from "./errors.js"; import type { ExportsManager } from "./exportsManager.js"; @@ -21,7 +21,7 @@ export interface SessionOptions { apiClientSecret?: string; logger: CompositeLogger; exportsManager: ExportsManager; - connectionManager: ConnectionManager; + connectionManager: MCPConnectionManager; } export type SessionEvents = { @@ -34,7 +34,7 @@ export type SessionEvents = { export class Session extends EventEmitter { readonly sessionId: string = new ObjectId().toString(); readonly exportsManager: ExportsManager; - readonly connectionManager: ConnectionManager; + readonly connectionManager: MCPConnectionManager; readonly apiClient: ApiClient; mcpClient?: { name?: string; diff --git a/src/tools/atlas/connect/connectCluster.ts b/src/tools/atlas/connect/connectCluster.ts index 9695ff36..e7445d92 100644 --- a/src/tools/atlas/connect/connectCluster.ts +++ b/src/tools/atlas/connect/connectCluster.ts @@ -6,7 +6,7 @@ import { generateSecurePassword } from "../../../helpers/generatePassword.js"; import { LogId } from "../../../common/logger.js"; import { inspectCluster } from "../../../common/atlas/cluster.js"; import { ensureCurrentIpInAccessList } from "../../../common/atlas/accessListUtils.js"; -import type { AtlasClusterConnectionInfo } from "../../../common/connectionManager.js"; +import type { AtlasClusterConnectionInfo } from "../../../common/mcpConnectionManager.js"; import { getDefaultRoleFromConfig } from "../../../common/atlas/roles.js"; const EXPIRY_MS = 1000 * 60 * 60 * 12; // 12 hours diff --git a/src/transports/base.ts b/src/transports/base.ts index 485752e7..02db4663 100644 --- a/src/transports/base.ts +++ b/src/transports/base.ts @@ -7,7 +7,7 @@ import { McpServer } from "@modelcontextprotocol/sdk/server/mcp.js"; import type { LoggerBase } from "../common/logger.js"; import { CompositeLogger, ConsoleLogger, DiskLogger, McpLogger } from "../common/logger.js"; import { ExportsManager } from "../common/exportsManager.js"; -import { ConnectionManager } from "../common/connectionManager.js"; +import { MCPConnectionManager } from "../common/mcpConnectionManager.js"; import { DeviceId } from "../helpers/deviceId.js"; export abstract class TransportRunnerBase { @@ -46,7 +46,7 @@ export abstract class TransportRunnerBase { const logger = new CompositeLogger(this.logger); const exportsManager = ExportsManager.init(this.userConfig, logger); - const connectionManager = new ConnectionManager(this.userConfig, this.driverOptions, logger, this.deviceId); + const connectionManager = new MCPConnectionManager(this.userConfig, this.driverOptions, logger, this.deviceId); const session = new Session({ apiBaseUrl: this.userConfig.apiBaseUrl, diff --git a/tests/integration/common/connectionManager.oidc.test.ts b/tests/integration/common/connectionManager.oidc.test.ts index e3a406eb..1f9fa263 100644 --- a/tests/integration/common/connectionManager.oidc.test.ts +++ b/tests/integration/common/connectionManager.oidc.test.ts @@ -5,7 +5,7 @@ import process from "process"; import type { MongoDBIntegrationTestCase } from "../tools/mongodb/mongodbHelpers.js"; import { describeWithMongoDB, isCommunityServer, getServerVersion } from "../tools/mongodb/mongodbHelpers.js"; import { defaultTestConfig, responseAsText, timeout, waitUntil } from "../helpers.js"; -import type { ConnectionStateConnected, ConnectionStateConnecting } from "../../../src/common/connectionManager.js"; +import type { ConnectionStateConnected, ConnectionStateConnecting } from "../../../src/common/mcpConnectionManager.js"; import type { UserConfig } from "../../../src/common/config.js"; import { setupDriverConfig } from "../../../src/common/config.js"; import path from "path"; diff --git a/tests/integration/common/connectionManager.test.ts b/tests/integration/common/connectionManager.test.ts index 40386fcd..ccea2010 100644 --- a/tests/integration/common/connectionManager.test.ts +++ b/tests/integration/common/connectionManager.test.ts @@ -1,15 +1,15 @@ import type { - ConnectionManagerEvents, + MCPConnectionManagerEvents, ConnectionStateConnected, ConnectionStringAuthType, -} from "../../../src/common/connectionManager.js"; -import { ConnectionManager } from "../../../src/common/connectionManager.js"; +} from "../../../src/common/mcpConnectionManager.js"; +import { MCPConnectionManager } from "../../../src/common/mcpConnectionManager.js"; import type { UserConfig } from "../../../src/common/config.js"; import { describeWithMongoDB } from "../tools/mongodb/mongodbHelpers.js"; import { describe, beforeEach, expect, it, vi, afterEach } from "vitest"; describeWithMongoDB("Connection Manager", (integration) => { - function connectionManager(): ConnectionManager { + function connectionManager(): MCPConnectionManager { return integration.mcpServer().session.connectionManager; } @@ -24,11 +24,11 @@ describeWithMongoDB("Connection Manager", (integration) => { describe("when successfully connected", () => { type ConnectionManagerSpies = { - "connection-requested": (event: ConnectionManagerEvents["connection-requested"][0]) => void; - "connection-succeeded": (event: ConnectionManagerEvents["connection-succeeded"][0]) => void; - "connection-timed-out": (event: ConnectionManagerEvents["connection-timed-out"][0]) => void; - "connection-closed": (event: ConnectionManagerEvents["connection-closed"][0]) => void; - "connection-errored": (event: ConnectionManagerEvents["connection-errored"][0]) => void; + "connection-requested": (event: MCPConnectionManagerEvents["connection-requested"][0]) => void; + "connection-succeeded": (event: MCPConnectionManagerEvents["connection-succeeded"][0]) => void; + "connection-timed-out": (event: MCPConnectionManagerEvents["connection-timed-out"][0]) => void; + "connection-closed": (event: MCPConnectionManagerEvents["connection-closed"][0]) => void; + "connection-errored": (event: MCPConnectionManagerEvents["connection-errored"][0]) => void; }; let connectionManagerSpies: ConnectionManagerSpies; @@ -43,7 +43,7 @@ describeWithMongoDB("Connection Manager", (integration) => { }; for (const [event, spy] of Object.entries(connectionManagerSpies)) { - connectionManager().on(event as keyof ConnectionManagerEvents, spy); + connectionManager().on(event as keyof MCPConnectionManagerEvents, spy); } await connectionManager().connect({ @@ -182,9 +182,12 @@ describe("Connection Manager connection type inference", () => { for (const { userConfig, connectionString, connectionType } of testCases) { it(`infers ${connectionType} from ${connectionString}`, () => { - const actualConnectionType = ConnectionManager.inferConnectionTypeFromSettings(userConfig as UserConfig, { - connectionString, - }); + const actualConnectionType = MCPConnectionManager.inferConnectionTypeFromSettings( + userConfig as UserConfig, + { + connectionString, + } + ); expect(actualConnectionType).toBe(connectionType); }); diff --git a/tests/integration/helpers.ts b/tests/integration/helpers.ts index b67fbc16..5c267cee 100644 --- a/tests/integration/helpers.ts +++ b/tests/integration/helpers.ts @@ -10,8 +10,8 @@ import type { UserConfig, DriverOptions } from "../../src/common/config.js"; import { McpError, ResourceUpdatedNotificationSchema } from "@modelcontextprotocol/sdk/types.js"; import { config, driverOptions } from "../../src/common/config.js"; import { afterAll, afterEach, beforeAll, describe, expect, it, vi } from "vitest"; -import type { ConnectionState } from "../../src/common/connectionManager.js"; -import { ConnectionManager } from "../../src/common/connectionManager.js"; +import type { ConnectionState } from "../../src/common/mcpConnectionManager.js"; +import { MCPConnectionManager } from "../../src/common/mcpConnectionManager.js"; import { DeviceId } from "../../src/helpers/deviceId.js"; interface ParameterInfo { @@ -72,7 +72,7 @@ export function setupIntegrationTest( const exportsManager = ExportsManager.init(userConfig, logger); deviceId = DeviceId.create(logger); - const connectionManager = new ConnectionManager(userConfig, driverOptions, logger, deviceId); + const connectionManager = new MCPConnectionManager(userConfig, driverOptions, logger, deviceId); const session = new Session({ apiBaseUrl: userConfig.apiBaseUrl, @@ -315,7 +315,7 @@ export function responseAsText(response: Awaited> export function waitUntil( tag: T["tag"], - cm: ConnectionManager, + cm: MCPConnectionManager, signal: AbortSignal, additionalCondition?: (state: T) => boolean ): Promise { diff --git a/tests/integration/telemetry.test.ts b/tests/integration/telemetry.test.ts index cc51ed8b..29a78469 100644 --- a/tests/integration/telemetry.test.ts +++ b/tests/integration/telemetry.test.ts @@ -4,7 +4,7 @@ import { config, driverOptions } from "../../src/common/config.js"; import { DeviceId } from "../../src/helpers/deviceId.js"; import { describe, expect, it } from "vitest"; import { CompositeLogger } from "../../src/common/logger.js"; -import { ConnectionManager } from "../../src/common/connectionManager.js"; +import { MCPConnectionManager } from "../../src/common/mcpConnectionManager.js"; import { ExportsManager } from "../../src/common/exportsManager.js"; describe("Telemetry", () => { @@ -19,7 +19,7 @@ describe("Telemetry", () => { apiBaseUrl: "", logger, exportsManager: ExportsManager.init(config, logger), - connectionManager: new ConnectionManager(config, driverOptions, logger, deviceId), + connectionManager: new MCPConnectionManager(config, driverOptions, logger, deviceId), }), config, deviceId diff --git a/tests/unit/common/session.test.ts b/tests/unit/common/session.test.ts index 6b2b3552..29075969 100644 --- a/tests/unit/common/session.test.ts +++ b/tests/unit/common/session.test.ts @@ -4,7 +4,7 @@ import { NodeDriverServiceProvider } from "@mongosh/service-provider-node-driver import { Session } from "../../../src/common/session.js"; import { config, driverOptions } from "../../../src/common/config.js"; import { CompositeLogger } from "../../../src/common/logger.js"; -import { ConnectionManager } from "../../../src/common/connectionManager.js"; +import { MCPConnectionManager } from "../../../src/common/mcpConnectionManager.js"; import { ExportsManager } from "../../../src/common/exportsManager.js"; import { DeviceId } from "../../../src/helpers/deviceId.js"; @@ -27,7 +27,7 @@ describe("Session", () => { apiBaseUrl: "https://api.test.com", logger, exportsManager: ExportsManager.init(config, logger), - connectionManager: new ConnectionManager(config, driverOptions, logger, mockDeviceId), + connectionManager: new MCPConnectionManager(config, driverOptions, logger, mockDeviceId), }); MockNodeDriverServiceProvider.connect = vi.fn().mockResolvedValue({} as unknown as NodeDriverServiceProvider); diff --git a/tests/unit/resources/common/debug.test.ts b/tests/unit/resources/common/debug.test.ts index 0292a726..59c0d3aa 100644 --- a/tests/unit/resources/common/debug.test.ts +++ b/tests/unit/resources/common/debug.test.ts @@ -4,7 +4,7 @@ import { Session } from "../../../../src/common/session.js"; import { Telemetry } from "../../../../src/telemetry/telemetry.js"; import { config, driverOptions } from "../../../../src/common/config.js"; import { CompositeLogger } from "../../../../src/common/logger.js"; -import { ConnectionManager } from "../../../../src/common/connectionManager.js"; +import { MCPConnectionManager } from "../../../../src/common/mcpConnectionManager.js"; import { ExportsManager } from "../../../../src/common/exportsManager.js"; import { DeviceId } from "../../../../src/helpers/deviceId.js"; @@ -15,7 +15,7 @@ describe("debug resource", () => { apiBaseUrl: "", logger, exportsManager: ExportsManager.init(config, logger), - connectionManager: new ConnectionManager(config, driverOptions, logger, deviceId), + connectionManager: new MCPConnectionManager(config, driverOptions, logger, deviceId), }); const telemetry = Telemetry.create(session, { ...config, telemetry: "disabled" }, deviceId); From 47c6a56a1f1ab252cf9770fd0ab716f029ab7100 Mon Sep 17 00:00:00 2001 From: Himanshu Singh Date: Wed, 27 Aug 2025 11:15:41 +0200 Subject: [PATCH 02/13] chore: extract abstract ConnectionManager This commit extracts an abstract class ConnectionManager out of MCPConnectionManager and modifies the TransportRunner interface to have the ConnectionManager implementation injected through a factory function. Contains a small drive by fix for making the ConnectionManager event emitting internal to the class itself. --- package-lock.json | 1 - src/common/connectionManager.ts | 92 ++++++++++++++ src/common/mcpConnectionManager.ts | 116 ++++-------------- src/common/session.ts | 26 ++-- src/index.ts | 10 +- src/lib.ts | 4 +- src/tools/atlas/connect/connectCluster.ts | 2 +- src/transports/base.ts | 19 +-- src/transports/stdio.ts | 17 +-- src/transports/streamableHttp.ts | 11 +- tests/integration/build.test.ts | 7 +- .../common/connectionManager.oidc.test.ts | 2 +- .../common/connectionManager.test.ts | 19 +-- tests/integration/helpers.ts | 4 +- .../transports/streamableHttp.test.ts | 22 +++- 15 files changed, 195 insertions(+), 157 deletions(-) create mode 100644 src/common/connectionManager.ts diff --git a/package-lock.json b/package-lock.json index d3148181..c3bb5783 100644 --- a/package-lock.json +++ b/package-lock.json @@ -18,7 +18,6 @@ "@vitest/eslint-plugin": "^1.3.4", "bson": "^6.10.4", "express": "^5.1.0", - "kerberos": "*", "lru-cache": "^11.1.0", "mongodb": "^6.17.0", "mongodb-connection-string-url": "^3.0.2", diff --git a/src/common/connectionManager.ts b/src/common/connectionManager.ts new file mode 100644 index 00000000..ae976171 --- /dev/null +++ b/src/common/connectionManager.ts @@ -0,0 +1,92 @@ +import { EventEmitter } from "events"; +import type { NodeDriverServiceProvider } from "@mongosh/service-provider-node-driver"; + +export interface AtlasClusterConnectionInfo { + username: string; + projectId: string; + clusterName: string; + expiryDate: Date; +} + +type ConnectionTag = "connected" | "connecting" | "disconnected" | "errored"; +export type OIDCConnectionAuthType = "oidc-auth-flow" | "oidc-device-flow"; +export type ConnectionStringAuthType = "scram" | "ldap" | "kerberos" | OIDCConnectionAuthType | "x.509"; + +export interface ConnectionState { + tag: ConnectionTag; + connectionStringAuthType?: ConnectionStringAuthType; + connectedAtlasCluster?: AtlasClusterConnectionInfo; +} + +export interface ConnectionStateConnected extends ConnectionState { + tag: "connected"; + serviceProvider: NodeDriverServiceProvider; +} + +export interface ConnectionStateConnecting extends ConnectionState { + tag: "connecting"; + serviceProvider: NodeDriverServiceProvider; + oidcConnectionType: OIDCConnectionAuthType; + oidcLoginUrl?: string; + oidcUserCode?: string; +} + +export interface ConnectionStateDisconnected extends ConnectionState { + tag: "disconnected"; +} + +export interface ConnectionStateErrored extends ConnectionState { + tag: "errored"; + errorReason: string; +} + +export type AnyConnectionState = + | ConnectionStateConnected + | ConnectionStateConnecting + | ConnectionStateDisconnected + | ConnectionStateErrored; + +export interface ConnectionManagerEvents { + "connection-requested": [AnyConnectionState]; + "connection-succeeded": [ConnectionStateConnected]; + "connection-timed-out": [ConnectionStateErrored]; + "connection-closed": [ConnectionStateDisconnected]; + "connection-errored": [ConnectionStateErrored]; +} + +export interface MCPConnectParams { + connectionString: string; + atlas?: AtlasClusterConnectionInfo; +} + +export abstract class ConnectionManager { + protected clientName: string = "unknown"; + + protected readonly _events = new EventEmitter(); + readonly events: Pick, "on" | "off" | "once"> = this._events; + + protected state: AnyConnectionState = { tag: "disconnected" }; + + get currentConnectionState(): AnyConnectionState { + return this.state; + } + + changeState( + event: Event, + newState: State + ): State { + this.state = newState; + // TypeScript doesn't seem to be happy with the spread operator and generics + // eslint-disable-next-line + this._events.emit(event, ...([newState] as any)); + return newState; + } + + setClientName(clientName: string): void { + this.clientName = clientName; + } + + abstract connect(connectParams: ConnectParams): Promise; + + abstract disconnect(): Promise; +} diff --git a/src/common/mcpConnectionManager.ts b/src/common/mcpConnectionManager.ts index 2d2c2be8..4119cc12 100644 --- a/src/common/mcpConnectionManager.ts +++ b/src/common/mcpConnectionManager.ts @@ -12,69 +12,18 @@ import type { CompositeLogger } from "./logger.js"; import { LogId } from "./logger.js"; import type { ConnectionInfo } from "@mongosh/arg-parser"; import { generateConnectionInfoFromCliArgs } from "@mongosh/arg-parser"; - -export interface AtlasClusterConnectionInfo { - username: string; - projectId: string; - clusterName: string; - expiryDate: Date; -} - -export interface ConnectionSettings { - connectionString: string; - atlas?: AtlasClusterConnectionInfo; -} - -type ConnectionTag = "connected" | "connecting" | "disconnected" | "errored"; -type OIDCConnectionAuthType = "oidc-auth-flow" | "oidc-device-flow"; -export type ConnectionStringAuthType = "scram" | "ldap" | "kerberos" | OIDCConnectionAuthType | "x.509"; - -export interface ConnectionState { - tag: ConnectionTag; - connectionStringAuthType?: ConnectionStringAuthType; - connectedAtlasCluster?: AtlasClusterConnectionInfo; -} - -export interface ConnectionStateConnected extends ConnectionState { - tag: "connected"; - serviceProvider: NodeDriverServiceProvider; -} - -export interface ConnectionStateConnecting extends ConnectionState { - tag: "connecting"; - serviceProvider: NodeDriverServiceProvider; - oidcConnectionType: OIDCConnectionAuthType; - oidcLoginUrl?: string; - oidcUserCode?: string; -} - -export interface ConnectionStateDisconnected extends ConnectionState { - tag: "disconnected"; -} - -export interface ConnectionStateErrored extends ConnectionState { - tag: "errored"; - errorReason: string; -} - -export type AnyConnectionState = - | ConnectionStateConnected - | ConnectionStateConnecting - | ConnectionStateDisconnected - | ConnectionStateErrored; - -export interface MCPConnectionManagerEvents { - "connection-requested": [AnyConnectionState]; - "connection-succeeded": [ConnectionStateConnected]; - "connection-timed-out": [ConnectionStateErrored]; - "connection-closed": [ConnectionStateDisconnected]; - "connection-errored": [ConnectionStateErrored]; -} - -export class MCPConnectionManager extends EventEmitter { - private state: AnyConnectionState; +import { + ConnectionManager, + type AnyConnectionState, + type ConnectionStringAuthType, + type OIDCConnectionAuthType, + type ConnectionStateDisconnected, + type ConnectionStateErrored, + type MCPConnectParams, +} from "./connectionManager.js"; + +export class MCPConnectionManager extends ConnectionManager { private deviceId: DeviceId; - private clientName: string; private bus: EventEmitter; constructor( @@ -85,23 +34,15 @@ export class MCPConnectionManager extends EventEmitter { - this.emit("connection-requested", this.state); + async connect(connectParams: MCPConnectParams): Promise { + this._events.emit("connection-requested", this.state); if (this.state.tag === "connected" || this.state.tag === "connecting") { await this.disconnect(); @@ -111,22 +52,22 @@ export class MCPConnectionManager extends EventEmitter( - event: Event, - newState: State - ): State { - this.state = newState; - // TypeScript doesn't seem to be happy with the spread operator and generics - // eslint-disable-next-line - this.emit(event, ...([newState] as any)); - return newState; - } - private onOidcAuthFailed(error: unknown): void { if (this.state.tag === "connecting" && this.state.connectionStringAuthType?.startsWith("oidc")) { void this.disconnectOnOidcError(error); diff --git a/src/common/session.ts b/src/common/session.ts index 40594cbb..c09a3bcd 100644 --- a/src/common/session.ts +++ b/src/common/session.ts @@ -7,10 +7,10 @@ import { LogId } from "./logger.js"; import EventEmitter from "events"; import type { AtlasClusterConnectionInfo, - MCPConnectionManager, - ConnectionSettings, + ConnectionManager, ConnectionStateConnected, -} from "./mcpConnectionManager.js"; + MCPConnectParams, +} from "./connectionManager.js"; import type { NodeDriverServiceProvider } from "@mongosh/service-provider-node-driver"; import { ErrorCodes, MongoDBError } from "./errors.js"; import type { ExportsManager } from "./exportsManager.js"; @@ -21,7 +21,7 @@ export interface SessionOptions { apiClientSecret?: string; logger: CompositeLogger; exportsManager: ExportsManager; - connectionManager: MCPConnectionManager; + connectionManager: ConnectionManager; } export type SessionEvents = { @@ -34,7 +34,7 @@ export type SessionEvents = { export class Session extends EventEmitter { readonly sessionId: string = new ObjectId().toString(); readonly exportsManager: ExportsManager; - readonly connectionManager: MCPConnectionManager; + readonly connectionManager: ConnectionManager; readonly apiClient: ApiClient; mcpClient?: { name?: string; @@ -66,10 +66,14 @@ export class Session extends EventEmitter { this.apiClient = new ApiClient({ baseUrl: apiBaseUrl, credentials }, logger); this.exportsManager = exportsManager; this.connectionManager = connectionManager; - this.connectionManager.on("connection-succeeded", () => this.emit("connect")); - this.connectionManager.on("connection-timed-out", (error) => this.emit("connection-error", error.errorReason)); - this.connectionManager.on("connection-closed", () => this.emit("disconnect")); - this.connectionManager.on("connection-errored", (error) => this.emit("connection-error", error.errorReason)); + this.connectionManager.events.on("connection-succeeded", () => this.emit("connect")); + this.connectionManager.events.on("connection-timed-out", (error) => + this.emit("connection-error", error.errorReason) + ); + this.connectionManager.events.on("connection-closed", () => this.emit("disconnect")); + this.connectionManager.events.on("connection-errored", (error) => + this.emit("connection-error", error.errorReason) + ); } setMcpClient(mcpClient: Implementation | undefined): void { @@ -135,9 +139,9 @@ export class Session extends EventEmitter { this.emit("close"); } - async connectToMongoDB(settings: ConnectionSettings): Promise { + async connectToMongoDB(connectParams: MCPConnectParams): Promise { try { - await this.connectionManager.connect({ ...settings }); + await this.connectionManager.connect({ ...connectParams }); } catch (error: unknown) { const message = error instanceof Error ? error.message : (error as string); this.emit("connection-error", message); diff --git a/src/index.ts b/src/index.ts index b1ac4b48..6138a6e4 100644 --- a/src/index.ts +++ b/src/index.ts @@ -42,6 +42,9 @@ import { packageInfo } from "./common/packageInfo.js"; import { StdioRunner } from "./transports/stdio.js"; import { StreamableHttpRunner } from "./transports/streamableHttp.js"; import { systemCA } from "@mongodb-js/devtools-proxy-support"; +import type { MCPConnectParams } from "./lib.js"; +import type { CreateConnectionManagerFn } from "./transports/base.js"; +import { MCPConnectionManager } from "./common/mcpConnectionManager.js"; async function main(): Promise { systemCA().catch(() => undefined); // load system CA asynchronously as in mongosh @@ -49,10 +52,13 @@ async function main(): Promise { assertHelpMode(); assertVersionMode(); + const createConnectionManager: CreateConnectionManagerFn = ({ logger, deviceId }) => + new MCPConnectionManager(config, driverOptions, logger, deviceId); + const transportRunner = config.transport === "stdio" - ? new StdioRunner(config, driverOptions) - : new StreamableHttpRunner(config, driverOptions); + ? new StdioRunner(config, createConnectionManager) + : new StreamableHttpRunner(config, createConnectionManager); const shutdown = (): void => { transportRunner.logger.info({ id: LogId.serverCloseRequested, diff --git a/src/lib.ts b/src/lib.ts index 9fd921e4..9985f381 100644 --- a/src/lib.ts +++ b/src/lib.ts @@ -3,5 +3,5 @@ export { Telemetry } from "./telemetry/telemetry.js"; export { Session, type SessionOptions } from "./common/session.js"; export { type UserConfig, defaultUserConfig } from "./common/config.js"; export { StreamableHttpRunner } from "./transports/streamableHttp.js"; -export { LoggerBase } from "./common/logger.js"; -export type { LogPayload, LoggerType, LogLevel } from "./common/logger.js"; +export { LoggerBase, CompositeLogger, type LogPayload, type LoggerType, type LogLevel } from "./common/logger.js"; +export * from "./common/connectionManager.js"; diff --git a/src/tools/atlas/connect/connectCluster.ts b/src/tools/atlas/connect/connectCluster.ts index e7445d92..9695ff36 100644 --- a/src/tools/atlas/connect/connectCluster.ts +++ b/src/tools/atlas/connect/connectCluster.ts @@ -6,7 +6,7 @@ import { generateSecurePassword } from "../../../helpers/generatePassword.js"; import { LogId } from "../../../common/logger.js"; import { inspectCluster } from "../../../common/atlas/cluster.js"; import { ensureCurrentIpInAccessList } from "../../../common/atlas/accessListUtils.js"; -import type { AtlasClusterConnectionInfo } from "../../../common/mcpConnectionManager.js"; +import type { AtlasClusterConnectionInfo } from "../../../common/connectionManager.js"; import { getDefaultRoleFromConfig } from "../../../common/atlas/roles.js"; const EXPIRY_MS = 1000 * 60 * 60 * 12; // 12 hours diff --git a/src/transports/base.ts b/src/transports/base.ts index 02db4663..2a60b725 100644 --- a/src/transports/base.ts +++ b/src/transports/base.ts @@ -1,4 +1,4 @@ -import type { DriverOptions, UserConfig } from "../common/config.js"; +import type { UserConfig } from "../common/config.js"; import { packageInfo } from "../common/packageInfo.js"; import { Server } from "../server.js"; import { Session } from "../common/session.js"; @@ -7,17 +7,22 @@ import { McpServer } from "@modelcontextprotocol/sdk/server/mcp.js"; import type { LoggerBase } from "../common/logger.js"; import { CompositeLogger, ConsoleLogger, DiskLogger, McpLogger } from "../common/logger.js"; import { ExportsManager } from "../common/exportsManager.js"; -import { MCPConnectionManager } from "../common/mcpConnectionManager.js"; +import type { ConnectionManager, MCPConnectParams } from "../common/connectionManager.js"; import { DeviceId } from "../helpers/deviceId.js"; -export abstract class TransportRunnerBase { +export type CreateConnectionManagerFn = (createParams: { + logger: CompositeLogger; + deviceId: DeviceId; +}) => ConnectionManager; + +export abstract class TransportRunnerBase { public logger: LoggerBase; public deviceId: DeviceId; - protected constructor( + constructor( protected readonly userConfig: UserConfig, - private readonly driverOptions: DriverOptions, - additionalLoggers: LoggerBase[] + private readonly createConnectionManager: CreateConnectionManagerFn, + additionalLoggers: LoggerBase[] = [] ) { const loggers: LoggerBase[] = [...additionalLoggers]; if (this.userConfig.loggers.includes("stderr")) { @@ -46,7 +51,7 @@ export abstract class TransportRunnerBase { const logger = new CompositeLogger(this.logger); const exportsManager = ExportsManager.init(this.userConfig, logger); - const connectionManager = new MCPConnectionManager(this.userConfig, this.driverOptions, logger, this.deviceId); + const connectionManager = this.createConnectionManager({ logger, deviceId: this.deviceId }); const session = new Session({ apiBaseUrl: this.userConfig.apiBaseUrl, diff --git a/src/transports/stdio.ts b/src/transports/stdio.ts index 0751cac7..7e9be4b4 100644 --- a/src/transports/stdio.ts +++ b/src/transports/stdio.ts @@ -1,12 +1,11 @@ -import type { LoggerBase } from "../common/logger.js"; -import { LogId } from "../common/logger.js"; -import type { Server } from "../server.js"; -import { TransportRunnerBase } from "./base.js"; +import { EJSON } from "bson"; import type { JSONRPCMessage } from "@modelcontextprotocol/sdk/types.js"; import { JSONRPCMessageSchema } from "@modelcontextprotocol/sdk/types.js"; -import { EJSON } from "bson"; import { StdioServerTransport } from "@modelcontextprotocol/sdk/server/stdio.js"; -import type { DriverOptions, UserConfig } from "../common/config.js"; +import { LogId } from "../common/logger.js"; +import type { Server } from "../server.js"; +import { TransportRunnerBase } from "./base.js"; +import type { MCPConnectParams } from "../common/connectionManager.js"; // This is almost a copy of ReadBuffer from @modelcontextprotocol/sdk // but it uses EJSON.parse instead of JSON.parse to handle BSON types @@ -52,13 +51,9 @@ export function createStdioTransport(): StdioServerTransport { return server; } -export class StdioRunner extends TransportRunnerBase { +export class StdioRunner extends TransportRunnerBase { private server: Server | undefined; - constructor(userConfig: UserConfig, driverOptions: DriverOptions, additionalLoggers: LoggerBase[] = []) { - super(userConfig, driverOptions, additionalLoggers); - } - async start(): Promise { try { this.server = this.setupServer(); diff --git a/src/transports/streamableHttp.ts b/src/transports/streamableHttp.ts index 1718252c..be92478e 100644 --- a/src/transports/streamableHttp.ts +++ b/src/transports/streamableHttp.ts @@ -1,13 +1,12 @@ import express from "express"; import type http from "http"; +import { randomUUID } from "crypto"; import { StreamableHTTPServerTransport } from "@modelcontextprotocol/sdk/server/streamableHttp.js"; import { isInitializeRequest } from "@modelcontextprotocol/sdk/types.js"; import { TransportRunnerBase } from "./base.js"; -import type { DriverOptions, UserConfig } from "../common/config.js"; -import type { LoggerBase } from "../common/logger.js"; import { LogId } from "../common/logger.js"; -import { randomUUID } from "crypto"; import { SessionStore } from "../common/sessionStore.js"; +import type { MCPConnectParams } from "../common/connectionManager.js"; const JSON_RPC_ERROR_CODE_PROCESSING_REQUEST_FAILED = -32000; const JSON_RPC_ERROR_CODE_SESSION_ID_REQUIRED = -32001; @@ -15,7 +14,7 @@ const JSON_RPC_ERROR_CODE_SESSION_ID_INVALID = -32002; const JSON_RPC_ERROR_CODE_SESSION_NOT_FOUND = -32003; const JSON_RPC_ERROR_CODE_INVALID_REQUEST = -32004; -export class StreamableHttpRunner extends TransportRunnerBase { +export class StreamableHttpRunner extends TransportRunnerBase { private httpServer: http.Server | undefined; private sessionStore!: SessionStore; @@ -31,10 +30,6 @@ export class StreamableHttpRunner extends TransportRunnerBase { throw new Error("Server is not started yet"); } - constructor(userConfig: UserConfig, driverOptions: DriverOptions, additionalLoggers: LoggerBase[] = []) { - super(userConfig, driverOptions, additionalLoggers); - } - async start(): Promise { const app = express(); this.sessionStore = new SessionStore( diff --git a/tests/integration/build.test.ts b/tests/integration/build.test.ts index f5b26827..e49509ca 100644 --- a/tests/integration/build.test.ts +++ b/tests/integration/build.test.ts @@ -41,11 +41,14 @@ describe("Build Test", () => { const esmKeys = Object.keys(esmModule).sort(); expect(cjsKeys).toEqual(esmKeys); - expect(cjsKeys).toIncludeSameMembers([ + expect(cjsKeys).toContainEqual([ + "CompositeLogger", + "ConnectionManager", + "LoggerBase", "Server", "Session", - "Telemetry", "StreamableHttpRunner", + "Telemetry", "defaultUserConfig", "LoggerBase", ]); diff --git a/tests/integration/common/connectionManager.oidc.test.ts b/tests/integration/common/connectionManager.oidc.test.ts index 1f9fa263..e3a406eb 100644 --- a/tests/integration/common/connectionManager.oidc.test.ts +++ b/tests/integration/common/connectionManager.oidc.test.ts @@ -5,7 +5,7 @@ import process from "process"; import type { MongoDBIntegrationTestCase } from "../tools/mongodb/mongodbHelpers.js"; import { describeWithMongoDB, isCommunityServer, getServerVersion } from "../tools/mongodb/mongodbHelpers.js"; import { defaultTestConfig, responseAsText, timeout, waitUntil } from "../helpers.js"; -import type { ConnectionStateConnected, ConnectionStateConnecting } from "../../../src/common/mcpConnectionManager.js"; +import type { ConnectionStateConnected, ConnectionStateConnecting } from "../../../src/common/connectionManager.js"; import type { UserConfig } from "../../../src/common/config.js"; import { setupDriverConfig } from "../../../src/common/config.js"; import path from "path"; diff --git a/tests/integration/common/connectionManager.test.ts b/tests/integration/common/connectionManager.test.ts index ccea2010..a5815b40 100644 --- a/tests/integration/common/connectionManager.test.ts +++ b/tests/integration/common/connectionManager.test.ts @@ -1,15 +1,16 @@ import type { - MCPConnectionManagerEvents, + ConnectionManager, + ConnectionManagerEvents, ConnectionStateConnected, ConnectionStringAuthType, -} from "../../../src/common/mcpConnectionManager.js"; +} from "../../../src/common/connectionManager.js"; import { MCPConnectionManager } from "../../../src/common/mcpConnectionManager.js"; import type { UserConfig } from "../../../src/common/config.js"; import { describeWithMongoDB } from "../tools/mongodb/mongodbHelpers.js"; import { describe, beforeEach, expect, it, vi, afterEach } from "vitest"; describeWithMongoDB("Connection Manager", (integration) => { - function connectionManager(): MCPConnectionManager { + function connectionManager(): ConnectionManager { return integration.mcpServer().session.connectionManager; } @@ -24,11 +25,11 @@ describeWithMongoDB("Connection Manager", (integration) => { describe("when successfully connected", () => { type ConnectionManagerSpies = { - "connection-requested": (event: MCPConnectionManagerEvents["connection-requested"][0]) => void; - "connection-succeeded": (event: MCPConnectionManagerEvents["connection-succeeded"][0]) => void; - "connection-timed-out": (event: MCPConnectionManagerEvents["connection-timed-out"][0]) => void; - "connection-closed": (event: MCPConnectionManagerEvents["connection-closed"][0]) => void; - "connection-errored": (event: MCPConnectionManagerEvents["connection-errored"][0]) => void; + "connection-requested": (event: ConnectionManagerEvents["connection-requested"][0]) => void; + "connection-succeeded": (event: ConnectionManagerEvents["connection-succeeded"][0]) => void; + "connection-timed-out": (event: ConnectionManagerEvents["connection-timed-out"][0]) => void; + "connection-closed": (event: ConnectionManagerEvents["connection-closed"][0]) => void; + "connection-errored": (event: ConnectionManagerEvents["connection-errored"][0]) => void; }; let connectionManagerSpies: ConnectionManagerSpies; @@ -43,7 +44,7 @@ describeWithMongoDB("Connection Manager", (integration) => { }; for (const [event, spy] of Object.entries(connectionManagerSpies)) { - connectionManager().on(event as keyof MCPConnectionManagerEvents, spy); + connectionManager().events.on(event as keyof ConnectionManagerEvents, spy); } await connectionManager().connect({ diff --git a/tests/integration/helpers.ts b/tests/integration/helpers.ts index 5c267cee..3c646f49 100644 --- a/tests/integration/helpers.ts +++ b/tests/integration/helpers.ts @@ -10,7 +10,7 @@ import type { UserConfig, DriverOptions } from "../../src/common/config.js"; import { McpError, ResourceUpdatedNotificationSchema } from "@modelcontextprotocol/sdk/types.js"; import { config, driverOptions } from "../../src/common/config.js"; import { afterAll, afterEach, beforeAll, describe, expect, it, vi } from "vitest"; -import type { ConnectionState } from "../../src/common/mcpConnectionManager.js"; +import type { ConnectionManager, ConnectionState } from "../../src/common/connectionManager.js"; import { MCPConnectionManager } from "../../src/common/mcpConnectionManager.js"; import { DeviceId } from "../../src/helpers/deviceId.js"; @@ -315,7 +315,7 @@ export function responseAsText(response: Awaited> export function waitUntil( tag: T["tag"], - cm: MCPConnectionManager, + cm: ConnectionManager, signal: AbortSignal, additionalCondition?: (state: T) => boolean ): Promise { diff --git a/tests/integration/transports/streamableHttp.test.ts b/tests/integration/transports/streamableHttp.test.ts index f45ce3cd..eb1d42b4 100644 --- a/tests/integration/transports/streamableHttp.test.ts +++ b/tests/integration/transports/streamableHttp.test.ts @@ -5,9 +5,11 @@ import { describe, expect, it, beforeAll, afterAll, beforeEach } from "vitest"; import { config, driverOptions } from "../../../src/common/config.js"; import type { LoggerType, LogLevel, LogPayload } from "../../../src/common/logger.js"; import { LoggerBase, LogId } from "../../../src/common/logger.js"; +import { MCPConnectionManager } from "../../../src/common/mcpConnectionManager.js"; +import type { MCPConnectParams } from "../../../src/lib.js"; describe("StreamableHttpRunner", () => { - let runner: StreamableHttpRunner; + let runner: StreamableHttpRunner; let oldTelemetry: "enabled" | "disabled"; let oldLoggers: ("stderr" | "disk" | "mcp")[]; @@ -28,7 +30,10 @@ describe("StreamableHttpRunner", () => { describe(description, () => { beforeAll(async () => { config.httpHeaders = headers; - runner = new StreamableHttpRunner(config, driverOptions); + runner = new StreamableHttpRunner( + config, + ({ logger, deviceId }) => new MCPConnectionManager(config, driverOptions, logger, deviceId) + ); await runner.start(); }); @@ -105,11 +110,14 @@ describe("StreamableHttpRunner", () => { } it("can create multiple runners", async () => { - const runners: StreamableHttpRunner[] = []; + const runners: StreamableHttpRunner[] = []; try { for (let i = 0; i < 3; i++) { config.httpPort = 0; // Use a random port for each runner - const runner = new StreamableHttpRunner(config, driverOptions); + const runner = new StreamableHttpRunner( + config, + ({ logger, deviceId }) => new MCPConnectionManager(config, driverOptions, logger, deviceId) + ); await runner.start(); runners.push(runner); } @@ -138,7 +146,11 @@ describe("StreamableHttpRunner", () => { it("can provide custom logger", async () => { const logger = new CustomLogger(); - const runner = new StreamableHttpRunner(config, driverOptions, [logger]); + const runner = new StreamableHttpRunner( + config, + ({ logger, deviceId }) => new MCPConnectionManager(config, driverOptions, logger, deviceId), + [logger] + ); await runner.start(); const messages = logger.messages; From 2573a45740800a030898e589a9ff60b7c628781d Mon Sep 17 00:00:00 2001 From: Himanshu Singh Date: Wed, 27 Aug 2025 14:35:09 +0200 Subject: [PATCH 03/13] chore: awaitable connection manager factory --- src/lib.ts | 18 ++++++++++++++---- src/transports/base.ts | 6 +++--- src/transports/stdio.ts | 2 +- src/transports/streamableHttp.ts | 2 +- tests/integration/build.test.ts | 23 ++++++++++++----------- 5 files changed, 31 insertions(+), 20 deletions(-) diff --git a/src/lib.ts b/src/lib.ts index 9985f381..2c72fc20 100644 --- a/src/lib.ts +++ b/src/lib.ts @@ -1,7 +1,17 @@ export { Server, type ServerOptions } from "./server.js"; -export { Telemetry } from "./telemetry/telemetry.js"; export { Session, type SessionOptions } from "./common/session.js"; -export { type UserConfig, defaultUserConfig } from "./common/config.js"; -export { StreamableHttpRunner } from "./transports/streamableHttp.js"; +export { defaultUserConfig, type UserConfig } from "./common/config.js"; export { LoggerBase, CompositeLogger, type LogPayload, type LoggerType, type LogLevel } from "./common/logger.js"; -export * from "./common/connectionManager.js"; +export { StreamableHttpRunner } from "./transports/streamableHttp.js"; +export { type CreateConnectionManagerFn } from "./transports/base.js"; +export { + ConnectionManager, + type MCPConnectParams, + type AnyConnectionState, + type ConnectionState, + type ConnectionStateConnected, + type ConnectionStateConnecting, + type ConnectionStateDisconnected, + type ConnectionStateErrored, +} from "./common/connectionManager.js"; +export { Telemetry } from "./telemetry/telemetry.js"; diff --git a/src/transports/base.ts b/src/transports/base.ts index 2a60b725..bfd09461 100644 --- a/src/transports/base.ts +++ b/src/transports/base.ts @@ -13,7 +13,7 @@ import { DeviceId } from "../helpers/deviceId.js"; export type CreateConnectionManagerFn = (createParams: { logger: CompositeLogger; deviceId: DeviceId; -}) => ConnectionManager; +}) => ConnectionManager | Promise>; export abstract class TransportRunnerBase { public logger: LoggerBase; @@ -43,7 +43,7 @@ export abstract class TransportRunnerBase { const mcpServer = new McpServer({ name: packageInfo.mcpServerName, version: packageInfo.version, @@ -51,7 +51,7 @@ export abstract class TransportRunnerBase extends Transpo async start(): Promise { try { - this.server = this.setupServer(); + this.server = await this.setupServer(); const transport = createStdioTransport(); diff --git a/src/transports/streamableHttp.ts b/src/transports/streamableHttp.ts index be92478e..0c39765a 100644 --- a/src/transports/streamableHttp.ts +++ b/src/transports/streamableHttp.ts @@ -108,7 +108,7 @@ export class StreamableHttpRunner extend return; } - const server = this.setupServer(); + const server = await this.setupServer(); let keepAliveLoop: NodeJS.Timeout; const transport = new StreamableHTTPServerTransport({ sessionIdGenerator: (): string => randomUUID().toString(), diff --git a/tests/integration/build.test.ts b/tests/integration/build.test.ts index e49509ca..ef328320 100644 --- a/tests/integration/build.test.ts +++ b/tests/integration/build.test.ts @@ -41,16 +41,17 @@ describe("Build Test", () => { const esmKeys = Object.keys(esmModule).sort(); expect(cjsKeys).toEqual(esmKeys); - expect(cjsKeys).toContainEqual([ - "CompositeLogger", - "ConnectionManager", - "LoggerBase", - "Server", - "Session", - "StreamableHttpRunner", - "Telemetry", - "defaultUserConfig", - "LoggerBase", - ]); + expect(cjsKeys).toEqual( + expect.arrayContaining([ + "CompositeLogger", + "ConnectionManager", + "LoggerBase", + "Server", + "Session", + "StreamableHttpRunner", + "Telemetry", + "defaultUserConfig", + ]) + ); }); }); From b7ac5785e326a6d4d9524c40e83ccbcabfe9928c Mon Sep 17 00:00:00 2001 From: Himanshu Singh Date: Wed, 27 Aug 2025 14:47:17 +0200 Subject: [PATCH 04/13] chore: bring ctor back on runner implementations --- src/transports/base.ts | 2 +- src/transports/stdio.ts | 13 +++++++++++-- src/transports/streamableHttp.ts | 13 +++++++++++-- 3 files changed, 23 insertions(+), 5 deletions(-) diff --git a/src/transports/base.ts b/src/transports/base.ts index bfd09461..ea7f571e 100644 --- a/src/transports/base.ts +++ b/src/transports/base.ts @@ -22,7 +22,7 @@ export abstract class TransportRunnerBase, - additionalLoggers: LoggerBase[] = [] + additionalLoggers: LoggerBase[] ) { const loggers: LoggerBase[] = [...additionalLoggers]; if (this.userConfig.loggers.includes("stderr")) { diff --git a/src/transports/stdio.ts b/src/transports/stdio.ts index 9a33d820..cc823858 100644 --- a/src/transports/stdio.ts +++ b/src/transports/stdio.ts @@ -2,9 +2,10 @@ import { EJSON } from "bson"; import type { JSONRPCMessage } from "@modelcontextprotocol/sdk/types.js"; import { JSONRPCMessageSchema } from "@modelcontextprotocol/sdk/types.js"; import { StdioServerTransport } from "@modelcontextprotocol/sdk/server/stdio.js"; -import { LogId } from "../common/logger.js"; +import { type LoggerBase, LogId } from "../common/logger.js"; import type { Server } from "../server.js"; -import { TransportRunnerBase } from "./base.js"; +import { type CreateConnectionManagerFn, TransportRunnerBase } from "./base.js"; +import { type UserConfig } from "../common/config.js"; import type { MCPConnectParams } from "../common/connectionManager.js"; // This is almost a copy of ReadBuffer from @modelcontextprotocol/sdk @@ -54,6 +55,14 @@ export function createStdioTransport(): StdioServerTransport { export class StdioRunner extends TransportRunnerBase { private server: Server | undefined; + constructor( + userConfig: UserConfig, + createConnectionManager: CreateConnectionManagerFn, + additionalLoggers: LoggerBase[] = [] + ) { + super(userConfig, createConnectionManager, additionalLoggers); + } + async start(): Promise { try { this.server = await this.setupServer(); diff --git a/src/transports/streamableHttp.ts b/src/transports/streamableHttp.ts index 0c39765a..b43125f8 100644 --- a/src/transports/streamableHttp.ts +++ b/src/transports/streamableHttp.ts @@ -3,10 +3,11 @@ import type http from "http"; import { randomUUID } from "crypto"; import { StreamableHTTPServerTransport } from "@modelcontextprotocol/sdk/server/streamableHttp.js"; import { isInitializeRequest } from "@modelcontextprotocol/sdk/types.js"; -import { TransportRunnerBase } from "./base.js"; -import { LogId } from "../common/logger.js"; +import { LogId, type LoggerBase } from "../common/logger.js"; +import { type UserConfig } from "../common/config.js"; import { SessionStore } from "../common/sessionStore.js"; import type { MCPConnectParams } from "../common/connectionManager.js"; +import { type CreateConnectionManagerFn, TransportRunnerBase } from "./base.js"; const JSON_RPC_ERROR_CODE_PROCESSING_REQUEST_FAILED = -32000; const JSON_RPC_ERROR_CODE_SESSION_ID_REQUIRED = -32001; @@ -18,6 +19,14 @@ export class StreamableHttpRunner extend private httpServer: http.Server | undefined; private sessionStore!: SessionStore; + constructor( + userConfig: UserConfig, + createConnectionManager: CreateConnectionManagerFn, + additionalLoggers: LoggerBase[] = [] + ) { + super(userConfig, createConnectionManager, additionalLoggers); + } + public get serverAddress(): string { const result = this.httpServer?.address(); if (typeof result === "string") { From 871be6ff70345f89ecd500d0b075f018fe812e9f Mon Sep 17 00:00:00 2001 From: Himanshu Singh Date: Thu, 28 Aug 2025 13:51:20 +0200 Subject: [PATCH 05/13] chore: remove generic from abstract ConnectionManager --- src/common/connectionManager.ts | 246 ++++++++++++++++- src/common/mcpConnectionManager.ts | 252 ------------------ src/common/session.ts | 6 +- src/index.ts | 9 +- src/lib.ts | 1 - src/transports/base.ts | 10 +- src/transports/stdio.ts | 5 +- src/transports/streamableHttp.ts | 5 +- .../common/connectionManager.test.ts | 2 +- tests/integration/helpers.ts | 2 +- tests/integration/telemetry.test.ts | 2 +- .../transports/streamableHttp.test.ts | 7 +- tests/unit/common/session.test.ts | 2 +- tests/unit/resources/common/debug.test.ts | 2 +- 14 files changed, 266 insertions(+), 285 deletions(-) delete mode 100644 src/common/mcpConnectionManager.ts diff --git a/src/common/connectionManager.ts b/src/common/connectionManager.ts index ae976171..9bea9daf 100644 --- a/src/common/connectionManager.ts +++ b/src/common/connectionManager.ts @@ -1,5 +1,14 @@ import { EventEmitter } from "events"; -import type { NodeDriverServiceProvider } from "@mongosh/service-provider-node-driver"; +import type { MongoClientOptions } from "mongodb"; +import ConnectionString from "mongodb-connection-string-url"; +import { NodeDriverServiceProvider } from "@mongosh/service-provider-node-driver"; +import { type ConnectionInfo, generateConnectionInfoFromCliArgs } from "@mongosh/arg-parser"; +import type { DeviceId } from "../helpers/deviceId.js"; +import type { DriverOptions, UserConfig } from "./config.js"; +import { MongoDBError, ErrorCodes } from "./errors.js"; +import { type CompositeLogger, LogId } from "./logger.js"; +import { packageInfo } from "./packageInfo.js"; +import { type AppNameComponents, setAppNameParamIfMissing } from "../helpers/connectionOptions.js"; export interface AtlasClusterConnectionInfo { username: string; @@ -54,12 +63,12 @@ export interface ConnectionManagerEvents { "connection-errored": [ConnectionStateErrored]; } -export interface MCPConnectParams { +export interface ConnectionSettings { connectionString: string; atlas?: AtlasClusterConnectionInfo; } -export abstract class ConnectionManager { +export abstract class ConnectionManager { protected clientName: string = "unknown"; protected readonly _events = new EventEmitter(); @@ -86,7 +95,236 @@ export abstract class ConnectionManager; + abstract connect(settings: ConnectionSettings): Promise; abstract disconnect(): Promise; } + +export class MCPConnectionManager extends ConnectionManager { + private deviceId: DeviceId; + private bus: EventEmitter; + + constructor( + private userConfig: UserConfig, + private driverOptions: DriverOptions, + private logger: CompositeLogger, + deviceId: DeviceId, + bus?: EventEmitter + ) { + super(); + this.bus = bus ?? new EventEmitter(); + this.bus.on("mongodb-oidc-plugin:auth-failed", this.onOidcAuthFailed.bind(this)); + this.bus.on("mongodb-oidc-plugin:auth-succeeded", this.onOidcAuthSucceeded.bind(this)); + this.deviceId = deviceId; + this.clientName = "unknown"; + } + + async connect(connectParams: ConnectionSettings): Promise { + this._events.emit("connection-requested", this.state); + + if (this.state.tag === "connected" || this.state.tag === "connecting") { + await this.disconnect(); + } + + let serviceProvider: NodeDriverServiceProvider; + let connectionInfo: ConnectionInfo; + + try { + connectParams = { ...connectParams }; + const appNameComponents: AppNameComponents = { + appName: `${packageInfo.mcpServerName} ${packageInfo.version}`, + deviceId: this.deviceId.get(), + clientName: this.clientName, + }; + + connectParams.connectionString = await setAppNameParamIfMissing({ + connectionString: connectParams.connectionString, + components: appNameComponents, + }); + + connectionInfo = generateConnectionInfoFromCliArgs({ + ...this.userConfig, + ...this.driverOptions, + connectionSpecifier: connectParams.connectionString, + }); + + if (connectionInfo.driverOptions.oidc) { + connectionInfo.driverOptions.oidc.allowedFlows ??= ["auth-code"]; + connectionInfo.driverOptions.oidc.notifyDeviceFlow ??= this.onOidcNotifyDeviceFlow.bind(this); + } + + connectionInfo.driverOptions.proxy ??= { useEnvironmentVariableProxies: true }; + connectionInfo.driverOptions.applyProxyToOIDC ??= true; + + serviceProvider = await NodeDriverServiceProvider.connect( + connectionInfo.connectionString, + { + productDocsLink: "https://github.com/mongodb-js/mongodb-mcp-server/", + productName: "MongoDB MCP", + ...connectionInfo.driverOptions, + }, + undefined, + this.bus + ); + } catch (error: unknown) { + const errorReason = error instanceof Error ? error.message : `${error as string}`; + this.changeState("connection-errored", { + tag: "errored", + errorReason, + connectedAtlasCluster: connectParams.atlas, + }); + throw new MongoDBError(ErrorCodes.MisconfiguredConnectionString, errorReason); + } + + try { + const connectionType = MCPConnectionManager.inferConnectionTypeFromSettings( + this.userConfig, + connectionInfo + ); + if (connectionType.startsWith("oidc")) { + void this.pingAndForget(serviceProvider); + + return this.changeState("connection-requested", { + tag: "connecting", + connectedAtlasCluster: connectParams.atlas, + serviceProvider, + connectionStringAuthType: connectionType, + oidcConnectionType: connectionType as OIDCConnectionAuthType, + }); + } + + await serviceProvider?.runCommand?.("admin", { hello: 1 }); + + return this.changeState("connection-succeeded", { + tag: "connected", + connectedAtlasCluster: connectParams.atlas, + serviceProvider, + connectionStringAuthType: connectionType, + }); + } catch (error: unknown) { + const errorReason = error instanceof Error ? error.message : `${error as string}`; + this.changeState("connection-errored", { + tag: "errored", + errorReason, + connectedAtlasCluster: connectParams.atlas, + }); + throw new MongoDBError(ErrorCodes.NotConnectedToMongoDB, errorReason); + } + } + + async disconnect(): Promise { + if (this.state.tag === "disconnected" || this.state.tag === "errored") { + return this.state; + } + + if (this.state.tag === "connected" || this.state.tag === "connecting") { + try { + await this.state.serviceProvider?.close(true); + } finally { + this.changeState("connection-closed", { + tag: "disconnected", + }); + } + } + + return { tag: "disconnected" }; + } + + private onOidcAuthFailed(error: unknown): void { + if (this.state.tag === "connecting" && this.state.connectionStringAuthType?.startsWith("oidc")) { + void this.disconnectOnOidcError(error); + } + } + + private onOidcAuthSucceeded(): void { + if (this.state.tag === "connecting" && this.state.connectionStringAuthType?.startsWith("oidc")) { + this.changeState("connection-succeeded", { ...this.state, tag: "connected" }); + } + + this.logger.info({ + id: LogId.oidcFlow, + context: "mongodb-oidc-plugin:auth-succeeded", + message: "Authenticated successfully.", + }); + } + + private onOidcNotifyDeviceFlow(flowInfo: { verificationUrl: string; userCode: string }): void { + if (this.state.tag === "connecting" && this.state.connectionStringAuthType?.startsWith("oidc")) { + this.changeState("connection-requested", { + ...this.state, + tag: "connecting", + connectionStringAuthType: "oidc-device-flow", + oidcLoginUrl: flowInfo.verificationUrl, + oidcUserCode: flowInfo.userCode, + }); + } + + this.logger.info({ + id: LogId.oidcFlow, + context: "mongodb-oidc-plugin:notify-device-flow", + message: "OIDC Flow changed automatically to device flow.", + }); + } + + static inferConnectionTypeFromSettings( + config: UserConfig, + settings: { connectionString: string } + ): ConnectionStringAuthType { + const connString = new ConnectionString(settings.connectionString); + const searchParams = connString.typedSearchParams(); + + switch (searchParams.get("authMechanism")) { + case "MONGODB-OIDC": { + if (config.transport === "stdio" && config.browser) { + return "oidc-auth-flow"; + } + + if (config.transport === "http" && config.httpHost === "127.0.0.1" && config.browser) { + return "oidc-auth-flow"; + } + + return "oidc-device-flow"; + } + case "MONGODB-X509": + return "x.509"; + case "GSSAPI": + return "kerberos"; + case "PLAIN": + if (searchParams.get("authSource") === "$external") { + return "ldap"; + } + return "scram"; + // default should catch also null, but eslint complains + // about it. + case null: + default: + return "scram"; + } + } + + private async pingAndForget(serviceProvider: NodeDriverServiceProvider): Promise { + try { + await serviceProvider?.runCommand?.("admin", { hello: 1 }); + } catch (error: unknown) { + this.logger.warning({ + id: LogId.oidcFlow, + context: "pingAndForget", + message: String(error), + }); + } + } + + private async disconnectOnOidcError(error: unknown): Promise { + try { + await this.disconnect(); + } catch (error: unknown) { + this.logger.warning({ + id: LogId.oidcFlow, + context: "disconnectOnOidcError", + message: String(error), + }); + } finally { + this.changeState("connection-errored", { tag: "errored", errorReason: String(error) }); + } + } +} diff --git a/src/common/mcpConnectionManager.ts b/src/common/mcpConnectionManager.ts deleted file mode 100644 index 4119cc12..00000000 --- a/src/common/mcpConnectionManager.ts +++ /dev/null @@ -1,252 +0,0 @@ -import type { UserConfig, DriverOptions } from "./config.js"; -import { NodeDriverServiceProvider } from "@mongosh/service-provider-node-driver"; -import EventEmitter from "events"; -import { setAppNameParamIfMissing } from "../helpers/connectionOptions.js"; -import { packageInfo } from "./packageInfo.js"; -import ConnectionString from "mongodb-connection-string-url"; -import type { MongoClientOptions } from "mongodb"; -import { ErrorCodes, MongoDBError } from "./errors.js"; -import type { DeviceId } from "../helpers/deviceId.js"; -import type { AppNameComponents } from "../helpers/connectionOptions.js"; -import type { CompositeLogger } from "./logger.js"; -import { LogId } from "./logger.js"; -import type { ConnectionInfo } from "@mongosh/arg-parser"; -import { generateConnectionInfoFromCliArgs } from "@mongosh/arg-parser"; -import { - ConnectionManager, - type AnyConnectionState, - type ConnectionStringAuthType, - type OIDCConnectionAuthType, - type ConnectionStateDisconnected, - type ConnectionStateErrored, - type MCPConnectParams, -} from "./connectionManager.js"; - -export class MCPConnectionManager extends ConnectionManager { - private deviceId: DeviceId; - private bus: EventEmitter; - - constructor( - private userConfig: UserConfig, - private driverOptions: DriverOptions, - private logger: CompositeLogger, - deviceId: DeviceId, - bus?: EventEmitter - ) { - super(); - this.bus = bus ?? new EventEmitter(); - this.bus.on("mongodb-oidc-plugin:auth-failed", this.onOidcAuthFailed.bind(this)); - this.bus.on("mongodb-oidc-plugin:auth-succeeded", this.onOidcAuthSucceeded.bind(this)); - this.deviceId = deviceId; - this.clientName = "unknown"; - } - - async connect(connectParams: MCPConnectParams): Promise { - this._events.emit("connection-requested", this.state); - - if (this.state.tag === "connected" || this.state.tag === "connecting") { - await this.disconnect(); - } - - let serviceProvider: NodeDriverServiceProvider; - let connectionInfo: ConnectionInfo; - - try { - connectParams = { ...connectParams }; - const appNameComponents: AppNameComponents = { - appName: `${packageInfo.mcpServerName} ${packageInfo.version}`, - deviceId: this.deviceId.get(), - clientName: this.clientName, - }; - - connectParams.connectionString = await setAppNameParamIfMissing({ - connectionString: connectParams.connectionString, - components: appNameComponents, - }); - - connectionInfo = generateConnectionInfoFromCliArgs({ - ...this.userConfig, - ...this.driverOptions, - connectionSpecifier: connectParams.connectionString, - }); - - if (connectionInfo.driverOptions.oidc) { - connectionInfo.driverOptions.oidc.allowedFlows ??= ["auth-code"]; - connectionInfo.driverOptions.oidc.notifyDeviceFlow ??= this.onOidcNotifyDeviceFlow.bind(this); - } - - connectionInfo.driverOptions.proxy ??= { useEnvironmentVariableProxies: true }; - connectionInfo.driverOptions.applyProxyToOIDC ??= true; - - serviceProvider = await NodeDriverServiceProvider.connect( - connectionInfo.connectionString, - { - productDocsLink: "https://github.com/mongodb-js/mongodb-mcp-server/", - productName: "MongoDB MCP", - ...connectionInfo.driverOptions, - }, - undefined, - this.bus - ); - } catch (error: unknown) { - const errorReason = error instanceof Error ? error.message : `${error as string}`; - this.changeState("connection-errored", { - tag: "errored", - errorReason, - connectedAtlasCluster: connectParams.atlas, - }); - throw new MongoDBError(ErrorCodes.MisconfiguredConnectionString, errorReason); - } - - try { - const connectionType = MCPConnectionManager.inferConnectionTypeFromSettings( - this.userConfig, - connectionInfo - ); - if (connectionType.startsWith("oidc")) { - void this.pingAndForget(serviceProvider); - - return this.changeState("connection-requested", { - tag: "connecting", - connectedAtlasCluster: connectParams.atlas, - serviceProvider, - connectionStringAuthType: connectionType, - oidcConnectionType: connectionType as OIDCConnectionAuthType, - }); - } - - await serviceProvider?.runCommand?.("admin", { hello: 1 }); - - return this.changeState("connection-succeeded", { - tag: "connected", - connectedAtlasCluster: connectParams.atlas, - serviceProvider, - connectionStringAuthType: connectionType, - }); - } catch (error: unknown) { - const errorReason = error instanceof Error ? error.message : `${error as string}`; - this.changeState("connection-errored", { - tag: "errored", - errorReason, - connectedAtlasCluster: connectParams.atlas, - }); - throw new MongoDBError(ErrorCodes.NotConnectedToMongoDB, errorReason); - } - } - - async disconnect(): Promise { - if (this.state.tag === "disconnected" || this.state.tag === "errored") { - return this.state; - } - - if (this.state.tag === "connected" || this.state.tag === "connecting") { - try { - await this.state.serviceProvider?.close(true); - } finally { - this.changeState("connection-closed", { - tag: "disconnected", - }); - } - } - - return { tag: "disconnected" }; - } - - private onOidcAuthFailed(error: unknown): void { - if (this.state.tag === "connecting" && this.state.connectionStringAuthType?.startsWith("oidc")) { - void this.disconnectOnOidcError(error); - } - } - - private onOidcAuthSucceeded(): void { - if (this.state.tag === "connecting" && this.state.connectionStringAuthType?.startsWith("oidc")) { - this.changeState("connection-succeeded", { ...this.state, tag: "connected" }); - } - - this.logger.info({ - id: LogId.oidcFlow, - context: "mongodb-oidc-plugin:auth-succeeded", - message: "Authenticated successfully.", - }); - } - - private onOidcNotifyDeviceFlow(flowInfo: { verificationUrl: string; userCode: string }): void { - if (this.state.tag === "connecting" && this.state.connectionStringAuthType?.startsWith("oidc")) { - this.changeState("connection-requested", { - ...this.state, - tag: "connecting", - connectionStringAuthType: "oidc-device-flow", - oidcLoginUrl: flowInfo.verificationUrl, - oidcUserCode: flowInfo.userCode, - }); - } - - this.logger.info({ - id: LogId.oidcFlow, - context: "mongodb-oidc-plugin:notify-device-flow", - message: "OIDC Flow changed automatically to device flow.", - }); - } - - static inferConnectionTypeFromSettings( - config: UserConfig, - settings: { connectionString: string } - ): ConnectionStringAuthType { - const connString = new ConnectionString(settings.connectionString); - const searchParams = connString.typedSearchParams(); - - switch (searchParams.get("authMechanism")) { - case "MONGODB-OIDC": { - if (config.transport === "stdio" && config.browser) { - return "oidc-auth-flow"; - } - - if (config.transport === "http" && config.httpHost === "127.0.0.1" && config.browser) { - return "oidc-auth-flow"; - } - - return "oidc-device-flow"; - } - case "MONGODB-X509": - return "x.509"; - case "GSSAPI": - return "kerberos"; - case "PLAIN": - if (searchParams.get("authSource") === "$external") { - return "ldap"; - } - return "scram"; - // default should catch also null, but eslint complains - // about it. - case null: - default: - return "scram"; - } - } - - private async pingAndForget(serviceProvider: NodeDriverServiceProvider): Promise { - try { - await serviceProvider?.runCommand?.("admin", { hello: 1 }); - } catch (error: unknown) { - this.logger.warning({ - id: LogId.oidcFlow, - context: "pingAndForget", - message: String(error), - }); - } - } - - private async disconnectOnOidcError(error: unknown): Promise { - try { - await this.disconnect(); - } catch (error: unknown) { - this.logger.warning({ - id: LogId.oidcFlow, - context: "disconnectOnOidcError", - message: String(error), - }); - } finally { - this.changeState("connection-errored", { tag: "errored", errorReason: String(error) }); - } - } -} diff --git a/src/common/session.ts b/src/common/session.ts index c09a3bcd..13b89adf 100644 --- a/src/common/session.ts +++ b/src/common/session.ts @@ -8,8 +8,8 @@ import EventEmitter from "events"; import type { AtlasClusterConnectionInfo, ConnectionManager, + ConnectionSettings, ConnectionStateConnected, - MCPConnectParams, } from "./connectionManager.js"; import type { NodeDriverServiceProvider } from "@mongosh/service-provider-node-driver"; import { ErrorCodes, MongoDBError } from "./errors.js"; @@ -139,9 +139,9 @@ export class Session extends EventEmitter { this.emit("close"); } - async connectToMongoDB(connectParams: MCPConnectParams): Promise { + async connectToMongoDB(settings: ConnectionSettings): Promise { try { - await this.connectionManager.connect({ ...connectParams }); + await this.connectionManager.connect({ ...settings }); } catch (error: unknown) { const message = error instanceof Error ? error.message : (error as string); this.emit("connection-error", message); diff --git a/src/index.ts b/src/index.ts index 6138a6e4..ccac557a 100644 --- a/src/index.ts +++ b/src/index.ts @@ -42,9 +42,8 @@ import { packageInfo } from "./common/packageInfo.js"; import { StdioRunner } from "./transports/stdio.js"; import { StreamableHttpRunner } from "./transports/streamableHttp.js"; import { systemCA } from "@mongodb-js/devtools-proxy-support"; -import type { MCPConnectParams } from "./lib.js"; import type { CreateConnectionManagerFn } from "./transports/base.js"; -import { MCPConnectionManager } from "./common/mcpConnectionManager.js"; +import { MCPConnectionManager } from "./common/connectionManager.js"; async function main(): Promise { systemCA().catch(() => undefined); // load system CA asynchronously as in mongosh @@ -52,13 +51,13 @@ async function main(): Promise { assertHelpMode(); assertVersionMode(); - const createConnectionManager: CreateConnectionManagerFn = ({ logger, deviceId }) => + const createConnectionManager: CreateConnectionManagerFn = ({ logger, deviceId }) => new MCPConnectionManager(config, driverOptions, logger, deviceId); const transportRunner = config.transport === "stdio" - ? new StdioRunner(config, createConnectionManager) - : new StreamableHttpRunner(config, createConnectionManager); + ? new StdioRunner(config, createConnectionManager) + : new StreamableHttpRunner(config, createConnectionManager); const shutdown = (): void => { transportRunner.logger.info({ id: LogId.serverCloseRequested, diff --git a/src/lib.ts b/src/lib.ts index 2c72fc20..e3af7401 100644 --- a/src/lib.ts +++ b/src/lib.ts @@ -6,7 +6,6 @@ export { StreamableHttpRunner } from "./transports/streamableHttp.js"; export { type CreateConnectionManagerFn } from "./transports/base.js"; export { ConnectionManager, - type MCPConnectParams, type AnyConnectionState, type ConnectionState, type ConnectionStateConnected, diff --git a/src/transports/base.ts b/src/transports/base.ts index ea7f571e..b7a5a982 100644 --- a/src/transports/base.ts +++ b/src/transports/base.ts @@ -7,21 +7,21 @@ import { McpServer } from "@modelcontextprotocol/sdk/server/mcp.js"; import type { LoggerBase } from "../common/logger.js"; import { CompositeLogger, ConsoleLogger, DiskLogger, McpLogger } from "../common/logger.js"; import { ExportsManager } from "../common/exportsManager.js"; -import type { ConnectionManager, MCPConnectParams } from "../common/connectionManager.js"; +import type { ConnectionManager } from "../common/connectionManager.js"; import { DeviceId } from "../helpers/deviceId.js"; -export type CreateConnectionManagerFn = (createParams: { +export type CreateConnectionManagerFn = (createParams: { logger: CompositeLogger; deviceId: DeviceId; -}) => ConnectionManager | Promise>; +}) => ConnectionManager | Promise; -export abstract class TransportRunnerBase { +export abstract class TransportRunnerBase { public logger: LoggerBase; public deviceId: DeviceId; constructor( protected readonly userConfig: UserConfig, - private readonly createConnectionManager: CreateConnectionManagerFn, + private readonly createConnectionManager: CreateConnectionManagerFn, additionalLoggers: LoggerBase[] ) { const loggers: LoggerBase[] = [...additionalLoggers]; diff --git a/src/transports/stdio.ts b/src/transports/stdio.ts index cc823858..e537ed1b 100644 --- a/src/transports/stdio.ts +++ b/src/transports/stdio.ts @@ -6,7 +6,6 @@ import { type LoggerBase, LogId } from "../common/logger.js"; import type { Server } from "../server.js"; import { type CreateConnectionManagerFn, TransportRunnerBase } from "./base.js"; import { type UserConfig } from "../common/config.js"; -import type { MCPConnectParams } from "../common/connectionManager.js"; // This is almost a copy of ReadBuffer from @modelcontextprotocol/sdk // but it uses EJSON.parse instead of JSON.parse to handle BSON types @@ -52,12 +51,12 @@ export function createStdioTransport(): StdioServerTransport { return server; } -export class StdioRunner extends TransportRunnerBase { +export class StdioRunner extends TransportRunnerBase { private server: Server | undefined; constructor( userConfig: UserConfig, - createConnectionManager: CreateConnectionManagerFn, + createConnectionManager: CreateConnectionManagerFn, additionalLoggers: LoggerBase[] = [] ) { super(userConfig, createConnectionManager, additionalLoggers); diff --git a/src/transports/streamableHttp.ts b/src/transports/streamableHttp.ts index b43125f8..3b4a9962 100644 --- a/src/transports/streamableHttp.ts +++ b/src/transports/streamableHttp.ts @@ -6,7 +6,6 @@ import { isInitializeRequest } from "@modelcontextprotocol/sdk/types.js"; import { LogId, type LoggerBase } from "../common/logger.js"; import { type UserConfig } from "../common/config.js"; import { SessionStore } from "../common/sessionStore.js"; -import type { MCPConnectParams } from "../common/connectionManager.js"; import { type CreateConnectionManagerFn, TransportRunnerBase } from "./base.js"; const JSON_RPC_ERROR_CODE_PROCESSING_REQUEST_FAILED = -32000; @@ -15,13 +14,13 @@ const JSON_RPC_ERROR_CODE_SESSION_ID_INVALID = -32002; const JSON_RPC_ERROR_CODE_SESSION_NOT_FOUND = -32003; const JSON_RPC_ERROR_CODE_INVALID_REQUEST = -32004; -export class StreamableHttpRunner extends TransportRunnerBase { +export class StreamableHttpRunner extends TransportRunnerBase { private httpServer: http.Server | undefined; private sessionStore!: SessionStore; constructor( userConfig: UserConfig, - createConnectionManager: CreateConnectionManagerFn, + createConnectionManager: CreateConnectionManagerFn, additionalLoggers: LoggerBase[] = [] ) { super(userConfig, createConnectionManager, additionalLoggers); diff --git a/tests/integration/common/connectionManager.test.ts b/tests/integration/common/connectionManager.test.ts index a5815b40..edda9b8a 100644 --- a/tests/integration/common/connectionManager.test.ts +++ b/tests/integration/common/connectionManager.test.ts @@ -4,7 +4,7 @@ import type { ConnectionStateConnected, ConnectionStringAuthType, } from "../../../src/common/connectionManager.js"; -import { MCPConnectionManager } from "../../../src/common/mcpConnectionManager.js"; +import { MCPConnectionManager } from "../../../src/common/connectionManager.js"; import type { UserConfig } from "../../../src/common/config.js"; import { describeWithMongoDB } from "../tools/mongodb/mongodbHelpers.js"; import { describe, beforeEach, expect, it, vi, afterEach } from "vitest"; diff --git a/tests/integration/helpers.ts b/tests/integration/helpers.ts index 3c646f49..e4913f6d 100644 --- a/tests/integration/helpers.ts +++ b/tests/integration/helpers.ts @@ -11,7 +11,7 @@ import { McpError, ResourceUpdatedNotificationSchema } from "@modelcontextprotoc import { config, driverOptions } from "../../src/common/config.js"; import { afterAll, afterEach, beforeAll, describe, expect, it, vi } from "vitest"; import type { ConnectionManager, ConnectionState } from "../../src/common/connectionManager.js"; -import { MCPConnectionManager } from "../../src/common/mcpConnectionManager.js"; +import { MCPConnectionManager } from "../../src/common/connectionManager.js"; import { DeviceId } from "../../src/helpers/deviceId.js"; interface ParameterInfo { diff --git a/tests/integration/telemetry.test.ts b/tests/integration/telemetry.test.ts index 29a78469..b63a3796 100644 --- a/tests/integration/telemetry.test.ts +++ b/tests/integration/telemetry.test.ts @@ -4,7 +4,7 @@ import { config, driverOptions } from "../../src/common/config.js"; import { DeviceId } from "../../src/helpers/deviceId.js"; import { describe, expect, it } from "vitest"; import { CompositeLogger } from "../../src/common/logger.js"; -import { MCPConnectionManager } from "../../src/common/mcpConnectionManager.js"; +import { MCPConnectionManager } from "../../src/common/connectionManager.js"; import { ExportsManager } from "../../src/common/exportsManager.js"; describe("Telemetry", () => { diff --git a/tests/integration/transports/streamableHttp.test.ts b/tests/integration/transports/streamableHttp.test.ts index eb1d42b4..3854a2d0 100644 --- a/tests/integration/transports/streamableHttp.test.ts +++ b/tests/integration/transports/streamableHttp.test.ts @@ -5,11 +5,10 @@ import { describe, expect, it, beforeAll, afterAll, beforeEach } from "vitest"; import { config, driverOptions } from "../../../src/common/config.js"; import type { LoggerType, LogLevel, LogPayload } from "../../../src/common/logger.js"; import { LoggerBase, LogId } from "../../../src/common/logger.js"; -import { MCPConnectionManager } from "../../../src/common/mcpConnectionManager.js"; -import type { MCPConnectParams } from "../../../src/lib.js"; +import { MCPConnectionManager } from "../../../src/common/connectionManager.js"; describe("StreamableHttpRunner", () => { - let runner: StreamableHttpRunner; + let runner: StreamableHttpRunner; let oldTelemetry: "enabled" | "disabled"; let oldLoggers: ("stderr" | "disk" | "mcp")[]; @@ -110,7 +109,7 @@ describe("StreamableHttpRunner", () => { } it("can create multiple runners", async () => { - const runners: StreamableHttpRunner[] = []; + const runners: StreamableHttpRunner[] = []; try { for (let i = 0; i < 3; i++) { config.httpPort = 0; // Use a random port for each runner diff --git a/tests/unit/common/session.test.ts b/tests/unit/common/session.test.ts index 29075969..53129d9e 100644 --- a/tests/unit/common/session.test.ts +++ b/tests/unit/common/session.test.ts @@ -4,7 +4,7 @@ import { NodeDriverServiceProvider } from "@mongosh/service-provider-node-driver import { Session } from "../../../src/common/session.js"; import { config, driverOptions } from "../../../src/common/config.js"; import { CompositeLogger } from "../../../src/common/logger.js"; -import { MCPConnectionManager } from "../../../src/common/mcpConnectionManager.js"; +import { MCPConnectionManager } from "../../../src/common/connectionManager.js"; import { ExportsManager } from "../../../src/common/exportsManager.js"; import { DeviceId } from "../../../src/helpers/deviceId.js"; diff --git a/tests/unit/resources/common/debug.test.ts b/tests/unit/resources/common/debug.test.ts index 59c0d3aa..d691bf31 100644 --- a/tests/unit/resources/common/debug.test.ts +++ b/tests/unit/resources/common/debug.test.ts @@ -4,7 +4,7 @@ import { Session } from "../../../../src/common/session.js"; import { Telemetry } from "../../../../src/telemetry/telemetry.js"; import { config, driverOptions } from "../../../../src/common/config.js"; import { CompositeLogger } from "../../../../src/common/logger.js"; -import { MCPConnectionManager } from "../../../../src/common/mcpConnectionManager.js"; +import { MCPConnectionManager } from "../../../../src/common/connectionManager.js"; import { ExportsManager } from "../../../../src/common/exportsManager.js"; import { DeviceId } from "../../../../src/helpers/deviceId.js"; From b4cbe994ffb5b65785b74518dab61e6fb28f6f42 Mon Sep 17 00:00:00 2001 From: Himanshu Singh Date: Thu, 28 Aug 2025 13:57:39 +0200 Subject: [PATCH 06/13] chore: further harmonisation with old code --- src/common/connectionManager.ts | 30 +++++++++++++++--------------- src/transports/base.ts | 2 +- 2 files changed, 16 insertions(+), 16 deletions(-) diff --git a/src/common/connectionManager.ts b/src/common/connectionManager.ts index 9bea9daf..2f9159ca 100644 --- a/src/common/connectionManager.ts +++ b/src/common/connectionManager.ts @@ -17,8 +17,13 @@ export interface AtlasClusterConnectionInfo { expiryDate: Date; } +export interface ConnectionSettings { + connectionString: string; + atlas?: AtlasClusterConnectionInfo; +} + type ConnectionTag = "connected" | "connecting" | "disconnected" | "errored"; -export type OIDCConnectionAuthType = "oidc-auth-flow" | "oidc-device-flow"; +type OIDCConnectionAuthType = "oidc-auth-flow" | "oidc-device-flow"; export type ConnectionStringAuthType = "scram" | "ldap" | "kerberos" | OIDCConnectionAuthType | "x.509"; export interface ConnectionState { @@ -63,11 +68,6 @@ export interface ConnectionManagerEvents { "connection-errored": [ConnectionStateErrored]; } -export interface ConnectionSettings { - connectionString: string; - atlas?: AtlasClusterConnectionInfo; -} - export abstract class ConnectionManager { protected clientName: string = "unknown"; @@ -119,7 +119,7 @@ export class MCPConnectionManager extends ConnectionManager { this.clientName = "unknown"; } - async connect(connectParams: ConnectionSettings): Promise { + async connect(settings: ConnectionSettings): Promise { this._events.emit("connection-requested", this.state); if (this.state.tag === "connected" || this.state.tag === "connecting") { @@ -130,22 +130,22 @@ export class MCPConnectionManager extends ConnectionManager { let connectionInfo: ConnectionInfo; try { - connectParams = { ...connectParams }; + settings = { ...settings }; const appNameComponents: AppNameComponents = { appName: `${packageInfo.mcpServerName} ${packageInfo.version}`, deviceId: this.deviceId.get(), clientName: this.clientName, }; - connectParams.connectionString = await setAppNameParamIfMissing({ - connectionString: connectParams.connectionString, + settings.connectionString = await setAppNameParamIfMissing({ + connectionString: settings.connectionString, components: appNameComponents, }); connectionInfo = generateConnectionInfoFromCliArgs({ ...this.userConfig, ...this.driverOptions, - connectionSpecifier: connectParams.connectionString, + connectionSpecifier: settings.connectionString, }); if (connectionInfo.driverOptions.oidc) { @@ -171,7 +171,7 @@ export class MCPConnectionManager extends ConnectionManager { this.changeState("connection-errored", { tag: "errored", errorReason, - connectedAtlasCluster: connectParams.atlas, + connectedAtlasCluster: settings.atlas, }); throw new MongoDBError(ErrorCodes.MisconfiguredConnectionString, errorReason); } @@ -186,7 +186,7 @@ export class MCPConnectionManager extends ConnectionManager { return this.changeState("connection-requested", { tag: "connecting", - connectedAtlasCluster: connectParams.atlas, + connectedAtlasCluster: settings.atlas, serviceProvider, connectionStringAuthType: connectionType, oidcConnectionType: connectionType as OIDCConnectionAuthType, @@ -197,7 +197,7 @@ export class MCPConnectionManager extends ConnectionManager { return this.changeState("connection-succeeded", { tag: "connected", - connectedAtlasCluster: connectParams.atlas, + connectedAtlasCluster: settings.atlas, serviceProvider, connectionStringAuthType: connectionType, }); @@ -206,7 +206,7 @@ export class MCPConnectionManager extends ConnectionManager { this.changeState("connection-errored", { tag: "errored", errorReason, - connectedAtlasCluster: connectParams.atlas, + connectedAtlasCluster: settings.atlas, }); throw new MongoDBError(ErrorCodes.NotConnectedToMongoDB, errorReason); } diff --git a/src/transports/base.ts b/src/transports/base.ts index b7a5a982..49c1474e 100644 --- a/src/transports/base.ts +++ b/src/transports/base.ts @@ -19,7 +19,7 @@ export abstract class TransportRunnerBase { public logger: LoggerBase; public deviceId: DeviceId; - constructor( + protected constructor( protected readonly userConfig: UserConfig, private readonly createConnectionManager: CreateConnectionManagerFn, additionalLoggers: LoggerBase[] From 3637c4c1ddcf3edd0a89c253354b344297f5bf55 Mon Sep 17 00:00:00 2001 From: Himanshu Singh Date: Thu, 28 Aug 2025 15:10:33 +0200 Subject: [PATCH 07/13] chore: changeState does not have to be public --- src/common/connectionManager.ts | 14 +++++++++++++- .../common/connectionManager.oidc.test.ts | 9 +++++++-- tests/integration/common/connectionManager.test.ts | 6 +++--- 3 files changed, 23 insertions(+), 6 deletions(-) diff --git a/src/common/connectionManager.ts b/src/common/connectionManager.ts index 2f9159ca..afdb7e04 100644 --- a/src/common/connectionManager.ts +++ b/src/common/connectionManager.ts @@ -68,6 +68,18 @@ export interface ConnectionManagerEvents { "connection-errored": [ConnectionStateErrored]; } +/** + * For a few tests, we need the changeState method to force a connection state + * which is we have this type to typecast the actual ConnectionManager with + * public changeState (only to make TS happy). + */ +export type TestConnectionManager = ConnectionManager & { + changeState( + event: Event, + newState: State + ): State; +}; + export abstract class ConnectionManager { protected clientName: string = "unknown"; @@ -80,7 +92,7 @@ export abstract class ConnectionManager { return this.state; } - changeState( + protected changeState( event: Event, newState: State ): State { diff --git a/tests/integration/common/connectionManager.oidc.test.ts b/tests/integration/common/connectionManager.oidc.test.ts index e3a406eb..c8d47d59 100644 --- a/tests/integration/common/connectionManager.oidc.test.ts +++ b/tests/integration/common/connectionManager.oidc.test.ts @@ -5,7 +5,11 @@ import process from "process"; import type { MongoDBIntegrationTestCase } from "../tools/mongodb/mongodbHelpers.js"; import { describeWithMongoDB, isCommunityServer, getServerVersion } from "../tools/mongodb/mongodbHelpers.js"; import { defaultTestConfig, responseAsText, timeout, waitUntil } from "../helpers.js"; -import type { ConnectionStateConnected, ConnectionStateConnecting } from "../../../src/common/connectionManager.js"; +import type { + ConnectionStateConnected, + ConnectionStateConnecting, + TestConnectionManager, +} from "../../../src/common/connectionManager.js"; import type { UserConfig } from "../../../src/common/config.js"; import { setupDriverConfig } from "../../../src/common/config.js"; import path from "path"; @@ -122,7 +126,8 @@ describe.skipIf(process.platform !== "linux")("ConnectionManager OIDC Tests", as } beforeEach(async () => { - const connectionManager = integration.mcpServer().session.connectionManager; + const connectionManager = integration.mcpServer().session + .connectionManager as TestConnectionManager; // disconnect on purpose doesn't change the state if it was failed to avoid losing // information in production. await connectionManager.disconnect(); diff --git a/tests/integration/common/connectionManager.test.ts b/tests/integration/common/connectionManager.test.ts index edda9b8a..62c3fba9 100644 --- a/tests/integration/common/connectionManager.test.ts +++ b/tests/integration/common/connectionManager.test.ts @@ -1,8 +1,8 @@ import type { - ConnectionManager, ConnectionManagerEvents, ConnectionStateConnected, ConnectionStringAuthType, + TestConnectionManager, } from "../../../src/common/connectionManager.js"; import { MCPConnectionManager } from "../../../src/common/connectionManager.js"; import type { UserConfig } from "../../../src/common/config.js"; @@ -10,8 +10,8 @@ import { describeWithMongoDB } from "../tools/mongodb/mongodbHelpers.js"; import { describe, beforeEach, expect, it, vi, afterEach } from "vitest"; describeWithMongoDB("Connection Manager", (integration) => { - function connectionManager(): ConnectionManager { - return integration.mcpServer().session.connectionManager; + function connectionManager(): TestConnectionManager { + return integration.mcpServer().session.connectionManager as TestConnectionManager; } afterEach(async () => { From 65c5332863e4ef115abdd62b4730eeb072edcd20 Mon Sep 17 00:00:00 2001 From: Himanshu Singh Date: Thu, 28 Aug 2025 15:52:00 +0200 Subject: [PATCH 08/13] chore: PR feedback --- src/index.ts | 6 +++--- src/lib.ts | 2 +- src/transports/base.ts | 6 +++--- src/transports/stdio.ts | 4 ++-- src/transports/streamableHttp.ts | 4 ++-- 5 files changed, 11 insertions(+), 11 deletions(-) diff --git a/src/index.ts b/src/index.ts index ccac557a..e8a60829 100644 --- a/src/index.ts +++ b/src/index.ts @@ -42,7 +42,7 @@ import { packageInfo } from "./common/packageInfo.js"; import { StdioRunner } from "./transports/stdio.js"; import { StreamableHttpRunner } from "./transports/streamableHttp.js"; import { systemCA } from "@mongodb-js/devtools-proxy-support"; -import type { CreateConnectionManagerFn } from "./transports/base.js"; +import type { ConnectionManagerFactoryFn } from "./transports/base.js"; import { MCPConnectionManager } from "./common/connectionManager.js"; async function main(): Promise { @@ -51,8 +51,8 @@ async function main(): Promise { assertHelpMode(); assertVersionMode(); - const createConnectionManager: CreateConnectionManagerFn = ({ logger, deviceId }) => - new MCPConnectionManager(config, driverOptions, logger, deviceId); + const createConnectionManager: ConnectionManagerFactoryFn = ({ logger, deviceId }) => + Promise.resolve(new MCPConnectionManager(config, driverOptions, logger, deviceId)); const transportRunner = config.transport === "stdio" diff --git a/src/lib.ts b/src/lib.ts index e3af7401..ae490d24 100644 --- a/src/lib.ts +++ b/src/lib.ts @@ -3,7 +3,7 @@ export { Session, type SessionOptions } from "./common/session.js"; export { defaultUserConfig, type UserConfig } from "./common/config.js"; export { LoggerBase, CompositeLogger, type LogPayload, type LoggerType, type LogLevel } from "./common/logger.js"; export { StreamableHttpRunner } from "./transports/streamableHttp.js"; -export { type CreateConnectionManagerFn } from "./transports/base.js"; +export { type ConnectionManagerFactoryFn } from "./transports/base.js"; export { ConnectionManager, type AnyConnectionState, diff --git a/src/transports/base.ts b/src/transports/base.ts index 49c1474e..4e215093 100644 --- a/src/transports/base.ts +++ b/src/transports/base.ts @@ -10,10 +10,10 @@ import { ExportsManager } from "../common/exportsManager.js"; import type { ConnectionManager } from "../common/connectionManager.js"; import { DeviceId } from "../helpers/deviceId.js"; -export type CreateConnectionManagerFn = (createParams: { +export type ConnectionManagerFactoryFn = (createParams: { logger: CompositeLogger; deviceId: DeviceId; -}) => ConnectionManager | Promise; +}) => Promise; export abstract class TransportRunnerBase { public logger: LoggerBase; @@ -21,7 +21,7 @@ export abstract class TransportRunnerBase { protected constructor( protected readonly userConfig: UserConfig, - private readonly createConnectionManager: CreateConnectionManagerFn, + private readonly createConnectionManager: ConnectionManagerFactoryFn, additionalLoggers: LoggerBase[] ) { const loggers: LoggerBase[] = [...additionalLoggers]; diff --git a/src/transports/stdio.ts b/src/transports/stdio.ts index e537ed1b..009a9c01 100644 --- a/src/transports/stdio.ts +++ b/src/transports/stdio.ts @@ -4,7 +4,7 @@ import { JSONRPCMessageSchema } from "@modelcontextprotocol/sdk/types.js"; import { StdioServerTransport } from "@modelcontextprotocol/sdk/server/stdio.js"; import { type LoggerBase, LogId } from "../common/logger.js"; import type { Server } from "../server.js"; -import { type CreateConnectionManagerFn, TransportRunnerBase } from "./base.js"; +import { type ConnectionManagerFactoryFn, TransportRunnerBase } from "./base.js"; import { type UserConfig } from "../common/config.js"; // This is almost a copy of ReadBuffer from @modelcontextprotocol/sdk @@ -56,7 +56,7 @@ export class StdioRunner extends TransportRunnerBase { constructor( userConfig: UserConfig, - createConnectionManager: CreateConnectionManagerFn, + createConnectionManager: ConnectionManagerFactoryFn, additionalLoggers: LoggerBase[] = [] ) { super(userConfig, createConnectionManager, additionalLoggers); diff --git a/src/transports/streamableHttp.ts b/src/transports/streamableHttp.ts index 3b4a9962..c06a6b97 100644 --- a/src/transports/streamableHttp.ts +++ b/src/transports/streamableHttp.ts @@ -6,7 +6,7 @@ import { isInitializeRequest } from "@modelcontextprotocol/sdk/types.js"; import { LogId, type LoggerBase } from "../common/logger.js"; import { type UserConfig } from "../common/config.js"; import { SessionStore } from "../common/sessionStore.js"; -import { type CreateConnectionManagerFn, TransportRunnerBase } from "./base.js"; +import { type ConnectionManagerFactoryFn, TransportRunnerBase } from "./base.js"; const JSON_RPC_ERROR_CODE_PROCESSING_REQUEST_FAILED = -32000; const JSON_RPC_ERROR_CODE_SESSION_ID_REQUIRED = -32001; @@ -20,7 +20,7 @@ export class StreamableHttpRunner extends TransportRunnerBase { constructor( userConfig: UserConfig, - createConnectionManager: CreateConnectionManagerFn, + createConnectionManager: ConnectionManagerFactoryFn, additionalLoggers: LoggerBase[] = [] ) { super(userConfig, createConnectionManager, additionalLoggers); From 71680ab3be2ebd3a22967358c2515a8278135d12 Mon Sep 17 00:00:00 2001 From: Himanshu Singh Date: Thu, 28 Aug 2025 16:01:33 +0200 Subject: [PATCH 09/13] chore: test check fixes --- tests/integration/transports/streamableHttp.test.ts | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/tests/integration/transports/streamableHttp.test.ts b/tests/integration/transports/streamableHttp.test.ts index 3854a2d0..70185203 100644 --- a/tests/integration/transports/streamableHttp.test.ts +++ b/tests/integration/transports/streamableHttp.test.ts @@ -29,9 +29,8 @@ describe("StreamableHttpRunner", () => { describe(description, () => { beforeAll(async () => { config.httpHeaders = headers; - runner = new StreamableHttpRunner( - config, - ({ logger, deviceId }) => new MCPConnectionManager(config, driverOptions, logger, deviceId) + runner = new StreamableHttpRunner(config, ({ logger, deviceId }) => + Promise.resolve(new MCPConnectionManager(config, driverOptions, logger, deviceId)) ); await runner.start(); }); @@ -113,9 +112,8 @@ describe("StreamableHttpRunner", () => { try { for (let i = 0; i < 3; i++) { config.httpPort = 0; // Use a random port for each runner - const runner = new StreamableHttpRunner( - config, - ({ logger, deviceId }) => new MCPConnectionManager(config, driverOptions, logger, deviceId) + const runner = new StreamableHttpRunner(config, ({ logger, deviceId }) => + Promise.resolve(new MCPConnectionManager(config, driverOptions, logger, deviceId)) ); await runner.start(); runners.push(runner); @@ -147,7 +145,8 @@ describe("StreamableHttpRunner", () => { const logger = new CustomLogger(); const runner = new StreamableHttpRunner( config, - ({ logger, deviceId }) => new MCPConnectionManager(config, driverOptions, logger, deviceId), + ({ logger, deviceId }) => + Promise.resolve(new MCPConnectionManager(config, driverOptions, logger, deviceId)), [logger] ); await runner.start(); From b079a6deb5eec661beffd6ef031da6f53b2e7550 Mon Sep 17 00:00:00 2001 From: Himanshu Singh Date: Thu, 28 Aug 2025 18:02:07 +0200 Subject: [PATCH 10/13] chore: further feedback --- src/common/connectionManager.ts | 45 +++++++++++++++++++++------------ 1 file changed, 29 insertions(+), 16 deletions(-) diff --git a/src/common/connectionManager.ts b/src/common/connectionManager.ts index afdb7e04..d54538fa 100644 --- a/src/common/connectionManager.ts +++ b/src/common/connectionManager.ts @@ -81,12 +81,16 @@ export type TestConnectionManager = ConnectionManager & { }; export abstract class ConnectionManager { - protected clientName: string = "unknown"; + protected clientName: string; + protected readonly _events; + readonly events: Pick, "on" | "off" | "once">; + private state: AnyConnectionState; - protected readonly _events = new EventEmitter(); - readonly events: Pick, "on" | "off" | "once"> = this._events; - - protected state: AnyConnectionState = { tag: "disconnected" }; + constructor() { + this.clientName = "unknown"; + this.events = this._events = new EventEmitter(); + this.state = { tag: "disconnected" }; + } get currentConnectionState(): AnyConnectionState { return this.state; @@ -132,9 +136,9 @@ export class MCPConnectionManager extends ConnectionManager { } async connect(settings: ConnectionSettings): Promise { - this._events.emit("connection-requested", this.state); + this._events.emit("connection-requested", this.currentConnectionState); - if (this.state.tag === "connected" || this.state.tag === "connecting") { + if (this.currentConnectionState.tag === "connected" || this.currentConnectionState.tag === "connecting") { await this.disconnect(); } @@ -225,13 +229,13 @@ export class MCPConnectionManager extends ConnectionManager { } async disconnect(): Promise { - if (this.state.tag === "disconnected" || this.state.tag === "errored") { - return this.state; + if (this.currentConnectionState.tag === "disconnected" || this.currentConnectionState.tag === "errored") { + return this.currentConnectionState; } - if (this.state.tag === "connected" || this.state.tag === "connecting") { + if (this.currentConnectionState.tag === "connected" || this.currentConnectionState.tag === "connecting") { try { - await this.state.serviceProvider?.close(true); + await this.currentConnectionState.serviceProvider?.close(true); } finally { this.changeState("connection-closed", { tag: "disconnected", @@ -243,14 +247,20 @@ export class MCPConnectionManager extends ConnectionManager { } private onOidcAuthFailed(error: unknown): void { - if (this.state.tag === "connecting" && this.state.connectionStringAuthType?.startsWith("oidc")) { + if ( + this.currentConnectionState.tag === "connecting" && + this.currentConnectionState.connectionStringAuthType?.startsWith("oidc") + ) { void this.disconnectOnOidcError(error); } } private onOidcAuthSucceeded(): void { - if (this.state.tag === "connecting" && this.state.connectionStringAuthType?.startsWith("oidc")) { - this.changeState("connection-succeeded", { ...this.state, tag: "connected" }); + if ( + this.currentConnectionState.tag === "connecting" && + this.currentConnectionState.connectionStringAuthType?.startsWith("oidc") + ) { + this.changeState("connection-succeeded", { ...this.currentConnectionState, tag: "connected" }); } this.logger.info({ @@ -261,9 +271,12 @@ export class MCPConnectionManager extends ConnectionManager { } private onOidcNotifyDeviceFlow(flowInfo: { verificationUrl: string; userCode: string }): void { - if (this.state.tag === "connecting" && this.state.connectionStringAuthType?.startsWith("oidc")) { + if ( + this.currentConnectionState.tag === "connecting" && + this.currentConnectionState.connectionStringAuthType?.startsWith("oidc") + ) { this.changeState("connection-requested", { - ...this.state, + ...this.currentConnectionState, tag: "connecting", connectionStringAuthType: "oidc-device-flow", oidcLoginUrl: flowInfo.verificationUrl, From c178465352f6aaeb83e124387c1f571d1e51ac46 Mon Sep 17 00:00:00 2001 From: Himanshu Singh Date: Thu, 28 Aug 2025 18:45:47 +0200 Subject: [PATCH 11/13] chore: further feedback --- src/common/connectionManager.ts | 5 ++--- src/lib.ts | 5 +---- src/transports/base.ts | 2 +- tests/integration/build.test.ts | 1 - 4 files changed, 4 insertions(+), 9 deletions(-) diff --git a/src/common/connectionManager.ts b/src/common/connectionManager.ts index ae6190b2..199f58e4 100644 --- a/src/common/connectionManager.ts +++ b/src/common/connectionManager.ts @@ -6,7 +6,7 @@ import { type ConnectionInfo, generateConnectionInfoFromCliArgs } from "@mongosh import type { DeviceId } from "../helpers/deviceId.js"; import type { DriverOptions, UserConfig } from "./config.js"; import { MongoDBError, ErrorCodes } from "./errors.js"; -import { type CompositeLogger, LogId } from "./logger.js"; +import { type LoggerBase, LogId } from "./logger.js"; import { packageInfo } from "./packageInfo.js"; import { type AppNameComponents, setAppNameParamIfMissing } from "../helpers/connectionOptions.js"; @@ -123,7 +123,7 @@ export class MCPConnectionManager extends ConnectionManager { constructor( private userConfig: UserConfig, private driverOptions: DriverOptions, - private logger: CompositeLogger, + private logger: LoggerBase, deviceId: DeviceId, bus?: EventEmitter ) { @@ -132,7 +132,6 @@ export class MCPConnectionManager extends ConnectionManager { this.bus.on("mongodb-oidc-plugin:auth-failed", this.onOidcAuthFailed.bind(this)); this.bus.on("mongodb-oidc-plugin:auth-succeeded", this.onOidcAuthSucceeded.bind(this)); this.deviceId = deviceId; - this.clientName = "unknown"; } async connect(settings: ConnectionSettings): Promise { diff --git a/src/lib.ts b/src/lib.ts index ae490d24..b5d752c2 100644 --- a/src/lib.ts +++ b/src/lib.ts @@ -1,16 +1,13 @@ export { Server, type ServerOptions } from "./server.js"; export { Session, type SessionOptions } from "./common/session.js"; export { defaultUserConfig, type UserConfig } from "./common/config.js"; -export { LoggerBase, CompositeLogger, type LogPayload, type LoggerType, type LogLevel } from "./common/logger.js"; +export { LoggerBase, type LogPayload, type LoggerType, type LogLevel } from "./common/logger.js"; export { StreamableHttpRunner } from "./transports/streamableHttp.js"; export { type ConnectionManagerFactoryFn } from "./transports/base.js"; export { ConnectionManager, type AnyConnectionState, type ConnectionState, - type ConnectionStateConnected, - type ConnectionStateConnecting, type ConnectionStateDisconnected, - type ConnectionStateErrored, } from "./common/connectionManager.js"; export { Telemetry } from "./telemetry/telemetry.js"; diff --git a/src/transports/base.ts b/src/transports/base.ts index 4e215093..671d0995 100644 --- a/src/transports/base.ts +++ b/src/transports/base.ts @@ -11,7 +11,7 @@ import type { ConnectionManager } from "../common/connectionManager.js"; import { DeviceId } from "../helpers/deviceId.js"; export type ConnectionManagerFactoryFn = (createParams: { - logger: CompositeLogger; + logger: LoggerBase; deviceId: DeviceId; }) => Promise; diff --git a/tests/integration/build.test.ts b/tests/integration/build.test.ts index ef328320..7453cb3d 100644 --- a/tests/integration/build.test.ts +++ b/tests/integration/build.test.ts @@ -43,7 +43,6 @@ describe("Build Test", () => { expect(cjsKeys).toEqual(esmKeys); expect(cjsKeys).toEqual( expect.arrayContaining([ - "CompositeLogger", "ConnectionManager", "LoggerBase", "Server", From b6ade3cfc0797d2e543c5fc086695899077dce0e Mon Sep 17 00:00:00 2001 From: Himanshu Singh Date: Fri, 29 Aug 2025 12:10:21 +0200 Subject: [PATCH 12/13] chore: have a default connection manager factory --- eslint-rules/no-config-imports.js | 4 ++++ src/common/connectionManager.ts | 22 +++++++++++++++++++++- src/index.ts | 12 ++---------- src/lib.ts | 3 ++- src/transports/base.ts | 13 ++++++------- src/transports/stdio.ts | 5 +++-- src/transports/streamableHttp.ts | 5 +++-- 7 files changed, 41 insertions(+), 23 deletions(-) diff --git a/eslint-rules/no-config-imports.js b/eslint-rules/no-config-imports.js index 908dd5ae..5c4efb7c 100644 --- a/eslint-rules/no-config-imports.js +++ b/eslint-rules/no-config-imports.js @@ -10,6 +10,10 @@ const allowedConfigValueImportFiles = [ "src/index.ts", // Config resource definition that works with the some config values "src/resources/common/config.ts", + // The file exports, a factory function to create MCPConnectionManager and + // it relies on driver options generator and default driver options from + // config file. + "src/common/connectionManager.ts", ]; // Ref: https://eslint.org/docs/latest/extend/custom-rules diff --git a/src/common/connectionManager.ts b/src/common/connectionManager.ts index 199f58e4..78e51edb 100644 --- a/src/common/connectionManager.ts +++ b/src/common/connectionManager.ts @@ -4,7 +4,7 @@ import ConnectionString from "mongodb-connection-string-url"; import { NodeDriverServiceProvider } from "@mongosh/service-provider-node-driver"; import { type ConnectionInfo, generateConnectionInfoFromCliArgs } from "@mongosh/arg-parser"; import type { DeviceId } from "../helpers/deviceId.js"; -import type { DriverOptions, UserConfig } from "./config.js"; +import { defaultDriverOptions, setupDriverConfig, type DriverOptions, type UserConfig } from "./config.js"; import { MongoDBError, ErrorCodes } from "./errors.js"; import { type LoggerBase, LogId } from "./logger.js"; import { packageInfo } from "./packageInfo.js"; @@ -360,3 +360,23 @@ export class MCPConnectionManager extends ConnectionManager { } } } + +/** + * Consumers of MCP server library have option to bring their own connection + * management if they need to. To support that, we enable injecting connection + * manager implementation through a factory function. + */ +export type ConnectionManagerFactoryFn = (createParams: { + logger: LoggerBase; + deviceId: DeviceId; + userConfig: UserConfig; +}) => Promise; + +export const createMCPConnectionManager: ConnectionManagerFactoryFn = ({ logger, deviceId, userConfig }) => { + const driverOptions = setupDriverConfig({ + config: userConfig, + defaults: defaultDriverOptions, + }); + + return Promise.resolve(new MCPConnectionManager(userConfig, driverOptions, logger, deviceId)); +}; diff --git a/src/index.ts b/src/index.ts index e8a60829..6a7150e3 100644 --- a/src/index.ts +++ b/src/index.ts @@ -36,14 +36,12 @@ function enableFipsIfRequested(): void { enableFipsIfRequested(); import { ConsoleLogger, LogId } from "./common/logger.js"; -import { config, driverOptions } from "./common/config.js"; +import { config } from "./common/config.js"; import crypto from "crypto"; import { packageInfo } from "./common/packageInfo.js"; import { StdioRunner } from "./transports/stdio.js"; import { StreamableHttpRunner } from "./transports/streamableHttp.js"; import { systemCA } from "@mongodb-js/devtools-proxy-support"; -import type { ConnectionManagerFactoryFn } from "./transports/base.js"; -import { MCPConnectionManager } from "./common/connectionManager.js"; async function main(): Promise { systemCA().catch(() => undefined); // load system CA asynchronously as in mongosh @@ -51,13 +49,7 @@ async function main(): Promise { assertHelpMode(); assertVersionMode(); - const createConnectionManager: ConnectionManagerFactoryFn = ({ logger, deviceId }) => - Promise.resolve(new MCPConnectionManager(config, driverOptions, logger, deviceId)); - - const transportRunner = - config.transport === "stdio" - ? new StdioRunner(config, createConnectionManager) - : new StreamableHttpRunner(config, createConnectionManager); + const transportRunner = config.transport === "stdio" ? new StdioRunner(config) : new StreamableHttpRunner(config); const shutdown = (): void => { transportRunner.logger.info({ id: LogId.serverCloseRequested, diff --git a/src/lib.ts b/src/lib.ts index b5d752c2..01dc8b88 100644 --- a/src/lib.ts +++ b/src/lib.ts @@ -3,11 +3,12 @@ export { Session, type SessionOptions } from "./common/session.js"; export { defaultUserConfig, type UserConfig } from "./common/config.js"; export { LoggerBase, type LogPayload, type LoggerType, type LogLevel } from "./common/logger.js"; export { StreamableHttpRunner } from "./transports/streamableHttp.js"; -export { type ConnectionManagerFactoryFn } from "./transports/base.js"; export { ConnectionManager, type AnyConnectionState, type ConnectionState, type ConnectionStateDisconnected, + type ConnectionStateErrored, + type ConnectionManagerFactoryFn, } from "./common/connectionManager.js"; export { Telemetry } from "./telemetry/telemetry.js"; diff --git a/src/transports/base.ts b/src/transports/base.ts index 671d0995..d6fc53ad 100644 --- a/src/transports/base.ts +++ b/src/transports/base.ts @@ -7,13 +7,8 @@ import { McpServer } from "@modelcontextprotocol/sdk/server/mcp.js"; import type { LoggerBase } from "../common/logger.js"; import { CompositeLogger, ConsoleLogger, DiskLogger, McpLogger } from "../common/logger.js"; import { ExportsManager } from "../common/exportsManager.js"; -import type { ConnectionManager } from "../common/connectionManager.js"; import { DeviceId } from "../helpers/deviceId.js"; - -export type ConnectionManagerFactoryFn = (createParams: { - logger: LoggerBase; - deviceId: DeviceId; -}) => Promise; +import { type ConnectionManagerFactoryFn } from "../common/connectionManager.js"; export abstract class TransportRunnerBase { public logger: LoggerBase; @@ -51,7 +46,11 @@ export abstract class TransportRunnerBase { const logger = new CompositeLogger(this.logger); const exportsManager = ExportsManager.init(this.userConfig, logger); - const connectionManager = await this.createConnectionManager({ logger, deviceId: this.deviceId }); + const connectionManager = await this.createConnectionManager({ + logger, + userConfig: this.userConfig, + deviceId: this.deviceId, + }); const session = new Session({ apiBaseUrl: this.userConfig.apiBaseUrl, diff --git a/src/transports/stdio.ts b/src/transports/stdio.ts index 009a9c01..4ed941ef 100644 --- a/src/transports/stdio.ts +++ b/src/transports/stdio.ts @@ -4,8 +4,9 @@ import { JSONRPCMessageSchema } from "@modelcontextprotocol/sdk/types.js"; import { StdioServerTransport } from "@modelcontextprotocol/sdk/server/stdio.js"; import { type LoggerBase, LogId } from "../common/logger.js"; import type { Server } from "../server.js"; -import { type ConnectionManagerFactoryFn, TransportRunnerBase } from "./base.js"; +import { TransportRunnerBase } from "./base.js"; import { type UserConfig } from "../common/config.js"; +import { createMCPConnectionManager, type ConnectionManagerFactoryFn } from "../common/connectionManager.js"; // This is almost a copy of ReadBuffer from @modelcontextprotocol/sdk // but it uses EJSON.parse instead of JSON.parse to handle BSON types @@ -56,7 +57,7 @@ export class StdioRunner extends TransportRunnerBase { constructor( userConfig: UserConfig, - createConnectionManager: ConnectionManagerFactoryFn, + createConnectionManager: ConnectionManagerFactoryFn = createMCPConnectionManager, additionalLoggers: LoggerBase[] = [] ) { super(userConfig, createConnectionManager, additionalLoggers); diff --git a/src/transports/streamableHttp.ts b/src/transports/streamableHttp.ts index c06a6b97..4e8aebb8 100644 --- a/src/transports/streamableHttp.ts +++ b/src/transports/streamableHttp.ts @@ -6,7 +6,8 @@ import { isInitializeRequest } from "@modelcontextprotocol/sdk/types.js"; import { LogId, type LoggerBase } from "../common/logger.js"; import { type UserConfig } from "../common/config.js"; import { SessionStore } from "../common/sessionStore.js"; -import { type ConnectionManagerFactoryFn, TransportRunnerBase } from "./base.js"; +import { TransportRunnerBase } from "./base.js"; +import { createMCPConnectionManager, type ConnectionManagerFactoryFn } from "../common/connectionManager.js"; const JSON_RPC_ERROR_CODE_PROCESSING_REQUEST_FAILED = -32000; const JSON_RPC_ERROR_CODE_SESSION_ID_REQUIRED = -32001; @@ -20,7 +21,7 @@ export class StreamableHttpRunner extends TransportRunnerBase { constructor( userConfig: UserConfig, - createConnectionManager: ConnectionManagerFactoryFn, + createConnectionManager: ConnectionManagerFactoryFn = createMCPConnectionManager, additionalLoggers: LoggerBase[] = [] ) { super(userConfig, createConnectionManager, additionalLoggers); From 280d8b1ceda0b7dadec42daee693b9f03674d950 Mon Sep 17 00:00:00 2001 From: Himanshu Singh Date: Fri, 29 Aug 2025 12:14:36 +0200 Subject: [PATCH 13/13] chore: use default factory in tests --- .../transports/streamableHttp.test.ts | 19 +++++-------------- 1 file changed, 5 insertions(+), 14 deletions(-) diff --git a/tests/integration/transports/streamableHttp.test.ts b/tests/integration/transports/streamableHttp.test.ts index 70185203..462ba933 100644 --- a/tests/integration/transports/streamableHttp.test.ts +++ b/tests/integration/transports/streamableHttp.test.ts @@ -2,10 +2,10 @@ import { StreamableHttpRunner } from "../../../src/transports/streamableHttp.js" import { Client } from "@modelcontextprotocol/sdk/client/index.js"; import { StreamableHTTPClientTransport } from "@modelcontextprotocol/sdk/client/streamableHttp.js"; import { describe, expect, it, beforeAll, afterAll, beforeEach } from "vitest"; -import { config, driverOptions } from "../../../src/common/config.js"; +import { config } from "../../../src/common/config.js"; import type { LoggerType, LogLevel, LogPayload } from "../../../src/common/logger.js"; import { LoggerBase, LogId } from "../../../src/common/logger.js"; -import { MCPConnectionManager } from "../../../src/common/connectionManager.js"; +import { createMCPConnectionManager } from "../../../src/common/connectionManager.js"; describe("StreamableHttpRunner", () => { let runner: StreamableHttpRunner; @@ -29,9 +29,7 @@ describe("StreamableHttpRunner", () => { describe(description, () => { beforeAll(async () => { config.httpHeaders = headers; - runner = new StreamableHttpRunner(config, ({ logger, deviceId }) => - Promise.resolve(new MCPConnectionManager(config, driverOptions, logger, deviceId)) - ); + runner = new StreamableHttpRunner(config); await runner.start(); }); @@ -112,9 +110,7 @@ describe("StreamableHttpRunner", () => { try { for (let i = 0; i < 3; i++) { config.httpPort = 0; // Use a random port for each runner - const runner = new StreamableHttpRunner(config, ({ logger, deviceId }) => - Promise.resolve(new MCPConnectionManager(config, driverOptions, logger, deviceId)) - ); + const runner = new StreamableHttpRunner(config); await runner.start(); runners.push(runner); } @@ -143,12 +139,7 @@ describe("StreamableHttpRunner", () => { it("can provide custom logger", async () => { const logger = new CustomLogger(); - const runner = new StreamableHttpRunner( - config, - ({ logger, deviceId }) => - Promise.resolve(new MCPConnectionManager(config, driverOptions, logger, deviceId)), - [logger] - ); + const runner = new StreamableHttpRunner(config, createMCPConnectionManager, [logger]); await runner.start(); const messages = logger.messages;