Skip to content

Commit 2faad0f

Browse files
committed
validate response IDs
Signed-off-by: Danny Kopping <dannykopping@gmail.com>
1 parent 3fc6d09 commit 2faad0f

File tree

2 files changed

+84
-14
lines changed

2 files changed

+84
-14
lines changed

aibridged/bridge.go

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -274,7 +274,7 @@ func (b *Bridge) proxyOpenAIRequest(w http.ResponseWriter, r *http.Request) {
274274
_, err = coderdClient.TrackToolUsage(ctx, &proto.TrackToolUsageRequest{
275275
SessionId: sessionID,
276276
MsgId: chunk.ID,
277-
Model: string(in.Model),
277+
Model: in.Model,
278278
Input: toolCall.Arguments,
279279
Tool: toolCall.Name,
280280
})
@@ -364,7 +364,7 @@ func (b *Bridge) proxyOpenAIRequest(w http.ResponseWriter, r *http.Request) {
364364
_, err = coderdClient.TrackToolUsage(ctx, &proto.TrackToolUsageRequest{
365365
SessionId: sessionID,
366366
MsgId: acc.ID,
367-
Model: string(in.Model),
367+
Model: in.Model,
368368
Input: tc.Arguments,
369369
Tool: tc.Name, // TODO: sanitize tool name.
370370
Injected: true,
@@ -460,7 +460,7 @@ func (b *Bridge) proxyOpenAIRequest(w http.ResponseWriter, r *http.Request) {
460460
if _, err = coderdClient.TrackTokenUsage(ctx, &proto.TrackTokenUsageRequest{
461461
SessionId: sessionID,
462462
MsgId: completion.ID,
463-
Model: string(completion.Model),
463+
Model: completion.Model,
464464
InputTokens: cumulativeUsage.PromptTokens,
465465
OutputTokens: cumulativeUsage.CompletionTokens,
466466
Other: map[string]int64{
@@ -485,7 +485,7 @@ func (b *Bridge) proxyOpenAIRequest(w http.ResponseWriter, r *http.Request) {
485485
_, err = coderdClient.TrackToolUsage(ctx, &proto.TrackToolUsageRequest{
486486
SessionId: sessionID,
487487
MsgId: completion.ID,
488-
Model: string(in.Model),
488+
Model: in.Model,
489489
Input: toolCall.Function.Arguments,
490490
Tool: toolCall.Function.Name,
491491
})
@@ -519,7 +519,7 @@ func (b *Bridge) proxyOpenAIRequest(w http.ResponseWriter, r *http.Request) {
519519
_, err = coderdClient.TrackToolUsage(ctx, &proto.TrackToolUsageRequest{
520520
SessionId: sessionID,
521521
MsgId: completion.ID,
522-
Model: string(in.Model),
522+
Model: in.Model,
523523
Input: tc.Function.Arguments,
524524
Tool: fn, // TODO: sanitize tool name.
525525
Injected: true,

aibridged/bridge_test.go renamed to aibridged/bridge_integration_test.go

Lines changed: 79 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,12 @@ import (
1616
"testing"
1717

1818
"golang.org/x/tools/txtar"
19+
"golang.org/x/xerrors"
1920
"storj.io/drpc"
2021

2122
"github.com/anthropics/anthropic-sdk-go"
2223
"github.com/anthropics/anthropic-sdk-go/packages/ssestream"
24+
"github.com/google/uuid"
2325
"github.com/stretchr/testify/assert"
2426
"github.com/stretchr/testify/require"
2527
"github.com/tidwall/sjson"
@@ -409,9 +411,9 @@ func TestOpenAIChatCompletions(t *testing.T) {
409411
{
410412
streaming: true,
411413
},
412-
// {
413-
// streaming: false,
414-
// },
414+
{
415+
streaming: false,
416+
},
415417
}
416418

417419
for _, tc := range cases {
@@ -492,7 +494,6 @@ func TestOpenAIChatCompletions(t *testing.T) {
492494
// We must ALWAYS have 2 calls to the bridge.
493495
require.Eventually(t, func() bool { return mockSrv.callCount.Load() == 2 }, testutil.WaitLong, testutil.IntervalFast)
494496

495-
// TODO: this is a bit flimsy since this API won't be in beta forever.
496497
var content *openai.ChatCompletionChoice
497498
if tc.streaming {
498499
// Parse the response stream.
@@ -512,7 +513,6 @@ func TestOpenAIChatCompletions(t *testing.T) {
512513
out, err := io.ReadAll(resp.Body)
513514
require.NoError(t, err)
514515

515-
// TODO: this is a bit flimsy since this API won't be in beta forever.
516516
var message openai.ChatCompletion
517517
require.NoError(t, json.Unmarshal(out, &message))
518518
require.NotNil(t, message)
@@ -534,10 +534,11 @@ func TestSimple(t *testing.T) {
534534
sessionToken := getSessionToken(t, client)
535535

536536
testCases := []struct {
537-
name string
538-
fixture []byte
539-
configureFunc func(string, proto.DRPCAIBridgeDaemonClient) (*aibridged.Bridge, error)
540-
createRequest func(*testing.T, string, []byte) *http.Request
537+
name string
538+
fixture []byte
539+
configureFunc func(string, proto.DRPCAIBridgeDaemonClient) (*aibridged.Bridge, error)
540+
getResponseIDFunc func(bool, *http.Response) (string, error)
541+
createRequest func(*testing.T, string, []byte) *http.Request
541542
}{
542543
{
543544
name: "anthropic",
@@ -554,6 +555,36 @@ func TestSimple(t *testing.T) {
554555
return client, true
555556
}, nil)
556557
},
558+
getResponseIDFunc: func(streaming bool, resp *http.Response) (string, error) {
559+
if streaming {
560+
decoder := ssestream.NewDecoder(resp)
561+
// TODO: this is a bit flimsy since this API won't be in beta forever.
562+
stream := ssestream.NewStream[anthropic.BetaRawMessageStreamEventUnion](decoder, nil)
563+
var message anthropic.BetaMessage
564+
for stream.Next() {
565+
event := stream.Current()
566+
if err := message.Accumulate(event); err != nil {
567+
return "", xerrors.Errorf("accumulate event: %w", err)
568+
}
569+
}
570+
if stream.Err() != nil {
571+
return "", xerrors.Errorf("stream error: %w", stream.Err())
572+
}
573+
return message.ID, nil
574+
}
575+
576+
body, err := io.ReadAll(resp.Body)
577+
if err != nil {
578+
return "", xerrors.Errorf("read body: %w", err)
579+
}
580+
581+
// TODO: this is a bit flimsy since this API won't be in beta forever.
582+
var message anthropic.BetaMessage
583+
if err := json.Unmarshal(body, &message); err != nil {
584+
return "", xerrors.Errorf("unmarshal response: %w", err)
585+
}
586+
return message.ID, nil
587+
},
557588
createRequest: createAnthropicMessagesReq,
558589
},
559590
{
@@ -571,6 +602,34 @@ func TestSimple(t *testing.T) {
571602
return client, true
572603
}, nil)
573604
},
605+
getResponseIDFunc: func(streaming bool, resp *http.Response) (string, error) {
606+
if streaming {
607+
// Parse the response stream.
608+
decoder := oai_ssestream.NewDecoder(resp)
609+
stream := oai_ssestream.NewStream[openai.ChatCompletionChunk](decoder, nil)
610+
var message openai.ChatCompletionAccumulator
611+
for stream.Next() {
612+
chunk := stream.Current()
613+
message.AddChunk(chunk)
614+
}
615+
if stream.Err() != nil {
616+
return "", xerrors.Errorf("stream error: %w", stream.Err())
617+
}
618+
return message.ID, nil
619+
}
620+
621+
// Parse & unmarshal the response.
622+
body, err := io.ReadAll(resp.Body)
623+
if err != nil {
624+
return "", xerrors.Errorf("read body: %w", err)
625+
}
626+
627+
var message openai.ChatCompletion
628+
if err := json.Unmarshal(body, &message); err != nil {
629+
return "", xerrors.Errorf("unmarshal response: %w", err)
630+
}
631+
return message.ID, nil
632+
},
574633
createRequest: createOpenAIChatCompletionsReq,
575634
},
576635
}
@@ -630,9 +689,20 @@ func TestSimple(t *testing.T) {
630689
require.NoError(t, err)
631690
assert.NotEmpty(t, bodyBytes, "should have received response body")
632691

692+
// Reset the body after being read.
693+
resp.Body = io.NopCloser(bytes.NewReader(bodyBytes))
694+
633695
// Then: I expect the prompt to have been tracked.
634696
require.NotEmpty(t, coderdClient.userPrompts, "no prompts tracked")
635697
assert.Equal(t, "how many angels can dance on the head of a pin", coderdClient.userPrompts[0].Prompt)
698+
699+
// Validate that responses have their IDs overridden with a session ID rather than the original ID from the upstream provider.
700+
// The reason for this is that Bridge may make multiple upstream requests (i.e. to invoke injected tools), and clients will not be expecting
701+
// multiple messages in response to a single request.
702+
// TODO: validate that expected upstream message ID is captured alongside returned ID in token usage.
703+
id, err := tc.getResponseIDFunc(sc.streaming, resp)
704+
require.NoError(t, err, "failed to retrieve response ID")
705+
require.Nil(t, uuid.Validate(id), "id is not a UUID")
636706
})
637707
}
638708
})

0 commit comments

Comments
 (0)