Skip to content

Commit cb5178f

Browse files
chore: make expectedToolCalls part of PromptResult
1 parent 129147d commit cb5178f

File tree

5 files changed: 73 additions and 49 deletions

scripts/accuracy/generate-test-summary.ts

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ type ComparableAccuracyResult = Omit<AccuracyResult, "promptResults"> & {
2121

2222
interface PromptAndModelResponse extends ModelResponse {
2323
prompt: string;
24+
expectedToolCalls: ExpectedToolCall[];
2425
baselineToolAccuracy?: number;
2526
}
2627

@@ -293,6 +294,7 @@ async function generateTestSummary() {
293294
return {
294295
...currentModelResponse,
295296
prompt: currentPromptResult.prompt,
297+
expectedToolCalls: currentPromptResult.expectedToolCalls,
296298
baselineToolAccuracy: baselineModelResponse?.toolCallingAccuracy,
297299
};
298300
});

tests/accuracy/sdk/accuracy-result-storage/disk-storage.ts

Lines changed: 30 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ import {
77
AccuracyResultStorage,
88
AccuracyRunStatus,
99
AccuracyRunStatuses,
10+
ExpectedToolCall,
1011
ModelResponse,
1112
} from "./result-storage.js";
1213

@@ -74,31 +75,36 @@ export class DiskBasedResultStorage implements AccuracyResultStorage {
7475
}
7576
}
7677

