@@ -24,6 +24,9 @@ import (
24
24
"github.com/stretchr/testify/require"
25
25
"github.com/tidwall/sjson"
26
26
27
+ "github.com/openai/openai-go"
28
+ oai_ssestream "github.com/openai/openai-go/packages/ssestream"
29
+
27
30
"github.com/mark3labs/mcp-go/mcp"
28
31
"github.com/mark3labs/mcp-go/server"
29
32
44
47
45
48
//go:embed fixtures/openai/single_builtin_tool.txtar
46
49
oaiSingleBuiltinTool []byte
50
+ //go:embed fixtures/openai/single_injected_tool.txtar
51
+ oaiSingleInjectedTool []byte
47
52
48
53
//go:embed fixtures/anthropic/simple.txtar
49
54
antSimple []byte
@@ -162,6 +167,7 @@ func TestAnthropicMessages(t *testing.T) {
162
167
}
163
168
})
164
169
170
+ // TODO: fixture contains hardcoded tool name; this is flimsy since our naming convention may change for injected tools.
165
171
t .Run ("single injected tool" , func (t * testing.T ) {
166
172
t .Parallel ()
167
173
@@ -392,6 +398,133 @@ func TestOpenAIChatCompletions(t *testing.T) {
392
398
})
393
399
}
394
400
})
401
+
402
+ // TODO: fixture contains hardcoded tool name; this is flimsy since our naming convention may change for injected tools.
403
+ t .Run ("single injected tool" , func (t * testing.T ) {
404
+ t .Parallel ()
405
+
406
+ cases := []struct {
407
+ streaming bool
408
+ }{
409
+ {
410
+ streaming : true ,
411
+ },
412
+ // {
413
+ // streaming: false,
414
+ // },
415
+ }
416
+
417
+ for _ , tc := range cases {
418
+ t .Run (fmt .Sprintf ("%s/streaming=%v" , t .Name (), tc .streaming ), func (t * testing.T ) {
419
+ t .Parallel ()
420
+
421
+ arc := txtar .Parse (oaiSingleInjectedTool )
422
+ t .Logf ("%s: %s" , t .Name (), arc .Comment )
423
+
424
+ files := filesMap (arc )
425
+ require .Len (t , files , 5 )
426
+ require .Contains (t , files , fixtureRequest )
427
+ require .Contains (t , files , fixtureStreamingResponse )
428
+ require .Contains (t , files , fixtureNonStreamingResponse )
429
+ require .Contains (t , files , fixtureStreamingToolResponse )
430
+ require .Contains (t , files , fixtureNonStreamingToolResponse )
431
+
432
+ reqBody := files [fixtureRequest ]
433
+
434
+ // Add the stream param to the request.
435
+ newBody , err := sjson .SetBytes (reqBody , "stream" , tc .streaming )
436
+ require .NoError (t , err )
437
+ reqBody = newBody
438
+
439
+ ctx := testutil .Context (t , testutil .WaitLong )
440
+ // Conditionally return fixtures based on request count.
441
+ // First request: halts with tool call instruction.
442
+ // Second request: tool call invocation.
443
+ mockSrv := newMockServer (ctx , t , files , func (reqCount uint32 , resp []byte ) []byte {
444
+ if reqCount == 1 {
445
+ return resp
446
+ }
447
+
448
+ if reqCount > 2 {
449
+ t .Fatalf ("did not expect more than 2 calls; received %d" , reqCount )
450
+ }
451
+
452
+ if ! tc .streaming {
453
+ return files [fixtureNonStreamingToolResponse ]
454
+ }
455
+ return files [fixtureStreamingToolResponse ]
456
+ })
457
+ t .Cleanup (mockSrv .Close )
458
+
459
+ coderdClient := & fakeBridgeDaemonClient {}
460
+ logger := testutil .Logger (t ) // slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug)
461
+
462
+ // Setup Coder MCP integration.
463
+ mcpSrv := httptest .NewServer (createMockMCPSrv (t ))
464
+ mcpBridge , err := aibridged .NewMCPToolBridge ("coder" , mcpSrv .URL , map [string ]string {}, logger )
465
+ require .NoError (t , err )
466
+
467
+ // Initialize MCP client, fetch tools, and inject into bridge.
468
+ require .NoError (t , mcpBridge .Init (testutil .Context (t , testutil .WaitShort )))
469
+ tools := mcpBridge .ListTools ()
470
+ require .NotEmpty (t , tools )
471
+
472
+ b , err := aibridged .NewBridge (codersdk.AIBridgeConfig {
473
+ Daemons : 1 ,
474
+ OpenAI : codersdk.AIBridgeOpenAIConfig {
475
+ BaseURL : serpent .String (mockSrv .URL ),
476
+ Key : serpent .String (sessionToken ),
477
+ },
478
+ }, logger , func () (proto.DRPCAIBridgeDaemonClient , bool ) {
479
+ return coderdClient , true
480
+ }, tools )
481
+ require .NoError (t , err )
482
+
483
+ // Invoke request to mocked API via aibridge.
484
+ bridgeSrv := httptest .NewServer (b .Handler ())
485
+ req := createOpenAIChatCompletionsReq (t , bridgeSrv .URL , reqBody )
486
+ client := & http.Client {}
487
+ resp , err := client .Do (req )
488
+ require .NoError (t , err )
489
+ require .Equal (t , http .StatusOK , resp .StatusCode )
490
+ defer resp .Body .Close ()
491
+
492
+ // We must ALWAYS have 2 calls to the bridge.
493
+ require .Eventually (t , func () bool { return mockSrv .callCount .Load () == 2 }, testutil .WaitLong , testutil .IntervalFast )
494
+
495
+ // TODO: this is a bit flimsy since this API won't be in beta forever.
496
+ var content * openai.ChatCompletionChoice
497
+ if tc .streaming {
498
+ // Parse the response stream.
499
+ decoder := oai_ssestream .NewDecoder (resp )
500
+ stream := oai_ssestream .NewStream [openai.ChatCompletionChunk ](decoder , nil )
501
+ var message openai.ChatCompletionAccumulator
502
+ for stream .Next () {
503
+ chunk := stream .Current ()
504
+ message .AddChunk (chunk )
505
+ }
506
+
507
+ require .NoError (t , stream .Err ())
508
+ require .Len (t , message .Choices , 1 )
509
+ content = & message .Choices [0 ]
510
+ } else {
511
+ // Parse & unmarshal the response.
512
+ out , err := io .ReadAll (resp .Body )
513
+ require .NoError (t , err )
514
+
515
+ // TODO: this is a bit flimsy since this API won't be in beta forever.
516
+ var message openai.ChatCompletion
517
+ require .NoError (t , json .Unmarshal (out , & message ))
518
+ require .NotNil (t , message )
519
+ require .Len (t , message .Choices , 1 )
520
+ content = & message .Choices [0 ]
521
+ }
522
+
523
+ require .NotNil (t , content )
524
+ require .Contains (t , content .Message .Content , "admin" )
525
+ })
526
+ }
527
+ })
395
528
}
396
529
397
530
func TestSimple (t * testing.T ) {
@@ -580,8 +713,7 @@ func newMockServer(ctx context.Context, t *testing.T, files archiveFileMap, resp
580
713
581
714
ms := & mockServer {}
582
715
srv := httptest .NewServer (http .HandlerFunc (func (w http.ResponseWriter , r * http.Request ) {
583
- callCount := ms .callCount .Add (1 )
584
- t .Logf ("\n \n CALL COUNT: %d\n \n " , callCount )
716
+ ms .callCount .Add (1 )
585
717
586
718
body , err := io .ReadAll (r .Body )
587
719
defer r .Body .Close ()
0 commit comments