Skip to content

Commit e6a4e3f

Browse files
committed
WIP: bridge refactor
Signed-off-by: Danny Kopping <dannykopping@gmail.com>
1 parent 645b6df commit e6a4e3f

File tree

10 files changed

+614
-83
lines changed

10 files changed

+614
-83
lines changed

aibridged/bridge.go

Lines changed: 49 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -81,11 +81,54 @@ type Bridge struct {
8181
tools map[string]*MCPTool
8282
}
8383

84-
func NewBridge(cfg codersdk.AIBridgeConfig, logger slog.Logger, clientFn func() (proto.DRPCAIBridgeDaemonClient, bool), tools []*MCPTool) (*Bridge, error) {
84+
func NewBridge(cfg codersdk.AIBridgeConfig, logger slog.Logger, clientFn func() (proto.DRPCAIBridgeDaemonClient, bool), tools map[string][]*MCPTool) (*Bridge, error) {
8585
var bridge Bridge
8686

8787
mux := &http.ServeMux{}
88-
mux.HandleFunc("/v1/chat/completions", bridge.proxyOpenAIRequest)
88+
mux.HandleFunc("/v1/chat/completions", func(w http.ResponseWriter, r *http.Request) {
89+
prov := NewOpenAIProvider(cfg.OpenAI.BaseURL.String(), cfg.OpenAI.Key.String())
90+
91+
// TODO: everything is generic beyond this point...
92+
93+
body, err := io.ReadAll(r.Body)
94+
if err != nil {
95+
if isConnectionError(err) {
96+
logger.Debug(r.Context(), "client disconnected during request body read", slog.Error(err))
97+
return // Don't send error response if client already disconnected
98+
}
99+
logger.Error(r.Context(), "failed to read body", slog.Error(err))
100+
http.Error(w, "failed to read body", http.StatusInternalServerError)
101+
return
102+
}
103+
104+
req, err := prov.ParseRequest(body)
105+
if err != nil {
106+
logger.Error(r.Context(), "failed to parse request", slog.Error(err))
107+
http.Error(w, "failed to parse request", http.StatusBadRequest)
108+
}
109+
110+
var sess Session[ChatCompletionNewParamsWrapper]
111+
if req.Stream {
112+
sess = prov.NewAsynchronousSession(req)
113+
} else {
114+
sess = prov.NewSynchronousSession(req)
115+
}
116+
117+
coderdClient, ok := clientFn()
118+
if !ok {
119+
logger.Error(r.Context(), "could not acquire coderd client for tracking")
120+
return
121+
}
122+
123+
sessID := sess.Init(logger, prov.baseURL, prov.key, NewDRPCTracker(coderdClient), NewInjectedToolManager(tools))
124+
defer func() {
125+
if err := sess.Close(); err != nil {
126+
logger.Warn(context.Background(), "failed to close session", slog.Error(err), slog.F("session_id", sessID), slog.F("kind", fmt.Sprintf("%T", sess)))
127+
}
128+
}()
129+
130+
sess.Execute(req, w, r) // TODO: handle error?
131+
})
89132
mux.HandleFunc("/v1/messages", bridge.proxyAnthropicRequest)
90133

91134
srv := &http.Server{
@@ -104,8 +147,10 @@ func NewBridge(cfg codersdk.AIBridgeConfig, logger slog.Logger, clientFn func()
104147
bridge.logger = logger
105148

106149
bridge.tools = make(map[string]*MCPTool, len(tools))
107-
for _, tool := range tools {
108-
bridge.tools[tool.ID] = tool
150+
for _, serverTools := range tools {
151+
for _, tool := range serverTools {
152+
bridge.tools[tool.ID] = tool
153+
}
109154
}
110155

111156
return &bridge, nil
@@ -166,11 +211,6 @@ func (b *Bridge) proxyOpenAIRequest(w http.ResponseWriter, r *http.Request) {
166211
b.trackUserPrompt(ctx, sessionID, "", in.Model, *prompt)
167212
}
168213

169-
// Prepend assistant message.
170-
in.Messages = append([]openai.ChatCompletionMessageParamUnion{
171-
openai.SystemMessage("You are a helpful assistant that explicitly mentions when tool calls are about to be made."),
172-
}, in.Messages...)
173-
174214
for _, tool := range b.tools {
175215
fn := openai.ChatCompletionToolParam{
176216
Function: openai.FunctionDefinitionParam{

aibridged/openai.go

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,14 @@
11
package aibridged
22

33
import (
4+
"regexp"
5+
"strings"
6+
47
"github.com/openai/openai-go"
58
"github.com/openai/openai-go/packages/param"
69
"github.com/tidwall/gjson"
10+
"golang.org/x/xerrors"
11+
"tailscale.com/types/ptr"
712
)
813

914
// ChatCompletionNewParamsWrapper exists because the "stream" param is not included in openai.ChatCompletionNewParams.
@@ -42,6 +47,36 @@ func (c *ChatCompletionNewParamsWrapper) UnmarshalJSON(raw []byte) error {
4247
return nil
4348
}
4449

50+
func (c *ChatCompletionNewParamsWrapper) LastUserPrompt() (*string, error) {
51+
if c == nil {
52+
return nil, xerrors.New("nil struct")
53+
}
54+
55+
if len(c.Messages) == 0 {
56+
return nil, xerrors.New("no messages")
57+
}
58+
59+
var msg *openai.ChatCompletionUserMessageParam
60+
for i := len(c.Messages) - 1; i >= 0; i-- {
61+
m := c.Messages[i]
62+
if m.OfUser != nil {
63+
msg = m.OfUser
64+
break
65+
}
66+
}
67+
68+
if msg == nil {
69+
return nil, nil
70+
}
71+
72+
userMessage := msg.Content.OfString.String()
73+
if isCursor, _ := regexp.MatchString("<user_query>", userMessage); isCursor {
74+
userMessage = extractCursorUserQuery(userMessage)
75+
}
76+
77+
return ptr.To(strings.TrimSpace(userMessage)), nil
78+
}
79+
4580
func sumUsage(ref, in openai.CompletionUsage) openai.CompletionUsage {
4681
return openai.CompletionUsage{
4782
CompletionTokens: ref.CompletionTokens + in.CompletionTokens,

aibridged/provider.go

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
package aibridged
2+
3+
type Provider[Req any] interface {
4+
ParseRequest(payload []byte) (*Req, error)
5+
NewAsynchronousSession(*Req) Session[Req]
6+
NewSynchronousSession(*Req) Session[Req]
7+
}

aibridged/provider_openai.go

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
package aibridged
2+
3+
import (
4+
"encoding/json"
5+
"os"
6+
7+
"github.com/openai/openai-go"
8+
"github.com/openai/openai-go/option"
9+
"golang.org/x/xerrors"
10+
)
11+
12+
type OpenAIProvider struct {
13+
baseURL, key string
14+
}
15+
16+
func NewOpenAIProvider(baseURL, key string) *OpenAIProvider {
17+
return &OpenAIProvider{
18+
baseURL: baseURL,
19+
key: key,
20+
}
21+
}
22+
23+
func (*OpenAIProvider) ParseRequest(payload []byte) (*ChatCompletionNewParamsWrapper, error) {
24+
var in ChatCompletionNewParamsWrapper
25+
if err := json.Unmarshal(payload, &in); err != nil {
26+
return nil, xerrors.Errorf("failed to unmarshal request: %w", err)
27+
}
28+
29+
return &in, nil
30+
}
31+
32+
func (p *OpenAIProvider) NewAsynchronousSession(req *ChatCompletionNewParamsWrapper) Session[ChatCompletionNewParamsWrapper] {
33+
return &OpenAIStreamingSession{}
34+
}
35+
func (p *OpenAIProvider) NewSynchronousSession(req *ChatCompletionNewParamsWrapper) Session[ChatCompletionNewParamsWrapper] {
36+
panic("not implemented")
37+
38+
}
39+
40+
func newOpenAIClient(baseURL, key string) openai.Client {
41+
var opts []option.RequestOption
42+
if key == "" {
43+
key = os.Getenv("OPENAI_API_KEY")
44+
}
45+
opts = append(opts, option.WithAPIKey(key))
46+
if baseURL != "" {
47+
opts = append(opts, option.WithBaseURL(baseURL))
48+
}
49+
50+
opts = append(opts, option.WithMiddleware(LoggingMiddleware))
51+
52+
return openai.NewClient(opts...)
53+
}

aibridged/session.go

Lines changed: 10 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -1,50 +1,19 @@
11
package aibridged
22

33
import (
4-
"fmt"
5-
"sync/atomic"
6-
)
7-
8-
type OpenAIToolCall struct {
9-
funcName string
10-
args map[string]string
11-
}
4+
"net/http"
125

13-
type OpenAIToolCallState int
14-
15-
const (
16-
OpenAIToolCallNotReady OpenAIToolCallState = iota
17-
OpenAIToolCallReady
18-
OpenAIToolCallInProgress
19-
OpenAIToolCallDone
6+
"cdr.dev/slog"
207
)
218

22-
func (o OpenAIToolCallState) String() string {
23-
switch o {
24-
case OpenAIToolCallNotReady:
25-
return "not ready"
26-
case OpenAIToolCallReady:
27-
return "ready"
28-
case OpenAIToolCallInProgress:
29-
return "in-progress"
30-
case OpenAIToolCallDone:
31-
return "done"
32-
default:
33-
return fmt.Sprintf("UNKNOWN STATE: %d", o)
34-
}
35-
}
36-
37-
type OpenAISession struct {
38-
done atomic.Bool
39-
// key = tool call ID
40-
toolCallsRequired map[string]*OpenAIToolCall
41-
toolCallsState map[string]OpenAIToolCallState
42-
phantomEvents [][]byte
9+
type Model struct {
10+
Provider, ModelName string
4311
}
4412

45-
func NewOpenAISession() *OpenAISession {
46-
return &OpenAISession{
47-
toolCallsRequired: make(map[string]*OpenAIToolCall),
48-
toolCallsState: make(map[string]OpenAIToolCallState),
49-
}
13+
type Session[Req any] interface {
14+
Init(logger slog.Logger, baseURL, key string, tracker Tracker, toolMgr ToolManager) (id string)
15+
LastUserPrompt(req Req) (*string, error)
16+
Model(req *Req) Model
17+
Execute(req *Req, w http.ResponseWriter, r *http.Request) error
18+
Close() error
5019
}

0 commit comments

Comments
 (0)