Skip to content

Commit 1760a91

Browse files
committed
bridge-per-key model, autoinject coder mcp using local mcp server
Signed-off-by: Danny Kopping <dannykopping@gmail.com>
1 parent 528358a commit 1760a91

File tree

9 files changed

+124
-171
lines changed

9 files changed

+124
-171
lines changed

aibridged/aibridged.go

Lines changed: 8 additions & 83 deletions
Original file line numberDiff line numberDiff line change
@@ -48,13 +48,11 @@ type Server struct {
4848
shuttingDownB bool
4949
// shuttingDownCh will receive when we start graceful shutdown
5050
shuttingDownCh chan struct{}
51-
52-
bridge *Bridge
5351
}
5452

5553
var _ proto.DRPCAIBridgeDaemonServer = &Server{}
5654

57-
func New(rpcDialer Dialer, httpAddr string, logger slog.Logger, bridgeCfg codersdk.AIBridgeConfig, tools []*MCPTool) (*Server, error) {
55+
func New(rpcDialer Dialer, httpAddr string, logger slog.Logger) (*Server, error) {
5856
if rpcDialer == nil {
5957
return nil, xerrors.Errorf("nil rpcDialer given")
6058
}
@@ -71,29 +69,13 @@ func New(rpcDialer Dialer, httpAddr string, logger slog.Logger, bridgeCfg coders
7169
initConnectionCh: make(chan struct{}),
7270
}
7371

74-
// TODO: improve error handling here; if this fails it prevents the whole server from starting up!
75-
76-
bridge, err := NewBridge(bridgeCfg, httpAddr, logger.Named("ai_bridge"), daemon.client, tools)
77-
if err != nil {
78-
return nil, xerrors.Errorf("create new bridge server: %w", err)
79-
}
80-
81-
daemon.bridge = bridge
82-
8372
daemon.wg.Add(1)
8473
go daemon.connect()
8574

86-
daemon.wg.Add(1)
87-
go func() {
88-
defer daemon.wg.Done()
89-
err := bridge.Serve()
90-
// TODO: better error handling.
91-
// TODO: close on shutdown.
92-
logger.Error(ctx, "bridge server stopped", slog.Error(err))
93-
}()
94-
9575
return daemon, nil
96-
} // Connect establishes a connection to coderd.
76+
}
77+
78+
// Connect establishes a connection to coderd.
9779
func (s *Server) connect() {
9880
defer s.logger.Debug(s.closeContext, "connect loop exited")
9981
defer s.wg.Done()
@@ -155,7 +137,7 @@ connectLoop:
155137
}
156138
}
157139

