Skip to content

Commit c980c34

Browse files
authored
REFACTOR: Simplify sentiment classification (#977)
This change adds a simpler class for sentiment classification, replacing the soon-to-be-removed `Classificator` hierarchy. Additionally, it adds a method for classifying multiple posts concurrently, which speeds up the backfill rake task.
1 parent 6456a4f commit c980c34

File tree

5 files changed

+163
-10
lines changed

5 files changed

+163
-10
lines changed

app/jobs/regular/post_sentiment_analysis.rb

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,7 @@ def execute(args)
99
post = Post.find_by(id: post_id, post_type: Post.types[:regular])
1010
return if post&.raw.blank?
1111

12-
DiscourseAi::PostClassificator.new(
13-
DiscourseAi::Sentiment::SentimentClassification.new,
14-
).classify!(post)
12+
DiscourseAi::Sentiment::PostClassification.new.classify!(post)
1513
end
1614
end
1715
end

lib/inference/hugging_face_text_embeddings.rb

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -64,8 +64,8 @@ def rerank(content, candidates)
6464
JSON.parse(response.body, symbolize_names: true)
6565
end
6666

67-
def classify(content, model_config)
68-
headers = { "Referer" => Discourse.base_url, "Content-Type" => "application/json" }
67+
def classify(content, model_config, base_url = Discourse.base_url)
68+
headers = { "Referer" => base_url, "Content-Type" => "application/json" }
6969
headers["X-API-KEY"] = model_config.api_key
7070
headers["Authorization"] = "Bearer #{model_config.api_key}"
7171

lib/sentiment/post_classification.rb

Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,111 @@
1+
# frozen_string_literal: true
2+
3+
module DiscourseAi
  module Sentiment
    # Runs every configured sentiment model against a post and persists one
    # ClassificationResult row per (target, model) pair.
    #
    # Two entry points are offered:
    # * #classify!      - sequential, for a single target (used by the job).
    # * #bulk_classify! - concurrent, for a batch of targets (used by backfill).
    class PostClassification
      # Classifies every record in +relation+ concurrently and stores the results.
      #
      # Records whose prepared text is blank are skipped. If any inference
      # request fails, +value!+ re-raises the failure and nothing from this
      # batch is stored; the thread pool is still shut down in that case.
      #
      # @param relation [Enumerable] records responding to the same interface
      #   #prepare_text relies on (+post_number+, +raw+, +topic+).
      # @return [Object] not meaningful; callers should not rely on it.
      def bulk_classify!(relation)
        http_pool_size = 100
        # CachedThreadPool overrides min/max_threads with its own (effectively
        # unbounded) values, so a plain ThreadPoolExecutor is used to make the
        # cap on simultaneous HTTP requests take effect.
        pool =
          Concurrent::ThreadPoolExecutor.new(
            min_threads: 0,
            max_threads: http_pool_size,
            idletime: 30,
          )

        available_classifiers = classifiers
        # Resolved once on the calling thread and handed to the workers.
        # NOTE(review): presumably because Discourse.base_url depends on
        # per-thread site context - confirm.
        base_url = Discourse.base_url

        promised_classifications =
          relation.filter_map do |record|
            text = prepare_text(record)
            next if text.blank?

            promise_classification(record, text, available_classifiers, base_url, pool)
          end

        Concurrent::Promises
          .zip(*promised_classifications)
          .value!
          .each { |r| store_classification(r[:target], r[:classification]) }
      ensure
        # Always release the worker threads, even when a request raised.
        pool&.shutdown
        pool&.wait_for_termination
      end

      # Classifies a single target with every configured model, one request
      # after another, and stores the results.
      #
      # @param target [#post_number, #raw, #topic, #id, nil]
      # @return [nil] when the target or its prepared text is blank.
      def classify!(target)
        return if target.blank?

        to_classify = prepare_text(target)
        return if to_classify.blank?

        results =
          classifiers.each_with_object({}) do |model, memo|
            memo[model.model_name] = request_with(to_classify, model)
          end

        store_classification(target, results)
      end

      private

      # Builds a future that resolves to
      # +{ target:, text:, classification: { model_name => result } }+ once
      # every classifier has answered for +record+.
      def promise_classification(record, text, available_classifiers, base_url, pool)
        Concurrent::Promises
          .fulfilled_future({ target: record, text: text }, pool)
          .then_on(pool) do |w_text|
            # Written to from several pool threads at once.
            results = Concurrent::Hash.new

            promised_target_results =
              available_classifiers.map do |c|
                Concurrent::Promises.future_on(pool) do
                  results[c.model_name] = request_with(w_text[:text], c, base_url)
                end
              end

            Concurrent::Promises
              .zip(*promised_target_results)
              .then_on(pool) { |_| w_text.merge(classification: results) }
          end
          .flat(1) # unwrap the future returned from the block above
      end

      # Text sent to the models: the first post of a topic is prefixed with
      # the topic title; the result is truncated via the BERT tokenizer (512).
      def prepare_text(target)
        content =
          if target.post_number == 1
            "#{target.topic.title}\n#{target.raw}"
          else
            target.raw
          end

        Tokenizer::BertTokenizer.truncate(content, 512)
      end

      # Model configurations read from the sentiment site setting schema.
      def classifiers
        DiscourseAi::Sentiment::SentimentSiteSettingJsonSchema.values
      end

      # Sends +content+ to the model described by +config+.
      def request_with(content, config, base_url = Discourse.base_url)
        DiscourseAi::Inference::HuggingFaceTextEmbeddings.classify(content, config, base_url)
      end

      # Upserts one row per model for +target+. On conflict only the
      # +classification+ column is updated.
      #
      # @param classification [Hash] model name => raw classification payload.
      def store_classification(target, classification)
        # One timestamp for the whole write so created_at and updated_at agree
        # within a row and across rows.
        now = Time.now

        attrs =
          classification.map do |model_name, classifications|
            {
              model_used: model_name,
              target_id: target.id,
              target_type: target.class.sti_name,
              classification_type: :sentiment,
              classification: classifications,
              updated_at: now,
              created_at: now,
            }
          end

        # Nothing to persist (e.g. no models configured).
        return if attrs.empty?

        ClassificationResult.upsert_all(
          attrs,
          unique_by: %i[target_id target_type model_used],
          update_only: %i[classification],
        )
      end
    end
  end
end

lib/tasks/modules/sentiment/backfill.rake

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,8 @@ task "ai:sentiment:backfill", [:start_post] => [:environment] do |_, args|
1414
.where("category_id IN (?)", public_categories)
1515
.where(posts: { deleted_at: nil })
1616
.where(topics: { deleted_at: nil })
17-
.order("posts.id ASC")
18-
.find_each do |post|
17+
.find_in_batches do |batch|
1918
print "."
20-
DiscourseAi::PostClassificator.new(
21-
DiscourseAi::Sentiment::SentimentClassification.new,
22-
).classify!(post)
19+
DiscourseAi::Sentiment::PostClassification.new.bulk_classify!(batch)
2320
end
2421
end
Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
# frozen_string_literal: true
2+
3+
require_relative "../../../support/sentiment_inference_stubs"
4+
5+
RSpec.describe DiscourseAi::Sentiment::PostClassification do
  fab!(:post_1) { Fabricate(:post, post_number: 2) }

  before do
    SiteSetting.ai_sentiment_enabled = true
    # Three model configs; the specs expect one ClassificationResult per model.
    SiteSetting.ai_sentiment_model_configs = [
      {
        model_name: "SamLowe/roberta-base-go_emotions",
        endpoint: "http://samlowe-emotion.com",
        api_key: "123",
      },
      {
        model_name: "j-hartmann/emotion-english-distilroberta-base",
        endpoint: "http://jhartmann-emotion.com",
        api_key: "123",
      },
      {
        model_name: "cardiffnlp/twitter-roberta-base-sentiment-latest",
        endpoint: "http://cardiffnlp-sentiment.com",
        api_key: "123",
      },
    ].to_json
  end

  describe "#classify!" do
    it "does nothing if the post content is blank" do
      post_1.update_columns(raw: "")

      subject.classify!(post_1)

      expect(ClassificationResult.where(target: post_1).count).to be_zero
    end

    it "successfully classifies the post" do
      expected_analysis = DiscourseAi::Sentiment::SentimentSiteSettingJsonSchema.values.length
      SentimentInferenceStubs.stub_classification(post_1)

      subject.classify!(post_1)

      expect(ClassificationResult.where(target: post_1).count).to eq(expected_analysis)
    end
  end

  # Label matches the method under test (it is named bulk_classify!).
  describe "#bulk_classify!" do
    fab!(:post_2) { Fabricate(:post, post_number: 2) }

    it "classifies all given posts" do
      expected_analysis = DiscourseAi::Sentiment::SentimentSiteSettingJsonSchema.values.length
      SentimentInferenceStubs.stub_classification(post_1)
      SentimentInferenceStubs.stub_classification(post_2)

      subject.bulk_classify!(Post.where(id: [post_1.id, post_2.id]))

      expect(ClassificationResult.where(target: post_1).count).to eq(expected_analysis)
      expect(ClassificationResult.where(target: post_2).count).to eq(expected_analysis)
    end
  end
end

0 commit comments

Comments
 (0)