@@ -12,15 +12,21 @@ import (
12
12
"net/http"
13
13
"net/http/httptest"
14
14
"sync"
15
+ "sync/atomic"
15
16
"testing"
16
17
17
18
"golang.org/x/tools/txtar"
18
19
"storj.io/drpc"
19
20
21
+ "github.com/anthropics/anthropic-sdk-go"
22
+ "github.com/anthropics/anthropic-sdk-go/packages/ssestream"
20
23
"github.com/stretchr/testify/assert"
21
24
"github.com/stretchr/testify/require"
22
25
"github.com/tidwall/sjson"
23
26
27
+ "github.com/mark3labs/mcp-go/mcp"
28
+ "github.com/mark3labs/mcp-go/server"
29
+
24
30
"github.com/coder/coder/v2/aibridged"
25
31
"github.com/coder/coder/v2/aibridged/proto"
26
32
"github.com/coder/coder/v2/coderd/coderdtest"
@@ -33,6 +39,8 @@ import (
33
39
var (
34
40
//go:embed fixtures/anthropic/single_builtin_tool.txtar
35
41
antSingleBuiltinTool []byte
42
+ //go:embed fixtures/anthropic/single_injected_tool.txtar
43
+ antSingleInjectedTool []byte
36
44
37
45
//go:embed fixtures/openai/single_builtin_tool.txtar
38
46
oaiSingleBuiltinTool []byte
@@ -45,15 +53,18 @@ var (
45
53
)
46
54
47
55
const (
48
- fixtureRequest = "request"
49
- fixtureStreamingResponse = "streaming"
50
- fixtureNonStreamingResponse = "non-streaming"
56
+ fixtureRequest = "request"
57
+ fixtureStreamingResponse = "streaming"
58
+ fixtureNonStreamingResponse = "non-streaming"
59
+ fixtureStreamingToolResponse = "streaming/tool-call"
60
+ fixtureNonStreamingToolResponse = "non-streaming/tool-call"
51
61
)
52
62
53
63
func TestAnthropicMessages (t * testing.T ) {
54
64
t .Parallel ()
55
65
56
- sessionToken := getSessionToken (t )
66
+ client := coderdtest .New (t , nil )
67
+ sessionToken := getSessionToken (t , client )
57
68
58
69
t .Run ("single builtin tool" , func (t * testing.T ) {
59
70
t .Parallel ()
@@ -95,7 +106,7 @@ func TestAnthropicMessages(t *testing.T) {
95
106
reqBody = newBody
96
107
97
108
ctx := testutil .Context (t , testutil .WaitLong )
98
- srv := newMockServer (ctx , t , files )
109
+ srv := newMockServer (ctx , t , files , nil )
99
110
t .Cleanup (srv .Close )
100
111
101
112
coderdClient := & fakeBridgeDaemonClient {}
@@ -150,12 +161,138 @@ func TestAnthropicMessages(t *testing.T) {
150
161
})
151
162
}
152
163
})
164
+
165
+ t .Run ("single injected tool" , func (t * testing.T ) {
166
+ t .Parallel ()
167
+
168
+ cases := []struct {
169
+ streaming bool
170
+ }{
171
+ {
172
+ streaming : true ,
173
+ },
174
+ {
175
+ streaming : false ,
176
+ },
177
+ }
178
+
179
+ for _ , tc := range cases {
180
+ t .Run (fmt .Sprintf ("%s/streaming=%v" , t .Name (), tc .streaming ), func (t * testing.T ) {
181
+ t .Parallel ()
182
+
183
+ arc := txtar .Parse (antSingleInjectedTool )
184
+ t .Logf ("%s: %s" , t .Name (), arc .Comment )
185
+
186
+ files := filesMap (arc )
187
+ require .Len (t , files , 5 )
188
+ require .Contains (t , files , fixtureRequest )
189
+ require .Contains (t , files , fixtureStreamingResponse )
190
+ require .Contains (t , files , fixtureNonStreamingResponse )
191
+ require .Contains (t , files , fixtureStreamingToolResponse )
192
+ require .Contains (t , files , fixtureNonStreamingToolResponse )
193
+
194
+ reqBody := files [fixtureRequest ]
195
+
196
+ // Add the stream param to the request.
197
+ newBody , err := sjson .SetBytes (reqBody , "stream" , tc .streaming )
198
+ require .NoError (t , err )
199
+ reqBody = newBody
200
+
201
+ ctx := testutil .Context (t , testutil .WaitLong )
202
+ // Conditionally return fixtures based on request count.
203
+ // First request: halts with tool call instruction.
204
+ // Second request: tool call invocation.
205
+ mockSrv := newMockServer (ctx , t , files , func (reqCount uint32 , resp []byte ) []byte {
206
+ if reqCount == 1 {
207
+ return resp
208
+ }
209
+
210
+ if reqCount > 2 {
211
+ t .Fatalf ("did not expect more than 2 calls; received %d" , reqCount )
212
+ }
213
+
214
+ if ! tc .streaming {
215
+ return files [fixtureNonStreamingToolResponse ]
216
+ }
217
+ return files [fixtureStreamingToolResponse ]
218
+ })
219
+ t .Cleanup (mockSrv .Close )
220
+
221
+ coderdClient := & fakeBridgeDaemonClient {}
222
+ logger := testutil .Logger (t ) // slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug)
223
+
224
+ // Setup Coder MCP integration.
225
+ mcpSrv := httptest .NewServer (createMockMCPSrv (t ))
226
+ mcpBridge , err := aibridged .NewMCPToolBridge ("coder" , mcpSrv .URL , map [string ]string {}, logger )
227
+ require .NoError (t , err )
228
+
229
+ // Initialize MCP client, fetch tools, and inject into bridge.
230
+ require .NoError (t , mcpBridge .Init (testutil .Context (t , testutil .WaitShort )))
231
+ tools := mcpBridge .ListTools ()
232
+ require .NotEmpty (t , tools )
233
+
234
+ b , err := aibridged .NewBridge (codersdk.AIBridgeConfig {
235
+ Daemons : 1 ,
236
+ Anthropic : codersdk.AIBridgeAnthropicConfig {
237
+ BaseURL : serpent .String (mockSrv .URL ),
238
+ Key : serpent .String (sessionToken ),
239
+ },
240
+ }, logger , func () (proto.DRPCAIBridgeDaemonClient , bool ) {
241
+ return coderdClient , true
242
+ }, tools )
243
+ require .NoError (t , err )
244
+
245
+ // Invoke request to mocked API via aibridge.
246
+ bridgeSrv := httptest .NewServer (b .Handler ())
247
+ req := createAnthropicMessagesReq (t , bridgeSrv .URL , reqBody )
248
+ client := & http.Client {}
249
+ resp , err := client .Do (req )
250
+ require .NoError (t , err )
251
+ require .Equal (t , http .StatusOK , resp .StatusCode )
252
+ defer resp .Body .Close ()
253
+
254
+ // We must ALWAYS have 2 calls to the bridge.
255
+ require .Eventually (t , func () bool { return mockSrv .callCount .Load () == 2 }, testutil .WaitLong , testutil .IntervalFast )
256
+
257
+ // TODO: this is a bit flimsy since this API won't be in beta forever.
258
+ var content * anthropic.BetaContentBlockUnion
259
+ if tc .streaming {
260
+ // Parse the response stream.
261
+ decoder := ssestream .NewDecoder (resp )
262
+ stream := ssestream .NewStream [anthropic.BetaRawMessageStreamEventUnion ](decoder , nil )
263
+ var message anthropic.BetaMessage
264
+ for stream .Next () {
265
+ event := stream .Current ()
266
+ require .NoError (t , message .Accumulate (event ))
267
+ }
268
+ require .NoError (t , stream .Err ())
269
+ require .Len (t , message .Content , 2 )
270
+ content = & message .Content [1 ]
271
+ } else {
272
+ // Parse & unmarshal the response.
273
+ out , err := io .ReadAll (resp .Body )
274
+ require .NoError (t , err )
275
+
276
+ // TODO: this is a bit flimsy since this API won't be in beta forever.
277
+ var message anthropic.BetaMessage
278
+ require .NoError (t , json .Unmarshal (out , & message ))
279
+ require .NotNil (t , message )
280
+ require .Len (t , message .Content , 1 )
281
+ content = & message .Content [0 ]
282
+ }
283
+
284
+ require .NotNil (t , content )
285
+ require .Equal (t , "admin" , content .Text )
286
+ })
287
+ }
288
+ })
153
289
}
154
290
155
291
func TestOpenAIChatCompletions (t * testing.T ) {
156
292
t .Parallel ()
157
293
158
- sessionToken := getSessionToken (t )
294
+ client := coderdtest .New (t , nil )
295
+ sessionToken := getSessionToken (t , client )
159
296
160
297
t .Run ("single builtin tool" , func (t * testing.T ) {
161
298
t .Parallel ()
@@ -197,7 +334,7 @@ func TestOpenAIChatCompletions(t *testing.T) {
197
334
reqBody = newBody
198
335
199
336
ctx := testutil .Context (t , testutil .WaitLong )
200
- srv := newMockServer (ctx , t , files )
337
+ srv := newMockServer (ctx , t , files , nil )
201
338
t .Cleanup (srv .Close )
202
339
203
340
coderdClient := & fakeBridgeDaemonClient {}
@@ -260,7 +397,8 @@ func TestOpenAIChatCompletions(t *testing.T) {
260
397
func TestSimple (t * testing.T ) {
261
398
t .Parallel ()
262
399
263
- sessionToken := getSessionToken (t )
400
+ client := coderdtest .New (t , nil )
401
+ sessionToken := getSessionToken (t , client )
264
402
265
403
testCases := []struct {
266
404
name string
@@ -337,7 +475,7 @@ func TestSimple(t *testing.T) {
337
475
338
476
// Given: a mock API server and a Bridge through which the requests will flow.
339
477
ctx := testutil .Context (t , testutil .WaitLong )
340
- srv := newMockServer (ctx , t , files )
478
+ srv := newMockServer (ctx , t , files , nil )
341
479
t .Cleanup (srv .Close )
342
480
343
481
coderdClient := & fakeBridgeDaemonClient {}
@@ -418,10 +556,9 @@ func createOpenAIChatCompletionsReq(t *testing.T, baseURL string, input []byte)
418
556
return req
419
557
}
420
558
421
- func getSessionToken (t * testing.T ) string {
559
+ func getSessionToken (t * testing.T , client * codersdk. Client ) string {
422
560
t .Helper ()
423
561
424
- client := coderdtest .New (t , nil )
425
562
_ = coderdtest .CreateFirstUser (t , client )
426
563
resp , err := client .LoginWithPassword (t .Context (), codersdk.LoginWithPasswordRequest {
427
564
Email : coderdtest .FirstUserParams .Email ,
@@ -434,11 +571,18 @@ func getSessionToken(t *testing.T) string {
434
571
435
572
type mockServer struct {
436
573
* httptest.Server
574
+
575
+ callCount atomic.Uint32
437
576
}
438
577
439
- func newMockServer (ctx context.Context , t * testing.T , files archiveFileMap ) * mockServer {
578
+ func newMockServer (ctx context.Context , t * testing.T , files archiveFileMap , responseMutatorFn func ( reqCount uint32 , resp [] byte ) [] byte ) * mockServer {
440
579
t .Helper ()
580
+
581
+ ms := & mockServer {}
441
582
srv := httptest .NewServer (http .HandlerFunc (func (w http.ResponseWriter , r * http.Request ) {
583
+ callCount := ms .callCount .Add (1 )
584
+ t .Logf ("\n \n CALL COUNT: %d\n \n " , callCount )
585
+
442
586
body , err := io .ReadAll (r .Body )
443
587
defer r .Body .Close ()
444
588
require .NoError (t , err )
@@ -450,9 +594,14 @@ func newMockServer(ctx context.Context, t *testing.T, files archiveFileMap) *moc
450
594
require .NoError (t , json .Unmarshal (body , & reqMsg ))
451
595
452
596
if ! reqMsg .Stream {
597
+ resp := files [fixtureNonStreamingResponse ]
598
+ if responseMutatorFn != nil {
599
+ resp = responseMutatorFn (ms .callCount .Load (), resp )
600
+ }
601
+
453
602
w .Header ().Set ("Content-Type" , "application/json" )
454
603
w .WriteHeader (http .StatusOK )
455
- w .Write (files [ fixtureNonStreamingResponse ] )
604
+ w .Write (resp )
456
605
return
457
606
}
458
607
@@ -461,7 +610,12 @@ func newMockServer(ctx context.Context, t *testing.T, files archiveFileMap) *moc
461
610
w .Header ().Set ("Connection" , "keep-alive" )
462
611
w .Header ().Set ("Access-Control-Allow-Origin" , "*" )
463
612
464
- scanner := bufio .NewScanner (bytes .NewReader (files [fixtureStreamingResponse ]))
613
+ resp := files [fixtureStreamingResponse ]
614
+ if responseMutatorFn != nil {
615
+ resp = responseMutatorFn (ms .callCount .Load (), resp )
616
+ }
617
+
618
+ scanner := bufio .NewScanner (bytes .NewReader (resp ))
465
619
flusher , ok := w .(http.Flusher )
466
620
if ! ok {
467
621
http .Error (w , "Streaming unsupported" , http .StatusInternalServerError )
@@ -480,14 +634,12 @@ func newMockServer(ctx context.Context, t *testing.T, files archiveFileMap) *moc
480
634
return
481
635
}
482
636
}))
483
-
484
637
srv .Config .BaseContext = func (_ net.Listener ) context.Context {
485
638
return ctx
486
639
}
487
640
488
- return & mockServer {
489
- Server : srv ,
490
- }
641
+ ms .Server = srv
642
+ return ms
491
643
}
492
644
493
645
type fakeBridgeDaemonClient struct {
@@ -526,3 +678,25 @@ func (f *fakeBridgeDaemonClient) TrackToolUsage(ctx context.Context, in *proto.T
526
678
527
679
return & proto.TrackToolUsageResponse {}, nil
528
680
}
681
+
682
+ func createMockMCPSrv (t * testing.T ) http.Handler {
683
+ t .Helper ()
684
+
685
+ s := server .NewMCPServer (
686
+ "Mock coder MCP server" ,
687
+ "1.0.0" ,
688
+ server .WithToolCapabilities (true ),
689
+ )
690
+
691
+ // Add tool
692
+ tool := mcp .NewTool ("coder_get_authenticated_user" ,
693
+ mcp .WithDescription ("Mock of the coder_get_authenticated_user tool" ),
694
+ )
695
+
696
+ // Add tool handler
697
+ s .AddTool (tool , func (ctx context.Context , request mcp.CallToolRequest ) (* mcp.CallToolResult , error ) {
698
+ return mcp .NewToolResultText ("mock" ), nil
699
+ })
700
+
701
+ return server .NewStreamableHTTPServer (s )
702
+ }
0 commit comments