Skip to content

Commit 8e88565

Browse files
committed
non-streaming injected tool calls for anthropic
Signed-off-by: Danny Kopping <dannykopping@gmail.com>
1 parent bd6cc94 commit 8e88565

File tree

1 file changed

+150
-38
lines changed

1 file changed

+150
-38
lines changed

aibridged/bridge.go

Lines changed: 150 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -722,9 +722,8 @@ func (b *Bridge) proxyAnthropicRequest(w http.ResponseWriter, r *http.Request) {
722722
// looks up API key with os.LookupEnv("ANTHROPIC_API_KEY")
723723
client := anthropic.NewClient(opts...)
724724
if !in.UseStreaming() {
725-
var resp *anthropic.BetaMessage
726725
for {
727-
resp, err = client.Beta.Messages.New(ctx, messages, opts...)
726+
resp, err := client.Beta.Messages.New(ctx, messages, opts...)
728727
if err != nil {
729728
if isConnectionError(err) {
730729
b.logger.Warn(ctx, "upstream connection closed", slog.Error(err))
@@ -754,56 +753,169 @@ func (b *Bridge) proxyAnthropicRequest(w http.ResponseWriter, r *http.Request) {
754753
b.logger.Error(ctx, "failed to track token usage", slog.Error(err))
755754
}
756755

757-
messages.Messages = append(messages.Messages, resp.ToParam())
756+
// Handle tool calls for non-streaming.
757+
var pendingToolCalls []anthropic.BetaToolUseBlock
758+
for _, c := range resp.Content {
759+
toolUse := c.AsToolUse()
760+
if toolUse.ID == "" {
761+
continue
762+
}
763+
764+
if b.isInjectedTool(toolUse.Name) {
765+
pendingToolCalls = append(pendingToolCalls, toolUse)
766+
continue
767+
}
758768

759-
if resp.StopReason == anthropic.BetaStopReasonEndTurn {
769+
// If tool is not injected, track it since the client will be handling it.
770+
if serialized, err := json.Marshal(toolUse.Input); err == nil {
771+
_, err = coderdClient.TrackToolUsage(ctx, &proto.TrackToolUsageRequest{
772+
Model: string(resp.Model),
773+
Input: string(serialized),
774+
Tool: toolUse.Name,
775+
})
776+
if err != nil {
777+
b.logger.Error(ctx, "failed to track tool usage", slog.Error(err))
778+
}
779+
} else {
780+
b.logger.Warn(ctx, "failed to marshal args for tool usage", slog.Error(err))
781+
}
782+
}
783+
784+
// If no injected tool calls, we're done.
785+
if len(pendingToolCalls) == 0 {
786+
out, err := json.Marshal(resp)
787+
if err != nil {
788+
http.Error(w, "error marshaling response", http.StatusInternalServerError)
789+
return
790+
}
791+
792+
w.Header().Set("Content-Type", "application/json")
793+
w.WriteHeader(http.StatusOK)
794+
_, _ = w.Write(out)
760795
break
761796
}
762797

763-
// TODO: implement injected tool calling.
764-
if resp.StopReason == anthropic.BetaStopReasonToolUse {
765-
var (
766-
toolUse anthropic.BetaToolUseBlock
767-
input any
768-
)
769-
for _, c := range resp.Content {
770-
toolUse = c.AsToolUse()
771-
if toolUse.ID == "" {
772-
continue
773-
}
798+
// Append the assistant's message (which contains the tool_use block)
799+
// to the messages for the next API call.
800+
messages.Messages = append(messages.Messages, resp.ToParam())
774801

775-
input = toolUse.Input
802+
// Process each pending tool call.
803+
for _, tc := range pendingToolCalls {
804+
tool := b.tools[tc.Name]
805+
806+
var args map[string]any
807+
serialized, err := json.Marshal(tc.Input)
808+
if err != nil {
809+
b.logger.Warn(ctx, "failed to marshal tool args for unmarshal", slog.Error(err), slog.F("tool", tc.Name))
810+
// Continue to next tool call, but still append an error tool_result
811+
messages.Messages = append(messages.Messages,
812+
anthropic.NewBetaUserMessage(anthropic.NewBetaToolResultBlock(tc.ID, fmt.Sprintf("Error unmarshaling tool arguments: %v", err), true)),
813+
)
814+
continue
815+
} else if err := json.Unmarshal(serialized, &args); err != nil {
816+
b.logger.Warn(ctx, "failed to unmarshal tool args", slog.Error(err), slog.F("tool", tc.Name))
817+
// Continue to next tool call, but still append an error tool_result
818+
messages.Messages = append(messages.Messages,
819+
anthropic.NewBetaUserMessage(anthropic.NewBetaToolResultBlock(tc.ID, fmt.Sprintf("Error unmarshaling tool arguments: %v", err), true)),
820+
)
821+
continue
776822
}
777823

778-
if input != nil {
779-
if serialized, err := json.Marshal(input); err == nil {
780-
_, err = coderdClient.TrackToolUsage(ctx, &proto.TrackToolUsageRequest{
781-
Model: string(resp.Model),
782-
Input: string(serialized),
783-
Tool: toolUse.Name,
784-
Injected: b.isInjectedTool(toolUse.Name),
824+
_, err = coderdClient.TrackToolUsage(ctx, &proto.TrackToolUsageRequest{
825+
Model: string(resp.Model),
826+
Input: string(serialized),
827+
Tool: tc.Name,
828+
Injected: true,
829+
})
830+
if err != nil {
831+
b.logger.Error(ctx, "failed to track injected tool usage", slog.Error(err))
832+
}
833+
834+
res, err := tool.Call(ctx, args)
835+
if err != nil {
836+
// Always provide a tool_result even if the tool call failed
837+
messages.Messages = append(messages.Messages,
838+
anthropic.NewBetaUserMessage(anthropic.NewBetaToolResultBlock(tc.ID, fmt.Sprintf("Error calling tool: %v", err), true)),
839+
)
840+
continue
841+
}
842+
843+
// Ensure at least one tool_result is always added for each tool_use.
844+
toolResult := anthropic.BetaContentBlockParamUnion{
845+
OfToolResult: &anthropic.BetaToolResultBlockParam{
846+
ToolUseID: tc.ID,
847+
IsError: anthropic.Bool(false),
848+
},
849+
}
850+
851+
var hasValidResult bool
852+
for _, content := range res.Content {
853+
switch cb := content.(type) {
854+
case mcp.TextContent:
855+
toolResult.OfToolResult.Content = append(toolResult.OfToolResult.Content, anthropic.BetaToolResultBlockParamContentUnion{
856+
OfText: &anthropic.BetaTextBlockParam{
857+
Text: cb.Text,
858+
},
785859
})
786-
if err != nil {
787-
b.logger.Error(ctx, "failed to track injected tool usage", slog.Error(err))
860+
hasValidResult = true
861+
case mcp.EmbeddedResource:
862+
switch resource := cb.Resource.(type) {
863+
case mcp.TextResourceContents:
864+
val := fmt.Sprintf("Binary resource (MIME: %s, URI: %s): %s",
865+
resource.MIMEType, resource.URI, resource.Text)
866+
toolResult.OfToolResult.Content = append(toolResult.OfToolResult.Content, anthropic.BetaToolResultBlockParamContentUnion{
867+
OfText: &anthropic.BetaTextBlockParam{
868+
Text: val,
869+
},
870+
})
871+
hasValidResult = true
872+
case mcp.BlobResourceContents:
873+
val := fmt.Sprintf("Binary resource (MIME: %s, URI: %s): %s",
874+
resource.MIMEType, resource.URI, resource.Blob)
875+
toolResult.OfToolResult.Content = append(toolResult.OfToolResult.Content, anthropic.BetaToolResultBlockParamContentUnion{
876+
OfText: &anthropic.BetaTextBlockParam{
877+
Text: val,
878+
},
879+
})
880+
hasValidResult = true
881+
default:
882+
b.logger.Error(ctx, "unknown embedded resource type", slog.F("type", fmt.Sprintf("%T", resource)))
883+
toolResult.OfToolResult.Content = append(toolResult.OfToolResult.Content, anthropic.BetaToolResultBlockParamContentUnion{
884+
OfText: &anthropic.BetaTextBlockParam{
885+
Text: "Error: unknown embedded resource type",
886+
},
887+
})
888+
toolResult.OfToolResult.IsError = anthropic.Bool(true)
889+
hasValidResult = true
788890
}
789-
} else {
790-
b.logger.Warn(ctx, "failed to marshal args for tool usage", slog.Error(err))
891+
default:
892+
b.logger.Error(ctx, "not handling non-text tool result", slog.F("type", fmt.Sprintf("%T", cb)))
893+
toolResult.OfToolResult.Content = append(toolResult.OfToolResult.Content, anthropic.BetaToolResultBlockParamContentUnion{
894+
OfText: &anthropic.BetaTextBlockParam{
895+
Text: "Error: unsupported tool result type",
896+
},
897+
})
898+
toolResult.OfToolResult.IsError = anthropic.Bool(true)
899+
hasValidResult = true
791900
}
792901
}
793902

794-
break
795-
}
796-
}
903+
// If no content was processed, still add a tool_result
904+
if !hasValidResult {
905+
b.logger.Error(ctx, "no tool result added", slog.F("content_len", len(res.Content)), slog.F("is_error", res.IsError)) // This can only happen if there's somehow no content.
906+
toolResult.OfToolResult.Content = append(toolResult.OfToolResult.Content, anthropic.BetaToolResultBlockParamContentUnion{
907+
OfText: &anthropic.BetaTextBlockParam{
908+
Text: "Error: no valid tool result content",
909+
},
910+
})
911+
toolResult.OfToolResult.IsError = anthropic.Bool(true)
912+
}
797913

798-
out, err := json.Marshal(resp)
799-
if err != nil {
800-
http.Error(w, "error marshaling response", http.StatusInternalServerError)
801-
return
914+
if len(toolResult.OfToolResult.Content) > 0 {
915+
messages.Messages = append(messages.Messages, anthropic.NewBetaUserMessage(toolResult))
916+
}
917+
}
802918
}
803-
804-
w.Header().Set("Content-Type", "application/json")
805-
w.WriteHeader(http.StatusOK)
806-
_, _ = w.Write(out)
807919
return
808920
}
809921

0 commit comments

Comments
 (0)