Skip to content

Commit 7bda811

Browse files
committed
feat(llm): add OpenAI support to LLM abstraction
1 parent d73389a commit 7bda811

File tree

4 files changed

+431
-0
lines changed

4 files changed

+431
-0
lines changed
Lines changed: 221 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,221 @@
1+
import { describe, expect, it, vi, beforeEach } from 'vitest';
2+
3+
import { TokenUsage } from '../../tokens.js';
4+
import { OpenAIProvider } from '../providers/openai.js';
5+
6+
// Mock the OpenAI module
7+
vi.mock('openai', () => {
8+
// Create a mock function for the create method
9+
const mockCreate = vi.fn().mockResolvedValue({
10+
id: 'chatcmpl-123',
11+
object: 'chat.completion',
12+
created: 1677858242,
13+
model: 'gpt-4',
14+
choices: [
15+
{
16+
index: 0,
17+
message: {
18+
role: 'assistant',
19+
content: 'This is a test response',
20+
tool_calls: [
21+
{
22+
id: 'tool-call-1',
23+
type: 'function',
24+
function: {
25+
name: 'testFunction',
26+
arguments: '{"arg1":"value1"}',
27+
},
28+
},
29+
],
30+
},
31+
finish_reason: 'stop',
32+
},
33+
],
34+
usage: {
35+
prompt_tokens: 10,
36+
completion_tokens: 20,
37+
total_tokens: 30,
38+
},
39+
});
40+
41+
// Return a mocked version of the OpenAI class
42+
return {
43+
default: class MockOpenAI {
44+
constructor() {
45+
// Constructor implementation
46+
}
47+
48+
chat = {
49+
completions: {
50+
create: mockCreate,
51+
},
52+
};
53+
},
54+
};
55+
});
56+
57+
describe('OpenAIProvider', () => {
58+
let provider: OpenAIProvider;
59+
60+
beforeEach(() => {
61+
// Set environment variable for testing
62+
process.env.OPENAI_API_KEY = 'test-api-key';
63+
provider = new OpenAIProvider('gpt-4');
64+
});
65+
66+
it('should initialize with correct properties', () => {
67+
expect(provider.name).toBe('openai');
68+
expect(provider.provider).toBe('openai.chat');
69+
expect(provider.model).toBe('gpt-4');
70+
});
71+
72+
it('should throw error if API key is missing', () => {
73+
// Clear environment variable
74+
const originalKey = process.env.OPENAI_API_KEY;
75+
delete process.env.OPENAI_API_KEY;
76+
77+
expect(() => new OpenAIProvider('gpt-4')).toThrow(
78+
'OpenAI API key is required',
79+
);
80+
81+
// Restore environment variable
82+
process.env.OPENAI_API_KEY = originalKey;
83+
});
84+
85+
it('should generate text and handle tool calls', async () => {
86+
const response = await provider.generateText({
87+
messages: [
88+
{ role: 'system', content: 'You are a helpful assistant.' },
89+
{ role: 'user', content: 'Hello, can you help me?' },
90+
],
91+
functions: [
92+
{
93+
name: 'testFunction',
94+
description: 'A test function',
95+
parameters: {
96+
type: 'object',
97+
properties: {
98+
arg1: { type: 'string' },
99+
},
100+
},
101+
},
102+
],
103+
});
104+
105+
expect(response.text).toBe('This is a test response');
106+
expect(response.toolCalls).toHaveLength(1);
107+
108+
const toolCall = response.toolCalls[0];
109+
expect(toolCall).toBeDefined();
110+
expect(toolCall?.name).toBe('testFunction');
111+
expect(toolCall?.id).toBe('tool-call-1');
112+
expect(toolCall?.content).toBe('{"arg1":"value1"}');
113+
114+
// Check token usage
115+
expect(response.tokenUsage).toBeInstanceOf(TokenUsage);
116+
expect(response.tokenUsage.input).toBe(10);
117+
expect(response.tokenUsage.output).toBe(20);
118+
});
119+
120+
it('should format messages correctly', async () => {
121+
await provider.generateText({
122+
messages: [
123+
{ role: 'system', content: 'You are a helpful assistant.' },
124+
{ role: 'user', content: 'Hello' },
125+
{ role: 'assistant', content: 'Hi there' },
126+
{
127+
role: 'tool_use',
128+
id: 'tool-1',
129+
name: 'testTool',
130+
content: '{"param":"value"}',
131+
},
132+
{
133+
role: 'tool_result',
134+
tool_use_id: 'tool-1',
135+
content: '{"result":"success"}',
136+
is_error: false,
137+
},
138+
],
139+
});
140+
141+
// Get the mock instance
142+
const client = provider['client'];
143+
const mockOpenAI = client?.chat?.completions
144+
?.create as unknown as ReturnType<typeof vi.fn>;
145+
146+
// Check that messages were formatted correctly
147+
expect(mockOpenAI).toHaveBeenCalled();
148+
149+
// Get the second call arguments (from this test)
150+
const calledWith = mockOpenAI.mock.calls[1]?.[0] || {};
151+
152+
expect(calledWith.messages).toHaveLength(5);
153+
154+
// We need to check each message individually to avoid TypeScript errors
155+
const systemMessage = calledWith.messages[0];
156+
if (
157+
systemMessage &&
158+
typeof systemMessage === 'object' &&
159+
'role' in systemMessage
160+
) {
161+
expect(systemMessage.role).toBe('system');
162+
expect(systemMessage.content).toBe('You are a helpful assistant.');
163+
}
164+
165+
const userMessage = calledWith.messages[1];
166+
if (
167+
userMessage &&
168+
typeof userMessage === 'object' &&
169+
'role' in userMessage
170+
) {
171+
expect(userMessage.role).toBe('user');
172+
expect(userMessage.content).toBe('Hello');
173+
}
174+
175+
const assistantMessage = calledWith.messages[2];
176+
if (
177+
assistantMessage &&
178+
typeof assistantMessage === 'object' &&
179+
'role' in assistantMessage
180+
) {
181+
expect(assistantMessage.role).toBe('assistant');
182+
expect(assistantMessage.content).toBe('Hi there');
183+
}
184+
185+
// Check tool_use formatting
186+
const toolUseMessage = calledWith.messages[3];
187+
if (
188+
toolUseMessage &&
189+
typeof toolUseMessage === 'object' &&
190+
'role' in toolUseMessage
191+
) {
192+
expect(toolUseMessage.role).toBe('assistant');
193+
expect(toolUseMessage.content).toBe(null);
194+
195+
if (
196+
'tool_calls' in toolUseMessage &&
197+
Array.isArray(toolUseMessage.tool_calls)
198+
) {
199+
expect(toolUseMessage.tool_calls.length).toBe(1);
200+
const toolCall = toolUseMessage.tool_calls[0];
201+
if (toolCall && 'function' in toolCall) {
202+
expect(toolCall.function.name).toBe('testTool');
203+
}
204+
}
205+
}
206+
207+
// Check tool_result formatting
208+
const toolResultMessage = calledWith.messages[4];
209+
if (
210+
toolResultMessage &&
211+
typeof toolResultMessage === 'object' &&
212+
'role' in toolResultMessage
213+
) {
214+
expect(toolResultMessage.role).toBe('tool');
215+
expect(toolResultMessage.content).toBe('{"result":"success"}');
216+
if ('tool_call_id' in toolResultMessage) {
217+
expect(toolResultMessage.tool_call_id).toBe('tool-1');
218+
}
219+
}
220+
});
221+
});

packages/agent/src/core/llm/provider.ts

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
*/
44

55
import { AnthropicProvider } from './providers/anthropic.js';
6+
import { OpenAIProvider } from './providers/openai.js';
67
import { ProviderOptions, GenerateOptions, LLMResponse } from './types.js';
78

89
/**
@@ -39,6 +40,7 @@ const providerFactories: Record<
3940
(model: string, options: ProviderOptions) => LLMProvider
4041
> = {
4142
anthropic: (model, options) => new AnthropicProvider(model, options),
43+
openai: (model, options) => new OpenAIProvider(model, options),
4244
};
4345

4446
/**

0 commit comments

Comments
 (0)