@@ -75,21 +75,15 @@ type Bridge struct {
75
75
cfg codersdk.AIBridgeConfig
76
76
77
77
httpSrv * http.Server
78
- clientFn func () (proto.DRPCAIBridgeDaemonClient , bool )
78
+ clientFn func () (proto.DRPCAIBridgeDaemonClient , error )
79
79
logger slog.Logger
80
80
81
81
tools map [string ]* MCPTool
82
82
}
83
83
84
- func NewBridge (cfg codersdk.AIBridgeConfig , logger slog.Logger , clientFn func () (proto.DRPCAIBridgeDaemonClient , bool ), tools map [string ][]* MCPTool ) (* Bridge , error ) {
85
- var bridge Bridge
86
-
87
- mux := & http.ServeMux {}
88
- mux .HandleFunc ("/v1/chat/completions" , func (w http.ResponseWriter , r * http.Request ) {
89
- prov := NewOpenAIProvider (cfg .OpenAI .BaseURL .String (), cfg .OpenAI .Key .String ())
90
-
91
- // TODO: everything is generic beyond this point...
92
-
84
+ func handleOpenAI (provider * OpenAIChatProvider , drpcClient proto.DRPCAIBridgeDaemonClient , tools map [string ][]* MCPTool , logger slog.Logger ) func (http.ResponseWriter , * http.Request ) {
85
+ return func (w http.ResponseWriter , r * http.Request ) {
86
+ // Read and parse request.
93
87
body , err := io .ReadAll (r .Body )
94
88
if err != nil {
95
89
if isConnectionError (err ) {
@@ -100,35 +94,49 @@ func NewBridge(cfg codersdk.AIBridgeConfig, logger slog.Logger, clientFn func()
100
94
http .Error (w , "failed to read body" , http .StatusInternalServerError )
101
95
return
102
96
}
103
-
104
- req , err := prov .ParseRequest (body )
97
+ req , err := provider .ParseRequest (body )
105
98
if err != nil {
106
99
logger .Error (r .Context (), "failed to parse request" , slog .Error (err ))
107
100
http .Error (w , "failed to parse request" , http .StatusBadRequest )
101
+ return
108
102
}
109
103
110
- var sess Session [ChatCompletionNewParamsWrapper ]
104
+ // Create a new session.
105
+ var sess Session
111
106
if req .Stream {
112
- sess = prov . NewAsynchronousSession (req )
107
+ sess = provider . NewStreamingSession (req )
113
108
} else {
114
- sess = prov . NewSynchronousSession (req )
109
+ sess = provider . NewBlockingSession (req )
115
110
}
116
111
117
- coderdClient , ok := clientFn ()
118
- if ! ok {
119
- logger .Error (r .Context (), "could not acquire coderd client for tracking" )
120
- return
121
- }
112
+ sessID := sess .Init (logger , provider .baseURL , provider .key , NewDRPCTracker (drpcClient ), NewInjectedToolManager (tools ))
113
+ logger .Debug (context .Background (), "starting openai session" , slog .F ("session_id" , sessID ))
122
114
123
- sessID := sess .Init (logger , prov .baseURL , prov .key , NewDRPCTracker (coderdClient ), NewInjectedToolManager (tools ))
124
115
defer func () {
125
116
if err := sess .Close (); err != nil {
126
117
logger .Warn (context .Background (), "failed to close session" , slog .Error (err ), slog .F ("session_id" , sessID ), slog .F ("kind" , fmt .Sprintf ("%T" , sess )))
127
118
}
128
119
}()
129
120
130
- sess .Execute (req , w , r ) // TODO: handle error?
131
- })
121
+ // Process the request.
122
+ if err := sess .ProcessRequest (w , r ); err != nil {
123
+ logger .Error (r .Context (), "session execution failed" , slog .Error (err ))
124
+ }
125
+ }
126
+ }
127
+
128
+ func NewBridge (cfg codersdk.AIBridgeConfig , logger slog.Logger , clientFn func () (proto.DRPCAIBridgeDaemonClient , error ), tools map [string ][]* MCPTool ) (* Bridge , error ) {
129
+ var bridge Bridge
130
+
131
+ openAIProvider := NewOpenAIChatProvider (cfg .OpenAI .BaseURL .String (), cfg .OpenAI .Key .String ())
132
+
133
+ drpcClient , err := clientFn ()
134
+ if err != nil {
135
+ return nil , xerrors .Errorf ("could not acquire coderd client for tracking: %w" , err )
136
+ }
137
+
138
+ mux := & http.ServeMux {}
139
+ mux .HandleFunc ("/v1/chat/completions" , handleOpenAI (openAIProvider , drpcClient , tools , logger .Named ("openai" )))
132
140
mux .HandleFunc ("/v1/messages" , bridge .proxyAnthropicRequest )
133
141
134
142
srv := & http.Server {
@@ -172,7 +180,6 @@ func (b *Bridge) Handler() http.Handler {
172
180
// proxyOpenAIRequest intercepts, filters, augments, and delivers requests & responses from client to upstream and back.
173
181
//
174
182
// References:
175
- // - https://platform.openai.com/docs/api-reference/chat-streaming
176
183
func (b * Bridge ) proxyOpenAIRequest (w http.ResponseWriter , r * http.Request ) {
177
184
sessionID := uuid .NewString ()
178
185
b .logger .Info (r .Context (), "openai request started" , slog .F ("session_id" , sessionID ), slog .F ("method" , r .Method ), slog .F ("path" , r .URL .Path ))
@@ -245,7 +252,7 @@ func (b *Bridge) proxyOpenAIRequest(w http.ResponseWriter, r *http.Request) {
245
252
opts = append (opts , oai_option .WithBaseURL (baseURL ))
246
253
}
247
254
248
- opts = append (opts , oai_option .WithMiddleware (LoggingMiddleware ))
255
+ // opts = append(opts, oai_option.WithMiddleware(LoggingMiddleware))
249
256
250
257
client := openai .NewClient (opts ... )
251
258
req := in .ChatCompletionNewParams
@@ -714,7 +721,7 @@ func (b *Bridge) proxyAnthropicRequest(w http.ResponseWriter, r *http.Request) {
714
721
if reqBetaHeader := r .Header .Get ("anthropic-beta" ); strings .TrimSpace (reqBetaHeader ) != "" {
715
722
opts = append (opts , option .WithHeader ("anthropic-beta" , reqBetaHeader ))
716
723
}
717
- opts = append (opts , option .WithMiddleware (LoggingMiddleware ))
724
+ // opts = append(opts, option.WithMiddleware(LoggingMiddleware))
718
725
719
726
apiKey := b .cfg .Anthropic .Key .String ()
720
727
if apiKey == "" {
@@ -1233,9 +1240,9 @@ func (b *Bridge) proxyAnthropicRequest(w http.ResponseWriter, r *http.Request) {
1233
1240
}
1234
1241
1235
1242
func (b * Bridge ) trackToolUsage (ctx context.Context , sessionID , msgID , model , toolName string , toolInput interface {}, injected bool ) {
1236
- coderdClient , ok := b .clientFn ()
1237
- if ! ok {
1238
- b .logger .Error (ctx , "could not acquire coderd client for tool usage tracking" )
1243
+ coderdClient , err := b .clientFn ()
1244
+ if err != nil {
1245
+ b .logger .Error (ctx , "could not acquire coderd client for tool usage tracking" , slog . Error ( err ) )
1239
1246
return
1240
1247
}
1241
1248
@@ -1265,7 +1272,7 @@ func (b *Bridge) trackToolUsage(ctx context.Context, sessionID, msgID, model, to
1265
1272
}
1266
1273
}
1267
1274
1268
- _ , err : = coderdClient .TrackToolUsage (ctx , & proto.TrackToolUsageRequest {
1275
+ _ , err = coderdClient .TrackToolUsage (ctx , & proto.TrackToolUsageRequest {
1269
1276
SessionId : sessionID ,
1270
1277
MsgId : msgID ,
1271
1278
Model : model ,
@@ -1279,13 +1286,13 @@ func (b *Bridge) trackToolUsage(ctx context.Context, sessionID, msgID, model, to
1279
1286
}
1280
1287
1281
1288
func (b * Bridge ) trackUserPrompt (ctx context.Context , sessionID , msgID , model , prompt string ) {
1282
- coderdClient , ok := b .clientFn ()
1283
- if ! ok {
1284
- b .logger .Error (ctx , "could not acquire coderd client for user prompt tracking" )
1289
+ coderdClient , err := b .clientFn ()
1290
+ if err != nil {
1291
+ b .logger .Error (ctx , "could not acquire coderd client for user prompt tracking" , slog . Error ( err ) )
1285
1292
return
1286
1293
}
1287
1294
1288
- _ , err : = coderdClient .TrackUserPrompt (ctx , & proto.TrackUserPromptRequest {
1295
+ _ , err = coderdClient .TrackUserPrompt (ctx , & proto.TrackUserPromptRequest {
1289
1296
SessionId : sessionID ,
1290
1297
MsgId : msgID ,
1291
1298
Model : model ,
@@ -1297,13 +1304,13 @@ func (b *Bridge) trackUserPrompt(ctx context.Context, sessionID, msgID, model, p
1297
1304
}
1298
1305
1299
1306
func (b * Bridge ) trackTokenUsage (ctx context.Context , sessionID , msgID , model string , inputTokens , outputTokens int64 , other map [string ]int64 ) {
1300
- coderdClient , ok := b .clientFn ()
1301
- if ! ok {
1302
- b .logger .Error (ctx , "could not acquire coderd client for token usage tracking" )
1307
+ coderdClient , err := b .clientFn ()
1308
+ if err != nil {
1309
+ b .logger .Error (ctx , "could not acquire coderd client for token usage tracking" , slog . Error ( err ) )
1303
1310
return
1304
1311
}
1305
1312
1306
- _ , err : = coderdClient .TrackTokenUsage (ctx , & proto.TrackTokenUsageRequest {
1313
+ _ , err = coderdClient .TrackTokenUsage (ctx , & proto.TrackTokenUsageRequest {
1307
1314
SessionId : sessionID ,
1308
1315
MsgId : msgID ,
1309
1316
Model : model ,
0 commit comments