Skip to content

Commit 069d66d

Browse files
committed
cheers wormhole!
Signed-off-by: Danny Kopping <dannykopping@gmail.com>
1 parent acd0f28 commit 069d66d

34 files changed

+1145
-292
lines changed

aibridged/aibridged.go

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,16 @@ func (s *Server) Client() (proto.DRPCAIBridgeDaemonClient, error) {
149149
}
150150
}
151151

152+
func (s *Server) StartSession(ctx context.Context, in *proto.StartSessionRequest) (*proto.StartSessionResponse, error) {
153+
out, err := clientDoWithRetries(ctx, s.Client, func(ctx context.Context, client proto.DRPCAIBridgeDaemonClient) (*proto.StartSessionResponse, error) {
154+
return client.StartSession(ctx, in)
155+
})
156+
if err != nil {
157+
return nil, err
158+
}
159+
return out, nil
160+
}
161+
152162
func (s *Server) TrackTokenUsage(ctx context.Context, in *proto.TrackTokenUsageRequest) (*proto.TrackTokenUsageResponse, error) {
153163
out, err := clientDoWithRetries(ctx, s.Client, func(ctx context.Context, client proto.DRPCAIBridgeDaemonClient) (*proto.TrackTokenUsageResponse, error) {
154164
return client.TrackTokenUsage(ctx, in)

aibridged/bridge.go

Lines changed: 44 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@ import (
1111

1212
"cdr.dev/slog"
1313

14+
"github.com/google/uuid"
15+
1416
"github.com/coder/coder/v2/aibridged/proto"
1517
"github.com/coder/coder/v2/codersdk"
1618
)
@@ -109,7 +111,27 @@ func handleOpenAIChat(provider *OpenAIChatProvider, drpcClient proto.DRPCAIBridg
109111
sess = provider.NewBlockingSession(req)
110112
}
111113

112-
sessID := sess.Init(logger, provider.baseURL, provider.key, NewDRPCTracker(drpcClient), NewInjectedToolManager(tools))
114+
userID, ok := r.Context().Value(ContextKeyBridgeUserID{}).(uuid.UUID)
115+
if !ok {
116+
logger.Error(r.Context(), "missing initiator ID in context")
117+
http.Error(w, "unable to retrieve initiator", http.StatusInternalServerError)
118+
return
119+
}
120+
121+
resp, err := drpcClient.StartSession(r.Context(), &proto.StartSessionRequest{
122+
InitiatorId: userID.String(),
123+
Provider: "openai",
124+
Model: req.Model,
125+
})
126+
if err != nil {
127+
logger.Error(r.Context(), "failed to start session", slog.Error(err))
128+
http.Error(w, "failed to start session", http.StatusInternalServerError)
129+
return
130+
}
131+
132+
sessID := resp.GetSessionId()
133+
134+
sess.Init(sessID, logger, provider.baseURL, provider.key, NewDRPCTracker(drpcClient), NewInjectedToolManager(tools))
113135
logger.Debug(context.Background(), "starting openai session", slog.F("session_id", sessID))
114136

115137
defer func() {
@@ -153,7 +175,27 @@ func handleAnthropicMessages(provider *AnthropicMessagesProvider, drpcClient pro
153175
sess = provider.NewBlockingSession(req)
154176
}
155177

156-
sessID := sess.Init(logger, provider.baseURL, provider.key, NewDRPCTracker(drpcClient), NewInjectedToolManager(tools))
178+
userID, ok := r.Context().Value(ContextKeyBridgeUserID{}).(uuid.UUID)
179+
if !ok {
180+
logger.Error(r.Context(), "missing initiator ID in context")
181+
http.Error(w, "unable to retrieve initiator", http.StatusInternalServerError)
182+
return
183+
}
184+
185+
resp, err := drpcClient.StartSession(r.Context(), &proto.StartSessionRequest{
186+
InitiatorId: userID.String(),
187+
Provider: "anthropic",
188+
Model: string(req.Model),
189+
})
190+
if err != nil {
191+
logger.Error(r.Context(), "failed to start session", slog.Error(err))
192+
http.Error(w, "failed to start session", http.StatusInternalServerError)
193+
return
194+
}
195+
196+
sessID := resp.GetSessionId()
197+
198+
sess.Init(sessID, logger, provider.baseURL, provider.key, NewDRPCTracker(drpcClient), NewInjectedToolManager(tools))
157199
logger.Debug(context.Background(), "starting anthropic messages session", slog.F("session_id", sessID))
158200

159201
defer func() {

aibridged/bridge_integration_test.go

Lines changed: 35 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,7 @@ func TestAnthropicMessages(t *testing.T) {
130130
}, nil)
131131
require.NoError(t, err)
132132

133-
mockSrv := httptest.NewServer(b.Handler())
133+
mockSrv := httptest.NewServer(withInitiator(getCurrentUserID(t, client), b.Handler()))
134134
// Make API call to aibridge for Anthropic /v1/messages
135135
req := createAnthropicMessagesReq(t, mockSrv.URL, reqBody)
136136
client := &http.Client{}
@@ -168,7 +168,6 @@ func TestAnthropicMessages(t *testing.T) {
168168
})
169169
}
170170
})
171-
172171
}
173172

174173
func TestOpenAIChatCompletions(t *testing.T) {
@@ -234,7 +233,7 @@ func TestOpenAIChatCompletions(t *testing.T) {
234233
}, nil)
235234
require.NoError(t, err)
236235

