@@ -16,10 +16,12 @@ import (
16
16
"testing"
17
17
18
18
"golang.org/x/tools/txtar"
19
+ "golang.org/x/xerrors"
19
20
"storj.io/drpc"
20
21
21
22
"github.com/anthropics/anthropic-sdk-go"
22
23
"github.com/anthropics/anthropic-sdk-go/packages/ssestream"
24
+ "github.com/google/uuid"
23
25
"github.com/stretchr/testify/assert"
24
26
"github.com/stretchr/testify/require"
25
27
"github.com/tidwall/sjson"
@@ -409,9 +411,9 @@ func TestOpenAIChatCompletions(t *testing.T) {
409
411
{
410
412
streaming : true ,
411
413
},
412
- // {
413
- // streaming: false,
414
- // },
414
+ {
415
+ streaming : false ,
416
+ },
415
417
}
416
418
417
419
for _ , tc := range cases {
@@ -492,7 +494,6 @@ func TestOpenAIChatCompletions(t *testing.T) {
492
494
// We must ALWAYS have 2 calls to the bridge.
493
495
require .Eventually (t , func () bool { return mockSrv .callCount .Load () == 2 }, testutil .WaitLong , testutil .IntervalFast )
494
496
495
- // TODO: this is a bit flimsy since this API won't be in beta forever.
496
497
var content * openai.ChatCompletionChoice
497
498
if tc .streaming {
498
499
// Parse the response stream.
@@ -512,7 +513,6 @@ func TestOpenAIChatCompletions(t *testing.T) {
512
513
out , err := io .ReadAll (resp .Body )
513
514
require .NoError (t , err )
514
515
515
- // TODO: this is a bit flimsy since this API won't be in beta forever.
516
516
var message openai.ChatCompletion
517
517
require .NoError (t , json .Unmarshal (out , & message ))
518
518
require .NotNil (t , message )
@@ -534,10 +534,11 @@ func TestSimple(t *testing.T) {
534
534
sessionToken := getSessionToken (t , client )
535
535
536
536
testCases := []struct {
537
- name string
538
- fixture []byte
539
- configureFunc func (string , proto.DRPCAIBridgeDaemonClient ) (* aibridged.Bridge , error )
540
- createRequest func (* testing.T , string , []byte ) * http.Request
537
+ name string
538
+ fixture []byte
539
+ configureFunc func (string , proto.DRPCAIBridgeDaemonClient ) (* aibridged.Bridge , error )
540
+ getResponseIDFunc func (bool , * http.Response ) (string , error )
541
+ createRequest func (* testing.T , string , []byte ) * http.Request
541
542
}{
542
543
{
543
544
name : "anthropic" ,
@@ -554,6 +555,36 @@ func TestSimple(t *testing.T) {
554
555
return client , true
555
556
}, nil )
556
557
},
558
+ getResponseIDFunc : func (streaming bool , resp * http.Response ) (string , error ) {
559
+ if streaming {
560
+ decoder := ssestream .NewDecoder (resp )
561
+ // TODO: this is a bit flimsy since this API won't be in beta forever.
562
+ stream := ssestream .NewStream [anthropic.BetaRawMessageStreamEventUnion ](decoder , nil )
563
+ var message anthropic.BetaMessage
564
+ for stream .Next () {
565
+ event := stream .Current ()
566
+ if err := message .Accumulate (event ); err != nil {
567
+ return "" , xerrors .Errorf ("accumulate event: %w" , err )
568
+ }
569
+ }
570
+ if stream .Err () != nil {
571
+ return "" , xerrors .Errorf ("stream error: %w" , stream .Err ())
572
+ }
573
+ return message .ID , nil
574
+ }
575
+
576
+ body , err := io .ReadAll (resp .Body )
577
+ if err != nil {
578
+ return "" , xerrors .Errorf ("read body: %w" , err )
579
+ }
580
+
581
+ // TODO: this is a bit flimsy since this API won't be in beta forever.
582
+ var message anthropic.BetaMessage
583
+ if err := json .Unmarshal (body , & message ); err != nil {
584
+ return "" , xerrors .Errorf ("unmarshal response: %w" , err )
585
+ }
586
+ return message .ID , nil
587
+ },
557
588
createRequest : createAnthropicMessagesReq ,
558
589
},
559
590
{
@@ -571,6 +602,34 @@ func TestSimple(t *testing.T) {
571
602
return client , true
572
603
}, nil )
573
604
},
605
+ getResponseIDFunc : func (streaming bool , resp * http.Response ) (string , error ) {
606
+ if streaming {
607
+ // Parse the response stream.
608
+ decoder := oai_ssestream .NewDecoder (resp )
609
+ stream := oai_ssestream .NewStream [openai.ChatCompletionChunk ](decoder , nil )
610
+ var message openai.ChatCompletionAccumulator
611
+ for stream .Next () {
612
+ chunk := stream .Current ()
613
+ message .AddChunk (chunk )
614
+ }
615
+ if stream .Err () != nil {
616
+ return "" , xerrors .Errorf ("stream error: %w" , stream .Err ())
617
+ }
618
+ return message .ID , nil
619
+ }
620
+
621
+ // Parse & unmarshal the response.
622
+ body , err := io .ReadAll (resp .Body )
623
+ if err != nil {
624
+ return "" , xerrors .Errorf ("read body: %w" , err )
625
+ }
626
+
627
+ var message openai.ChatCompletion
628
+ if err := json .Unmarshal (body , & message ); err != nil {
629
+ return "" , xerrors .Errorf ("unmarshal response: %w" , err )
630
+ }
631
+ return message .ID , nil
632
+ },
574
633
createRequest : createOpenAIChatCompletionsReq ,
575
634
},
576
635
}
@@ -630,9 +689,20 @@ func TestSimple(t *testing.T) {
630
689
require .NoError (t , err )
631
690
assert .NotEmpty (t , bodyBytes , "should have received response body" )
632
691
692
+ // Reset the body after being read.
693
+ resp .Body = io .NopCloser (bytes .NewReader (bodyBytes ))
694
+
633
695
// Then: I expect the prompt to have been tracked.
634
696
require .NotEmpty (t , coderdClient .userPrompts , "no prompts tracked" )
635
697
assert .Equal (t , "how many angels can dance on the head of a pin" , coderdClient .userPrompts [0 ].Prompt )
698
+
699
+ // Validate that responses have their IDs overridden with a session ID rather than the original ID from the upstream provider.
700
+ // The reason for this is that Bridge may make multiple upstream requests (i.e. to invoke injected tools), and clients will not be expecting
701
+ // multiple messages in response to a single request.
702
+ // TODO: validate that expected upstream message ID is captured alongside returned ID in token usage.
703
+ id , err := tc .getResponseIDFunc (sc .streaming , resp )
704
+ require .NoError (t , err , "failed to retrieve response ID" )
705
+ require .Nil (t , uuid .Validate (id ), "id is not a UUID" )
636
706
})
637
707
}
638
708
})
0 commit comments