FIX: Correctly pass tool_choice when using Claude models. #1364


Merged (1 commit, May 23, 2025)
lib/completions/dialects/claude.rb (9 changes: 4 additions & 5 deletions)

@@ -13,14 +13,13 @@ def can_translate?(llm_model)
   end

   class ClaudePrompt
-    attr_reader :system_prompt
-    attr_reader :messages
-    attr_reader :tools
+    attr_reader :system_prompt, :messages, :tools, :tool_choice

-    def initialize(system_prompt, messages, tools)
+    def initialize(system_prompt, messages, tools, tool_choice)
       @system_prompt = system_prompt
       @messages = messages
       @tools = tools
+      @tool_choice = tool_choice
     end

     def has_tools?
@@ -55,7 +54,7 @@ def translate
     tools = nil
     tools = tools_dialect.translated_tools if native_tool_support?

-    ClaudePrompt.new(system_prompt.presence, interleving_messages, tools)
+    ClaudePrompt.new(system_prompt.presence, interleving_messages, tools, tool_choice)
   end

   def max_prompt_tokens
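
With tool_choice now stored on ClaudePrompt, the endpoint code that assembles the HTTP request can read it straight off the prompt object. That endpoint code is not part of this diff; the sketch below only illustrates the hand-off, assuming Anthropic's documented tool_choice format for forcing a specific tool ({ type: "tool", name: ... }) and a hypothetical build_payload helper that receives a ClaudePrompt.

# Illustrative sketch, not code from this PR: how an endpoint could fold
# ClaudePrompt#tool_choice into the outgoing request payload.
def build_payload(prompt)
  payload = { system: prompt.system_prompt, messages: prompt.messages }
  payload[:tools] = prompt.tools if prompt.has_tools?

  if prompt.tool_choice
    # Anthropic's Messages API forces a specific tool via type: "tool" plus its name.
    payload[:tool_choice] = { type: "tool", name: prompt.tool_choice }
  end

  payload.compact
end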
spec/lib/completions/endpoints/anthropic_spec.rb (49 changes: 49 additions & 0 deletions)

@@ -770,6 +770,55 @@
     end
   end

+  describe "forced tool use" do
+    it "can properly force tool use" do
+      prompt =
+        DiscourseAi::Completions::Prompt.new(
+          "You are a bot",
+          messages: [type: :user, id: "user1", content: "echo hello"],
+          tools: [echo_tool],
+          tool_choice: "echo",
+        )
+
+      response_body = {
+        id: "msg_01RdJkxCbsEj9VFyFYAkfy2S",
+        type: "message",
+        role: "assistant",
+        model: "claude-3-haiku-20240307",
+        content: [
+          {
+            type: "tool_use",
+            id: "toolu_bdrk_014CMjxtGmKUtGoEFPgc7PF7",
+            name: "echo",
+            input: {
+              text: "hello",
+            },
+          },
+        ],
+        stop_reason: "end_turn",
+        stop_sequence: nil,
+        usage: {
+          input_tokens: 345,
+          output_tokens: 65,
+        },
+      }.to_json
+
+      parsed_body = nil
+      stub_request(:post, url).with(
+        body:
+          proc do |req_body|
+            parsed_body = JSON.parse(req_body, symbolize_names: true)
+            true
+          end,
+      ).to_return(status: 200, body: response_body)
+
+      llm.generate(prompt, user: Discourse.system_user)
+
+      # Verify that tool_choice: "echo" is present
+      expect(parsed_body.dig(:tool_choice, :name)).to eq("echo")
+    end
+  end
+
   describe "structured output via prefilling" do
     it "forces the response to be a JSON and using the given JSON schema" do
       schema = {
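
Note: echo_tool, referenced in the prompt above, is a helper defined elsewhere in anthropic_spec.rb and is not part of this diff. Judging by the inline tools definition in the Bedrock spec below, it presumably looks something like the following (hypothetical reconstruction):

# Assumed shape of the echo_tool spec helper; mirrors the inline tool
# definition used in aws_bedrock_spec.rb below.
let(:echo_tool) do
  {
    name: "echo",
    description: "echo something",
    parameters: [
      { name: "text", type: "string", description: "text to echo", required: true },
    ],
  }
end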
spec/lib/completions/endpoints/aws_bedrock_spec.rb (61 changes: 61 additions & 0 deletions)

@@ -547,6 +547,67 @@ def encode_message(message)
     end
   end

+  describe "forced tool use" do
+    it "can properly force tool use" do
+      proxy = DiscourseAi::Completions::Llm.proxy("custom:#{model.id}")
+      request = nil
+
+      tools = [
+        {
+          name: "echo",
+          description: "echo something",
+          parameters: [
+            { name: "text", type: "string", description: "text to echo", required: true },
+          ],
+        },
+      ]
+
+      prompt =
+        DiscourseAi::Completions::Prompt.new(
+          "You are a bot",
+          messages: [type: :user, id: "user1", content: "echo hello"],
+          tools: tools,
+          tool_choice: "echo",
+        )
+
+      # Mock response from Bedrock
+      content = {
+        content: [
+          {
+            type: "tool_use",
+            id: "toolu_bdrk_014CMjxtGmKUtGoEFPgc7PF7",
+            name: "echo",
+            input: {
+              text: "hello",
+            },
+          },
+        ],
+        usage: {
+          input_tokens: 25,
+          output_tokens: 15,
+        },
+      }.to_json
+
+      stub_request(
+        :post,
+        "https://bedrock-runtime.us-east-1.amazonaws.com/model/anthropic.claude-3-sonnet-20240229-v1:0/invoke",
+      )
+        .with do |inner_request|
+          request = inner_request
+          true
+        end
+        .to_return(status: 200, body: content)
+
+      proxy.generate(prompt, user: user)
+
+      # Parse the request body
+      request_body = JSON.parse(request.body)
+
+      # Verify that tool_choice: "echo" is present
+      expect(request_body.dig("tool_choice", "name")).to eq("echo")
+    end
+  end
+
   describe "structured output via prefilling" do
     it "forces the response to be a JSON and using the given JSON schema" do
       schema = {
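
Both specs assert only that the forced tool's name reaches the request body under tool_choice. For reference, the forced-tool-use portion of a request in Anthropic's native Messages API format (which Bedrock's invoke endpoint for Claude models also accepts) would look roughly like the hash below; the exact payload the dialect and endpoint produce is not shown in this PR.

# Rough expected shape of the forced-tool-use slice of the request body.
# This is an assumption based on Anthropic's public Messages API format,
# not an assertion made by the specs above.
expected_slice = {
  tools: [
    {
      name: "echo",
      description: "echo something",
      input_schema: {
        type: "object",
        properties: {
          text: { type: "string", description: "text to echo" },
        },
        required: ["text"],
      },
    },
  ],
  tool_choice: { type: "tool", name: "echo" },
}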