Skip to content

Commit 48ab223

Browse files
committed
cheers wormhole!
Signed-off-by: Danny Kopping <dannykopping@gmail.com>
1 parent acd0f28 commit 48ab223

34 files changed

+1145
-290
lines changed

aibridged/aibridged.go

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,16 @@ func (s *Server) Client() (proto.DRPCAIBridgeDaemonClient, error) {
149149
}
150150
}
151151

152+
func (s *Server) StartSession(ctx context.Context, in *proto.StartSessionRequest) (*proto.StartSessionResponse, error) {
153+
out, err := clientDoWithRetries(ctx, s.Client, func(ctx context.Context, client proto.DRPCAIBridgeDaemonClient) (*proto.StartSessionResponse, error) {
154+
return client.StartSession(ctx, in)
155+
})
156+
if err != nil {
157+
return nil, err
158+
}
159+
return out, nil
160+
}
161+
152162
func (s *Server) TrackTokenUsage(ctx context.Context, in *proto.TrackTokenUsageRequest) (*proto.TrackTokenUsageResponse, error) {
153163
out, err := clientDoWithRetries(ctx, s.Client, func(ctx context.Context, client proto.DRPCAIBridgeDaemonClient) (*proto.TrackTokenUsageResponse, error) {
154164
return client.TrackTokenUsage(ctx, in)

aibridged/bridge.go

Lines changed: 44 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@ import (
1111

1212
"cdr.dev/slog"
1313

14+
"github.com/google/uuid"
15+
1416
"github.com/coder/coder/v2/aibridged/proto"
1517
"github.com/coder/coder/v2/codersdk"
1618
)
@@ -109,7 +111,27 @@ func handleOpenAIChat(provider *OpenAIChatProvider, drpcClient proto.DRPCAIBridg
109111
sess = provider.NewBlockingSession(req)
110112
}
111113

112-
sessID := sess.Init(logger, provider.baseURL, provider.key, NewDRPCTracker(drpcClient), NewInjectedToolManager(tools))
114+
userID, ok := r.Context().Value(ContextKeyBridgeUserID{}).(uuid.UUID)
115+
if !ok {
116+
logger.Error(r.Context(), "missing initiator ID in context")
117+
http.Error(w, "unable to retrieve initiator", http.StatusInternalServerError)
118+
return
119+
}
120+
121+
resp, err := drpcClient.StartSession(r.Context(), &proto.StartSessionRequest{
122+
InitiatorId: userID.String(),
123+
Provider: "openai",
124+
Model: req.Model,
125+
})
126+
if err != nil {
127+
logger.Error(r.Context(), "failed to start session", slog.Error(err))
128+
http.Error(w, "failed to start session", http.StatusInternalServerError)
129+
return
130+
}
131+
132+
sessID := resp.GetSessionId()
133+
134+
sess.Init(sessID, logger, provider.baseURL, provider.key, NewDRPCTracker(drpcClient), NewInjectedToolManager(tools))
113135
logger.Debug(context.Background(), "starting openai session", slog.F("session_id", sessID))
114136

115137
defer func() {
@@ -153,7 +175,27 @@ func handleAnthropicMessages(provider *AnthropicMessagesProvider, drpcClient pro
153175
sess = provider.NewBlockingSession(req)
154176
}
155177

156-
sessID := sess.Init(logger, provider.baseURL, provider.key, NewDRPCTracker(drpcClient), NewInjectedToolManager(tools))
178+
userID, ok := r.Context().Value(ContextKeyBridgeUserID{}).(uuid.UUID)
179+
if !ok {
180+
logger.Error(r.Context(), "missing initiator ID in context")
181+
http.Error(w, "unable to retrieve initiator", http.StatusInternalServerError)
182+
return
183+
}
184+
185+
resp, err := drpcClient.StartSession(r.Context(), &proto.StartSessionRequest{
186+
InitiatorId: userID.String(),
187+
Provider: "anthropic",
188+
Model: string(req.Model),
189+
})
190+
if err != nil {
191+
logger.Error(r.Context(), "failed to start session", slog.Error(err))
192+
http.Error(w, "failed to start session", http.StatusInternalServerError)
193+
return
194+
}
195+
196+
sessID := resp.GetSessionId()
197+
198+
sess.Init(sessID, logger, provider.baseURL, provider.key, NewDRPCTracker(drpcClient), NewInjectedToolManager(tools))
157199
logger.Debug(context.Background(), "starting anthropic messages session", slog.F("session_id", sessID))
158200

159201
defer func() {

aibridged/bridge_integration_test.go

Lines changed: 35 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,7 @@ func TestAnthropicMessages(t *testing.T) {
130130
}, nil)
131131
require.NoError(t, err)
132132

133-
mockSrv := httptest.NewServer(b.Handler())
133+
mockSrv := httptest.NewServer(withInitiator(getCurrentUserID(t, client), b.Handler()))
134134
// Make API call to aibridge for Anthropic /v1/messages
135135
req := createAnthropicMessagesReq(t, mockSrv.URL, reqBody)
136136
client := &http.Client{}
@@ -234,7 +234,7 @@ func TestOpenAIChatCompletions(t *testing.T) {
234234
}, nil)
235235
require.NoError(t, err)
236236

237-
mockSrv := httptest.NewServer(b.Handler())
237+
mockSrv := httptest.NewServer(withInitiator(getCurrentUserID(t, client), b.Handler()))
238238
// Make API call to aibridge for OpenAI /v1/chat/completions
239239
req := createOpenAIChatCompletionsReq(t, mockSrv.URL, reqBody)
240240

@@ -426,7 +426,7 @@ func TestSimple(t *testing.T) {
426426
b, err := tc.configureFunc(srv.URL, coderdClient)
427427
require.NoError(t, err)
428428

429-
mockSrv := httptest.NewServer(b.Handler())
429+
mockSrv := httptest.NewServer(withInitiator(getCurrentUserID(t, client), b.Handler()))
430430
// When: calling the "API server" with the fixture's request body.
431431
req := tc.createRequest(t, mockSrv.URL, reqBody)
432432
client := &http.Client{}
@@ -681,7 +681,7 @@ func TestInjectedTool(t *testing.T) {
681681
require.NoError(t, err)
682682

683683
// Invoke request to mocked API via aibridge.
684-
bridgeSrv := httptest.NewServer(b.Handler())
684+
bridgeSrv := httptest.NewServer(withInitiator(getCurrentUserID(t, client), b.Handler()))
685685
t.Cleanup(bridgeSrv.Close)
686686

687687
req := tc.createRequest(t, bridgeSrv.URL, reqBody)
@@ -846,9 +846,12 @@ func newMockServer(ctx context.Context, t *testing.T, files archiveFileMap, resp
846846
return ms
847847
}
848848

849+
var _ proto.DRPCAIBridgeDaemonClient = &fakeBridgeDaemonClient{}
850+
849851
type fakeBridgeDaemonClient struct {
850852
mu sync.Mutex
851853

854+
sessions []*proto.StartSessionRequest
852855
tokenUsages []*proto.TrackTokenUsageRequest
853856
userPrompts []*proto.TrackUserPromptRequest
854857
toolUsages []*proto.TrackToolUsageRequest
@@ -859,6 +862,17 @@ func (*fakeBridgeDaemonClient) DRPCConn() drpc.Conn {
859862
return conn
860863
}
861864

865+
// StartSession implements proto.DRPCAIBridgeDaemonClient.
866+
func (f *fakeBridgeDaemonClient) StartSession(ctx context.Context, in *proto.StartSessionRequest) (*proto.StartSessionResponse, error) {
867+
f.mu.Lock()
868+
defer f.mu.Unlock()
869+
f.sessions = append(f.sessions, in)
870+
871+
return &proto.StartSessionResponse{
872+
SessionId: uuid.NewString(),
873+
}, nil
874+
}
875+
862876
func (f *fakeBridgeDaemonClient) TrackTokenUsage(ctx context.Context, in *proto.TrackTokenUsageRequest) (*proto.TrackTokenUsageResponse, error) {
863877
f.mu.Lock()
864878
defer f.mu.Unlock()
@@ -903,3 +917,20 @@ func createMockMCPSrv(t *testing.T) http.Handler {
903917

904918
return server.NewStreamableHTTPServer(s)
905919
}
920+
921+
// withInitiator wraps a handler injecting the Bridge user ID into context.
922+
// TODO: this is only necessary because we're not exercising the real API's middleware, which may hide some problems.
923+
func withInitiator(userID uuid.UUID, next http.Handler) http.Handler {
924+
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
925+
ctx := context.WithValue(r.Context(), aibridged.ContextKeyBridgeUserID{}, userID)
926+
next.ServeHTTP(w, r.WithContext(ctx))
927+
})
928+
}
929+
930+
func getCurrentUserID(t *testing.T, client *codersdk.Client) uuid.UUID {
931+
t.Helper()
932+
933+
me, err := client.User(t.Context(), "me")
934+
require.NoError(t, err)
935+
return me.ID
936+
}

aibridged/middleware.go

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ import (
1111
)
1212

1313
type ContextKeyBridgeAPIKey struct{}
14+
type ContextKeyBridgeUserID struct{}
1415

1516
// AuthMiddleware extracts and validates authorization tokens for AI bridge endpoints.
1617
// It supports both Bearer tokens in Authorization headers and Coder session tokens
@@ -28,7 +29,7 @@ func AuthMiddleware(db database.Store) func(http.Handler) http.Handler {
2829
}
2930

3031
// Validate token using httpmw.APIKeyFromRequest
31-
_, _, ok := httpmw.APIKeyFromRequest(ctx, db, func(r *http.Request) string {
32+
key, _, ok := httpmw.APIKeyFromRequest(ctx, db, func(r *http.Request) string {
3233
return token
3334
}, &http.Request{})
3435

@@ -37,8 +38,12 @@ func AuthMiddleware(db database.Store) func(http.Handler) http.Handler {
3738
return
3839
}
3940

41+
ctx = context.WithValue(
42+
context.WithValue(ctx, ContextKeyBridgeUserID{}, key.UserID),
43+
ContextKeyBridgeAPIKey{}, token)
44+
4045
// Pass request with modify context including the request token.
41-
next.ServeHTTP(rw, r.WithContext(context.WithValue(ctx, ContextKeyBridgeAPIKey{}, token)))
46+
next.ServeHTTP(rw, r.WithContext(ctx))
4247
})
4348
}
4449
}

0 commit comments

Comments
 (0)