Skip to content

Commit 98022d7

Browse files
authored
FEATURE: support custom instructions for persona streaming (#890)
This allows us to inject information into the system prompt, which can help shape replies without repeating it over and over in messages.
1 parent fa7ca8b commit 98022d7

File tree

6 files changed

+61
-43
lines changed

6 files changed

+61
-43
lines changed

app/controllers/discourse_ai/admin/ai_personas_controller.rb

Lines changed: 7 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -111,55 +111,26 @@ def stream_reply
111111

112112
topic_id = params[:topic_id].to_i
113113
topic = nil
114-
post = nil
115114

116115
if topic_id > 0
117116
topic = Topic.find(topic_id)
118117

119-
raise Discourse::NotFound if topic.nil?
120-
121118
if topic.topic_allowed_users.where(user_id: user.id).empty?
122119
return render_json_error(I18n.t("discourse_ai.errors.user_not_allowed"))
123120
end
124-
125-
post =
126-
PostCreator.create!(
127-
user,
128-
topic_id: topic_id,
129-
raw: params[:query],
130-
skip_validations: true,
131-
custom_fields: {
132-
DiscourseAi::AiBot::Playground::BYPASS_AI_REPLY_CUSTOM_FIELD => true,
133-
},
134-
)
135-
else
136-
post =
137-
PostCreator.create!(
138-
user,
139-
title: I18n.t("discourse_ai.ai_bot.default_pm_prefix"),
140-
raw: params[:query],
141-
archetype: Archetype.private_message,
142-
target_usernames: "#{user.username},#{persona.user.username}",
143-
skip_validations: true,
144-
custom_fields: {
145-
DiscourseAi::AiBot::Playground::BYPASS_AI_REPLY_CUSTOM_FIELD => true,
146-
},
147-
)
148-
149-
topic = post.topic
150121
end
151122

152123
hijack = request.env["rack.hijack"]
153124
io = hijack.call
154125

155-
user = current_user
156-
157126
DiscourseAi::AiBot::ResponseHttpStreamer.queue_streamed_reply(
158-
io,
159-
persona,
160-
user,
161-
topic,
162-
post,
127+
io: io,
128+
persona: persona,
129+
user: user,
130+
topic: topic,
131+
query: params[:query].to_s,
132+
custom_instructions: params[:custom_instructions].to_s,
133+
current_user: current_user,
163134
)
164135
end
165136

lib/ai_bot/personas/persona.rb

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -171,6 +171,11 @@ def craft_prompt(context, llm: nil)
171171
DiscourseAi::Completions::Llm.proxy(self.class.question_consolidator_llm)
172172
end
173173

174+
if context[:custom_instructions].present?
175+
prompt_insts << "\n"
176+
prompt_insts << context[:custom_instructions]
177+
end
178+
174179
fragments_guidance =
175180
rag_fragments_prompt(
176181
context[:conversation_context].to_a,

lib/ai_bot/playground.rb

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -392,7 +392,7 @@ def get_context(participants:, conversation_context:, user:, skip_tool_details:
392392
result
393393
end
394394

395-
def reply_to(post, &blk)
395+
def reply_to(post, custom_instructions: nil, &blk)
396396
# this is a multithreading issue
397397
# post custom prompt is needed and it may not
398398
# be properly loaded, ensure it is loaded
@@ -413,6 +413,7 @@ def reply_to(post, &blk)
413413
context[:post_id] = post.id
414414
context[:topic_id] = post.topic_id
415415
context[:private_message] = post.topic.private_message?
416+
context[:custom_instructions] = custom_instructions
416417

417418
reply_user = bot.bot_user
418419
if bot.persona.class.respond_to?(:user_id)

lib/ai_bot/response_http_streamer.rb

Lines changed: 32 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -31,9 +31,36 @@ def schedule_block(&block)
3131

3232
# keeping this in a static method so we don't capture ENV and other bits
3333
# this allows us to release memory earlier
34-
def queue_streamed_reply(io, persona, user, topic, post)
34+
def queue_streamed_reply(
35+
io:,
36+
persona:,
37+
user:,
38+
topic:,
39+
query:,
40+
custom_instructions:,
41+
current_user:
42+
)
3543
schedule_block do
3644
begin
45+
post_params = {
46+
raw: query,
47+
skip_validations: true,
48+
custom_fields: {
49+
DiscourseAi::AiBot::Playground::BYPASS_AI_REPLY_CUSTOM_FIELD => true,
50+
},
51+
}
52+
53+
if topic
54+
post_params[:topic_id] = topic.id
55+
else
56+
post_params[:title] = I18n.t("discourse_ai.ai_bot.default_pm_prefix")
57+
post_params[:archetype] = Archetype.private_message
58+
post_params[:target_usernames] = "#{user.username},#{persona.user.username}"
59+
end
60+
61+
post = PostCreator.create!(user, post_params)
62+
topic = post.topic
63+
3764
io.write "HTTP/1.1 200 OK"
3865
io.write CRLF
3966
io.write "Content-Type: text/plain; charset=utf-8"
@@ -52,7 +79,7 @@ def queue_streamed_reply(io, persona, user, topic, post)
5279
io.flush
5380

5481
persona_class =
55-
DiscourseAi::AiBot::Personas::Persona.find_by(id: persona.id, user: user)
82+
DiscourseAi::AiBot::Personas::Persona.find_by(id: persona.id, user: current_user)
5683
bot = DiscourseAi::AiBot::Bot.as(persona.user, persona: persona_class.new)
5784

5885
data =
@@ -69,7 +96,7 @@ def queue_streamed_reply(io, persona, user, topic, post)
6996

7097
DiscourseAi::AiBot::Playground
7198
.new(bot)
72-
.reply_to(post) do |partial|
99+
.reply_to(post, custom_instructions: custom_instructions) do |partial|
73100
next if partial.length == 0
74101

75102
data = { partial: partial }.to_json + "\n\n"
@@ -88,11 +115,11 @@ def queue_streamed_reply(io, persona, user, topic, post)
88115
io.write CRLF
89116

90117
io.flush
91-
io.done
118+
io.done if io.respond_to?(:done)
92119
rescue StandardError => e
93120
# make it a tiny bit easier to debug in dev, this is tricky
94121
# multi-threaded code that exhibits various limitations in rails
95-
p e if Rails.env.development?
122+
p e if Rails.env.development? || Rails.env.test?
96123
Discourse.warn_exception(e, message: "Discourse AI: Unable to stream reply")
97124
ensure
98125
io.close

lib/completions/endpoints/fake.rb

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,10 @@ def self.last_call=(params)
104104
@last_call = params
105105
end
106106

107+
def self.previous_calls
108+
@previous_calls ||= []
109+
end
110+
107111
def self.reset!
108112
@last_call = nil
109113
@fake_content = nil
@@ -118,7 +122,11 @@ def perform_completion!(
118122
feature_name: nil,
119123
feature_context: nil
120124
)
121-
self.class.last_call = { dialect: dialect, user: user, model_params: model_params }
125+
last_call = { dialect: dialect, user: user, model_params: model_params }
126+
self.class.last_call = last_call
127+
self.class.previous_calls << last_call
128+
# guard memory in test
129+
self.class.previous_calls.shift if self.class.previous_calls.length > 10
122130

123131
content = self.class.fake_content
124132

spec/requests/admin/ai_personas_controller_spec.rb

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -500,6 +500,7 @@ def validate_streamed_response(raw_http, expected)
500500
allowed_group_ids: [Group::AUTO_GROUPS[:trust_level_0]],
501501
default_llm: "custom:#{llm.id}",
502502
allow_personal_messages: true,
503+
system_prompt: "you are a helpful bot",
503504
)
504505

505506
io_out, io_in = IO.pipe
@@ -510,6 +511,7 @@ def validate_streamed_response(raw_http, expected)
510511
query: "how are you today?",
511512
user_unique_id: "site:test.com:user_id:1",
512513
preferred_username: "test_user",
514+
custom_instructions: "To be appended to system prompt",
513515
},
514516
env: {
515517
"rack.hijack" => lambda { io_in },
@@ -521,6 +523,10 @@ def validate_streamed_response(raw_http, expected)
521523
raw = io_out.read
522524
context_info = validate_streamed_response(raw, "This is a test! Testing!")
523525

526+
system_prompt = fake_endpoint.previous_calls[-2][:dialect].prompt.messages.first[:content]
527+
528+
expect(system_prompt).to eq("you are a helpful bot\nTo be appended to system prompt")
529+
524530
expect(context_info["topic_id"]).to be_present
525531
topic = Topic.find(context_info["topic_id"])
526532
last_post = topic.posts.order(:created_at).last

0 commit comments

Comments (0)