Skip to content

Commit f9cb8a1

Browse files
committed
Cleaned tests and remote fallback working for search and vector_search
1 parent 9df3528 commit f9cb8a1

File tree

5 files changed

+270
-318
lines changed

5 files changed

+270
-318
lines changed

pgml-sdks/pgml/Cargo.lock

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pgml-sdks/pgml/src/collection.rs

Lines changed: 98 additions & 140 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ use std::time::SystemTime;
1515
use tracing::{instrument, warn};
1616
use walkdir::WalkDir;
1717

18+
use crate::search_query_builder::build_search_query;
1819
use crate::vector_search_query_builder::build_vector_search_query;
1920
use crate::{
2021
filter_builder, get_or_initialize_pool,
@@ -712,15 +713,42 @@ impl Collection {
712713

// Diff view of `Collection::search`: lines following an `NNN-` gutter are the
// removed implementation, lines following an `NNN+` gutter are the new one.
//
// New behavior: build the search query with `build_search_query`, run it on the
// pool for `self.database_url`, and return the first column of every row as
// `Json`. If the query fails with database error code "XX000", the collection
// and pipeline are verified in the database and the query is rebuilt and run
// one more time.
713714
#[instrument(skip(self))]
714715
pub async fn search(
715-
&self,
// The receiver becomes `&mut self` because the fallback path below calls
// `self.verify_in_database(false)`.
716+
&mut self,
716717
query: Json,
717-
pipeline: &MultiFieldPipeline,
// The pipeline becomes `&mut` because the fallback path calls
// `set_project_info` and `verify_in_database` on it.
718+
pipeline: &mut MultiFieldPipeline,
718719
) -> anyhow::Result<Vec<Json>> {
719720
let pool = get_or_initialize_pool(&self.database_url).await?;
720-
let (query, values) =
721-
crate::search_query_builder::build_search_query(self, query, pipeline).await?;
722-
let results: Vec<(Json,)> = sqlx::query_as_with(&query, values).fetch_all(&pool).await?;
723-
Ok(results.into_iter().map(|r| r.0).collect())
// `query` is cloned here because the original value is consumed by the retry
// in the "XX000" branch further down.
721+
let (built_query, values) = build_search_query(self, query.clone(), pipeline).await?;
// The result is kept as a `Result` (no `?`) so the error can be inspected
// before deciding whether to retry.
722+
let results: Result<Vec<(Json,)>, _> = sqlx::query_as_with(&built_query, values)
723+
.fetch_all(&pool)
724+
.await;
725+
726+
match results {
727+
Ok(r) => Ok(r.into_iter().map(|r| r.0).collect()),
728+
Err(e) => match e.as_database_error() {
729+
Some(d) => {
// Only database error code "XX000" triggers the fallback.
// NOTE(review): presumably this is what the database raises when it cannot
// produce the embedding itself (the context message below mentions remote
// embeddings) — confirm against the query builder.
730+
if d.code() == Some(Cow::from("XX000")) {
731+
self.verify_in_database(false).await?;
732+
let project_info = &self
733+
.database_data
734+
.as_ref()
735+
.context("Database data must be set to do remote embeddings search")?
736+
.project_info;
737+
pipeline.set_project_info(project_info.to_owned());
738+
pipeline.verify_in_database(false).await?;
// Rebuild the query after the pipeline has been verified, then run it once
// more; a second failure is propagated by `?` without another retry.
739+
let (built_query, values) =
740+
build_search_query(self, query, pipeline).await?;
741+
let results: Vec<(Json,)> = sqlx::query_as_with(&built_query, values)
742+
.fetch_all(&pool)
743+
.await?;
744+
Ok(results.into_iter().map(|r| r.0).collect())
745+
} else {
// Any other database error code is returned to the caller unchanged.
746+
Err(anyhow::anyhow!(e))
747+
}
748+
}
// Errors that are not database errors are returned without a retry.
749+
None => Err(anyhow::anyhow!(e)),
750+
},
751+
}
724752
}
725753

726754
/// Performs vector search on the [Collection]
@@ -752,142 +780,72 @@ impl Collection {
752780
pipeline: &mut MultiFieldPipeline,
753781
query_parameters: Option<Json>,
754782
top_k: Option<i64>,
755-
) -> anyhow::Result<Vec<(f64, String, Json)>> {
783+
) -> anyhow::Result<Vec<Json>> {
756784
let pool = get_or_initialize_pool(&self.database_url).await?;
757785

758-
let (query, sqlx_values) =
759-
build_vector_search_query(query, self, query_parameters.unwrap_or_default(), pipeline)
760-
.await?;
761-
762-
// With this system, we only do the wrong type of vector search once
763-
// let runtime = if pipeline.model.is_some() {
764-
// pipeline.model.as_ref().unwrap().runtime
765-
// } else {
766-
// ModelRuntime::Python
767-
// };
768-
769-
unimplemented!()
770-
771-
// let pool = get_or_initialize_pool(&self.database_url).await?;
772-
773-
// let query_parameters = query_parameters.unwrap_or_default();
774-
// let top_k = top_k.unwrap_or(5);
775-
776-
// // With this system, we only do the wrong type of vector search once
777-
// let runtime = if pipeline.model.is_some() {
778-
// pipeline.model.as_ref().unwrap().runtime
779-
// } else {
780-
// ModelRuntime::Python
781-
// };
782-
// match runtime {
783-
// ModelRuntime::Python => {
784-
// let embeddings_table_name = format!("{}.{}_embeddings", self.name, pipeline.name);
785-
786-
// let result = sqlx::query_as(&query_builder!(
787-
// queries::EMBED_AND_VECTOR_SEARCH,
788-
// self.pipelines_table_name,
789-
// embeddings_table_name,
790-
// self.chunks_table_name,
791-
// self.documents_table_name
792-
// ))
793-
// .bind(&pipeline.name)
794-
// .bind(query)
795-
// .bind(&query_parameters)
796-
// .bind(top_k)
797-
// .fetch_all(&pool)
798-
// .await;
799-
800-
// match result {
801-
// Ok(r) => Ok(r),
802-
// Err(e) => match e.as_database_error() {
803-
// Some(d) => {
804-
// if d.code() == Some(Cow::from("XX000")) {
805-
// self.vector_search_with_remote_embeddings(
806-
// query,
807-
// pipeline,
808-
// query_parameters,
809-
// top_k,
810-
// &pool,
811-
// )
812-
// .await
813-
// } else {
814-
// Err(anyhow::anyhow!(e))
815-
// }
816-
// }
817-
// None => Err(anyhow::anyhow!(e)),
818-
// },
819-
// }
820-
// }
821-
// _ => {
822-
// self.vector_search_with_remote_embeddings(
823-
// query,
824-
// pipeline,
825-
// query_parameters,
826-
// top_k,
827-
// &pool,
828-
// )
829-
// .await
830-
// }
831-
// }
832-
// .map(|r| {
833-
// r.into_iter()
834-
// .map(|(score, id, metadata)| (1. - score, id, metadata))
835-
// .collect()
836-
// })
837-
}
838-
839-
#[instrument(skip(self, pool))]
840-
#[allow(clippy::type_complexity)]
841-
async fn vector_search_with_remote_embeddings(
842-
&mut self,
843-
query: &str,
844-
pipeline: &mut Pipeline,
845-
query_parameters: Json,
846-
top_k: i64,
847-
pool: &PgPool,
848-
) -> anyhow::Result<Vec<(f64, String, Json)>> {
849-
// TODO: Make this actually work maybe an alias for the new search or something idk
850-
unimplemented!()
851-
852-
// self.verify_in_database(false).await?;
853-
854-
// // Have to set the project info before we can get and set the model
855-
// pipeline.set_project_info(
856-
// self.database_data
857-
// .as_ref()
858-
// .context(
859-
// "Collection must be verified to perform vector search with remote embeddings",
860-
// )?
861-
// .project_info
862-
// .clone(),
863-
// );
864-
// // Verify to get and set the model if we don't have it set on the pipeline yet
865-
// pipeline.verify_in_database(false).await?;
866-
// let model = pipeline
867-
// .model
868-
// .as_ref()
869-
// .context("Pipeline must be verified to perform vector search with remote embeddings")?;
870-
871-
// // We need to make sure we are not mutably and immutably borrowing the same things
872-
// let embedding = {
873-
// let remote_embeddings =
874-
// build_remote_embeddings(model.runtime, &model.name, &query_parameters)?;
875-
// let mut embeddings = remote_embeddings.embed(vec![query.to_string()]).await?;
876-
// std::mem::take(&mut embeddings[0])
877-
// };
878-
879-
// let embeddings_table_name = format!("{}.{}_embeddings", self.name, pipeline.name);
880-
// sqlx::query_as(&query_builder!(
881-
// queries::VECTOR_SEARCH,
882-
// embeddings_table_name,
883-
// self.chunks_table_name,
884-
// self.documents_table_name
885-
// ))
886-
// .bind(embedding)
887-
// .bind(top_k)
888-
// .fetch_all(pool)
889-
// .await
890-
// .map_err(|e| anyhow::anyhow!(e))
786+
let (built_query, values) = build_vector_search_query(
787+
query,
788+
self,
789+
query_parameters.clone().unwrap_or_default(),
790+
pipeline,
791+
)
792+
.await?;
793+
let results: Result<Vec<(Json, String, f64)>, _> =
794+
sqlx::query_as_with(&built_query, values)
795+
.fetch_all(&pool)
796+
.await;
797+
match results {
798+
Ok(r) => Ok(r
799+
.into_iter()
800+
.map(|v| {
801+
serde_json::json!({
802+
"document": v.0,
803+
"chunk": v.1,
804+
"score": v.2
805+
})
806+
.into()
807+
})
808+
.collect()),
809+
Err(e) => match e.as_database_error() {
810+
Some(d) => {
811+
if d.code() == Some(Cow::from("XX000")) {
812+
self.verify_in_database(false).await?;
813+
let project_info = &self
814+
.database_data
815+
.as_ref()
816+
.context("Database data must be set to do remote embeddings search")?
817+
.project_info;
818+
pipeline.set_project_info(project_info.to_owned());
819+
pipeline.verify_in_database(false).await?;
820+
let (built_query, values) = build_vector_search_query(
821+
query,
822+
self,
823+
query_parameters.clone().unwrap_or_default(),
824+
pipeline,
825+
)
826+
.await?;
827+
let results: Vec<(Json, String, f64)> =
828+
sqlx::query_as_with(&built_query, values)
829+
.fetch_all(&pool)
830+
.await?;
831+
Ok(results
832+
.into_iter()
833+
.map(|v| {
834+
serde_json::json!({
835+
"document": v.0,
836+
"chunk": v.1,
837+
"score": v.2
838+
})
839+
.into()
840+
})
841+
.collect())
842+
} else {
843+
Err(anyhow::anyhow!(e))
844+
}
845+
}
846+
None => Err(anyhow::anyhow!(e)),
847+
},
848+
}
891849
}
892850

893851
#[instrument(skip(self))]

0 commit comments

Comments
 (0)