
Commit 34a59b6

FIX: ensure replies are never double streamed (#879)
The custom field "discourse_ai_bypass_ai_reply" was added so we can signal the post-created hook to bypass replying even if it thinks it should. Otherwise there are cases where we answer user questions twice, leading to much confusion. This also slightly refactors the code, making the controller smaller.
1 parent be0b78c commit 34a59b6
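The fix has two halves, both visible in the diffs below: the controller tags the posts it creates with the bypass custom field, and Playground#can_attach? refuses to queue an AI reply for any post carrying that field. A minimal sketch of the flow, assuming the posts are created through Discourse's PostCreator as in the controller (surrounding setup omitted):

# Controller side: create the user's post already marked "do not auto-reply",
# since the controller is about to stream the bot's answer itself.
post =
  PostCreator.create!(
    user,
    raw: params[:query],
    topic_id: topic_id,
    skip_validations: true,
    custom_fields: {
      DiscourseAi::AiBot::Playground::BYPASS_AI_REPLY_CUSTOM_FIELD => true,
    },
  )

# Playground side: the post-created hook consults can_attach?, which now bails
# out when the flag is present, so only the streamed reply is ever produced.
return false if post.custom_fields[DiscourseAi::AiBot::Playground::BYPASS_AI_REPLY_CUSTOM_FIELD].present?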

File tree

4 files changed, +138 -97 lines changed


app/controllers/discourse_ai/admin/ai_personas_controller.rb

Lines changed: 13 additions & 96 deletions
@@ -74,33 +74,6 @@ def destroy
         end
       end
 
-      class << self
-        POOL_SIZE = 10
-        def thread_pool
-          @thread_pool ||=
-            Concurrent::CachedThreadPool.new(min_threads: 0, max_threads: POOL_SIZE, idletime: 30)
-        end
-
-        def schedule_block(&block)
-          # think about a better way to handle cross thread connections
-          if Rails.env.test?
-            block.call
-            return
-          end
-
-          db = RailsMultisite::ConnectionManagement.current_db
-          thread_pool.post do
-            begin
-              RailsMultisite::ConnectionManagement.with_connection(db) { block.call }
-            rescue StandardError => e
-              Discourse.warn_exception(e, message: "Discourse AI: Unable to stream reply")
-            end
-          end
-        end
-      end
-
-      CRLF = "\r\n"
-
       def stream_reply
         persona =
           AiPersona.find_by(name: params[:persona_name]) ||
@@ -155,6 +128,9 @@ def stream_reply
             topic_id: topic_id,
             raw: params[:query],
             skip_validations: true,
+            custom_fields: {
+              DiscourseAi::AiBot::Playground::BYPASS_AI_REPLY_CUSTOM_FIELD => true,
+            },
           )
         else
           post =
@@ -165,6 +141,9 @@ def stream_reply
             archetype: Archetype.private_message,
             target_usernames: "#{user.username},#{persona.user.username}",
             skip_validations: true,
+            custom_fields: {
+              DiscourseAi::AiBot::Playground::BYPASS_AI_REPLY_CUSTOM_FIELD => true,
+            },
           )
 
         topic = post.topic
@@ -175,81 +154,19 @@ def stream_reply
 
         user = current_user
 
-        self.class.queue_streamed_reply(io, persona, user, topic, post)
+        DiscourseAi::AiBot::ResponseHttpStreamer.queue_streamed_reply(
+          io,
+          persona,
+          user,
+          topic,
+          post,
+        )
       end
 
       private
 
      AI_STREAM_CONVERSATION_UNIQUE_ID = "ai-stream-conversation-unique-id"
 
-      # keeping this in a static method so we don't capture ENV and other bits
-      # this allows us to release memory earlier
-      def self.queue_streamed_reply(io, persona, user, topic, post)
-        schedule_block do
-          begin
-            io.write "HTTP/1.1 200 OK"
-            io.write CRLF
-            io.write "Content-Type: text/plain; charset=utf-8"
-            io.write CRLF
-            io.write "Transfer-Encoding: chunked"
-            io.write CRLF
-            io.write "Cache-Control: no-cache, no-store, must-revalidate"
-            io.write CRLF
-            io.write "Connection: close"
-            io.write CRLF
-            io.write "X-Accel-Buffering: no"
-            io.write CRLF
-            io.write "X-Content-Type-Options: nosniff"
-            io.write CRLF
-            io.write CRLF
-            io.flush
-
-            persona_class =
-              DiscourseAi::AiBot::Personas::Persona.find_by(id: persona.id, user: user)
-            bot = DiscourseAi::AiBot::Bot.as(persona.user, persona: persona_class.new)
-
-            data =
-              { topic_id: topic.id, bot_user_id: persona.user.id, persona_id: persona.id }.to_json +
-                "\n\n"
-
-            io.write data.bytesize.to_s(16)
-            io.write CRLF
-            io.write data
-            io.write CRLF
-
-            DiscourseAi::AiBot::Playground
-              .new(bot)
-              .reply_to(post) do |partial|
-                next if partial.length == 0
-
-                data = { partial: partial }.to_json + "\n\n"
-
-                data.force_encoding("UTF-8")
-
-                io.write data.bytesize.to_s(16)
-                io.write CRLF
-                io.write data
-                io.write CRLF
-                io.flush
-              end
-
-            io.write "0"
-            io.write CRLF
-            io.write CRLF
-
-            io.flush
-            io.done
-          rescue StandardError => e
-            # make it a tiny bit easier to debug in dev, this is tricky
-            # multi-threaded code that exhibits various limitations in rails
-            p e if Rails.env.development?
-            Discourse.warn_exception(e, message: "Discourse AI: Unable to stream reply")
-          ensure
-            io.close
-          end
-        end
-      end
-
       def stage_user
         unique_id = params[:user_unique_id].to_s
         field = UserCustomField.find_by(name: AI_STREAM_CONVERSATION_UNIQUE_ID, value: unique_id)

lib/ai_bot/playground.rb

Lines changed: 3 additions & 0 deletions
@@ -3,6 +3,8 @@
 module DiscourseAi
   module AiBot
     class Playground
+      BYPASS_AI_REPLY_CUSTOM_FIELD = "discourse_ai_bypass_ai_reply"
+
       attr_reader :bot
 
       # An abstraction to manage the bot and topic interactions.
@@ -550,6 +552,7 @@ def can_attach?(post)
         return false if bot.bot_user.nil?
         return false if post.topic.private_message? && post.post_type != Post.types[:regular]
         return false if (SiteSetting.ai_bot_allowed_groups_map & post.user.group_ids).blank?
+        return false if post.custom_fields[BYPASS_AI_REPLY_CUSTOM_FIELD].present?
 
         true
       end

lib/ai_bot/response_http_streamer.rb

Lines changed: 105 additions & 0 deletions
@@ -0,0 +1,105 @@
+# frozen_string_literal: true
+
+module DiscourseAi
+  module AiBot
+    class ResponseHttpStreamer
+      CRLF = "\r\n"
+      POOL_SIZE = 10
+
+      class << self
+        def thread_pool
+          @thread_pool ||=
+            Concurrent::CachedThreadPool.new(min_threads: 0, max_threads: POOL_SIZE, idletime: 30)
+        end
+
+        def schedule_block(&block)
+          # think about a better way to handle cross thread connections
+          if Rails.env.test?
+            block.call
+            return
+          end
+
+          db = RailsMultisite::ConnectionManagement.current_db
+          thread_pool.post do
+            begin
+              RailsMultisite::ConnectionManagement.with_connection(db) { block.call }
+            rescue StandardError => e
+              Discourse.warn_exception(e, message: "Discourse AI: Unable to stream reply")
+            end
+          end
+        end
+
+        # keeping this in a static method so we don't capture ENV and other bits
+        # this allows us to release memory earlier
+        def queue_streamed_reply(io, persona, user, topic, post)
+          schedule_block do
+            begin
+              io.write "HTTP/1.1 200 OK"
+              io.write CRLF
+              io.write "Content-Type: text/plain; charset=utf-8"
+              io.write CRLF
+              io.write "Transfer-Encoding: chunked"
+              io.write CRLF
+              io.write "Cache-Control: no-cache, no-store, must-revalidate"
+              io.write CRLF
+              io.write "Connection: close"
+              io.write CRLF
+              io.write "X-Accel-Buffering: no"
+              io.write CRLF
+              io.write "X-Content-Type-Options: nosniff"
+              io.write CRLF
+              io.write CRLF
+              io.flush
+
+              persona_class =
+                DiscourseAi::AiBot::Personas::Persona.find_by(id: persona.id, user: user)
+              bot = DiscourseAi::AiBot::Bot.as(persona.user, persona: persona_class.new)
+
+              data =
+                {
+                  topic_id: topic.id,
+                  bot_user_id: persona.user.id,
+                  persona_id: persona.id,
+                }.to_json + "\n\n"
+
+              io.write data.bytesize.to_s(16)
+              io.write CRLF
+              io.write data
+              io.write CRLF
+
+              DiscourseAi::AiBot::Playground
+                .new(bot)
+                .reply_to(post) do |partial|
+                  next if partial.length == 0
+
+                  data = { partial: partial }.to_json + "\n\n"
+
+                  data.force_encoding("UTF-8")
+
+                  io.write data.bytesize.to_s(16)
+                  io.write CRLF
+                  io.write data
+                  io.write CRLF
+                  io.flush
+                end
+
+              io.write "0"
+              io.write CRLF
+              io.write CRLF
+
+              io.flush
+              io.done
+            rescue StandardError => e
+              # make it a tiny bit easier to debug in dev, this is tricky
+              # multi-threaded code that exhibits various limitations in rails
+              p e if Rails.env.development?
+              Discourse.warn_exception(e, message: "Discourse AI: Unable to stream reply")
+            ensure
+              io.close
+            end
+          end
+        end
+      end
+    end
+  end
+end
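The new streamer writes the raw HTTP/1.1 response itself and frames every JSON payload with chunked transfer encoding: the chunk size in hexadecimal, CRLF, the payload, CRLF, and a final zero-length chunk to terminate the stream. A small illustrative helper showing that framing (write_chunk and finish_chunked_stream are hypothetical names, not part of the plugin):

CRLF = "\r\n"

# Frames one payload as an HTTP/1.1 chunk, mirroring what
# queue_streamed_reply does inline above.
def write_chunk(io, payload)
  io.write payload.bytesize.to_s(16) # chunk size in hex
  io.write CRLF
  io.write payload
  io.write CRLF
  io.flush
end

# End of stream: a zero-sized chunk followed by a blank line.
def finish_chunked_stream(io)
  io.write "0"
  io.write CRLF
  io.write CRLF
  io.flush
end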

spec/requests/admin/ai_personas_controller_spec.rb

Lines changed: 17 additions & 1 deletion
@@ -490,13 +490,16 @@ def validate_streamed_response(raw_http, expected)
 
     it "is able to create a new conversation" do
       Jobs.run_immediately!
+      # trust level 0
+      SiteSetting.ai_bot_allowed_groups = "10"
 
       fake_endpoint.fake_content = ["This is a test! Testing!", "An amazing title"]
 
       ai_persona.create_user!
       ai_persona.update!(
-        allowed_group_ids: [Group::AUTO_GROUPS[:staff]],
+        allowed_group_ids: [Group::AUTO_GROUPS[:trust_level_0]],
         default_llm: "custom:#{llm.id}",
+        allow_personal_messages: true,
       )
 
       io_out, io_in = IO.pipe
@@ -530,6 +533,7 @@ def validate_streamed_response(raw_http, expected)
       expect(topic.topic_allowed_users.count).to eq(2)
       expect(topic.archetype).to eq(Archetype.private_message)
       expect(topic.title).to eq("An amazing title")
+      expect(topic.posts.count).to eq(2)
 
       # now let's try to make a reply with a tool call
       function_call = <<~XML
@@ -546,6 +550,16 @@ def validate_streamed_response(raw_http, expected)
 
       ai_persona.update!(tools: ["Categories"])
 
+      # lets also unstage the user and add the user to tl0
+      # this will ensure there are no feedback loops
+      new_user = user_post.user
+      new_user.update!(staged: false)
+      Group.user_trust_level_change!(new_user.id, new_user.trust_level)
+
+      # double check this happened and user is in group
+      personas = AiPersona.allowed_modalities(user: new_user.reload, allow_personal_messages: true)
+      expect(personas.count).to eq(1)
+
       io_out, io_in = IO.pipe
 
       post "/admin/plugins/discourse-ai/ai-personas/stream-reply.json",
@@ -579,6 +593,8 @@ def validate_streamed_response(raw_http, expected)
       expect(user_post.user.custom_fields).to eq(
         { "ai-stream-conversation-unique-id" => "site:test.com:user_id:1" },
       )
+
+      expect(topic.posts.count).to eq(4)
     end
   end
 end