77-
async saveModelResponseForPrompt(
78-
commitSHA: string,
79-
runId: string,
80-
prompt: string,
81-
modelResponse: ModelResponse
82-
): Promise<void> {
78+
async saveModelResponseForPrompt({
79+
commitSHA,
80+
runId,
81+
prompt,
82+
expectedToolCalls,
83+
modelResponse,
84+
}: {
85+
commitSHA: string;
86+
runId: string;
87+
prompt: string;
88+
expectedToolCalls: ExpectedToolCall[];
89+
modelResponse: ModelResponse;
90+
}): Promise<void> {
91+
const initialData: AccuracyResult = {
92+
runId,
93+
runStatus: AccuracyRunStatus.InProgress,
94+
createdOn: Date.now(),
95+
commitSHA,
96+
promptResults: [
97+
{
98+
prompt,
99+
expectedToolCalls,
100+
modelResponses: [modelResponse],
101+
},
102+
],
103+
};
83104
const resultFilePath = this.getAccuracyResultFilePath(commitSHA, runId);
84105
const { fileCreatedWithInitialData } = await this.ensureAccuracyResultFile(
85106
resultFilePath,
86-
JSON.stringify(
87-
{
88-
runId,
89-
runStatus: AccuracyRunStatus.InProgress,
90-
createdOn: Date.now(),
91-
commitSHA,
92-
promptResults: [
93-
{
94-
prompt,
95-
modelResponses: [modelResponse],
96-
},
97-
],
98-
},
99-
null,
100-
2
101-
)
107+
JSON.stringify(initialData, null, 2)
102108
);
103109

104110
if (fileCreatedWithInitialData) {
@@ -124,6 +130,7 @@ export class DiskBasedResultStorage implements AccuracyResultStorage {
124130
...accuracyResult.promptResults,
125131
{
126132
prompt,
133+
expectedToolCalls,
127134
modelResponses: [modelResponse],
128135
},
129136
],
@@ -136,6 +143,7 @@ export class DiskBasedResultStorage implements AccuracyResultStorage {
136143

137144
accuracyResult.promptResults.splice(existingPromptIdx, 1, {
138145
prompt: promptResult.prompt,
146+
expectedToolCalls: promptResult.expectedToolCalls,
139147
modelResponses: [...promptResult.modelResponses, modelResponse],
140148
});
141149

tests/accuracy/sdk/accuracy-result-storage/mongodb-storage.ts

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import {
44
AccuracyResultStorage,
55
AccuracyRunStatus,
66
AccuracyRunStatuses,
7+
ExpectedToolCall,
78
ModelResponse,
89
} from "./result-storage.js";
910

@@ -48,12 +49,19 @@ export class MongoDBBasedResultStorage implements AccuracyResultStorage {
4849
);
4950
}
5051

51-
async saveModelResponseForPrompt(
52-
commitSHA: string,
53-
runId: string,
54-
prompt: string,
55-
modelResponse: ModelResponse
56-
): Promise<void> {
52+
async saveModelResponseForPrompt({
53+
commitSHA,
54+
runId,
55+
prompt,
56+
expectedToolCalls,
57+
modelResponse,
58+
}: {
59+
commitSHA: string;
60+
runId: string;
61+
prompt: string;
62+
expectedToolCalls: ExpectedToolCall[];
63+
modelResponse: ModelResponse;
64+
}): Promise<void> {
5765
const savedModelResponse: ModelResponse = { ...modelResponse };
5866
for (const field of this.omittedModelResponseFields) {
5967
delete savedModelResponse[field];
@@ -81,7 +89,7 @@ export class MongoDBBasedResultStorage implements AccuracyResultStorage {
8189
},
8290
{
8391
$push: {
84-
promptResults: { prompt, modelResponses: [] },
92+
promptResults: { prompt, expectedToolCalls, modelResponses: [] },
8593
},
8694
}
8795
);

tests/accuracy/sdk/accuracy-result-storage/result-storage.ts

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,10 @@ export interface PromptResult {
4242
/**
4343
* The actual prompt that was provided to LLM as test */
4444
prompt: string;
45+
/**
46+
* A list of tools, along with their parameters, that are expected to be
47+
* called by the LLM in test. */
48+
expectedToolCalls: ExpectedToolCall[];
4549
/**
4650
* The responses from the LLMs tested, when provided with the prompt. */
4751
modelResponses: ModelResponse[];
@@ -65,10 +69,6 @@ export interface ModelResponse {
6569
* were called by LLM when responding to the provided prompts. To know more
6670
* about how this number is generated, check - toolCallingAccuracy.ts */
6771
toolCallingAccuracy: number;
68-
/**
69-
* A list of tools, along with their parameters, that are expected to be
70-
* called by the LLM in test. */
71-
expectedToolCalls: ExpectedToolCall[];
7272
/**
7373
* A list of tools, along with their parameters, that were actually called
7474
* by the LLM in test. */
@@ -106,11 +106,12 @@ export interface AccuracyResultStorage {
106106
/**
107107
* Attempts to atomically insert the model response for the prompt in the
108108
* stored accuracy result. */
109-
saveModelResponseForPrompt(
110-
commitSHA: string,
111-
runId: string,
112-
prompt: string,
113-
modelResponse: ModelResponse
114-
): Promise<void>;
109+
saveModelResponseForPrompt(data: {
110+
commitSHA: string;
111+
runId: string;
112+
prompt: string;
113+
expectedToolCalls: ExpectedToolCall[];
114+
modelResponse: ModelResponse;
115+
}): Promise<void>;
115116
close(): Promise<void>;
116117
}

tests/accuracy/sdk/describe-accuracy-tests.ts

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -102,17 +102,22 @@ export function describeAccuracyTests(models: TestableModels, accuracyTestConfig
102102
const toolCallingAccuracy = calculateToolCallingAccuracy(testConfig.expectedToolCalls, llmToolCalls);
103103

104104
const responseTime = timeAfterPrompt - timeBeforePrompt;
105-
await accuracyResultStorage.saveModelResponseForPrompt(commitSHA, accuracyRunId, testConfig.prompt, {
106-
provider: model.provider,
107-
requestedModel: model.modelName,
108-
respondingModel: result.respondingModel,
109-
llmResponseTime: responseTime,
110-
toolCallingAccuracy: toolCallingAccuracy,
105+
await accuracyResultStorage.saveModelResponseForPrompt({
106+
commitSHA,
107+
runId: accuracyRunId,
108+
prompt: testConfig.prompt,
111109
expectedToolCalls: testConfig.expectedToolCalls,
112-
llmToolCalls: llmToolCalls,
113-
tokensUsed: result.tokensUsage,
114-
text: result.text,
115-
messages: result.messages,
110+
modelResponse: {
111+
provider: model.provider,
112+
requestedModel: model.modelName,
113+
respondingModel: result.respondingModel,
114+
llmResponseTime: responseTime,
115+
toolCallingAccuracy: toolCallingAccuracy,
116+
llmToolCalls: llmToolCalls,
117+
tokensUsed: result.tokensUsage,
118+
text: result.text,
119+
messages: result.messages,
120+
},
116121
});
117122
});
118123
});

0 commit comments

Comments
 (0)