From c0bd190d0d0d81ef8fb6ac33932caa3940be41ea Mon Sep 17 00:00:00 2001 From: Roman Rizzi Date: Fri, 23 May 2025 10:25:08 -0300 Subject: [PATCH] FIX: Correctly pass tool_choice when using Claude models. The `ClaudePrompt` object couldn't access the original prompt's tool_choice attribute, affecting both Anthropic and Bedrock. --- lib/completions/dialects/claude.rb | 9 ++- .../completions/endpoints/anthropic_spec.rb | 49 +++++++++++++++ .../completions/endpoints/aws_bedrock_spec.rb | 61 +++++++++++++++++++ 3 files changed, 114 insertions(+), 5 deletions(-) diff --git a/lib/completions/dialects/claude.rb b/lib/completions/dialects/claude.rb index 1cea22156..0ad756249 100644 --- a/lib/completions/dialects/claude.rb +++ b/lib/completions/dialects/claude.rb @@ -13,14 +13,13 @@ def can_translate?(llm_model) end class ClaudePrompt - attr_reader :system_prompt - attr_reader :messages - attr_reader :tools + attr_reader :system_prompt, :messages, :tools, :tool_choice - def initialize(system_prompt, messages, tools) + def initialize(system_prompt, messages, tools, tool_choice) @system_prompt = system_prompt @messages = messages @tools = tools + @tool_choice = tool_choice end def has_tools? @@ -55,7 +54,7 @@ def translate tools = nil tools = tools_dialect.translated_tools if native_tool_support? - ClaudePrompt.new(system_prompt.presence, interleving_messages, tools) + ClaudePrompt.new(system_prompt.presence, interleving_messages, tools, tool_choice) end def max_prompt_tokens diff --git a/spec/lib/completions/endpoints/anthropic_spec.rb b/spec/lib/completions/endpoints/anthropic_spec.rb index f0e53bea5..24d7e0f55 100644 --- a/spec/lib/completions/endpoints/anthropic_spec.rb +++ b/spec/lib/completions/endpoints/anthropic_spec.rb @@ -770,6 +770,55 @@ end end + describe "forced tool use" do + it "can properly force tool use" do + prompt = + DiscourseAi::Completions::Prompt.new( + "You are a bot", + messages: [type: :user, id: "user1", content: "echo hello"], + tools: [echo_tool], + tool_choice: "echo", + ) + + response_body = { + id: "msg_01RdJkxCbsEj9VFyFYAkfy2S", + type: "message", + role: "assistant", + model: "claude-3-haiku-20240307", + content: [ + { + type: "tool_use", + id: "toolu_bdrk_014CMjxtGmKUtGoEFPgc7PF7", + name: "echo", + input: { + text: "hello", + }, + }, + ], + stop_reason: "end_turn", + stop_sequence: nil, + usage: { + input_tokens: 345, + output_tokens: 65, + }, + }.to_json + + parsed_body = nil + stub_request(:post, url).with( + body: + proc do |req_body| + parsed_body = JSON.parse(req_body, symbolize_names: true) + true + end, + ).to_return(status: 200, body: response_body) + + llm.generate(prompt, user: Discourse.system_user) + + # Verify that tool_choice: "echo" is present + expect(parsed_body.dig(:tool_choice, :name)).to eq("echo") + end + end + describe "structured output via prefilling" do it "forces the response to be a JSON and using the given JSON schema" do schema = { diff --git a/spec/lib/completions/endpoints/aws_bedrock_spec.rb b/spec/lib/completions/endpoints/aws_bedrock_spec.rb index bd60f9883..364c3b6b4 100644 --- a/spec/lib/completions/endpoints/aws_bedrock_spec.rb +++ b/spec/lib/completions/endpoints/aws_bedrock_spec.rb @@ -547,6 +547,67 @@ def encode_message(message) end end + describe "forced tool use" do + it "can properly force tool use" do + proxy = DiscourseAi::Completions::Llm.proxy("custom:#{model.id}") + request = nil + + tools = [ + { + name: "echo", + description: "echo something", + parameters: [ + { name: "text", type: "string", description: "text to echo", required: true }, + ], + }, + ] + + prompt = + DiscourseAi::Completions::Prompt.new( + "You are a bot", + messages: [type: :user, id: "user1", content: "echo hello"], + tools: tools, + tool_choice: "echo", + ) + + # Mock response from Bedrock + content = { + content: [ + { + type: "tool_use", + id: "toolu_bdrk_014CMjxtGmKUtGoEFPgc7PF7", + name: "echo", + input: { + text: "hello", + }, + }, + ], + usage: { + input_tokens: 25, + output_tokens: 15, + }, + }.to_json + + stub_request( + :post, + "https://bedrock-runtime.us-east-1.amazonaws.com/model/anthropic.claude-3-sonnet-20240229-v1:0/invoke", + ) + .with do |inner_request| + request = inner_request + true + end + .to_return(status: 200, body: content) + + proxy.generate(prompt, user: user) + + # Parse the request body + request_body = JSON.parse(request.body) + + # Verify that tool_choice: "echo" is present + expect(request_body.dig("tool_choice", "name")).to eq("echo") + end + end + describe "structured output via prefilling" do it "forces the response to be a JSON and using the given JSON schema" do schema = {