Skip to content

Commit ce5e7e3

Browse files
committed
refactored openai into new structure
Signed-off-by: Danny Kopping <dannykopping@gmail.com>
1 parent 6bdcd57 commit ce5e7e3

14 files changed

+753
-920
lines changed

aibridged/aibridged.go

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -137,15 +137,15 @@ connectLoop:
137137
}
138138
}
139139

140-
func (s *Server) Client() (proto.DRPCAIBridgeDaemonClient, bool) {
140+
func (s *Server) Client() (proto.DRPCAIBridgeDaemonClient, error) {
141141
select {
142142
case <-s.closeContext.Done():
143-
return nil, false
143+
return nil, xerrors.New("context closed")
144144
case <-s.shuttingDownCh:
145145
// Shutting down should return a nil client and unblock
146-
return nil, false
146+
return nil, xerrors.New("shutting down")
147147
case client := <-s.clientCh:
148-
return client, true
148+
return client, nil
149149
}
150150
}
151151

@@ -179,7 +179,7 @@ func (s *Server) TrackToolUsage(ctx context.Context, in *proto.TrackToolUsageReq
179179
return out, nil
180180
}
181181

182-
// TODO: direct copy/paste from provisionerd, abstract into common util.
182+
// NOTE: mostly copypasta from provisionerd; might be work abstracting.
183183
func retryable(err error) bool {
184184
return xerrors.Is(err, yamux.ErrSessionShutdown) || xerrors.Is(err, io.EOF) || xerrors.Is(err, fasthttputil.ErrInmemoryListenerClosed) ||
185185
// annoyingly, dRPC sometimes returns context.Canceled if the transport was closed, even if the context for
@@ -190,15 +190,19 @@ func retryable(err error) bool {
190190
// clientDoWithRetries runs the function f with a client, and retries with
191191
// backoff until either the error returned is not retryable() or the context
192192
// expires.
193-
// TODO: direct copy/paste from provisionerd, abstract into common util.
193+
// NOTE: mostly copypasta from provisionerd; might be work abstracting.
194194
func clientDoWithRetries[T any](ctx context.Context,
195-
getClient func() (proto.DRPCAIBridgeDaemonClient, bool),
195+
getClient func() (proto.DRPCAIBridgeDaemonClient, error),
196196
f func(context.Context, proto.DRPCAIBridgeDaemonClient) (T, error),
197197
) (ret T, _ error) {
198198
for retrier := retry.New(25*time.Millisecond, 5*time.Second); retrier.Wait(ctx); {
199-
client, ok := getClient()
200-
if !ok {
201-
continue
199+
var empty T
200+
client, err := getClient()
201+
if err != nil {
202+
if retryable(err) {
203+
continue
204+
}
205+
return empty, err
202206
}
203207
resp, err := f(ctx, client)
204208
if retryable(err) {

aibridged/anthropic.go

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,14 @@ package aibridged
22

33
import (
44
"encoding/json"
5+
"regexp"
6+
"strings"
57

68
"github.com/anthropics/anthropic-sdk-go"
79
ant_param "github.com/anthropics/anthropic-sdk-go/packages/param"
810
"github.com/tidwall/gjson"
11+
"golang.org/x/xerrors"
12+
"tailscale.com/types/ptr"
913
)
1014

1115
type streamer interface {
@@ -161,6 +165,62 @@ func (b *BetaMessageNewParamsWrapper) UnmarshalJSON(raw []byte) error {
161165
b.Stream = extractStreamFlag(raw)
162166
return nil
163167
}
168+
164169
func (b *BetaMessageNewParamsWrapper) UseStreaming() bool {
165170
return b.Stream
166171
}
172+
173+
func (b *BetaMessageNewParamsWrapper) LastUserPrompt() (*string, error) {
174+
if b == nil {
175+
return nil, xerrors.New("nil struct")
176+
}
177+
178+
if len(b.Messages) == 0 {
179+
return nil, xerrors.New("no messages")
180+
}
181+
182+
var userMessage string
183+
for i := len(b.Messages) - 1; i >= 0; i-- {
184+
m := b.Messages[i]
185+
if m.Role != anthropic.BetaMessageParamRoleUser {
186+
continue
187+
}
188+
if len(m.Content) == 0 {
189+
continue
190+
}
191+
192+
for j := len(m.Content) - 1; j >= 0; j-- {
193+
if textContent := m.Content[j].GetText(); textContent != nil {
194+
userMessage = *textContent
195+
}
196+
197+
// Ignore internal Claude Code prompts.
198+
if userMessage == "test" ||
199+
strings.Contains(userMessage, "<system-reminder>") {
200+
userMessage = ""
201+
continue
202+
}
203+
204+
// Handle Cursor-specific formatting by extracting content from <user_query> tags
205+
if isCursor, _ := regexp.MatchString("<user_query>", userMessage); isCursor {
206+
userMessage = extractCursorUserQuery(userMessage)
207+
}
208+
return ptr.To(strings.TrimSpace(userMessage)), nil
209+
}
210+
}
211+
212+
return nil, nil
213+
}
214+
215+
func extractCursorUserQuery(message string) string {
216+
pat := regexp.MustCompile(`<user_query>(?P<content>[\s\S]*?)</user_query>`)
217+
match := pat.FindStringSubmatch(message)
218+
if match != nil {
219+
// Get the named group by index
220+
contentIndex := pat.SubexpIndex("content")
221+
if contentIndex != -1 {
222+
message = match[contentIndex]
223+
}
224+
}
225+
return strings.TrimSpace(message)
226+
}

0 commit comments

Comments
 (0)