Skip to content

Commit b04ead6

Browse files
committed
Clean up vector search
1 parent f9cb8a1 commit b04ead6

File tree

3 files changed

+53
-64
lines changed

3 files changed

+53
-64
lines changed

pgml-sdks/pgml/src/collection.rs

Lines changed: 6 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -776,20 +776,14 @@ impl Collection {
776776
#[allow(clippy::type_complexity)]
777777
pub async fn vector_search(
778778
&mut self,
779-
query: &str,
779+
query: Json,
780780
pipeline: &mut MultiFieldPipeline,
781-
query_parameters: Option<Json>,
782781
top_k: Option<i64>,
783782
) -> anyhow::Result<Vec<Json>> {
784783
let pool = get_or_initialize_pool(&self.database_url).await?;
785784

786-
let (built_query, values) = build_vector_search_query(
787-
query,
788-
self,
789-
query_parameters.clone().unwrap_or_default(),
790-
pipeline,
791-
)
792-
.await?;
785+
let (built_query, values) =
786+
build_vector_search_query(query.clone(), self, pipeline).await?;
793787
let results: Result<Vec<(Json, String, f64)>, _> =
794788
sqlx::query_as_with(&built_query, values)
795789
.fetch_all(&pool)
@@ -817,13 +811,8 @@ impl Collection {
817811
.project_info;
818812
pipeline.set_project_info(project_info.to_owned());
819813
pipeline.verify_in_database(false).await?;
820-
let (built_query, values) = build_vector_search_query(
821-
query,
822-
self,
823-
query_parameters.clone().unwrap_or_default(),
824-
pipeline,
825-
)
826-
.await?;
814+
let (built_query, values) =
815+
build_vector_search_query(query, self, pipeline).await?;
827816
let results: Vec<(Json, String, f64)> =
828817
sqlx::query_as_with(&built_query, values)
829818
.fetch_all(&pool)
@@ -862,6 +851,7 @@ impl Collection {
862851
.bind(&self.name)
863852
.execute(&mut *transaciton)
864853
.await?;
854+
// TODO: Alter pipeline schema
865855
sqlx::query(&query_builder!(
866856
"ALTER SCHEMA %s RENAME TO %s",
867857
&self.name,

pgml-sdks/pgml/src/lib.rs

Lines changed: 42 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -606,11 +606,11 @@ mod tests {
606606
#[sqlx::test]
607607
async fn can_vector_search_with_local_embeddings() -> anyhow::Result<()> {
608608
internal_init_logger(None, None).ok();
609-
let collection_name = "test_r_c_cvs_3";
609+
let collection_name = "test_r_c_cvswle_3";
610610
let mut collection = Collection::new(collection_name, None);
611611
let documents = generate_dummy_documents(10);
612612
collection.upsert_documents(documents.clone(), None).await?;
613-
let pipeline_name = "test_r_p_cvs_0";
613+
let pipeline_name = "test_r_p_cvswle_0";
614614
let mut pipeline = MultiFieldPipeline::new(
615615
pipeline_name,
616616
Some(
@@ -638,27 +638,27 @@ mod tests {
638638
collection.add_pipeline(&mut pipeline).await?;
639639
let results = collection
640640
.vector_search(
641-
"Test document: 2",
642-
&mut pipeline,
643-
Some(
644-
json!({
645-
"query": {
646-
"fields": {
647-
"title": {
648-
"full_text_search": "test",
649-
},
650-
"body": {},
641+
json!({
642+
"query": {
643+
"fields": {
644+
"title": {
645+
"query": "Test document: 2",
646+
"full_text_search": "test"
647+
},
648+
"body": {
649+
"query": "Test document: 2"
651650
},
652-
"filter": {
653-
"id": {
654-
"$gt": 3
655-
}
656-
}
657651
},
658-
"limit": 5
659-
})
660-
.into(),
661-
),
652+
"filter": {
653+
"id": {
654+
"$gt": 3
655+
}
656+
}
657+
},
658+
"limit": 5
659+
})
660+
.into(),
661+
&mut pipeline,
662662
None,
663663
)
664664
.await?;
@@ -674,11 +674,11 @@ mod tests {
674674
#[sqlx::test]
675675
async fn can_vector_search_with_remote_embeddings() -> anyhow::Result<()> {
676676
internal_init_logger(None, None).ok();
677-
let collection_name = "test_r_c_cvs_4";
677+
let collection_name = "test_r_c_cvswre_4";
678678
let mut collection = Collection::new(collection_name, None);
679679
let documents = generate_dummy_documents(10);
680680
collection.upsert_documents(documents.clone(), None).await?;
681-
let pipeline_name = "test_r_p_cvs_0";
681+
let pipeline_name = "test_r_p_cvswre_0";
682682
let mut pipeline = MultiFieldPipeline::new(
683683
pipeline_name,
684684
Some(
@@ -708,27 +708,27 @@ mod tests {
708708
let mut pipeline = MultiFieldPipeline::new(pipeline_name, None)?;
709709
let results = collection
710710
.vector_search(
711-
"Test document: 2",
712-
&mut pipeline,
713-
Some(
714-
json!({
715-
"query": {
716-
"fields": {
717-
"title": {
718-
"full_text_search": "test",
719-
},
720-
"body": {},
711+
json!({
712+
"query": {
713+
"fields": {
714+
"title": {
715+
"full_text_search": "test",
716+
"query": "Test document: 2"
717+
},
718+
"body": {
719+
"query": "Test document: 2"
721720
},
722-
"filter": {
723-
"id": {
724-
"$gt": 3
725-
}
726-
}
727721
},
728-
"limit": 5
729-
})
730-
.into(),
731-
),
722+
"filter": {
723+
"id": {
724+
"$gt": 3
725+
}
726+
}
727+
},
728+
"limit": 5
729+
})
730+
.into(),
731+
&mut pipeline,
732732
None,
733733
)
734734
.await?;

pgml-sdks/pgml/src/vector_search_query_builder.rs

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ struct ValidFullTextSearchAction {
2626

2727
#[derive(Debug, Deserialize)]
2828
struct ValidField {
29+
query: String,
2930
model_parameters: Option<Json>,
3031
full_text_search: Option<String>,
3132
}
@@ -43,9 +44,8 @@ struct ValidQuery {
4344
}
4445

4546
pub async fn build_vector_search_query(
46-
query_text: &str,
47-
collection: &Collection,
4847
query: Json,
48+
collection: &Collection,
4949
pipeline: &MultiFieldPipeline,
5050
) -> anyhow::Result<(String, SqlxValues)> {
5151
let valid_query: ValidQuery = serde_json::from_value(query.0)?;
@@ -107,7 +107,7 @@ pub async fn build_vector_search_query(
107107
Expr::cust(format!(
108108
"transformer => (SELECT schema #>> '{{{key},embed,model}}' FROM pipeline)",
109109
)),
110-
Expr::cust_with_values("text => $1", [query_text]),
110+
Expr::cust_with_values("text => $1", [vf.query]),
111111
Expr::cust(format!("kwargs => COALESCE((SELECT schema #> '{{{key},embed,model_parameters}}' FROM pipeline), '{{}}'::jsonb)")),
112112
]),
113113
Alias::new("embedding"),
@@ -144,9 +144,8 @@ pub async fn build_vector_search_query(
144144
&model.name,
145145
vf.model_parameters.as_ref(),
146146
)?;
147-
let mut embeddings = remote_embeddings
148-
.embed(vec![query_text.to_string()])
149-
.await?;
147+
let mut embeddings =
148+
remote_embeddings.embed(vec![vf.query.to_string()]).await?;
150149
std::mem::take(&mut embeddings[0])
151150
};
152151

0 commit comments

Comments
 (0)