237-
mockSrv := httptest.NewServer(b.Handler())
236+
mockSrv := httptest.NewServer(withInitiator(getCurrentUserID(t, client), b.Handler()))
238237
// Make API call to aibridge for OpenAI /v1/chat/completions
239238
req := createOpenAIChatCompletionsReq(t, mockSrv.URL, reqBody)
240239

@@ -275,7 +274,6 @@ func TestOpenAIChatCompletions(t *testing.T) {
275274
})
276275
}
277276
})
278-
279277
}
280278

281279
func TestSimple(t *testing.T) {
@@ -426,7 +424,7 @@ func TestSimple(t *testing.T) {
426424
b, err := tc.configureFunc(srv.URL, coderdClient)
427425
require.NoError(t, err)
428426

429-
mockSrv := httptest.NewServer(b.Handler())
427+
mockSrv := httptest.NewServer(withInitiator(getCurrentUserID(t, client), b.Handler()))
430428
// When: calling the "API server" with the fixture's request body.
431429
req := tc.createRequest(t, mockSrv.URL, reqBody)
432430
client := &http.Client{}
@@ -681,7 +679,7 @@ func TestInjectedTool(t *testing.T) {
681679
require.NoError(t, err)
682680

683681
// Invoke request to mocked API via aibridge.
684-
bridgeSrv := httptest.NewServer(b.Handler())
682+
bridgeSrv := httptest.NewServer(withInitiator(getCurrentUserID(t, client), b.Handler()))
685683
t.Cleanup(bridgeSrv.Close)
686684

687685
req := tc.createRequest(t, bridgeSrv.URL, reqBody)
@@ -846,9 +844,12 @@ func newMockServer(ctx context.Context, t *testing.T, files archiveFileMap, resp
846844
return ms
847845
}
848846

847+
var _ proto.DRPCAIBridgeDaemonClient = &fakeBridgeDaemonClient{}
848+
849849
type fakeBridgeDaemonClient struct {
850850
mu sync.Mutex
851851

852+
sessions []*proto.StartSessionRequest
852853
tokenUsages []*proto.TrackTokenUsageRequest
853854
userPrompts []*proto.TrackUserPromptRequest
854855
toolUsages []*proto.TrackToolUsageRequest
@@ -859,6 +860,17 @@ func (*fakeBridgeDaemonClient) DRPCConn() drpc.Conn {
859860
return conn
860861
}
861862

863+
// StartSession implements proto.DRPCAIBridgeDaemonClient.
864+
func (f *fakeBridgeDaemonClient) StartSession(ctx context.Context, in *proto.StartSessionRequest) (*proto.StartSessionResponse, error) {
865+
f.mu.Lock()
866+
defer f.mu.Unlock()
867+
f.sessions = append(f.sessions, in)
868+
869+
return &proto.StartSessionResponse{
870+
SessionId: uuid.NewString(),
871+
}, nil
872+
}
873+
862874
func (f *fakeBridgeDaemonClient) TrackTokenUsage(ctx context.Context, in *proto.TrackTokenUsageRequest) (*proto.TrackTokenUsageResponse, error) {
863875
f.mu.Lock()
864876
defer f.mu.Unlock()
@@ -903,3 +915,20 @@ func createMockMCPSrv(t *testing.T) http.Handler {
903915

904916
return server.NewStreamableHTTPServer(s)
905917
}
918+
919+
// withInitiator wraps a handler injecting the Bridge user ID into context.
920+
// TODO: this is only necessary because we're not exercising the real API's middleware, which may hide some problems.
921+
func withInitiator(userID uuid.UUID, next http.Handler) http.Handler {
922+
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
923+
ctx := context.WithValue(r.Context(), aibridged.ContextKeyBridgeUserID{}, userID)
924+
next.ServeHTTP(w, r.WithContext(ctx))
925+
})
926+
}
927+
928+
func getCurrentUserID(t *testing.T, client *codersdk.Client) uuid.UUID {
929+
t.Helper()
930+
931+
me, err := client.User(t.Context(), "me")
932+
require.NoError(t, err)
933+
return me.ID
934+
}

aibridged/middleware.go

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ import (
1111
)
1212

1313
type ContextKeyBridgeAPIKey struct{}
14+
type ContextKeyBridgeUserID struct{}
1415

1516
// AuthMiddleware extracts and validates authorization tokens for AI bridge endpoints.
1617
// It supports both Bearer tokens in Authorization headers and Coder session tokens
@@ -28,7 +29,7 @@ func AuthMiddleware(db database.Store) func(http.Handler) http.Handler {
2829
}
2930

3031
// Validate token using httpmw.APIKeyFromRequest
31-
_, _, ok := httpmw.APIKeyFromRequest(ctx, db, func(r *http.Request) string {
32+
key, _, ok := httpmw.APIKeyFromRequest(ctx, db, func(r *http.Request) string {
3233
return token
3334
}, &http.Request{})
3435

@@ -37,8 +38,12 @@ func AuthMiddleware(db database.Store) func(http.Handler) http.Handler {
3738
return
3839
}
3940

41+
ctx = context.WithValue(
42+
context.WithValue(ctx, ContextKeyBridgeUserID{}, key.UserID),
43+
ContextKeyBridgeAPIKey{}, token)
44+
4045
// Pass request with modify context including the request token.
41-
next.ServeHTTP(rw, r.WithContext(context.WithValue(ctx, ContextKeyBridgeAPIKey{}, token)))
46+
next.ServeHTTP(rw, r.WithContext(ctx))
4247
})
4348
}
4449
}

0 commit comments

Comments
 (0)