Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions eslint-rules/no-config-imports.js
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,10 @@ const allowedConfigValueImportFiles = [
"src/index.ts",
// Config resource definition that works with the some config values
"src/resources/common/config.ts",
// The file exports, a factory function to create MCPConnectionManager and
// it relies on driver options generator and default driver options from
// config file.
"src/common/connectionManager.ts",
];

// Ref: https://eslint.org/docs/latest/extend/custom-rules
Expand Down
155 changes: 103 additions & 52 deletions src/common/connectionManager.ts
Original file line number Diff line number Diff line change
@@ -1,17 +1,14 @@
import type { UserConfig, DriverOptions } from "./config.js";
import { NodeDriverServiceProvider } from "@mongosh/service-provider-node-driver";
import EventEmitter from "events";
import { setAppNameParamIfMissing } from "../helpers/connectionOptions.js";
import { packageInfo } from "./packageInfo.js";
import ConnectionString from "mongodb-connection-string-url";
import { EventEmitter } from "events";
import type { MongoClientOptions } from "mongodb";
import { ErrorCodes, MongoDBError } from "./errors.js";
import ConnectionString from "mongodb-connection-string-url";
import { NodeDriverServiceProvider } from "@mongosh/service-provider-node-driver";
import { type ConnectionInfo, generateConnectionInfoFromCliArgs } from "@mongosh/arg-parser";
import type { DeviceId } from "../helpers/deviceId.js";
import type { AppNameComponents } from "../helpers/connectionOptions.js";
import type { CompositeLogger } from "./logger.js";
import { LogId } from "./logger.js";
import type { ConnectionInfo } from "@mongosh/arg-parser";
import { generateConnectionInfoFromCliArgs } from "@mongosh/arg-parser";
import { defaultDriverOptions, setupDriverConfig, type DriverOptions, type UserConfig } from "./config.js";
import { MongoDBError, ErrorCodes } from "./errors.js";
import { type LoggerBase, LogId } from "./logger.js";
import { packageInfo } from "./packageInfo.js";
import { type AppNameComponents, setAppNameParamIfMissing } from "../helpers/connectionOptions.js";

