Skip to content

Commit dfea784

Browse files
authored
DEV: Improve diff streaming accuracy with safety checker (#1338)
This update adds a safety checker that scans the streamed updates. It ensures that incomplete segments of text are not yet sent over the message bus, because sending them would break the diff streamer. It also updates the diff streamer to handle a "thinking" state while it waits for further message bus updates.
1 parent ff2e18f commit dfea784

File tree

6 files changed

+240
-14
lines changed

6 files changed

+240
-14
lines changed

assets/javascripts/discourse/components/modal/diff-modal.gjs

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ export default class ModalDiffModal extends Component {
2222
@service messageBus;
2323

2424
@tracked loading = false;
25+
@tracked finalResult = "";
2526
@tracked diffStreamer = new DiffStreamer(this.args.model.selectedText);
2627
@tracked suggestion = "";
2728
@tracked
@@ -65,6 +66,10 @@ export default class ModalDiffModal extends Component {
6566
async updateResult(result) {
6667
this.loading = false;
6768

69+
if (result.done) {
70+
this.finalResult = result.result;
71+
}
72+
6873
if (this.args.model.showResultAsDiff) {
6974
this.diffStreamer.updateResult(result, "result");
7075
} else {
@@ -105,10 +110,14 @@ export default class ModalDiffModal extends Component {
105110
);
106111
}
107112

108-
if (this.args.model.showResultAsDiff && this.diffStreamer.suggestion) {
113+
const finalResult =
114+
this.finalResult?.length > 0
115+
? this.finalResult
116+
: this.diffStreamer.suggestion;
117+
if (this.args.model.showResultAsDiff && finalResult) {
109118
this.args.model.toolbarEvent.replaceText(
110119
this.args.model.selectedText,
111-
this.diffStreamer.suggestion
120+
finalResult
112121
);
113122
}
114123
}
@@ -131,6 +140,7 @@ export default class ModalDiffModal extends Component {
131140
"composer-ai-helper-modal__suggestion"
132141
"streamable-content"
133142
(if this.isStreaming "streaming")
143+
(if this.diffStreamer.isThinking "thinking")
134144
(if @model.showResultAsDiff "inline-diff")
135145
}}
136146
>

assets/javascripts/discourse/lib/diff-streamer.gjs

Lines changed: 34 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@ export default class DiffStreamer {
1212
@tracked lastResultText = "";
1313
@tracked diff = "";
1414
@tracked suggestion = "";
15+
@tracked isDone = false;
16+
@tracked isThinking = false;
1517
typingTimer = null;
1618
currentWordIndex = 0;
1719

@@ -35,6 +37,7 @@ export default class DiffStreamer {
3537
const newText = result[newTextKey];
3638
const diffText = newText.slice(this.lastResultText.length).trim();
3739
const newWords = diffText.split(/\s+/).filter(Boolean);
40+
this.isDone = result?.done;
3841

3942
if (newWords.length > 0) {
4043
this.isStreaming = true;
@@ -64,14 +67,20 @@ export default class DiffStreamer {
6467
* Highlights the current word if streaming is ongoing.
6568
*/
6669
#streamNextWord() {
67-
if (this.currentWordIndex === this.words.length) {
70+
if (this.currentWordIndex === this.words.length && !this.isDone) {
71+
this.isThinking = true;
72+
}
73+
74+
if (this.currentWordIndex === this.words.length && this.isDone) {
75+
this.isThinking = false;
6876
this.diff = this.#compareText(this.selectedText, this.suggestion, {
6977
markLastWord: false,
7078
});
7179
this.isStreaming = false;
7280
}
7381

7482
if (this.currentWordIndex < this.words.length) {
83+
this.isThinking = false;
7584
this.suggestion += this.words[this.currentWordIndex] + " ";
7685
this.diff = this.#compareText(this.selectedText, this.suggestion, {
7786
markLastWord: true,
@@ -99,29 +108,49 @@ export default class DiffStreamer {
99108
const oldWords = oldText.trim().split(/\s+/);
100109
const newWords = newText.trim().split(/\s+/);
101110

111+
// Track where the line breaks are in the original oldText
112+
const lineBreakMap = (() => {
113+
const lines = oldText.trim().split("\n");
114+
const map = new Set();
115+
let wordIndex = 0;
116+
117+
for (const line of lines) {
118+
const wordsInLine = line.trim().split(/\s+/);
119+
wordIndex += wordsInLine.length;
120+
map.add(wordIndex - 1); // Mark the last word in each line
121+
}
122+
123+
return map;
124+
})();
125+
102126
const diff = [];
103127
let i = 0;
104128

105-
while (i < oldWords.length) {
129+
while (i < oldWords.length || i < newWords.length) {
106130
const oldWord = oldWords[i];
107131
const newWord = newWords[i];
108132

109133
let wordHTML = "";
110-
let originalWordHTML = `<span class="ghost">${oldWord}</span>`;
111134

112135
if (newWord === undefined) {
113-
wordHTML = originalWordHTML;
136+
wordHTML = `<span class="ghost">${oldWord}</span>`;
114137
} else if (oldWord === newWord) {
115138
wordHTML = `<span class="same-word">${newWord}</span>`;
116139
} else if (oldWord !== newWord) {
117-
wordHTML = `<del>${oldWord}</del> <ins>${newWord}</ins>`;
140+
wordHTML = `<del>${oldWord ?? ""}</del> <ins>${newWord ?? ""}</ins>`;
118141
}
119142
120143
if (i === newWords.length - 1 && opts.markLastWord) {
121144
wordHTML = `<mark class="highlight">${wordHTML}</mark>`;
122145
}
123146
124147
diff.push(wordHTML);
148+
149+
// Add a line break after this word if it ended a line in the original text
150+
if (lineBreakMap.has(i)) {
151+
diff.push("<br>");
152+
}
153+
125154
i++;
126155
}
127156

assets/stylesheets/common/streaming.scss

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,3 +79,18 @@ article.streaming .cooked {
7979
}
8080
}
8181
}
82+
83+
// Blink animation: the border color is --highlight-high at the start and end
// of each cycle and transparent at the midpoint.
@keyframes mark-blink {
84+
0%,
85+
100% {
86+
border-color: var(--highlight-high);
87+
}
88+
89+
50% {
90+
border-color: transparent;
91+
}
92+
}
93+
94+
// Runs the blink animation on the highlighted <mark> while the suggestion
// container carries the "thinking" class.
.composer-ai-helper-modal__suggestion.thinking mark.highlight {
95+
animation: mark-blink 1s step-start 0s infinite;
96+
}

lib/ai_helper/assistant.rb

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -181,14 +181,15 @@ def stream_prompt(completion_prompt, input, user, channel, force_default_locale:
181181

182182
streamed_diff = parse_diff(input, partial_response) if completion_prompt.diff?
183183

184-
# Throttle the updates and
185-
# checking length prevents partial tags
186-
# that aren't sanitized correctly yet (i.e. '<output')
187-
# from being sent in the stream
184+
# Throttle updates and check for safe stream points
188185
if (streamed_result.length > 10 && (Time.now - start > 0.3)) || Rails.env.test?
189-
payload = { result: sanitize_result(streamed_result), diff: streamed_diff, done: false }
190-
publish_update(channel, payload, user)
191-
start = Time.now
186+
sanitized = sanitize_result(streamed_result)
187+
188+
if DiscourseAi::Utils::DiffUtils::SafetyChecker.safe_to_stream?(sanitized)
189+
payload = { result: sanitized, diff: streamed_diff, done: false }
190+
publish_update(channel, payload, user)
191+
start = Time.now
192+
end
192193
end
193194
end
194195

Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
# frozen_string_literal: true
2+
3+
require "cgi"
4+
5+
module DiscourseAi
  module Utils
    module DiffUtils
      # Heuristic gate for partially streamed text.
      #
      # The checker strips HTML tags from the given text, unescapes HTML
      # entities, and then runs a series of checks looking for constructs
      # that appear to be cut off mid-way (links, tags, URLs, code spans,
      # emphasis markers, images, quote blocks, emoji shortcodes). The text
      # is considered safe only when none of the checks fire.
      class SafetyChecker
        # Builds a checker for +html_text+ and returns its verdict.
        #
        # @param html_text [String] text (possibly containing HTML) to inspect
        # @return [Boolean] true when no check reports an incomplete construct
        def self.safe_to_stream?(html_text)
          new(html_text).safe?
        end

        # @param html_text [String] text (possibly containing HTML) to inspect
        def initialize(html_text)
          @original_html = html_text
          @text = sanitize(html_text)
        end

        # Runs every check in turn and stops at the first one that fires.
        #
        # @return [Boolean] false as soon as any check reports a problem,
        #   true otherwise
        def safe?
          return false if unclosed_markdown_links?
          return false if unclosed_raw_html_tag?
          return false if trailing_incomplete_url?
          return false if unclosed_backticks?
          return false if unbalanced_bold_or_italic?
          return false if incomplete_image_markdown?
          return false if unbalanced_quote_blocks?
          return false if unclosed_triple_backticks?
          return false if partial_emoji?

          true
        end

        private

        # Removes complete tags (<span>, </del>, ...) and unescapes HTML
        # entities so the checks below operate on plain text.
        def sanitize(html)
          text = html.gsub(%r{</?[^>]+>}, "") # remove tags like <span>, <del>, etc.
          CGI.unescapeHTML(text)
        end

        # True when the number of "[" differs from "]" or "(" differs from ")".
        def unclosed_markdown_links?
          open_brackets = @text.count("[")
          close_brackets = @text.count("]")
          open_parens = @text.count("(")
          close_parens = @text.count(")")

          open_brackets != close_brackets || open_parens != close_parens
        end

        # True when the last "<" in the text has no ">" after it.
        def unclosed_raw_html_tag?
          last_lt = @text.rindex("<")
          last_gt = @text.rindex(">")
          last_lt && (!last_gt || last_gt < last_lt)
        end

        # True when the final whitespace-separated word is an http(s) URL that
        # does not end in closing punctuation.
        def trailing_incomplete_url?
          last_word = @text.split(/\s/).last
          last_word =~ %r{\Ahttps?://[^\s]*\z} && last_word !~ /[)\].,!?:;'"]\z/
        end

        # True when the text contains an odd number of backticks.
        def unclosed_backticks?
          @text.count("`").odd?
        end

        # True when "**", "*" (not followed by another "*") or "_" occurs an
        # odd number of times.
        def unbalanced_bold_or_italic?
          @text.scan(/\*\*/).count.odd? || @text.scan(/\*(?!\*)/).count.odd? ||
            @text.scan(/_/).count.odd?
        end

        # True when an image's "![alt](" is opened but its ")" never appears
        # before the end of the text.
        #
        # The pattern is anchored with \z (end of text) and excludes ")" from
        # the URL part, so a finished image followed by more text on the same
        # line is NOT reported as incomplete.
        def incomplete_image_markdown?
          @text.match?(/!\[[^\]]*\]\([^)]*\z/)
        end

        # True when there are more [quote] / [quote=...] openers than
        # [/quote] closers.
        def unbalanced_quote_blocks?
          opens = @text.scan(/\[quote(=.*?)?\]/i).count
          closes = @text.scan(%r{\[/quote\]}i).count
          opens > closes
        end

        # True when "```" occurs an odd number of times.
        def unclosed_triple_backticks?
          @text.scan(/```/).count.odd?
        end

        # True when a ":name" token is found without its closing ":".
        # Complete image markdown and http(s) URLs are removed first so the
        # colons inside them are not mistaken for emoji shortcodes.
        def partial_emoji?
          text = @text.gsub(/!\[.*?\]\(.*?\)/, "").gsub(%r{https?://[^\s]+}, "")
          tokens = text.scan(/:[a-z0-9_+\-\.]+:?/i)
          tokens.any? { |token| token.start_with?(":") && !token.end_with?(":") }
        end
      end
    end
  end
end
Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
# frozen_string_literal: true
2+
3+
RSpec.describe DiscourseAi::Utils::DiffUtils::SafetyChecker do
  describe "#safe?" do
    subject { described_class.new(text).safe? }

    context "with safe text" do
      let(:text) { "This is a simple safe text without issues." }

      it { is_expected.to eq(true) }

      # Each entry is a nested context description paired with an input the
      # checker is expected to accept.
      {
        "with normal HTML tags" => "Here is <strong>bold</strong> and <em>italic</em> text.",
        "with balanced markdown and no partial emoji" =>
          "This is **bold**, *italic*, and a smiley :smile:!",
        "with balanced quote blocks" => "[quote]Quoted text[/quote]",
        "with complete image markdown" => "![alt text](https://example.com/image.png)",
      }.each do |description, sample|
        context description do
          let(:text) { sample }

          it { is_expected.to eq(true) }
        end
      end
    end

    context "with unsafe text" do
      # Each entry is a context description paired with an input the checker
      # is expected to reject.
      {
        "with unclosed markdown link" => "This is a [link(https://example.com)",
        "with unclosed raw HTML tag" => "Text with <div unclosed tag",
        # no closing punctuation
        "with trailing incomplete URL" => "Check this out https://example.com/something",
        "with unclosed backticks" => "Here is some `inline code without closing",
        "with unbalanced bold or italic markdown" => "This is *italic without closing",
        # missing closing )
        "with incomplete image markdown" => "Image ![alt text](https://example.com/image.png",
        "with unbalanced quote blocks" => "[quote]Unclosed quote block",
        "with unclosed triple backticks" => "```code block without closing",
        "with partial emoji" => "A partial emoji :smile",
      }.each do |description, sample|
        context description do
          let(:text) { sample }

          it { is_expected.to eq(false) }
        end
      end
    end
  end
end

0 commit comments

Comments
 (0)