-
Notifications
You must be signed in to change notification settings - Fork 1.1k
/
Copy pathws.ts
96 lines (85 loc) · 2.89 KB
/
ws.ts
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
import * as WS from 'ws';
import { AzureOpenAI, OpenAI } from '../../index';
import type { RealtimeClientEvent, RealtimeServerEvent } from '../../resources/beta/realtime/realtime';
import { OpenAIRealtimeEmitter, buildRealtimeURL, isAzure } from './internal-base';
/**
 * Realtime API connection backed by the Node.js `ws` package.
 *
 * The socket is opened in the constructor. Every incoming message is parsed
 * as JSON and re-emitted as a generic `'event'`. Server events of type
 * `'error'` are then routed to `_onError`; all other events are additionally
 * emitted under their own `type`. JSON parse failures and socket `'error'`
 * events are routed to `_onError` as well.
 */
export class OpenAIRealtimeWS extends OpenAIRealtimeEmitter {
  /** Realtime endpoint URL, as built by `buildRealtimeURL`. */
  url: URL;
  /** The underlying `ws` socket. */
  socket: WS.WebSocket;

  /**
   * Opens a realtime WebSocket connection.
   *
   * @param props.model - Name passed to `buildRealtimeURL` (for Azure, the deployment name).
   * @param props.options - Extra `ws` client options. Caller-supplied `headers` are kept, but the
   *   `Authorization` header (non-Azure clients only) and the `OpenAI-Beta` header set here
   *   override any caller header with the same name.
   * @param client - Supplies `apiKey` / `baseURL`; a default `new OpenAI()` is created when omitted.
   */
  constructor(
    props: { model: string; options?: WS.ClientOptions | undefined },
    client?: Pick<OpenAI, 'apiKey' | 'baseURL'>,
  ) {
    super();
    client ??= new OpenAI();
    this.url = buildRealtimeURL(client, props.model);
    this.socket = new WS.WebSocket(this.url, {
      ...props.options,
      headers: {
        ...props.options?.headers,
        // No bearer key is added for Azure clients; their auth header is expected in
        // `props.options.headers` (which is what `azure()` supplies).
        ...(isAzure(client) ? {} : { Authorization: `Bearer ${client.apiKey}` }),
        'OpenAI-Beta': 'realtime=v1',
      },
    });

    this.socket.on('message', (wsEvent) => {
      const event = (() => {
        try {
          return JSON.parse(wsEvent.toString()) as RealtimeServerEvent;
        } catch (err) {
          this._onError(null, 'could not parse websocket event', err);
          return null;
        }
      })();

      if (event) {
        this._emit('event', event);

        if (event.type === 'error') {
          this._onError(event);
        } else {
          // @ts-expect-error TS isn't smart enough to get the relationship right here
          this._emit(event.type, event);
        }
      }
    });

    this.socket.on('error', (err) => {
      this._onError(null, err.message, err);
    });
  }

  /**
   * Creates a connection to an Azure OpenAI realtime deployment.
   *
   * Authentication comes from `getAzureHeaders`: the client's API key if one
   * is set, otherwise an Azure AD bearer token.
   *
   * @param client - Azure client providing credentials, `baseURL` and the default deployment name.
   * @param options.deploymentName - Overrides `client.deploymentName`.
   * @param options.options - Extra `ws` client options forwarded to the socket. Their `headers`
   *   are merged with the Azure auth header; the auth header wins on a name clash.
   * @returns A new `OpenAIRealtimeWS` connected to the deployment.
   * @throws Error when no deployment name is available, or when the client has neither an API
   *   key nor an AD token.
   */
  static async azure(
    client: Pick<AzureOpenAI, '_getAzureADToken' | 'apiVersion' | 'apiKey' | 'baseURL' | 'deploymentName'>,
    options: { deploymentName?: string; options?: WS.ClientOptions | undefined } = {},
  ): Promise<OpenAIRealtimeWS> {
    const deploymentName = options.deploymentName ?? client.deploymentName;
    if (!deploymentName) {
      throw new Error('No deployment name provided');
    }

    const authHeaders: Record<string, string> = await getAzureHeaders(client);
    return new OpenAIRealtimeWS(
      {
        model: deploymentName,
        options: {
          // Forward caller-supplied socket options instead of discarding them.
          ...options.options,
          headers: { ...options.options?.headers, ...authHeaders },
        },
      },
      client,
    );
  }

  /**
   * Serializes `event` as JSON and sends it over the socket.
   * A synchronous throw from `ws` is routed to `_onError` instead of propagating.
   */
  send(event: RealtimeClientEvent) {
    try {
      this.socket.send(JSON.stringify(event));
    } catch (err) {
      this._onError(null, 'could not send data', err);
    }
  }

  /**
   * Closes the socket, defaulting to close code 1000 and reason `'OK'`.
   * A synchronous throw from `ws` is routed to `_onError` instead of propagating.
   */
  close(props?: { code: number; reason: string }) {
    try {
      this.socket.close(props?.code ?? 1000, props?.reason ?? 'OK');
    } catch (err) {
      this._onError(null, 'could not close the connection', err);
    }
  }
}
/**
 * Builds the authentication header for an Azure realtime connection.
 *
 * A real API key (anything other than the `'<Missing Key>'` placeholder) is
 * sent as `api-key`. Otherwise an Azure AD token is fetched via
 * `client._getAzureADToken()` and sent as a bearer `Authorization` header.
 *
 * @param client - Azure client exposing `apiKey` and `_getAzureADToken`.
 * @returns A single-entry header object.
 * @throws Error when there is no API key and no token could be obtained.
 */
async function getAzureHeaders(client: Pick<AzureOpenAI, '_getAzureADToken' | 'apiKey'>) {
  const hasRealKey = client.apiKey !== '<Missing Key>';
  if (hasRealKey) {
    return { 'api-key': client.apiKey };
  }

  const adToken = await client._getAzureADToken();
  if (!adToken) {
    throw new Error('AzureOpenAI is not instantiated correctly. No API key or token provided.');
  }
  return { Authorization: `Bearer ${adToken}` };
}