Skip to content

Commit 1ae6284

Browse files
committed
continued refactoring
Signed-off-by: Danny Kopping <dannykopping@gmail.com>
1 parent 6bdcd57 commit 1ae6284

14 files changed

+754
-494
lines changed

aibridged/aibridged.go

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -137,15 +137,15 @@ connectLoop:
137137
}
138138
}
139139

140-
func (s *Server) Client() (proto.DRPCAIBridgeDaemonClient, bool) {
140+
func (s *Server) Client() (proto.DRPCAIBridgeDaemonClient, error) {
141141
select {
142142
case <-s.closeContext.Done():
143-
return nil, false
143+
return nil, xerrors.New("context closed")
144144
case <-s.shuttingDownCh:
145145
// Shutting down should return a nil client and unblock
146-
return nil, false
146+
return nil, xerrors.New("shutting down")
147147
case client := <-s.clientCh:
148-
return client, true
148+
return client, nil
149149
}
150150
}
151151

@@ -179,7 +179,7 @@ func (s *Server) TrackToolUsage(ctx context.Context, in *proto.TrackToolUsageReq
179179
return out, nil
180180
}
181181

182-
// TODO: direct copy/paste from provisionerd, abstract into common util.
182+
// NOTE: mostly copypasta from provisionerd; might be work abstracting.
183183
func retryable(err error) bool {
184184
return xerrors.Is(err, yamux.ErrSessionShutdown) || xerrors.Is(err, io.EOF) || xerrors.Is(err, fasthttputil.ErrInmemoryListenerClosed) ||
185185
// annoyingly, dRPC sometimes returns context.Canceled if the transport was closed, even if the context for
@@ -190,15 +190,19 @@ func retryable(err error) bool {
190190
// clientDoWithRetries runs the function f with a client, and retries with
191191
// backoff until either the error returned is not retryable() or the context
192192
// expires.
193-
// TODO: direct copy/paste from provisionerd, abstract into common util.
193+
// NOTE: mostly copypasta from provisionerd; might be work abstracting.
194194
func clientDoWithRetries[T any](ctx context.Context,
195-
getClient func() (proto.DRPCAIBridgeDaemonClient, bool),
195+
getClient func() (proto.DRPCAIBridgeDaemonClient, error),
196196
f func(context.Context, proto.DRPCAIBridgeDaemonClient) (T, error),
197197
) (ret T, _ error) {
198198
for retrier := retry.New(25*time.Millisecond, 5*time.Second); retrier.Wait(ctx); {
199-
client, ok := getClient()
200-
if !ok {
201-
continue
199+
var empty T
200+
client, err := getClient()
201+
if err != nil {
202+
if retryable(err) {
203+
continue
204+
}
205+
return empty, err
202206
}
203207
resp, err := f(ctx, client)
204208
if retryable(err) {

aibridged/anthropic.go

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,14 @@ package aibridged
22

33
import (
44
"encoding/json"
5+
"regexp"
6+
"strings"
57

68
"github.com/anthropics/anthropic-sdk-go"
79
ant_param "github.com/anthropics/anthropic-sdk-go/packages/param"
810
"github.com/tidwall/gjson"
11+
"golang.org/x/xerrors"
12+
"tailscale.com/types/ptr"
913
)
1014

1115
type streamer interface {
@@ -161,6 +165,62 @@ func (b *BetaMessageNewParamsWrapper) UnmarshalJSON(raw []byte) error {
161165
b.Stream = extractStreamFlag(raw)
162166
return nil
163167
}
168+
164169
func (b *BetaMessageNewParamsWrapper) UseStreaming() bool {
165170
return b.Stream
166171
}
172+
173+
func (b *BetaMessageNewParamsWrapper) LastUserPrompt() (*string, error) {
174+
if b == nil {
175+
return nil, xerrors.New("nil struct")
176+
}
177+
178+
if len(b.Messages) == 0 {
179+
return nil, xerrors.New("no messages")
180+
}
181+
182+
var userMessage string
183+
for i := len(b.Messages) - 1; i >= 0; i-- {
184+
m := b.Messages[i]
185+
if m.Role != anthropic.BetaMessageParamRoleUser {
186+
continue
187+
}
188+
if len(m.Content) == 0 {
189+
continue
190+
}
191+
192+
for j := len(m.Content) - 1; j >= 0; j-- {
193+
if textContent := m.Content[j].GetText(); textContent != nil {
194+
userMessage = *textContent
195+
}
196+
197+
// Ignore internal Claude Code prompts.
198+
if userMessage == "test" ||
199+
strings.Contains(userMessage, "<system-reminder>") {
200+
userMessage = ""
201+
continue
202+
}
203+
204+
// Handle Cursor-specific formatting by extracting content from <user_query> tags
205+
if isCursor, _ := regexp.MatchString("<user_query>", userMessage); isCursor {
206+
userMessage = extractCursorUserQuery(userMessage)
207+
}
208+
return ptr.To(strings.TrimSpace(userMessage)), nil
209+
}
210+
}
211+
212+
return nil, nil
213+
}
214+
215+
func extractCursorUserQuery(message string) string {
216+
pat := regexp.MustCompile(`<user_query>(?P<content>[\s\S]*?)</user_query>`)
217+
match := pat.FindStringSubmatch(message)
218+
if match != nil {
219+
// Get the named group by index
220+
contentIndex := pat.SubexpIndex("content")
221+
if contentIndex != -1 {
222+
message = match[contentIndex]
223+
}
224+
}
225+
return strings.TrimSpace(message)
226+
}

aibridged/bridge.go

Lines changed: 45 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -75,21 +75,15 @@ type Bridge struct {
7575
cfg codersdk.AIBridgeConfig
7676

7777
httpSrv *http.Server
78-
clientFn func() (proto.DRPCAIBridgeDaemonClient, bool)
78+
clientFn func() (proto.DRPCAIBridgeDaemonClient, error)
7979
logger slog.Logger
8080

8181
tools map[string]*MCPTool
8282
}
8383

84-
func NewBridge(cfg codersdk.AIBridgeConfig, logger slog.Logger, clientFn func() (proto.DRPCAIBridgeDaemonClient, bool), tools map[string][]*MCPTool) (*Bridge, error) {
85-
var bridge Bridge
86-
87-
mux := &http.ServeMux{}
88-
mux.HandleFunc("/v1/chat/completions", func(w http.ResponseWriter, r *http.Request) {
89-
prov := NewOpenAIProvider(cfg.OpenAI.BaseURL.String(), cfg.OpenAI.Key.String())
90-
91-
// TODO: everything is generic beyond this point...
92-
84+
func handleOpenAI(provider *OpenAIChatProvider, drpcClient proto.DRPCAIBridgeDaemonClient, tools map[string][]*MCPTool, logger slog.Logger) func(http.ResponseWriter, *http.Request) {
85+
return func(w http.ResponseWriter, r *http.Request) {
86+
// Read and parse request.
9387
body, err := io.ReadAll(r.Body)
9488
if err != nil {
9589
if isConnectionError(err) {
@@ -100,35 +94,49 @@ func NewBridge(cfg codersdk.AIBridgeConfig, logger slog.Logger, clientFn func()
10094
http.Error(w, "failed to read body", http.StatusInternalServerError)
10195
return
10296
}
103-
104-
req, err := prov.ParseRequest(body)
97+
req, err := provider.ParseRequest(body)
10598
if err != nil {
10699
logger.Error(r.Context(), "failed to parse request", slog.Error(err))
107100
http.Error(w, "failed to parse request", http.StatusBadRequest)
101+
return
108102
}
109103

110-
var sess Session[ChatCompletionNewParamsWrapper]
104+
// Create a new session.
105+
var sess Session
111106
if req.Stream {
112-
sess = prov.NewAsynchronousSession(req)
107+
sess = provider.NewStreamingSession(req)
113108
} else {
114-
sess = prov.NewSynchronousSession(req)
109+
sess = provider.NewBlockingSession(req)
115110
}
116111

117-
coderdClient, ok := clientFn()
118-
if !ok {
119-
logger.Error(r.Context(), "could not acquire coderd client for tracking")
120-
return
121-
}
112+
sessID := sess.Init(logger, provider.baseURL, provider.key, NewDRPCTracker(drpcClient), NewInjectedToolManager(tools))
113+
logger.Debug(context.Background(), "starting openai session", slog.F("session_id", sessID))
122114

123-
sessID := sess.Init(logger, prov.baseURL, prov.key, NewDRPCTracker(coderdClient), NewInjectedToolManager(tools))
124115
defer func() {
125116
if err := sess.Close(); err != nil {
126117
logger.Warn(context.Background(), "failed to close session", slog.Error(err), slog.F("session_id", sessID), slog.F("kind", fmt.Sprintf("%T", sess)))
127118
}
128119
}()
129120

130-
sess.Execute(req, w, r) // TODO: handle error?
131-
})
121+
// Process the request.
122+
if err := sess.ProcessRequest(w, r); err != nil {
123+
logger.Error(r.Context(), "session execution failed", slog.Error(err))
124+
}
125+
}
126+
}
127+
128+
func NewBridge(cfg codersdk.AIBridgeConfig, logger slog.Logger, clientFn func() (proto.DRPCAIBridgeDaemonClient, error), tools map[string][]*MCPTool) (*Bridge, error) {
129+
var bridge Bridge
130+
131+
openAIProvider := NewOpenAIChatProvider(cfg.OpenAI.BaseURL.String(), cfg.OpenAI.Key.String())
132+
133+
drpcClient, err := clientFn()
134+
if err != nil {
135+
return nil, xerrors.Errorf("could not acquire coderd client for tracking: %w", err)
136+
}
137+
138+
mux := &http.ServeMux{}
139+
mux.HandleFunc("/v1/chat/completions", handleOpenAI(openAIProvider, drpcClient, tools, logger.Named("openai")))
132140
mux.HandleFunc("/v1/messages", bridge.proxyAnthropicRequest)
133141

134142
srv := &http.Server{
@@ -172,7 +180,6 @@ func (b *Bridge) Handler() http.Handler {
172180
// proxyOpenAIRequest intercepts, filters, augments, and delivers requests & responses from client to upstream and back.
173181
//
174182
// References:
175-
// - https://platform.openai.com/docs/api-reference/chat-streaming
176183
func (b *Bridge) proxyOpenAIRequest(w http.ResponseWriter, r *http.Request) {
177184
sessionID := uuid.NewString()
178185
b.logger.Info(r.Context(), "openai request started", slog.F("session_id", sessionID), slog.F("method", r.Method), slog.F("path", r.URL.Path))
@@ -245,7 +252,7 @@ func (b *Bridge) proxyOpenAIRequest(w http.ResponseWriter, r *http.Request) {
245252
opts = append(opts, oai_option.WithBaseURL(baseURL))
246253
}
247254

248-
opts = append(opts, oai_option.WithMiddleware(LoggingMiddleware))
255+
// opts = append(opts, oai_option.WithMiddleware(LoggingMiddleware))
249256

250257
client := openai.NewClient(opts...)
251258
req := in.ChatCompletionNewParams
@@ -714,7 +721,7 @@ func (b *Bridge) proxyAnthropicRequest(w http.ResponseWriter, r *http.Request) {
714721
if reqBetaHeader := r.Header.Get("anthropic-beta"); strings.TrimSpace(reqBetaHeader) != "" {
715722
opts = append(opts, option.WithHeader("anthropic-beta", reqBetaHeader))
716723
}
717-
opts = append(opts, option.WithMiddleware(LoggingMiddleware))
724+
// opts = append(opts, option.WithMiddleware(LoggingMiddleware))
718725

719726
apiKey := b.cfg.Anthropic.Key.String()
720727
if apiKey == "" {
@@ -1233,9 +1240,9 @@ func (b *Bridge) proxyAnthropicRequest(w http.ResponseWriter, r *http.Request) {
12331240
}
12341241

12351242
func (b *Bridge) trackToolUsage(ctx context.Context, sessionID, msgID, model, toolName string, toolInput interface{}, injected bool) {
1236-
coderdClient, ok := b.clientFn()
1237-
if !ok {
1238-
b.logger.Error(ctx, "could not acquire coderd client for tool usage tracking")
1243+
coderdClient, err := b.clientFn()
1244+
if err != nil {
1245+
b.logger.Error(ctx, "could not acquire coderd client for tool usage tracking", slog.Error(err))
12391246
return
12401247
}
12411248

@@ -1265,7 +1272,7 @@ func (b *Bridge) trackToolUsage(ctx context.Context, sessionID, msgID, model, to
12651272
}
12661273
}
12671274

1268-
_, err := coderdClient.TrackToolUsage(ctx, &proto.TrackToolUsageRequest{
1275+
_, err = coderdClient.TrackToolUsage(ctx, &proto.TrackToolUsageRequest{
12691276
SessionId: sessionID,
12701277
MsgId: msgID,
12711278
Model: model,
@@ -1279,13 +1286,13 @@ func (b *Bridge) trackToolUsage(ctx context.Context, sessionID, msgID, model, to
12791286
}
12801287

12811288
func (b *Bridge) trackUserPrompt(ctx context.Context, sessionID, msgID, model, prompt string) {
1282-
coderdClient, ok := b.clientFn()
1283-
if !ok {
1284-
b.logger.Error(ctx, "could not acquire coderd client for user prompt tracking")
1289+
coderdClient, err := b.clientFn()
1290+
if err != nil {
1291+
b.logger.Error(ctx, "could not acquire coderd client for user prompt tracking", slog.Error(err))
12851292
return
12861293
}
12871294

1288-
_, err := coderdClient.TrackUserPrompt(ctx, &proto.TrackUserPromptRequest{
1295+
_, err = coderdClient.TrackUserPrompt(ctx, &proto.TrackUserPromptRequest{
12891296
SessionId: sessionID,
12901297
MsgId: msgID,
12911298
Model: model,
@@ -1297,13 +1304,13 @@ func (b *Bridge) trackUserPrompt(ctx context.Context, sessionID, msgID, model, p
12971304
}
12981305

12991306
func (b *Bridge) trackTokenUsage(ctx context.Context, sessionID, msgID, model string, inputTokens, outputTokens int64, other map[string]int64) {
1300-
coderdClient, ok := b.clientFn()
1301-
if !ok {
1302-
b.logger.Error(ctx, "could not acquire coderd client for token usage tracking")
1307+
coderdClient, err := b.clientFn()
1308+
if err != nil {
1309+
b.logger.Error(ctx, "could not acquire coderd client for token usage tracking", slog.Error(err))
13031310
return
13041311
}
13051312

1306-
_, err := coderdClient.TrackTokenUsage(ctx, &proto.TrackTokenUsageRequest{
1313+
_, err = coderdClient.TrackTokenUsage(ctx, &proto.TrackTokenUsageRequest{
13071314
SessionId: sessionID,
13081315
MsgId: msgID,
13091316
Model: model,

0 commit comments

Comments
 (0)