Skip to content

Commit f5b4f49

Browse files
committed
cleaner abstraction
Signed-off-by: Danny Kopping <dannykopping@gmail.com>
1 parent 069d66d commit f5b4f49

25 files changed

+379
-502
lines changed

aibridged/bridge.go

Lines changed: 7 additions & 141 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,13 @@
11
package aibridged
22

33
import (
4-
"context"
5-
"fmt"
6-
"io"
74
"net/http"
85
"time"
96

107
"golang.org/x/xerrors"
118

129
"cdr.dev/slog"
1310

14-
"github.com/google/uuid"
15-
1611
"github.com/coder/coder/v2/aibridged/proto"
1712
"github.com/coder/coder/v2/codersdk"
1813
)
@@ -39,20 +34,18 @@ type Bridge struct {
3934
tools map[string]*MCPTool
4035
}
4136

42-
func NewBridge(cfg codersdk.AIBridgeConfig, logger slog.Logger, clientFn func() (proto.DRPCAIBridgeDaemonClient, error), tools map[string][]*MCPTool) (*Bridge, error) {
43-
var bridge Bridge
44-
45-
openAIChatProvider := NewOpenAIChatProvider(cfg.OpenAI.BaseURL.String(), cfg.OpenAI.Key.String())
46-
anthropicMessagesProvider := NewAnthropicMessagesProvider(cfg.Anthropic.BaseURL.String(), cfg.Anthropic.Key.String())
47-
37+
func NewBridge(cfg codersdk.AIBridgeConfig, logger slog.Logger, clientFn func() (proto.DRPCAIBridgeDaemonClient, error), tools ToolRegistry) (*Bridge, error) {
4838
drpcClient, err := clientFn()
4939
if err != nil {
5040
return nil, xerrors.Errorf("could not acquire coderd client for tracking: %w", err)
5141
}
5242

43+
openAIProvider := NewOpenAIProvider(cfg.OpenAI.BaseURL.String(), cfg.OpenAI.Key.String())
44+
anthropicMessagesProvider := NewAnthropicMessagesProvider(cfg.Anthropic.BaseURL.String(), cfg.Anthropic.Key.String())
45+
5346
mux := &http.ServeMux{}
54-
mux.HandleFunc("/v1/chat/completions", handleOpenAIChat(openAIChatProvider, drpcClient, tools, logger.Named("openai")))
55-
mux.HandleFunc("/v1/messages", handleAnthropicMessages(anthropicMessagesProvider, drpcClient, tools, logger.Named("anthropic")))
47+
mux.HandleFunc("/v1/chat/completions", NewSessionProcessor(openAIProvider, logger, drpcClient, tools))
48+
mux.HandleFunc("/v1/messages", NewSessionProcessor(anthropicMessagesProvider, logger, drpcClient, tools))
5649

5750
srv := &http.Server{
5851
Handler: mux,
@@ -64,6 +57,7 @@ func NewBridge(cfg codersdk.AIBridgeConfig, logger slog.Logger, clientFn func()
6457
ReadHeaderTimeout: 10 * time.Second,
6558
}
6659

60+
var bridge Bridge
6761
bridge.cfg = cfg
6862
bridge.httpSrv = srv
6963
bridge.clientFn = clientFn
@@ -82,131 +76,3 @@ func NewBridge(cfg codersdk.AIBridgeConfig, logger slog.Logger, clientFn func()
8276
func (b *Bridge) Handler() http.Handler {
8377
return b.httpSrv.Handler
8478
}
85-
86-
func handleOpenAIChat(provider *OpenAIChatProvider, drpcClient proto.DRPCAIBridgeDaemonClient, tools map[string][]*MCPTool, logger slog.Logger) func(http.ResponseWriter, *http.Request) {
87-
return func(w http.ResponseWriter, r *http.Request) {
88-
// Read and parse request.
89-
body, err := io.ReadAll(r.Body)
90-
if err != nil {
91-
if isConnectionError(err) {
92-
logger.Debug(r.Context(), "client disconnected during request body read", slog.Error(err))
93-
return // Don't send error response if client already disconnected
94-
}
95-
logger.Error(r.Context(), "failed to read body", slog.Error(err))
96-
http.Error(w, "failed to read body", http.StatusInternalServerError)
97-
return
98-
}
99-
req, err := provider.ParseRequest(body)
100-
if err != nil {
101-
logger.Error(r.Context(), "failed to parse request", slog.Error(err))
102-
http.Error(w, "failed to parse request", http.StatusBadRequest)
103-
return
104-
}
105-
106-
// Create a new session.
107-
var sess Session
108-
if req.Stream {
109-
sess = provider.NewStreamingSession(req)
110-
} else {
111-
sess = provider.NewBlockingSession(req)
112-
}
113-
114-
userID, ok := r.Context().Value(ContextKeyBridgeUserID{}).(uuid.UUID)
115-
if !ok {
116-
logger.Error(r.Context(), "missing initiator ID in context")
117-
http.Error(w, "unable to retrieve initiator", http.StatusInternalServerError)
118-
return
119-
}
120-
121-
resp, err := drpcClient.StartSession(r.Context(), &proto.StartSessionRequest{
122-
InitiatorId: userID.String(),
123-
Provider: "openai",
124-
Model: req.Model,
125-
})
126-
if err != nil {
127-
logger.Error(r.Context(), "failed to start session", slog.Error(err))
128-
http.Error(w, "failed to start session", http.StatusInternalServerError)
129-
return
130-
}
131-
132-
sessID := resp.GetSessionId()
133-
134-
sess.Init(sessID, logger, provider.baseURL, provider.key, NewDRPCTracker(drpcClient), NewInjectedToolManager(tools))
135-
logger.Debug(context.Background(), "starting openai session", slog.F("session_id", sessID))
136-
137-
defer func() {
138-
if err := sess.Close(); err != nil {
139-
logger.Warn(context.Background(), "failed to close session", slog.Error(err), slog.F("session_id", sessID), slog.F("kind", fmt.Sprintf("%T", sess)))
140-
}
141-
}()
142-
143-
// Process the request.
144-
if err := sess.ProcessRequest(w, r); err != nil {
145-
logger.Error(r.Context(), "session execution failed", slog.Error(err))
146-
}
147-
}
148-
}
149-
150-
func handleAnthropicMessages(provider *AnthropicMessagesProvider, drpcClient proto.DRPCAIBridgeDaemonClient, tools map[string][]*MCPTool, logger slog.Logger) func(http.ResponseWriter, *http.Request) {
151-
return func(w http.ResponseWriter, r *http.Request) {
152-
// Read and parse request.
153-
body, err := io.ReadAll(r.Body)
154-
if err != nil {
155-
if isConnectionError(err) {
156-
logger.Debug(r.Context(), "client disconnected during request body read", slog.Error(err))
157-
return // Don't send error response if client already disconnected
158-
}
159-
logger.Error(r.Context(), "failed to read body", slog.Error(err))
160-
http.Error(w, "failed to read body", http.StatusInternalServerError)
161-
return
162-
}
163-
req, err := provider.ParseRequest(body)
164-
if err != nil {
165-
logger.Error(r.Context(), "failed to parse request", slog.Error(err))
166-
http.Error(w, "failed to parse request", http.StatusBadRequest)
167-
return
168-
}
169-
170-
// Create a new session.
171-
var sess Session
172-
if req.UseStreaming() {
173-
sess = provider.NewStreamingSession(req)
174-
} else {
175-
sess = provider.NewBlockingSession(req)
176-
}
177-
178-
userID, ok := r.Context().Value(ContextKeyBridgeUserID{}).(uuid.UUID)
179-
if !ok {
180-
logger.Error(r.Context(), "missing initiator ID in context")
181-
http.Error(w, "unable to retrieve initiator", http.StatusInternalServerError)
182-
return
183-
}
184-
185-
resp, err := drpcClient.StartSession(r.Context(), &proto.StartSessionRequest{
186-
InitiatorId: userID.String(),
187-
Provider: "anthropic",
188-
Model: string(req.Model),
189-
})
190-
if err != nil {
191-
logger.Error(r.Context(), "failed to start session", slog.Error(err))
192-
http.Error(w, "failed to start session", http.StatusInternalServerError)
193-
return
194-
}
195-
196-
sessID := resp.GetSessionId()
197-
198-
sess.Init(sessID, logger, provider.baseURL, provider.key, NewDRPCTracker(drpcClient), NewInjectedToolManager(tools))
199-
logger.Debug(context.Background(), "starting anthropic messages session", slog.F("session_id", sessID))
200-
201-
defer func() {
202-
if err := sess.Close(); err != nil {
203-
logger.Warn(context.Background(), "failed to close session", slog.Error(err), slog.F("session_id", sessID), slog.F("kind", fmt.Sprintf("%T", sess)))
204-
}
205-
}()
206-
207-
// Process the request.
208-
if err := sess.ProcessRequest(w, r); err != nil {
209-
logger.Error(r.Context(), "session execution failed", slog.Error(err))
210-
}
211-
}
212-
}

aibridged/bridge_integration_test.go

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -866,9 +866,7 @@ func (f *fakeBridgeDaemonClient) StartSession(ctx context.Context, in *proto.Sta
866866
defer f.mu.Unlock()
867867
f.sessions = append(f.sessions, in)
868868

869-
return &proto.StartSessionResponse{
870-
SessionId: uuid.NewString(),
871-
}, nil
869+
return &proto.StartSessionResponse{}, nil
872870
}
873871

874872
func (f *fakeBridgeDaemonClient) TrackTokenUsage(ctx context.Context, in *proto.TrackTokenUsageRequest) (*proto.TrackTokenUsageResponse, error) {

0 commit comments

Comments
 (0)