Skip to content

Commit bd90c83

Browse files
committed
Test tool calls to Anthropic
Signed-off-by: Danny Kopping <dannykopping@gmail.com>
1 parent eef9bf7 commit bd90c83

File tree

3 files changed

+318
-19
lines changed

3 files changed

+318
-19
lines changed

aibridged/bridge.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -954,7 +954,7 @@ func (b *Bridge) proxyAnthropicRequest(w http.ResponseWriter, r *http.Request) {
954954
if err := message.Accumulate(event); err != nil {
955955
b.logger.Error(ctx, "failed to accumulate streaming events", slog.Error(err), slog.F("event", event), slog.F("msg", message.RawJSON()))
956956
http.Error(w, "failed to proxy request", http.StatusInternalServerError)
957-
return
957+
return // TODO: don't return, skip to close.
958958
}
959959

960960
// Tool-related handling.

aibridged/bridge_test.go

Lines changed: 192 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -12,15 +12,21 @@ import (
1212
"net/http"
1313
"net/http/httptest"
1414
"sync"
15+
"sync/atomic"
1516
"testing"
1617

1718
"golang.org/x/tools/txtar"
1819
"storj.io/drpc"
1920

21+
"github.com/anthropics/anthropic-sdk-go"
22+
"github.com/anthropics/anthropic-sdk-go/packages/ssestream"
2023
"github.com/stretchr/testify/assert"
2124
"github.com/stretchr/testify/require"
2225
"github.com/tidwall/sjson"
2326

27+
"github.com/mark3labs/mcp-go/mcp"
28+
"github.com/mark3labs/mcp-go/server"
29+
2430
"github.com/coder/coder/v2/aibridged"
2531
"github.com/coder/coder/v2/aibridged/proto"
2632
"github.com/coder/coder/v2/coderd/coderdtest"
@@ -33,6 +39,8 @@ import (
3339
var (
3440
//go:embed fixtures/anthropic/single_builtin_tool.txtar
3541
antSingleBuiltinTool []byte
42+
//go:embed fixtures/anthropic/single_injected_tool.txtar
43+
antSingleInjectedTool []byte
3644

3745
//go:embed fixtures/openai/single_builtin_tool.txtar
3846
oaiSingleBuiltinTool []byte
@@ -45,15 +53,18 @@ var (
4553
)
4654

4755
const (
48-
fixtureRequest = "request"
49-
fixtureStreamingResponse = "streaming"
50-
fixtureNonStreamingResponse = "non-streaming"
56+
fixtureRequest = "request"
57+
fixtureStreamingResponse = "streaming"
58+
fixtureNonStreamingResponse = "non-streaming"
59+
fixtureStreamingToolResponse = "streaming/tool-call"
60+
fixtureNonStreamingToolResponse = "non-streaming/tool-call"
5161
)
5262

5363
func TestAnthropicMessages(t *testing.T) {
5464
t.Parallel()
5565

56-
sessionToken := getSessionToken(t)
66+
client := coderdtest.New(t, nil)
67+
sessionToken := getSessionToken(t, client)
5768

5869
t.Run("single builtin tool", func(t *testing.T) {
5970
t.Parallel()
@@ -95,7 +106,7 @@ func TestAnthropicMessages(t *testing.T) {
95106
reqBody = newBody
96107

97108
ctx := testutil.Context(t, testutil.WaitLong)
98-
srv := newMockServer(ctx, t, files)
109+
srv := newMockServer(ctx, t, files, nil)
99110
t.Cleanup(srv.Close)
100111

101112
coderdClient := &fakeBridgeDaemonClient{}
@@ -150,12 +161,138 @@ func TestAnthropicMessages(t *testing.T) {
150161
})
151162
}
152163
})
164+
165+
t.Run("single injected tool", func(t *testing.T) {
	t.Parallel()

	// Run the same scenario in both response modes; the fixture archive
	// carries a streaming and a non-streaming variant of every response.
	cases := []struct {
		streaming bool
	}{
		{streaming: true},
		{streaming: false},
	}

	for _, tc := range cases {
		t.Run(fmt.Sprintf("%s/streaming=%v", t.Name(), tc.streaming), func(t *testing.T) {
			t.Parallel()

			arc := txtar.Parse(antSingleInjectedTool)
			t.Logf("%s: %s", t.Name(), arc.Comment)

			files := filesMap(arc)
			require.Len(t, files, 5)
			require.Contains(t, files, fixtureRequest)
			require.Contains(t, files, fixtureStreamingResponse)
			require.Contains(t, files, fixtureNonStreamingResponse)
			require.Contains(t, files, fixtureStreamingToolResponse)
			require.Contains(t, files, fixtureNonStreamingToolResponse)

			// Add the stream param to the request.
			reqBody, err := sjson.SetBytes(files[fixtureRequest], "stream", tc.streaming)
			require.NoError(t, err)

			ctx := testutil.Context(t, testutil.WaitLong)

			// Conditionally return fixtures based on request count.
			// First request: halts with tool call instruction.
			// Second request: tool call invocation.
			mockSrv := newMockServer(ctx, t, files, func(reqCount uint32, resp []byte) []byte {
				if reqCount == 1 {
					return resp
				}

				if reqCount > 2 {
					// This callback runs on the mock server's handler goroutine, where
					// FailNow (and therefore t.Fatalf) must not be called. Record the
					// failure and still answer with the final fixture so the exchange
					// can terminate.
					t.Errorf("did not expect more than 2 calls; received %d", reqCount)
				}

				if tc.streaming {
					return files[fixtureStreamingToolResponse]
				}
				return files[fixtureNonStreamingToolResponse]
			})
			t.Cleanup(mockSrv.Close)

			coderdClient := &fakeBridgeDaemonClient{}
			logger := testutil.Logger(t)

			// Setup Coder MCP integration.
			mcpSrv := httptest.NewServer(createMockMCPSrv(t))
			t.Cleanup(mcpSrv.Close)
			mcpBridge, err := aibridged.NewMCPToolBridge("coder", mcpSrv.URL, map[string]string{}, logger)
			require.NoError(t, err)

			// Initialize MCP client, fetch tools, and inject into bridge.
			require.NoError(t, mcpBridge.Init(testutil.Context(t, testutil.WaitShort)))
			tools := mcpBridge.ListTools()
			require.NotEmpty(t, tools)

			b, err := aibridged.NewBridge(codersdk.AIBridgeConfig{
				Daemons: 1,
				Anthropic: codersdk.AIBridgeAnthropicConfig{
					BaseURL: serpent.String(mockSrv.URL),
					Key:     serpent.String(sessionToken),
				},
			}, logger, func() (proto.DRPCAIBridgeDaemonClient, bool) {
				return coderdClient, true
			}, tools)
			require.NoError(t, err)

			// Invoke request to mocked API via aibridge.
			bridgeSrv := httptest.NewServer(b.Handler())
			t.Cleanup(bridgeSrv.Close)

			req := createAnthropicMessagesReq(t, bridgeSrv.URL, reqBody)
			// Named httpClient so it does not shadow the coderd client of the enclosing test.
			httpClient := &http.Client{}
			resp, err := httpClient.Do(req)
			require.NoError(t, err)
			// Register the close before any further assertion so the body is
			// released even when the status check below fails.
			defer resp.Body.Close()
			require.Equal(t, http.StatusOK, resp.StatusCode)

			// We must ALWAYS have 2 calls to the bridge.
			require.Eventually(t, func() bool { return mockSrv.callCount.Load() == 2 }, testutil.WaitLong, testutil.IntervalFast)

			// TODO: this is a bit flimsy since this API won't be in beta forever.
			var content *anthropic.BetaContentBlockUnion
			if tc.streaming {
				// Parse the response stream and fold the events into one message.
				decoder := ssestream.NewDecoder(resp)
				stream := ssestream.NewStream[anthropic.BetaRawMessageStreamEventUnion](decoder, nil)
				var message anthropic.BetaMessage
				for stream.Next() {
					event := stream.Current()
					require.NoError(t, message.Accumulate(event))
				}
				require.NoError(t, stream.Err())
				require.Len(t, message.Content, 2)
				content = &message.Content[1]
			} else {
				// Parse & unmarshal the response.
				out, err := io.ReadAll(resp.Body)
				require.NoError(t, err)

				var message anthropic.BetaMessage
				require.NoError(t, json.Unmarshal(out, &message))
				require.Len(t, message.Content, 1)
				content = &message.Content[0]
			}

			require.NotNil(t, content)
			require.Equal(t, "admin", content.Text)
		})
	}
})
153289
}
154290

155291
func TestOpenAIChatCompletions(t *testing.T) {
156292
t.Parallel()
157293

158-
sessionToken := getSessionToken(t)
294+
client := coderdtest.New(t, nil)
295+
sessionToken := getSessionToken(t, client)
159296

160297
t.Run("single builtin tool", func(t *testing.T) {
161298
t.Parallel()
@@ -197,7 +334,7 @@ func TestOpenAIChatCompletions(t *testing.T) {
197334
reqBody = newBody
198335

199336
ctx := testutil.Context(t, testutil.WaitLong)
200-
srv := newMockServer(ctx, t, files)
337+
srv := newMockServer(ctx, t, files, nil)
201338
t.Cleanup(srv.Close)
202339

203340
coderdClient := &fakeBridgeDaemonClient{}
@@ -260,7 +397,8 @@ func TestOpenAIChatCompletions(t *testing.T) {
260397
func TestSimple(t *testing.T) {
261398
t.Parallel()
262399

263-
sessionToken := getSessionToken(t)
400+
client := coderdtest.New(t, nil)
401+
sessionToken := getSessionToken(t, client)
264402

265403
testCases := []struct {
266404
name string
@@ -337,7 +475,7 @@ func TestSimple(t *testing.T) {
337475

338476
// Given: a mock API server and a Bridge through which the requests will flow.
339477
ctx := testutil.Context(t, testutil.WaitLong)
340-
srv := newMockServer(ctx, t, files)
478+
srv := newMockServer(ctx, t, files, nil)
341479
t.Cleanup(srv.Close)
342480

343481
coderdClient := &fakeBridgeDaemonClient{}
@@ -418,10 +556,9 @@ func createOpenAIChatCompletionsReq(t *testing.T, baseURL string, input []byte)
418556
return req
419557
}
420558

421-
func getSessionToken(t *testing.T) string {
559+
func getSessionToken(t *testing.T, client *codersdk.Client) string {
422560
t.Helper()
423561

424-
client := coderdtest.New(t, nil)
425562
_ = coderdtest.CreateFirstUser(t, client)
426563
resp, err := client.LoginWithPassword(t.Context(), codersdk.LoginWithPasswordRequest{
427564
Email: coderdtest.FirstUserParams.Email,
@@ -434,11 +571,18 @@ func getSessionToken(t *testing.T) string {
434571

435572
// mockServer is the test double for an upstream provider API: an
// httptest.Server plus a counter of the requests it has received.
type mockServer struct {
436573
// Embedded so callers can use URL and Close directly on the mock.
*httptest.Server
574+
575+
// callCount is incremented atomically at the start of every request handled
// by the server built in newMockServer; tests read it with Load.
callCount atomic.Uint32
437576
}
438577

439-
func newMockServer(ctx context.Context, t *testing.T, files archiveFileMap) *mockServer {
578+
func newMockServer(ctx context.Context, t *testing.T, files archiveFileMap, responseMutatorFn func(reqCount uint32, resp []byte) []byte) *mockServer {
440579
t.Helper()
580+
581+
ms := &mockServer{}
441582
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
583+
callCount := ms.callCount.Add(1)
584+
t.Logf("\n\nCALL COUNT: %d\n\n", callCount)
585+
442586
body, err := io.ReadAll(r.Body)
443587
defer r.Body.Close()
444588
require.NoError(t, err)
@@ -450,9 +594,14 @@ func newMockServer(ctx context.Context, t *testing.T, files archiveFileMap) *moc
450594
require.NoError(t, json.Unmarshal(body, &reqMsg))
451595

452596
if !reqMsg.Stream {
597+
resp := files[fixtureNonStreamingResponse]
598+
if responseMutatorFn != nil {
599+
resp = responseMutatorFn(ms.callCount.Load(), resp)
600+
}
601+
453602
w.Header().Set("Content-Type", "application/json")
454603
w.WriteHeader(http.StatusOK)
455-
w.Write(files[fixtureNonStreamingResponse])
604+
w.Write(resp)
456605
return
457606
}
458607

@@ -461,7 +610,12 @@ func newMockServer(ctx context.Context, t *testing.T, files archiveFileMap) *moc
461610
w.Header().Set("Connection", "keep-alive")
462611
w.Header().Set("Access-Control-Allow-Origin", "*")
463612

464-
scanner := bufio.NewScanner(bytes.NewReader(files[fixtureStreamingResponse]))
613+
resp := files[fixtureStreamingResponse]
614+
if responseMutatorFn != nil {
615+
resp = responseMutatorFn(ms.callCount.Load(), resp)
616+
}
617+
618+
scanner := bufio.NewScanner(bytes.NewReader(resp))
465619
flusher, ok := w.(http.Flusher)
466620
if !ok {
467621
http.Error(w, "Streaming unsupported", http.StatusInternalServerError)
@@ -480,14 +634,12 @@ func newMockServer(ctx context.Context, t *testing.T, files archiveFileMap) *moc
480634
return
481635
}
482636
}))
483-
484637
srv.Config.BaseContext = func(_ net.Listener) context.Context {
485638
return ctx
486639
}
487640

488-
return &mockServer{
489-
Server: srv,
490-
}
641+
ms.Server = srv
642+
return ms
491643
}
492644

493645
type fakeBridgeDaemonClient struct {
@@ -526,3 +678,25 @@ func (f *fakeBridgeDaemonClient) TrackToolUsage(ctx context.Context, in *proto.T
526678

527679
return &proto.TrackToolUsageResponse{}, nil
528680
}
681+
682+
func createMockMCPSrv(t *testing.T) http.Handler {
683+
t.Helper()
684+
685+
s := server.NewMCPServer(
686+
"Mock coder MCP server",
687+
"1.0.0",
688+
server.WithToolCapabilities(true),
689+
)
690+
691+
// Add tool
692+
tool := mcp.NewTool("coder_get_authenticated_user",
693+
mcp.WithDescription("Mock of the coder_get_authenticated_user tool"),
694+
)
695+
696+
// Add tool handler
697+
s.AddTool(tool, func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
698+
return mcp.NewToolResultText("mock"), nil
699+
})
700+
701+
return server.NewStreamableHTTPServer(s)
702+
}

0 commit comments

Comments
 (0)