158-
func (s *Server) client() (proto.DRPCAIBridgeDaemonClient, bool) {
140+
func (s *Server) Client() (proto.DRPCAIBridgeDaemonClient, bool) {
159141
select {
160142
case <-s.closeContext.Done():
161143
return nil, false
@@ -168,7 +150,7 @@ func (s *Server) client() (proto.DRPCAIBridgeDaemonClient, bool) {
168150
}
169151

170152
func (s *Server) TrackTokenUsage(ctx context.Context, in *proto.TrackTokenUsageRequest) (*proto.TrackTokenUsageResponse, error) {
171-
out, err := clientDoWithRetries(ctx, s.client, func(ctx context.Context, client proto.DRPCAIBridgeDaemonClient) (*proto.TrackTokenUsageResponse, error) {
153+
out, err := clientDoWithRetries(ctx, s.Client, func(ctx context.Context, client proto.DRPCAIBridgeDaemonClient) (*proto.TrackTokenUsageResponse, error) {
172154
return client.TrackTokenUsage(ctx, in)
173155
})
174156
if err != nil {
@@ -178,7 +160,7 @@ func (s *Server) TrackTokenUsage(ctx context.Context, in *proto.TrackTokenUsageR
178160
}
179161

180162
func (s *Server) TrackUserPrompt(ctx context.Context, in *proto.TrackUserPromptRequest) (*proto.TrackUserPromptResponse, error) {
181-
out, err := clientDoWithRetries(ctx, s.client, func(ctx context.Context, client proto.DRPCAIBridgeDaemonClient) (*proto.TrackUserPromptResponse, error) {
163+
out, err := clientDoWithRetries(ctx, s.Client, func(ctx context.Context, client proto.DRPCAIBridgeDaemonClient) (*proto.TrackUserPromptResponse, error) {
182164
return client.TrackUserPrompt(ctx, in)
183165
})
184166
if err != nil {
@@ -188,7 +170,7 @@ func (s *Server) TrackUserPrompt(ctx context.Context, in *proto.TrackUserPromptR
188170
}
189171

190172
func (s *Server) TrackToolUsage(ctx context.Context, in *proto.TrackToolUsageRequest) (*proto.TrackToolUsageResponse, error) {
191-
out, err := clientDoWithRetries(ctx, s.client, func(ctx context.Context, client proto.DRPCAIBridgeDaemonClient) (*proto.TrackToolUsageResponse, error) {
173+
out, err := clientDoWithRetries(ctx, s.Client, func(ctx context.Context, client proto.DRPCAIBridgeDaemonClient) (*proto.TrackToolUsageResponse, error) {
192174
return client.TrackToolUsage(ctx, in)
193175
})
194176
if err != nil {
@@ -197,63 +179,6 @@ func (s *Server) TrackToolUsage(ctx context.Context, in *proto.TrackToolUsageReq
197179
return out, nil
198180
}
199181

200-
// func (s *Server) ChatCompletions(payload *proto.JSONPayload, stream proto.DRPCOpenAIService_ChatCompletionsStream) error {
201-
// // TODO: call OpenAI API.
202-
//
203-
// select {
204-
// case <-stream.Context().Done():
205-
// return nil
206-
// default:
207-
// }
208-
//
209-
// err := stream.Send(&proto.JSONPayload{
210-
// Content: `
211-
//{
212-
// "id": "chatcmpl-B9MBs8CjcvOU2jLn4n570S5qMJKcT",
213-
// "object": "chat.completion",
214-
// "created": 1741569952,
215-
// "model": "gpt-4.1-2025-04-14",
216-
// "choices": [
217-
// {
218-
// "index": 0,
219-
// "message": {
220-
// "role": "assistant",
221-
// "content": "Hello! How can I assist you today?",
222-
// "refusal": null,
223-
// "annotations": []
224-
// },
225-
// "logprobs": null,
226-
// "finish_reason": "stop"
227-
// }
228-
// ],
229-
// "usage": {
230-
// "prompt_tokens": 19,
231-
// "completion_tokens": 10,
232-
// "total_tokens": 29,
233-
// "prompt_tokens_details": {
234-
// "cached_tokens": 0,
235-
// "audio_tokens": 0
236-
// },
237-
// "completion_tokens_details": {
238-
// "reasoning_tokens": 0,
239-
// "audio_tokens": 0,
240-
// "accepted_prediction_tokens": 0,
241-
// "rejected_prediction_tokens": 0
242-
// }
243-
// },
244-
// "service_tier": "default"
245-
//}
246-
// `})
247-
// if err != nil {
248-
// return xerrors.Errorf("stream chat completion response: %w", err)
249-
// }
250-
// return nil
251-
//}
252-
253-
func (s *Server) BridgeAddr() string {
254-
return s.bridge.Addr()
255-
}
256-
257182
// TODO: direct copy/paste from provisionerd, abstract into common util.
258183
func retryable(err error) bool {
259184
return xerrors.Is(err, yamux.ErrSessionShutdown) || xerrors.Is(err, io.EOF) || xerrors.Is(err, fasthttputil.ErrInmemoryListenerClosed) ||

aibridged/bridge.go

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -76,22 +76,20 @@ type Bridge struct {
7676
cfg codersdk.AIBridgeConfig
7777

7878
httpSrv *http.Server
79-
addr string
8079
clientFn func() (proto.DRPCAIBridgeDaemonClient, bool)
8180
logger slog.Logger
8281

8382
tools map[string]*MCPTool
8483
}
8584

86-
func NewBridge(cfg codersdk.AIBridgeConfig, addr string, logger slog.Logger, clientFn func() (proto.DRPCAIBridgeDaemonClient, bool), tools []*MCPTool) (*Bridge, error) {
85+
func NewBridge(cfg codersdk.AIBridgeConfig, logger slog.Logger, clientFn func() (proto.DRPCAIBridgeDaemonClient, bool), tools []*MCPTool) (*Bridge, error) {
8786
var bridge Bridge
8887

8988
mux := &http.ServeMux{}
9089
mux.HandleFunc("/v1/chat/completions", bridge.proxyOpenAIRequest)
9190
mux.HandleFunc("/v1/messages", bridge.proxyAnthropicRequest)
9291

9392
srv := &http.Server{
94-
Addr: addr,
9593
Handler: mux,
9694

9795
// TODO: configurable.
@@ -123,6 +121,10 @@ func (b *Bridge) openAITarget() *url.URL {
123121
return target
124122
}
125123

124+
func (b *Bridge) Handler() http.Handler {
125+
return b.httpSrv.Handler
126+
}
127+
126128
// proxyOpenAIRequest intercepts, filters, augments, and delivers requests & responses from client to upstream and back.
127129
//
128130
// References:
@@ -1239,15 +1241,9 @@ func (b *Bridge) Serve() error {
12391241
return xerrors.Errorf("listen: %w", err)
12401242
}
12411243

1242-
b.addr = list.Addr().String()
1243-
12441244
return b.httpSrv.Serve(list) // TODO: TLS.
12451245
}
12461246

1247-
func (b *Bridge) Addr() string {
1248-
return b.addr
1249-
}
1250-
12511247
// logConnectionError logs connection errors with appropriate severity
12521248
func (b *Bridge) logConnectionError(ctx context.Context, err error, operation string) {
12531249
if isConnectionError(err) {

aibridged/middleware.go

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,16 @@ package aibridged
22

33
import (
44
"bytes"
5+
"context"
56
"crypto/subtle"
67
"net/http"
78

89
"github.com/coder/coder/v2/coderd/database"
910
"github.com/coder/coder/v2/coderd/httpmw"
1011
)
1112

13+
type ContextKeyBridgeAPIKey struct{}
14+
1215
// AuthMiddleware extracts and validates authorization tokens for AI bridge endpoints.
1316
// It supports both Bearer tokens in Authorization headers and Coder session tokens
1417
// from cookies/headers following the same patterns as existing Coder authentication.
@@ -34,7 +37,8 @@ func AuthMiddleware(db database.Store) func(http.Handler) http.Handler {
3437
return
3538
}
3639

37-
next.ServeHTTP(rw, r)
40+
// Pass request with modify context including the request token.
41+
next.ServeHTTP(rw, r.WithContext(context.WithValue(ctx, ContextKeyBridgeAPIKey{}, token)))
3842
})
3943
}
4044
}

cli/server.go

Lines changed: 3 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ import (
3131
"sync/atomic"
3232
"time"
3333

34+
"github.com/ammario/tlru"
3435
"github.com/charmbracelet/lipgloss"
3536
"github.com/coreos/go-oidc/v3/oidc"
3637
"github.com/coreos/go-systemd/daemon"
@@ -1125,6 +1126,7 @@ func (r *RootCmd) Server(newAPI func(context.Context, *coderd.Options) (*coderd.
11251126
aiBridgeDaemons = append(aiBridgeDaemons, daemon)
11261127
}
11271128
coderAPI.AIBridgeDaemons = aiBridgeDaemons
1129+
coderAPI.AIBridges = tlru.New[string](tlru.ConstantCost[*aibridged.Bridge], 100) // TODO: configurable.
11281130

11291131
// Updates the systemd status from activating to activated.
11301132
_, err = daemon.SdNotify(false, daemon.SdNotifyReady)
@@ -1571,68 +1573,11 @@ func newProvisionerDaemon(
15711573

15721574
func newAIBridgeDaemon(ctx context.Context, coderAPI *coderd.API, name string, bridgeCfg codersdk.AIBridgeConfig) (*aibridged.Server, error) {
15731575
httpAddr := "0.0.0.0:0" // TODO: configurable.
1574-
1575-
// TODO: in reality, it won't work this way. We'll have to load the tools dynamically
1576-
tools, err := loadMCP(coderAPI.Logger.Named("mcp-tools"))
1577-
if err != nil {
1578-
coderAPI.Logger.Error(ctx, "failed to load MCP tools", slog.Error(err))
1579-
}
1580-
15811576
return aibridged.New(func(dialCtx context.Context) (aibridgedproto.DRPCAIBridgeDaemonClient, error) {
15821577
// This debounces calls to listen every second.
15831578
// TODO: is this true / necessary?
15841579
return coderAPI.CreateInMemoryAIBridgeDaemon(dialCtx, name)
1585-
}, httpAddr, coderAPI.Logger.Named("aibridged").With(slog.F("name", name)), bridgeCfg, tools)
1586-
}
1587-
1588-
func loadMCP(logger slog.Logger) ([]*aibridged.MCPTool, error) {
1589-
const (
1590-
githubMCPName = "github"
1591-
coderMCPName = "coder"
1592-
)
1593-
githubMCP, err := aibridged.NewMCPToolBridge(githubMCPName, "https://api.githubcopilot.com/mcp/", map[string]string{
1594-
"Authorization": "Bearer " + os.Getenv("GITHUB_MCP_TOKEN"),
1595-
}, logger.Named("mcp-bridge-github"))
1596-
if err != nil {
1597-
return nil, xerrors.Errorf("github MCP bridge setup: %w", err)
1598-
}
1599-
coderMCP, err := aibridged.NewMCPToolBridge(coderMCPName, "https://dev.coder.com/api/experimental/mcp/http", map[string]string{
1600-
// "Authorization": "Bearer " + os.Getenv("CODER_MCP_TOKEN"),
1601-
// This is necessary to even access the MCP endpoint.
1602-
"Coder-Session-Token": os.Getenv("CODER_MCP_SESSION_TOKEN"),
1603-
}, logger.Named("mcp-bridge-coder"))
1604-
if err != nil {
1605-
return nil, xerrors.Errorf("coder MCP bridge setup: %w", err)
1606-
}
1607-
1608-
var eg errgroup.Group
1609-
eg.Go(func() error {
1610-
ctx, cancel := context.WithTimeout(context.Background(), time.Second*30)
1611-
defer cancel()
1612-
1613-
err := githubMCP.Init(ctx)
1614-
if err == nil {
1615-
return nil
1616-
}
1617-
return xerrors.Errorf("github: %w", err)
1618-
})
1619-
eg.Go(func() error {
1620-
ctx, cancel := context.WithTimeout(context.Background(), time.Second*30)
1621-
defer cancel()
1622-
1623-
err := coderMCP.Init(ctx)
1624-
if err == nil {
1625-
return nil
1626-
}
1627-
return xerrors.Errorf("coder: %w", err)
1628-
})
1629-
1630-
// This must block requests until MCP proxies are setup.
1631-
if err := eg.Wait(); err != nil {
1632-
return nil, xerrors.Errorf("MCP proxy init: %w", err)
1633-
}
1634-
1635-
return append(githubMCP.ListTools(), coderMCP.ListTools()...), nil
1580+
}, httpAddr, coderAPI.Logger.Named("aibridged").With(slog.F("name", name)))
16361581
}
16371582

16381583
// nolint: revive

0 commit comments

Comments
 (0)