export interface AtlasClusterConnectionInfo {
username: string;
Expand Down Expand Up @@ -71,39 +68,76 @@ export interface ConnectionManagerEvents {
"connection-error": [ConnectionStateErrored];
}

export class ConnectionManager extends EventEmitter<ConnectionManagerEvents> {
/**
* For a few tests, we need the changeState method to force a connection state
* which is we have this type to typecast the actual ConnectionManager with
* public changeState (only to make TS happy).
*/
export type TestConnectionManager = ConnectionManager & {
changeState<Event extends keyof ConnectionManagerEvents, State extends ConnectionManagerEvents[Event][0]>(
event: Event,
newState: State
): State;
};

export abstract class ConnectionManager {
protected clientName: string;
protected readonly _events;
readonly events: Pick<EventEmitter<ConnectionManagerEvents>, "on" | "off" | "once">;
private state: AnyConnectionState;

constructor() {
this.clientName = "unknown";
this.events = this._events = new EventEmitter<ConnectionManagerEvents>();
this.state = { tag: "disconnected" };
}

get currentConnectionState(): AnyConnectionState {
return this.state;
}

protected changeState<Event extends keyof ConnectionManagerEvents, State extends ConnectionManagerEvents[Event][0]>(
event: Event,
newState: State
): State {
this.state = newState;
// TypeScript doesn't seem to be happy with the spread operator and generics
// eslint-disable-next-line
this._events.emit(event, ...([newState] as any));
return newState;
}

setClientName(clientName: string): void {
this.clientName = clientName;
}

abstract connect(settings: ConnectionSettings): Promise<AnyConnectionState>;

abstract disconnect(): Promise<ConnectionStateDisconnected | ConnectionStateErrored>;
}

export class MCPConnectionManager extends ConnectionManager {
private deviceId: DeviceId;
private clientName: string;
private bus: EventEmitter;

constructor(
private userConfig: UserConfig,
private driverOptions: DriverOptions,
private logger: CompositeLogger,
private logger: LoggerBase,
deviceId: DeviceId,
bus?: EventEmitter
) {
super();

this.bus = bus ?? new EventEmitter();
this.state = { tag: "disconnected" };

this.bus.on("mongodb-oidc-plugin:auth-failed", this.onOidcAuthFailed.bind(this));
this.bus.on("mongodb-oidc-plugin:auth-succeeded", this.onOidcAuthSucceeded.bind(this));

this.deviceId = deviceId;
this.clientName = "unknown";
}

setClientName(clientName: string): void {
this.clientName = clientName;
}

async connect(settings: ConnectionSettings): Promise<AnyConnectionState> {
this.emit("connection-request", this.state);
this._events.emit("connection-request", this.currentConnectionState);

if (this.state.tag === "connected" || this.state.tag === "connecting") {
if (this.currentConnectionState.tag === "connected" || this.currentConnectionState.tag === "connecting") {
await this.disconnect();
}

Expand Down Expand Up @@ -138,7 +172,7 @@ export class ConnectionManager extends EventEmitter<ConnectionManagerEvents> {
connectionInfo.driverOptions.proxy ??= { useEnvironmentVariableProxies: true };
connectionInfo.driverOptions.applyProxyToOIDC ??= true;

connectionStringAuthType = ConnectionManager.inferConnectionTypeFromSettings(
connectionStringAuthType = MCPConnectionManager.inferConnectionTypeFromSettings(
this.userConfig,
connectionInfo
);
Expand All @@ -165,7 +199,10 @@ export class ConnectionManager extends EventEmitter<ConnectionManagerEvents> {
}

try {
const connectionType = ConnectionManager.inferConnectionTypeFromSettings(this.userConfig, connectionInfo);
const connectionType = MCPConnectionManager.inferConnectionTypeFromSettings(
this.userConfig,
connectionInfo
);
if (connectionType.startsWith("oidc")) {
void this.pingAndForget(serviceProvider);

Expand Down Expand Up @@ -199,13 +236,13 @@ export class ConnectionManager extends EventEmitter<ConnectionManagerEvents> {
}

async disconnect(): Promise<ConnectionStateDisconnected | ConnectionStateErrored> {
if (this.state.tag === "disconnected" || this.state.tag === "errored") {
return this.state;
if (this.currentConnectionState.tag === "disconnected" || this.currentConnectionState.tag === "errored") {
return this.currentConnectionState;
}

if (this.state.tag === "connected" || this.state.tag === "connecting") {
if (this.currentConnectionState.tag === "connected" || this.currentConnectionState.tag === "connecting") {
try {
await this.state.serviceProvider?.close(true);
await this.currentConnectionState.serviceProvider?.close(true);
} finally {
this.changeState("connection-close", {
tag: "disconnected",
Expand All @@ -216,30 +253,21 @@ export class ConnectionManager extends EventEmitter<ConnectionManagerEvents> {
return { tag: "disconnected" };
}

get currentConnectionState(): AnyConnectionState {
return this.state;
}

changeState<Event extends keyof ConnectionManagerEvents, State extends ConnectionManagerEvents[Event][0]>(
event: Event,
newState: State
): State {
this.state = newState;
// TypeScript doesn't seem to be happy with the spread operator and generics
// eslint-disable-next-line
this.emit(event, ...([newState] as any));
return newState;
}

private onOidcAuthFailed(error: unknown): void {
if (this.state.tag === "connecting" && this.state.connectionStringAuthType?.startsWith("oidc")) {
if (
this.currentConnectionState.tag === "connecting" &&
this.currentConnectionState.connectionStringAuthType?.startsWith("oidc")
) {
void this.disconnectOnOidcError(error);
}
}

private onOidcAuthSucceeded(): void {
if (this.state.tag === "connecting" && this.state.connectionStringAuthType?.startsWith("oidc")) {
this.changeState("connection-success", { ...this.state, tag: "connected" });
if (
this.currentConnectionState.tag === "connecting" &&
this.currentConnectionState.connectionStringAuthType?.startsWith("oidc")
) {
this.changeState("connection-success", { ...this.currentConnectionState, tag: "connected" });
}

this.logger.info({
Expand All @@ -250,9 +278,12 @@ export class ConnectionManager extends EventEmitter<ConnectionManagerEvents> {
}

private onOidcNotifyDeviceFlow(flowInfo: { verificationUrl: string; userCode: string }): void {
if (this.state.tag === "connecting" && this.state.connectionStringAuthType?.startsWith("oidc")) {
if (
this.currentConnectionState.tag === "connecting" &&
this.currentConnectionState.connectionStringAuthType?.startsWith("oidc")
) {
this.changeState("connection-request", {
...this.state,
...this.currentConnectionState,
tag: "connecting",
connectionStringAuthType: "oidc-device-flow",
oidcLoginUrl: flowInfo.verificationUrl,
Expand Down Expand Up @@ -329,3 +360,23 @@ export class ConnectionManager extends EventEmitter<ConnectionManagerEvents> {
}
}
}

/**
* Consumers of MCP server library have option to bring their own connection
* management if they need to. To support that, we enable injecting connection
* manager implementation through a factory function.
*/
export type ConnectionManagerFactoryFn = (createParams: {
logger: LoggerBase;
deviceId: DeviceId;
userConfig: UserConfig;
}) => Promise<ConnectionManager>;

export const createMCPConnectionManager: ConnectionManagerFactoryFn = ({ logger, deviceId, userConfig }) => {
const driverOptions = setupDriverConfig({
config: userConfig,
defaults: defaultDriverOptions,
});

return Promise.resolve(new MCPConnectionManager(userConfig, driverOptions, logger, deviceId));
};
8 changes: 4 additions & 4 deletions src/common/session.ts
Original file line number Diff line number Diff line change
Expand Up @@ -67,10 +67,10 @@ export class Session extends EventEmitter<SessionEvents> {
this.apiClient = new ApiClient({ baseUrl: apiBaseUrl, credentials }, logger);
this.exportsManager = exportsManager;
this.connectionManager = connectionManager;
this.connectionManager.on("connection-success", () => this.emit("connect"));
this.connectionManager.on("connection-time-out", (error) => this.emit("connection-error", error));
this.connectionManager.on("connection-close", () => this.emit("disconnect"));
this.connectionManager.on("connection-error", (error) => this.emit("connection-error", error));
this.connectionManager.events.on("connection-success", () => this.emit("connect"));
this.connectionManager.events.on("connection-time-out", (error) => this.emit("connection-error", error));
this.connectionManager.events.on("connection-close", () => this.emit("disconnect"));
this.connectionManager.events.on("connection-error", (error) => this.emit("connection-error", error));
}

setMcpClient(mcpClient: Implementation | undefined): void {
Expand Down
7 changes: 2 additions & 5 deletions src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ function enableFipsIfRequested(): void {
enableFipsIfRequested();

import { ConsoleLogger, LogId } from "./common/logger.js";
import { config, driverOptions } from "./common/config.js";
import { config } from "./common/config.js";
import crypto from "crypto";
import { packageInfo } from "./common/packageInfo.js";
import { StdioRunner } from "./transports/stdio.js";
Expand All @@ -49,10 +49,7 @@ async function main(): Promise<void> {
assertHelpMode();
assertVersionMode();

const transportRunner =
config.transport === "stdio"
? new StdioRunner(config, driverOptions)
: new StreamableHttpRunner(config, driverOptions);
const transportRunner = config.transport === "stdio" ? new StdioRunner(config) : new StreamableHttpRunner(config);
const shutdown = (): void => {
transportRunner.logger.info({
id: LogId.serverCloseRequested,
Expand Down
15 changes: 11 additions & 4 deletions src/lib.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,14 @@
export { Server, type ServerOptions } from "./server.js";
export { Telemetry } from "./telemetry/telemetry.js";
export { Session, type SessionOptions } from "./common/session.js";
export { type UserConfig, defaultUserConfig } from "./common/config.js";
export { defaultUserConfig, type UserConfig } from "./common/config.js";
export { LoggerBase, type LogPayload, type LoggerType, type LogLevel } from "./common/logger.js";
export { StreamableHttpRunner } from "./transports/streamableHttp.js";
export { LoggerBase } from "./common/logger.js";
export type { LogPayload, LoggerType, LogLevel } from "./common/logger.js";
export {
ConnectionManager,
type AnyConnectionState,
type ConnectionState,
type ConnectionStateDisconnected,
type ConnectionStateErrored,
type ConnectionManagerFactoryFn,
} from "./common/connectionManager.js";
export { Telemetry } from "./telemetry/telemetry.js";
14 changes: 9 additions & 5 deletions src/transports/base.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import type { DriverOptions, UserConfig } from "../common/config.js";
import type { UserConfig } from "../common/config.js";
import { packageInfo } from "../common/packageInfo.js";
import { Server } from "../server.js";
import { Session } from "../common/session.js";
Expand All @@ -7,16 +7,16 @@ import { McpServer } from "@modelcontextprotocol/sdk/server/mcp.js";
import type { LoggerBase } from "../common/logger.js";
import { CompositeLogger, ConsoleLogger, DiskLogger, McpLogger } from "../common/logger.js";
import { ExportsManager } from "../common/exportsManager.js";
import { ConnectionManager } from "../common/connectionManager.js";
import { DeviceId } from "../helpers/deviceId.js";
import { type ConnectionManagerFactoryFn } from "../common/connectionManager.js";

export abstract class TransportRunnerBase {
public logger: LoggerBase;
public deviceId: DeviceId;

protected constructor(
protected readonly userConfig: UserConfig,
private readonly driverOptions: DriverOptions,
private readonly createConnectionManager: ConnectionManagerFactoryFn,
additionalLoggers: LoggerBase[]
) {
const loggers: LoggerBase[] = [...additionalLoggers];
Expand All @@ -38,15 +38,19 @@ export abstract class TransportRunnerBase {
this.deviceId = DeviceId.create(this.logger);
}

protected setupServer(): Server {
protected async setupServer(): Promise<Server> {
const mcpServer = new McpServer({
name: packageInfo.mcpServerName,
version: packageInfo.version,
});

const logger = new CompositeLogger(this.logger);
const exportsManager = ExportsManager.init(this.userConfig, logger);
const connectionManager = new ConnectionManager(this.userConfig, this.driverOptions, logger, this.deviceId);
const connectionManager = await this.createConnectionManager({
logger,
userConfig: this.userConfig,
deviceId: this.deviceId,
});

const session = new Session({
apiBaseUrl: this.userConfig.apiBaseUrl,
Expand Down
22 changes: 13 additions & 9 deletions src/transports/stdio.ts
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
import type { LoggerBase } from "../common/logger.js";
import { LogId } from "../common/logger.js";
import type { Server } from "../server.js";
import { TransportRunnerBase } from "./base.js";
import { EJSON } from "bson";
import type { JSONRPCMessage } from "@modelcontextprotocol/sdk/types.js";
import { JSONRPCMessageSchema } from "@modelcontextprotocol/sdk/types.js";
import { EJSON } from "bson";
import { StdioServerTransport } from "@modelcontextprotocol/sdk/server/stdio.js";
import type { DriverOptions, UserConfig } from "../common/config.js";
import { type LoggerBase, LogId } from "../common/logger.js";
import type { Server } from "../server.js";
import { TransportRunnerBase } from "./base.js";
import { type UserConfig } from "../common/config.js";
import { createMCPConnectionManager, type ConnectionManagerFactoryFn } from "../common/connectionManager.js";

// This is almost a copy of ReadBuffer from @modelcontextprotocol/sdk
// but it uses EJSON.parse instead of JSON.parse to handle BSON types
Expand Down Expand Up @@ -55,13 +55,17 @@ export function createStdioTransport(): StdioServerTransport {
export class StdioRunner extends TransportRunnerBase {
private server: Server | undefined;

constructor(userConfig: UserConfig, driverOptions: DriverOptions, additionalLoggers: LoggerBase[] = []) {
super(userConfig, driverOptions, additionalLoggers);
constructor(
userConfig: UserConfig,
createConnectionManager: ConnectionManagerFactoryFn = createMCPConnectionManager,
additionalLoggers: LoggerBase[] = []
) {
super(userConfig, createConnectionManager, additionalLoggers);
}

async start(): Promise<void> {
try {
this.server = this.setupServer();
this.server = await this.setupServer();

const transport = createStdioTransport();

Expand Down
Loading
Loading