Skip to content

Commit 7ed01ed

Browse files
chore: LangChain based accuracy tests
1 parent d7d4aa9 commit 7ed01ed

File tree

8 files changed

+965
-4
lines changed

8 files changed

+965
-4
lines changed

package-lock.json

Lines changed: 506 additions & 4 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

package.json

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,10 @@
3535
"devDependencies": {
3636
"@eslint/js": "^9.24.0",
3737
"@jest/globals": "^30.0.0",
38+
"@langchain/core": "^0.3.61",
39+
"@langchain/google-genai": "^0.2.14",
40+
"@langchain/ollama": "^0.2.3",
41+
"@langchain/openai": "^0.5.16",
3842
"@modelcontextprotocol/inspector": "^0.14.0",
3943
"@redocly/cli": "^1.34.2",
4044
"@types/jest": "^29.5.14",
@@ -49,6 +53,7 @@
4953
"jest": "^29.7.0",
5054
"jest-environment-node": "^29.7.0",
5155
"jest-extended": "^6.0.0",
56+
"langchain": "^0.3.29",
5257
"mongodb-runner": "^5.8.2",
5358
"openapi-types": "^12.1.3",
5459
"openapi-typescript": "^7.6.1",
@@ -57,6 +62,7 @@
5762
"tsx": "^4.19.3",
5863
"typescript": "^5.8.2",
5964
"typescript-eslint": "^8.29.1",
65+
"uuid": "^11.1.0",
6066
"yaml": "^2.7.1"
6167
},
6268
"dependencies": {

tests/accuracy/list-databases.test.ts

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
import { describeAccuracyTests } from "./sdk/describe-accuracy-tests.js";
2+
import { getAvailableModels } from "./sdk/models.js";
3+
4+
describeAccuracyTests("list-databases", getAvailableModels(), [
5+
{
6+
prompt: "Assume that you're already connected. How many collections are there in sample_mflix database",
7+
mockedTools: {
8+
"list-collections": function listCollections() {
9+
return {
10+
content: [
11+
{
12+
type: "text",
13+
text: "Name: coll1",
14+
},
15+
],
16+
};
17+
},
18+
},
19+
expectedToolCalls: [
20+
{
21+
toolName: "list-collections",
22+
parameters: { database: "sample_mflix" },
23+
},
24+
],
25+
},
26+
]);
Lines changed: 125 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,125 @@
1+
export type ToolCall = {
2+
toolCallId: string;
3+
toolName: string;
4+
parameters: unknown;
5+
};
6+
export type ExpectedToolCall = Omit<ToolCall, "toolCallId">;
7+
8+
export function toolCallingAccuracyScorer(expectedToolCalls: ExpectedToolCall[], actualToolCalls: ToolCall[]): number {
9+
if (actualToolCalls.length < expectedToolCalls.length) {
10+
return 0;
11+
}
12+
13+
const possibleScore = actualToolCalls.length > expectedToolCalls.length ? 0.75 : 1;
14+
const checkedToolCallIds = new Set<string>();
15+
for (const expectedToolCall of expectedToolCalls) {
16+
const matchingActualToolCall = actualToolCalls.find(
17+
(actualToolCall) =>
18+
actualToolCall.toolName === expectedToolCall.toolName &&
19+
!checkedToolCallIds.has(actualToolCall.toolCallId)
20+
);
21+
22+
if (!matchingActualToolCall) {
23+
return 0;
24+
}
25+
26+
checkedToolCallIds.add(matchingActualToolCall.toolCallId);
27+
}
28+
29+
return possibleScore;
30+
}
31+
32+
export function parameterMatchingAccuracyScorer(
33+
expectedToolCalls: ExpectedToolCall[],
34+
actualToolCalls: ToolCall[]
35+
): number {
36+
if (expectedToolCalls.length === 0) {
37+
return 1;
38+
}
39+
40+
const toolCallScores: number[] = [];
41+
const checkedToolCallIds = new Set<string>();
42+
43+
for (const expectedToolCall of expectedToolCalls) {
44+
const matchingActualToolCall = actualToolCalls.find(
45+
(actualToolCall) =>
46+
actualToolCall.toolName === expectedToolCall.toolName &&
47+
!checkedToolCallIds.has(actualToolCall.toolCallId)
48+
);
49+
50+
if (!matchingActualToolCall) {
51+
toolCallScores.push(0);
52+
continue;
53+
}
54+
55+
checkedToolCallIds.add(matchingActualToolCall.toolCallId);
56+
const score = compareParams(expectedToolCall.parameters, matchingActualToolCall.parameters);
57+
toolCallScores.push(score);
58+
}
59+
60+
const totalScore = toolCallScores.reduce((sum, score) => sum + score, 0);
61+
return totalScore / toolCallScores.length;
62+
}
63+
64+
/**
65+
* Recursively compares expected and actual parameters and returns a score.
66+
* - 1: Perfect match.
67+
* - 0.75: All expected parameters are present and match, but there are extra actual parameters.
68+
* - 0: Missing parameters or mismatched values.
69+
*/
70+
function compareParams(expected: unknown, actual: unknown): number {
71+
if (expected === null || expected === undefined) {
72+
return actual === null || actual === undefined ? 1 : 0;
73+
}
74+
if (actual === null || actual === undefined) {
75+
return 0;
76+
}
77+
78+
if (Array.isArray(expected)) {
79+
if (!Array.isArray(actual) || actual.length < expected.length) {
80+
return 0;
81+
}
82+
let minScore = 1;
83+
for (let i = 0; i < expected.length; i++) {
84+
minScore = Math.min(minScore, compareParams(expected[i], actual[i]));
85+
}
86+
if (minScore === 0) {
87+
return 0;
88+
}
89+
if (actual.length > expected.length) {
90+
minScore = Math.min(minScore, 0.75);
91+
}
92+
return minScore;
93+
}
94+
95+
if (typeof expected === "object") {
96+
if (typeof actual !== "object" || Array.isArray(actual)) {
97+
return 0;
98+
}
99+
const expectedKeys = Object.keys(expected as Record<string, unknown>);
100+
const actualKeys = Object.keys(actual as Record<string, unknown>);
101+
102+
let minScore = 1;
103+
for (const key of expectedKeys) {
104+
if (!Object.prototype.hasOwnProperty.call(actual, key)) {
105+
return 0;
106+
}
107+
minScore = Math.min(
108+
minScore,
109+
compareParams((expected as Record<string, unknown>)[key], (actual as Record<string, unknown>)[key])
110+
);
111+
}
112+
113+
if (minScore === 0) {
114+
return 0;
115+
}
116+
117+
if (actualKeys.length > expectedKeys.length) {
118+
minScore = Math.min(minScore, 0.75);
119+
}
120+
return minScore;
121+
}
122+
123+
// eslint-disable-next-line eqeqeq
124+
return expected == actual ? 1 : 0;
125+
}
Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
import { AgentExecutor } from "langchain/agents";
2+
import { Tool } from "@modelcontextprotocol/sdk/types.js";
3+
import { discoverMongoDBTools, TestTools, ToolResultGenerators } from "./test-tools.js";
4+
import { TestableModels } from "./models.js";
5+
import { getToolCallingAgent } from "./tool-calling-agent.js";
6+
import { ExpectedToolCall, parameterMatchingAccuracyScorer, toolCallingAccuracyScorer } from "./accuracy-scorers.js";
7+
8+
interface AccuracyTestConfig {
9+
prompt: string;
10+
expectedToolCalls: ExpectedToolCall[];
11+
mockedTools: ToolResultGenerators;
12+
}
13+
14+
export function describeAccuracyTests(
15+
suiteName: string,
16+
models: TestableModels,
17+
accuracyTestConfigs: AccuracyTestConfig[]
18+
) {
19+
const eachModel = describe.each(models);
20+
const eachTest = it.each(accuracyTestConfigs);
21+
22+
eachModel(`$modelName - ${suiteName}`, function (model) {
23+
let mcpTools: Tool[];
24+
let testTools: TestTools;
25+
let agent: AgentExecutor;
26+
27+
beforeAll(async () => {
28+
mcpTools = await discoverMongoDBTools();
29+
});
30+
31+
beforeEach(() => {
32+
testTools = new TestTools(mcpTools);
33+
const transformToolResult = model.transformToolResult.bind(model);
34+
agent = getToolCallingAgent(model, testTools.langChainTools(transformToolResult));
35+
});
36+
37+
eachTest("$prompt", async function (testConfig) {
38+
testTools.mockTools(testConfig.mockedTools);
39+
const conversation = await agent.invoke({ input: testConfig.prompt });
40+
console.log("conversation", conversation);
41+
const toolCalls = testTools.getToolCalls();
42+
console.log("?????? toolCalls", toolCalls);
43+
console.log("???? expected", testConfig.expectedToolCalls);
44+
const toolCallingAccuracy = toolCallingAccuracyScorer(testConfig.expectedToolCalls, toolCalls);
45+
const parameterMatchingAccuracy = parameterMatchingAccuracyScorer(testConfig.expectedToolCalls, toolCalls);
46+
47+
expect(toolCallingAccuracy).not.toEqual(0);
48+
expect(parameterMatchingAccuracy).toBeGreaterThanOrEqual(0.5);
49+
});
50+
});
51+
}

tests/accuracy/sdk/models.ts

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
import { BaseChatModel } from "@langchain/core/language_models/chat_models";
2+
import { ChatGoogleGenerativeAI } from "@langchain/google-genai";
3+
import { ChatOllama } from "@langchain/ollama";
4+
import { CallToolResult } from "@modelcontextprotocol/sdk/types.js";
5+
6+
type ToolResultForOllama = string;
7+
export type AcceptableToolResponse = CallToolResult | ToolResultForOllama;
8+
9+
export interface Model<M extends BaseChatModel = BaseChatModel, T extends AcceptableToolResponse = CallToolResult> {
10+
isAvailable(): boolean;
11+
getLangChainModel(): M;
12+
transformToolResult(callToolResult: CallToolResult): T;
13+
}
14+
15+
export class GeminiModel implements Model<ChatGoogleGenerativeAI> {
16+
constructor(readonly modelName: string) {}
17+
18+
isAvailable(): boolean {
19+
return !!process.env.MDB_GEMINI_API_KEY;
20+
}
21+
22+
getLangChainModel(): ChatGoogleGenerativeAI {
23+
return new ChatGoogleGenerativeAI({
24+
model: this.modelName,
25+
apiKey: process.env.MDB_GEMINI_API_KEY,
26+
});
27+
}
28+
29+
transformToolResult(callToolResult: CallToolResult) {
30+
return callToolResult;
31+
}
32+
}
33+
34+
export class OllamaModel implements Model<ChatOllama, ToolResultForOllama> {
35+
constructor(readonly modelName: string) {}
36+
37+
isAvailable(): boolean {
38+
return !!process.env.MDB_GEMINI_API_KEY;
39+
}
40+
41+
getLangChainModel(): ChatOllama {
42+
return new ChatOllama({
43+
model: this.modelName,
44+
});
45+
}
46+
47+
transformToolResult(callToolResult: CallToolResult): ToolResultForOllama {
48+
return JSON.stringify(callToolResult);
49+
}
50+
}
51+
52+
const ALL_TESTABLE_MODELS = [
53+
// new GeminiModel("gemini-1.5-flash"),
54+
// new GeminiModel("gemini-2.0-flash"),
55+
new OllamaModel("qwen3:latest"),
56+
];
57+
58+
export type TestableModels = ReturnType<typeof getAvailableModels>;
59+
60+
export function getAvailableModels() {
61+
return ALL_TESTABLE_MODELS.filter((model) => model.isAvailable());
62+
}

0 commit comments

Comments
 (0)