Skip to content

Commit d6939be

Browse files
committed
correct token accumulation and tracking
Signed-off-by: Danny Kopping <dannykopping@gmail.com>
1 parent f0691cb commit d6939be

File tree

2 files changed

+22
-29
lines changed

aibridged/bridge_integration_test.go

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -143,11 +143,10 @@ func TestAnthropicMessages(t *testing.T) {
143143
// Ensure the message starts and completes, at a minimum.
144144
assert.Contains(t, sp.AllEvents(), "message_start")
145145
assert.Contains(t, sp.AllEvents(), "message_stop")
146-
require.Len(t, coderdClient.tokenUsages, 2) // One from message_start, one from message_delta.
147-
} else {
148-
require.Len(t, coderdClient.tokenUsages, 1)
149146
}
150147

148+
require.Len(t, coderdClient.tokenUsages, 1)
149+
151150
assert.EqualValues(t, tc.expectedInputTokens, calculateTotalInputTokens(coderdClient.tokenUsages), "input tokens miscalculated")
152151
assert.EqualValues(t, tc.expectedOutputTokens, calculateTotalOutputTokens(coderdClient.tokenUsages), "output tokens miscalculated")
153152

aibridged/session_anthropic_messages_streaming.go

Lines changed: 20 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,10 @@ func (s *AnthropicMessagesStreamingSession) ProcessRequest(w http.ResponseWriter
8888
messages := s.req.BetaMessageNewParams
8989
logger := s.logger.With(slog.F("model", s.req.Model))
9090

91+
// Accumulate usage across the entire streaming interaction (including tool reinvocations).
92+
var cumulativeInputTokens int64
93+
var cumulativeOutputTokens int64
94+
9195
isFirst := true
9296
for {
9397
newStream:
@@ -134,19 +138,9 @@ func (s *AnthropicMessagesStreamingSession) ProcessRequest(w http.ResponseWriter
134138
continue
135139
}
136140
case string(ant_constant.ValueOf[ant_constant.MessageStart]()):
137-
// Track token usage
138141
start := event.AsMessageStart()
139-
metadata := Metadata{
140-
"web_search_requests": start.Message.Usage.ServerToolUse.WebSearchRequests,
141-
"cache_creation_input": start.Message.Usage.CacheCreationInputTokens,
142-
"cache_read_input": start.Message.Usage.CacheReadInputTokens,
143-
"cache_ephemeral_1h_input": start.Message.Usage.CacheCreation.Ephemeral1hInputTokens,
144-
"cache_ephemeral_5m_input": start.Message.Usage.CacheCreation.Ephemeral5mInputTokens,
145-
}
146-
if err := s.tracker.TrackTokensUsage(streamCtx, s.id, message.ID, start.Message.Usage.InputTokens, start.Message.Usage.OutputTokens, metadata); err != nil {
147-
logger.Warn(ctx, "failed to track token usage", slog.Error(err))
148-
}
149-
142+
cumulativeInputTokens += start.Message.Usage.InputTokens
143+
cumulativeOutputTokens += start.Message.Usage.OutputTokens
150144
if !isFirst {
151145
// Don't send message_start unless first message!
152146
// We're sending multiple messages back and forth with the API, but from the client's perspective
@@ -155,19 +149,8 @@ func (s *AnthropicMessagesStreamingSession) ProcessRequest(w http.ResponseWriter
155149
}
156150
case string(ant_constant.ValueOf[ant_constant.MessageDelta]()):
157151
delta := event.AsMessageDelta()
158-
// Track token usage
159-
metadata := Metadata{
160-
"web_search_requests": delta.Usage.ServerToolUse.WebSearchRequests,
161-
"cache_creation_input": delta.Usage.CacheCreationInputTokens,
162-
"cache_read_input": delta.Usage.CacheReadInputTokens,
163-
// Note: CacheCreation fields are not available in MessageDeltaUsage
164-
"cache_ephemeral_1h_input": 0,
165-
"cache_ephemeral_5m_input": 0,
166-
}
167-
if err := s.tracker.TrackTokensUsage(streamCtx, s.id, message.ID, delta.Usage.InputTokens, delta.Usage.OutputTokens, metadata); err != nil {
168-
logger.Warn(ctx, "failed to track token usage", slog.Error(err))
169-
}
170-
152+
cumulativeInputTokens += delta.Usage.InputTokens
153+
cumulativeOutputTokens += delta.Usage.OutputTokens
171154
// Don't relay message_delta events which indicate injected tool use.
172155
if len(pendingToolCalls) > 0 && s.toolMgr.GetTool(lastToolName) != nil {
173156
continue
@@ -183,7 +166,6 @@ func (s *AnthropicMessagesStreamingSession) ProcessRequest(w http.ResponseWriter
183166

184167
// Don't send message_stop until all tools have been called.
185168
case string(ant_constant.ValueOf[ant_constant.MessageStop]()):
186-
187169
if len(pendingToolCalls) > 0 {
188170
// Append the whole message from this stream as context since we'll be sending a new request with the tool results.
189171
messages.Messages = append(messages.Messages, message.ToParam())
@@ -347,6 +329,18 @@ func (s *AnthropicMessagesStreamingSession) ProcessRequest(w http.ResponseWriter
347329
}
348330
}
349331

332+
// Emit a single, final token usage total for this stream.
333+
metadata := Metadata{
334+
"web_search_requests": message.Usage.ServerToolUse.WebSearchRequests,
335+
"cache_creation_input": message.Usage.CacheCreationInputTokens,
336+
"cache_read_input": message.Usage.CacheReadInputTokens,
337+
"cache_ephemeral_1h_input": message.Usage.CacheCreation.Ephemeral1hInputTokens,
338+
"cache_ephemeral_5m_input": message.Usage.CacheCreation.Ephemeral5mInputTokens,
339+
}
340+
if err := s.tracker.TrackTokensUsage(streamCtx, s.id, message.ID, cumulativeInputTokens, cumulativeOutputTokens, metadata); err != nil {
341+
logger.Warn(ctx, "failed to track token usage", slog.Error(err))
342+
}
343+
350344
var streamErr error
351345
if streamErr = stream.Err(); streamErr != nil {
352346
if isConnectionError(streamErr) {

0 commit comments

Comments
 (0)