Commit e6b60f3

Added hybrid search example [skip ci]

1 parent bdc7747
File tree

3 files changed: +143 −0 lines

- README.md
- examples/hybrid_search/Cargo.toml
- examples/hybrid_search/src/main.rs

README.md

Lines changed: 1 addition & 0 deletions
@@ -19,6 +19,7 @@ Or check out some examples:
 - [Embeddings](https://github.com/pgvector/pgvector-rust/blob/master/examples/openai/src/main.rs) with OpenAI
 - [Binary embeddings](https://github.com/pgvector/pgvector-rust/blob/master/examples/cohere/src/main.rs) with Cohere
 - [Sentence embeddings](https://github.com/pgvector/pgvector-rust/blob/master/examples/candle/src/main.rs) with Candle
+- [Hybrid search](https://github.com/pgvector/pgvector-rust/blob/master/examples/hybrid_search/src/main.rs) with Candle (Reciprocal Rank Fusion)
 - [Recommendations](https://github.com/pgvector/pgvector-rust/blob/master/examples/disco/src/main.rs) with Disco
 - [Bulk loading](https://github.com/pgvector/pgvector-rust/blob/master/examples/loading/src/main.rs) with `COPY`

examples/hybrid_search/Cargo.toml (new file)

Lines changed: 15 additions & 0 deletions
@@ -0,0 +1,15 @@
[package]
name = "example"
version = "0.1.0"
edition = "2021"
publish = false

[dependencies]
candle-core = "0.6"
candle-nn = "0.6"
candle-transformers = "0.6"
hf-hub = "0.3"
pgvector = { path = "../..", features = ["postgres"] }
postgres = "0.19"
serde_json = "1"
tokenizers = "0.19"
examples/hybrid_search/src/main.rs (new file)

Lines changed: 127 additions & 0 deletions
@@ -0,0 +1,127 @@
// https://github.com/huggingface/candle/tree/main/candle-examples/examples/bert
// https://huggingface.co/sentence-transformers/multi-qa-MiniLM-L6-cos-v1

use candle_core::{Device, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::models::bert::{BertModel, Config, DTYPE};
use hf_hub::api::sync::Api;
use pgvector::Vector;
use postgres::{Client, NoTls};
use std::error::Error;
use std::fs::read_to_string;
use tokenizers::{PaddingParams, PaddingStrategy, Tokenizer};

fn main() -> Result<(), Box<dyn Error + Send + Sync>> {
    let mut client = Client::configure()
        .host("localhost")
        .dbname("pgvector_example")
        .user(std::env::var("USER")?.as_str())
        .connect(NoTls)?;

    client.execute("CREATE EXTENSION IF NOT EXISTS vector", &[])?;
    client.execute("DROP TABLE IF EXISTS documents", &[])?;
    client.execute(
        "CREATE TABLE documents (id serial PRIMARY KEY, content text, embedding vector(384))",
        &[],
    )?;
    client.execute(
        "CREATE INDEX ON documents USING GIN (to_tsvector('english', content))",
        &[],
    )?;
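    // the GIN index on to_tsvector backs the keyword (full-text) half of the
    // hybrid query; the semantic half orders by the <=> cosine distance operator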

    let model = EmbeddingModel::new("sentence-transformers/multi-qa-MiniLM-L6-cos-v1")?;

    let input = [
        "The dog is barking",
        "The cat is purring",
        "The bear is growling",
    ];
    let embeddings = input
        .iter()
        .map(|text| model.embed(text))
        .collect::<Result<Vec<_>, _>>()?;

    for (content, embedding) in input.iter().zip(embeddings) {
        client.execute(
            "INSERT INTO documents (content, embedding) VALUES ($1, $2)",
            &[&content, &Vector::from(embedding)],
        )?;
    }

    let sql = "
    WITH semantic_search AS (
        SELECT id, RANK () OVER (ORDER BY embedding <=> $2) AS rank
        FROM documents
        ORDER BY embedding <=> $2
        LIMIT 20
    ),
    keyword_search AS (
        SELECT id, RANK () OVER (ORDER BY ts_rank_cd(to_tsvector('english', content), query) DESC)
        FROM documents, plainto_tsquery('english', $1) query
        WHERE to_tsvector('english', content) @@ query
        ORDER BY ts_rank_cd(to_tsvector('english', content), query) DESC
        LIMIT 20
    )
    SELECT
        COALESCE(semantic_search.id, keyword_search.id) AS id,
        COALESCE(1.0 / ($3::double precision + semantic_search.rank), 0.0) +
        COALESCE(1.0 / ($3::double precision + keyword_search.rank), 0.0) AS score
    FROM semantic_search
    FULL OUTER JOIN keyword_search ON semantic_search.id = keyword_search.id
    ORDER BY score DESC
    LIMIT 5
    ";

    let query = "growling bear";
    let query_embedding = model.embed(query)?;
    let k = 60.0;

    for row in client.query(sql, &[&query, &Vector::from(query_embedding), &k])? {
        let id: i32 = row.get(0);
        let score: f64 = row.get(1);
        println!("document: {}, RRF score: {}", id, score);
    }

    Ok(())
}

struct EmbeddingModel {
    tokenizer: Tokenizer,
    model: BertModel,
}

impl EmbeddingModel {
    pub fn new(model_id: &str) -> Result<Self, Box<dyn Error + Send + Sync>> {
        let api = Api::new()?;
        let repo = api.model(model_id.to_string());
        let tokenizer_path = repo.get("tokenizer.json")?;
        let config_path = repo.get("config.json")?;
        let weights_path = repo.get("model.safetensors")?;

        let mut tokenizer = Tokenizer::from_file(tokenizer_path)?;
        let padding = PaddingParams {
            strategy: PaddingStrategy::BatchLongest,
            ..Default::default()
        };
        tokenizer.with_padding(Some(padding));

        let device = Device::Cpu;
        let config: Config = serde_json::from_str(&read_to_string(config_path)?)?;
        let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[weights_path], DTYPE, &device)? };
        let model = BertModel::load(vb, &config)?;

        Ok(Self { tokenizer, model })
    }

    // embed one at a time since BertModel does not support attention mask
    // https://github.com/huggingface/candle/issues/1798
    fn embed(&self, text: &str) -> Result<Vec<f32>, Box<dyn Error + Send + Sync>> {
        let tokens = self.tokenizer.encode(text, true)?;
        let token_ids = Tensor::new(vec![tokens.get_ids().to_vec()], &self.model.device)?;
        let token_type_ids = token_ids.zeros_like()?;
        let embeddings = self.model.forward(&token_ids, &token_type_ids)?;
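        // mean-pool over the token axis, then L2-normalize so the stored
        // vector has unit length for cosine distance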
        let embeddings = (embeddings.sum(1)? / (embeddings.dim(1)? as f64))?;
        let embeddings = embeddings.broadcast_div(&embeddings.sqr()?.sum_keepdim(1)?.sqrt()?)?;
        Ok(embeddings.squeeze(0)?.to_vec1::<f32>()?)
    }
}
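
For reference, the same fusion can be computed client-side. Below is a minimal Rust sketch (not part of this commit; rrf_fuse is a hypothetical helper) that merges ranked lists of ids with the same 1 / (k + rank) formula the SQL uses:

use std::collections::HashMap;

// fuse ranked lists of document ids with Reciprocal Rank Fusion;
// ranks are 1-based in the formula
fn rrf_fuse(lists: &[Vec<i32>], k: f64) -> Vec<(i32, f64)> {
    let mut scores: HashMap<i32, f64> = HashMap::new();
    for list in lists {
        for (rank, id) in list.iter().enumerate() {
            *scores.entry(*id).or_insert(0.0) += 1.0 / (k + (rank + 1) as f64);
        }
    }
    // highest fused score first
    let mut fused: Vec<(i32, f64)> = scores.into_iter().collect();
    fused.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
    fused
}

Doing the fusion in SQL instead keeps everything in one round trip and lets Postgres cap each list with LIMIT 20 before the FULL OUTER JOIN.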
