Skip to content

Commit 0687ec7

Browse files
authored
FEATURE: allow embedding based search without hyde (#777)
This allows callers of embedding-based search to bypass hyde. Hyde expands the search term using an LLM, but if an LLM is already performing the search we can skip this expansion. It also introduces some tests for the controller, which we did not have before.
1 parent 72607c3 commit 0687ec7

File tree

3 files changed

+106
-18
lines changed

3 files changed

+106
-18
lines changed

app/controllers/discourse_ai/embeddings/embeddings_controller.rb

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ class EmbeddingsController < ::ApplicationController
99

1010
def search
1111
query = params[:q].to_s
12+
skip_hyde = params[:hyde].to_s.downcase == "false" || params[:hyde].to_s == "0"
1213

1314
if query.length < SiteSetting.min_search_term_length
1415
raise Discourse::InvalidParameters.new(:q)
@@ -31,14 +32,17 @@ def search
3132

3233
hijack do
3334
semantic_search
34-
.search_for_topics(query)
35+
.search_for_topics(query, _page = 1, hyde: !skip_hyde)
3536
.each { |topic_post| grouped_results.add(topic_post) }
3637

3738
render_serialized(grouped_results, GroupedSearchResultSerializer, result: grouped_results)
3839
end
3940
end
4041

4142
def quick_search
43+
# this search function searches posts (vs: topics)
44+
# it requires post embeddings and a reranker
45+
# it will not perform a hyde expansion
4246
query = params[:q].to_s
4347

4448
if query.length < SiteSetting.min_search_term_length

lib/embeddings/semantic_search.rb

Lines changed: 33 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ def self.clear_cache_for(query)
1111

1212
Discourse.cache.delete(hyde_key)
1313
Discourse.cache.delete("#{hyde_key}-#{SiteSetting.ai_embeddings_model}")
14+
Discourse.cache.delete("-#{SiteSetting.ai_embeddings_model}")
1415
end
1516

1617
def initialize(guardian)
@@ -29,19 +30,14 @@ def cached_query?(query)
2930
Discourse.cache.read(embedding_key).present?
3031
end
3132

32-
def search_for_topics(query, page = 1)
33-
max_results_per_page = 100
34-
limit = [Search.per_filter, max_results_per_page].min + 1
35-
offset = (page - 1) * limit
36-
search = Search.new(query, { guardian: guardian })
37-
search_term = search.term
38-
39-
return [] if search_term.nil? || search_term.length < SiteSetting.min_search_term_length
40-
41-
strategy = DiscourseAi::Embeddings::Strategies::Truncation.new
42-
vector_rep =
43-
DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation(strategy)
33+
def vector_rep
34+
@vector_rep ||=
35+
DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation(
36+
DiscourseAi::Embeddings::Strategies::Truncation.new,
37+
)
38+
end
4439

40+
def hyde_embedding(search_term)
4541
digest = OpenSSL::Digest::SHA1.hexdigest(search_term)
4642
hyde_key = build_hyde_key(digest, SiteSetting.ai_embeddings_semantic_search_hyde_model)
4743

@@ -57,14 +53,34 @@ def search_for_topics(query, page = 1)
5753
.cache
5854
.fetch(hyde_key, expires_in: 1.week) { hypothetical_post_from(search_term) }
5955

60-
hypothetical_post_embedding =
61-
Discourse
62-
.cache
63-
.fetch(embedding_key, expires_in: 1.week) { vector_rep.vector_from(hypothetical_post) }
56+
Discourse
57+
.cache
58+
.fetch(embedding_key, expires_in: 1.week) { vector_rep.vector_from(hypothetical_post) }
59+
end
60+
61+
def embedding(search_term)
62+
digest = OpenSSL::Digest::SHA1.hexdigest(search_term)
63+
embedding_key = build_embedding_key(digest, "", SiteSetting.ai_embeddings_model)
64+
65+
Discourse
66+
.cache
67+
.fetch(embedding_key, expires_in: 1.week) { vector_rep.vector_from(search_term) }
68+
end
69+
70+
def search_for_topics(query, page = 1, hyde: true)
71+
max_results_per_page = 100
72+
limit = [Search.per_filter, max_results_per_page].min + 1
73+
offset = (page - 1) * limit
74+
search = Search.new(query, { guardian: guardian })
75+
search_term = search.term
76+
77+
return [] if search_term.nil? || search_term.length < SiteSetting.min_search_term_length
78+
79+
search_embedding = hyde ? hyde_embedding(search_term) : embedding(search_term)
6480

6581
candidate_topic_ids =
6682
vector_rep.asymmetric_topics_similarity_search(
67-
hypothetical_post_embedding,
83+
search_embedding,
6884
limit: limit,
6985
offset: offset,
7086
)
Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
# frozen_string_literal: true
2+
3+
describe DiscourseAi::Embeddings::EmbeddingsController do
4+
context "when performing a topic search" do
5+
before do
6+
SiteSetting.min_search_term_length = 3
7+
SiteSetting.ai_embeddings_model = "text-embedding-3-small"
8+
DiscourseAi::Embeddings::SemanticSearch.clear_cache_for("test")
9+
SearchIndexer.enable
10+
end
11+
12+
fab!(:category)
13+
fab!(:subcategory) { Fabricate(:category, parent_category_id: category.id) }
14+
15+
fab!(:topic)
16+
fab!(:post) { Fabricate(:post, topic: topic) }
17+
18+
fab!(:topic_in_subcategory) { Fabricate(:topic, category: subcategory) }
19+
fab!(:post_in_subcategory) { Fabricate(:post, topic: topic_in_subcategory) }
20+
21+
def index(topic)
22+
strategy = DiscourseAi::Embeddings::Strategies::Truncation.new
23+
vector_rep =
24+
DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation(strategy)
25+
26+
stub_request(:post, "https://api.openai.com/v1/embeddings").to_return(
27+
status: 200,
28+
body: JSON.dump({ data: [{ embedding: [0.1] * 1536 }] }),
29+
)
30+
31+
vector_rep.generate_representation_from(topic)
32+
end
33+
34+
def stub_embedding(query)
35+
embedding = [0.049382] * 1536
36+
EmbeddingsGenerationStubs.openai_service(SiteSetting.ai_embeddings_model, query, embedding)
37+
end
38+
39+
it "returns results correctly when performing a non Hyde search" do
40+
index(topic)
41+
index(topic_in_subcategory)
42+
43+
query = "test"
44+
stub_embedding(query)
45+
46+
get "/discourse-ai/embeddings/semantic-search.json?q=#{query}&hyde=false"
47+
48+
expect(response.status).to eq(200)
49+
expect(response.parsed_body["topics"].map { |t| t["id"] }).to contain_exactly(
50+
topic.id,
51+
topic_in_subcategory.id,
52+
)
53+
end
54+
55+
it "is able to filter to a specific category (including sub categories)" do
56+
index(topic)
57+
index(topic_in_subcategory)
58+
59+
query = "test category:#{category.slug}"
60+
stub_embedding("test")
61+
62+
get "/discourse-ai/embeddings/semantic-search.json?q=#{query}&hyde=false"
63+
64+
expect(response.status).to eq(200)
65+
expect(response.parsed_body["topics"].map { |t| t["id"] }).to eq([topic_in_subcategory.id])
66+
end
67+
end
68+
end

0 commit comments

Comments
 (0)