diff --git a/src/helpers/EJsonTransport.ts b/src/helpers/EJsonTransport.ts new file mode 100644 index 00000000..307e90bd --- /dev/null +++ b/src/helpers/EJsonTransport.ts @@ -0,0 +1,47 @@ +import { JSONRPCMessage, JSONRPCMessageSchema } from "@modelcontextprotocol/sdk/types.js"; +import { EJSON } from "bson"; +import { StdioServerTransport } from "@modelcontextprotocol/sdk/server/stdio.js"; + +// This is almost a copy of ReadBuffer from @modelcontextprotocol/sdk +// but it uses EJSON.parse instead of JSON.parse to handle BSON types +export class EJsonReadBuffer { + private _buffer?: Buffer; + + append(chunk: Buffer): void { + this._buffer = this._buffer ? Buffer.concat([this._buffer, chunk]) : chunk; + } + + readMessage(): JSONRPCMessage | null { + if (!this._buffer) { + return null; + } + + const index = this._buffer.indexOf("\n"); + if (index === -1) { + return null; + } + + const line = this._buffer.toString("utf8", 0, index).replace(/\r$/, ""); + this._buffer = this._buffer.subarray(index + 1); + + // This is using EJSON.parse instead of JSON.parse to handle BSON types + return JSONRPCMessageSchema.parse(EJSON.parse(line)); + } + + clear(): void { + this._buffer = undefined; + } +} + +// This is a hacky workaround for https://github.com/mongodb-js/mongodb-mcp-server/issues/211 +// The underlying issue is that StdioServerTransport uses JSON.parse to deserialize +// messages, but that doesn't handle bson types, such as ObjectId when serialized as EJSON. +// +// This function creates a StdioServerTransport and replaces the internal readBuffer with EJsonReadBuffer +// that uses EJson.parse instead. +export function createEJsonTransport(): StdioServerTransport { + const server = new StdioServerTransport(); + server["_readBuffer"] = new EJsonReadBuffer(); + + return server; +} diff --git a/src/index.ts b/src/index.ts index f91db447..ee332072 100644 --- a/src/index.ts +++ b/src/index.ts @@ -1,6 +1,5 @@ #!/usr/bin/env node -import { StdioServerTransport } from "@modelcontextprotocol/sdk/server/stdio.js"; import logger, { LogId } from "./logger.js"; import { McpServer } from "@modelcontextprotocol/sdk/server/mcp.js"; import { config } from "./config.js"; @@ -8,6 +7,7 @@ import { Session } from "./session.js"; import { Server } from "./server.js"; import { packageInfo } from "./helpers/packageInfo.js"; import { Telemetry } from "./telemetry/telemetry.js"; +import { createEJsonTransport } from "./helpers/EJsonTransport.js"; try { const session = new Session({ @@ -29,7 +29,7 @@ try { userConfig: config, }); - const transport = new StdioServerTransport(); + const transport = createEJsonTransport(); await server.connect(transport); } catch (error: unknown) { diff --git a/tests/integration/helpers.ts b/tests/integration/helpers.ts index b5c31b9b..fd79ecfa 100644 --- a/tests/integration/helpers.ts +++ b/tests/integration/helpers.ts @@ -227,6 +227,7 @@ export function validateThrowsForInvalidArguments( } /** Expects the argument being defined and asserts it */ -export function expectDefined(arg: T): asserts arg is Exclude { +export function expectDefined(arg: T): asserts arg is Exclude { expect(arg).toBeDefined(); + expect(arg).not.toBeNull(); } diff --git a/tests/integration/tools/mongodb/read/find.test.ts b/tests/integration/tools/mongodb/read/find.test.ts index d62d67a9..05fd0b75 100644 --- a/tests/integration/tools/mongodb/read/find.test.ts +++ b/tests/integration/tools/mongodb/read/find.test.ts @@ -4,6 +4,7 @@ import { validateToolMetadata, validateThrowsForInvalidArguments, getResponseElements, + expectDefined, } from "../../../helpers.js"; import { describeWithMongoDB, validateAutoConnectBehavior } from "../mongodbHelpers.js"; @@ -171,6 +172,33 @@ describeWithMongoDB("find tool", (integration) => { expect(JSON.parse(elements[i + 1].text).value).toEqual(i); } }); + + it("can find objects by $oid", async () => { + await integration.connectMcpClient(); + + const fooObject = await integration + .mongoClient() + .db(integration.randomDbName()) + .collection("foo") + .findOne(); + expectDefined(fooObject); + + const response = await integration.mcpClient().callTool({ + name: "find", + arguments: { + database: integration.randomDbName(), + collection: "foo", + filter: { _id: fooObject._id }, + }, + }); + + const elements = getResponseElements(response.content); + expect(elements).toHaveLength(2); + expect(elements[0].text).toEqual('Found 1 documents in the collection "foo":'); + + // eslint-disable-next-line @typescript-eslint/no-unsafe-member-access + expect(JSON.parse(elements[1].text).value).toEqual(fooObject.value); + }); }); validateAutoConnectBehavior(integration, "find", () => { diff --git a/tests/unit/EJsonTransport.test.ts b/tests/unit/EJsonTransport.test.ts new file mode 100644 index 00000000..f0371cf4 --- /dev/null +++ b/tests/unit/EJsonTransport.test.ts @@ -0,0 +1,71 @@ +import { Decimal128, MaxKey, MinKey, ObjectId, Timestamp, UUID } from "bson"; +import { createEJsonTransport, EJsonReadBuffer } from "../../src/helpers/EJsonTransport.js"; +import { JSONRPCMessage } from "@modelcontextprotocol/sdk/types.js"; +import { AuthInfo } from "@modelcontextprotocol/sdk/server/auth/types.js"; +import { StdioServerTransport } from "@modelcontextprotocol/sdk/server/stdio.js"; +import { Readable } from "stream"; +import { ReadBuffer } from "@modelcontextprotocol/sdk/shared/stdio.js"; + +describe("EJsonTransport", () => { + let transport: StdioServerTransport; + beforeEach(async () => { + transport = createEJsonTransport(); + await transport.start(); + }); + + afterEach(async () => { + await transport.close(); + }); + + it("ejson deserializes messages", () => { + const messages: { message: JSONRPCMessage; extra?: { authInfo?: AuthInfo } }[] = []; + transport.onmessage = ( + message, + extra?: { + authInfo?: AuthInfo; + } + ) => { + messages.push({ message, extra }); + }; + + (transport["_stdin"] as Readable).emit( + "data", + Buffer.from( + '{"jsonrpc":"2.0","id":1,"method":"testMethod","params":{"oid":{"$oid":"681b741f13aa74a0687b5110"},"uuid":{"$uuid":"f81d4fae-7dec-11d0-a765-00a0c91e6bf6"},"date":{"$date":"2025-05-07T14:54:23.973Z"},"decimal":{"$numberDecimal":"1234567890987654321"},"int32":123,"maxKey":{"$maxKey":1},"minKey":{"$minKey":1},"timestamp":{"$timestamp":{"t":123,"i":456}}}}\n', + "utf-8" + ) + ); + + expect(messages.length).toBe(1); + const message = messages[0].message; + + expect(message).toEqual({ + jsonrpc: "2.0", + id: 1, + method: "testMethod", + params: { + oid: new ObjectId("681b741f13aa74a0687b5110"), + uuid: new UUID("f81d4fae-7dec-11d0-a765-00a0c91e6bf6"), + date: new Date(Date.parse("2025-05-07T14:54:23.973Z")), + decimal: new Decimal128("1234567890987654321"), + int32: 123, + maxKey: new MaxKey(), + minKey: new MinKey(), + timestamp: new Timestamp({ t: 123, i: 456 }), + }, + }); + }); + + it("has _readBuffer field of type EJsonReadBuffer", () => { + expect(transport["_readBuffer"]).toBeDefined(); + expect(transport["_readBuffer"]).toBeInstanceOf(EJsonReadBuffer); + }); + + describe("standard StdioServerTransport", () => { + it("has a _readBuffer field", () => { + const standardTransport = new StdioServerTransport(); + expect(standardTransport["_readBuffer"]).toBeDefined(); + expect(standardTransport["_readBuffer"]).toBeInstanceOf(ReadBuffer); + }); + }); +});