@@ -15,6 +15,7 @@ use std::time::SystemTime;
15
15
use tracing:: { instrument, warn} ;
16
16
use walkdir:: WalkDir ;
17
17
18
+ use crate :: search_query_builder:: build_search_query;
18
19
use crate :: vector_search_query_builder:: build_vector_search_query;
19
20
use crate :: {
20
21
filter_builder, get_or_initialize_pool,
@@ -712,15 +713,42 @@ impl Collection {
712
713
713
714
#[ instrument( skip( self ) ) ]
714
715
pub async fn search (
715
- & self ,
716
+ & mut self ,
716
717
query : Json ,
717
- pipeline : & MultiFieldPipeline ,
718
+ pipeline : & mut MultiFieldPipeline ,
718
719
) -> anyhow:: Result < Vec < Json > > {
719
720
let pool = get_or_initialize_pool ( & self . database_url ) . await ?;
720
- let ( query, values) =
721
- crate :: search_query_builder:: build_search_query ( self , query, pipeline) . await ?;
722
- let results: Vec < ( Json , ) > = sqlx:: query_as_with ( & query, values) . fetch_all ( & pool) . await ?;
723
- Ok ( results. into_iter ( ) . map ( |r| r. 0 ) . collect ( ) )
721
+ let ( built_query, values) = build_search_query ( self , query. clone ( ) , pipeline) . await ?;
722
+ let results: Result < Vec < ( Json , ) > , _ > = sqlx:: query_as_with ( & built_query, values)
723
+ . fetch_all ( & pool)
724
+ . await ;
725
+
726
+ match results {
727
+ Ok ( r) => Ok ( r. into_iter ( ) . map ( |r| r. 0 ) . collect ( ) ) ,
728
+ Err ( e) => match e. as_database_error ( ) {
729
+ Some ( d) => {
730
+ if d. code ( ) == Some ( Cow :: from ( "XX000" ) ) {
731
+ self . verify_in_database ( false ) . await ?;
732
+ let project_info = & self
733
+ . database_data
734
+ . as_ref ( )
735
+ . context ( "Database data must be set to do remote embeddings search" ) ?
736
+ . project_info ;
737
+ pipeline. set_project_info ( project_info. to_owned ( ) ) ;
738
+ pipeline. verify_in_database ( false ) . await ?;
739
+ let ( built_query, values) =
740
+ build_search_query ( self , query, pipeline) . await ?;
741
+ let results: Vec < ( Json , ) > = sqlx:: query_as_with ( & built_query, values)
742
+ . fetch_all ( & pool)
743
+ . await ?;
744
+ Ok ( results. into_iter ( ) . map ( |r| r. 0 ) . collect ( ) )
745
+ } else {
746
+ Err ( anyhow:: anyhow!( e) )
747
+ }
748
+ }
749
+ None => Err ( anyhow:: anyhow!( e) ) ,
750
+ } ,
751
+ }
724
752
}
725
753
726
754
/// Performs vector search on the [Collection]
@@ -752,142 +780,72 @@ impl Collection {
752
780
pipeline : & mut MultiFieldPipeline ,
753
781
query_parameters : Option < Json > ,
754
782
top_k : Option < i64 > ,
755
- ) -> anyhow:: Result < Vec < ( f64 , String , Json ) > > {
783
+ ) -> anyhow:: Result < Vec < Json > > {
756
784
let pool = get_or_initialize_pool ( & self . database_url ) . await ?;
757
785
758
- let ( query, sqlx_values) =
759
- build_vector_search_query ( query, self , query_parameters. unwrap_or_default ( ) , pipeline)
760
- . await ?;
761
-
762
- // With this system, we only do the wrong type of vector search once
763
- // let runtime = if pipeline.model.is_some() {
764
- // pipeline.model.as_ref().unwrap().runtime
765
- // } else {
766
- // ModelRuntime::Python
767
- // };
768
-
769
- unimplemented ! ( )
770
-
771
- // let pool = get_or_initialize_pool(&self.database_url).await?;
772
-
773
- // let query_parameters = query_parameters.unwrap_or_default();
774
- // let top_k = top_k.unwrap_or(5);
775
-
776
- // // With this system, we only do the wrong type of vector search once
777
- // let runtime = if pipeline.model.is_some() {
778
- // pipeline.model.as_ref().unwrap().runtime
779
- // } else {
780
- // ModelRuntime::Python
781
- // };
782
- // match runtime {
783
- // ModelRuntime::Python => {
784
- // let embeddings_table_name = format!("{}.{}_embeddings", self.name, pipeline.name);
785
-
786
- // let result = sqlx::query_as(&query_builder!(
787
- // queries::EMBED_AND_VECTOR_SEARCH,
788
- // self.pipelines_table_name,
789
- // embeddings_table_name,
790
- // self.chunks_table_name,
791
- // self.documents_table_name
792
- // ))
793
- // .bind(&pipeline.name)
794
- // .bind(query)
795
- // .bind(&query_parameters)
796
- // .bind(top_k)
797
- // .fetch_all(&pool)
798
- // .await;
799
-
800
- // match result {
801
- // Ok(r) => Ok(r),
802
- // Err(e) => match e.as_database_error() {
803
- // Some(d) => {
804
- // if d.code() == Some(Cow::from("XX000")) {
805
- // self.vector_search_with_remote_embeddings(
806
- // query,
807
- // pipeline,
808
- // query_parameters,
809
- // top_k,
810
- // &pool,
811
- // )
812
- // .await
813
- // } else {
814
- // Err(anyhow::anyhow!(e))
815
- // }
816
- // }
817
- // None => Err(anyhow::anyhow!(e)),
818
- // },
819
- // }
820
- // }
821
- // _ => {
822
- // self.vector_search_with_remote_embeddings(
823
- // query,
824
- // pipeline,
825
- // query_parameters,
826
- // top_k,
827
- // &pool,
828
- // )
829
- // .await
830
- // }
831
- // }
832
- // .map(|r| {
833
- // r.into_iter()
834
- // .map(|(score, id, metadata)| (1. - score, id, metadata))
835
- // .collect()
836
- // })
837
- }
838
-
839
- #[ instrument( skip( self , pool) ) ]
840
- #[ allow( clippy:: type_complexity) ]
841
- async fn vector_search_with_remote_embeddings (
842
- & mut self ,
843
- query : & str ,
844
- pipeline : & mut Pipeline ,
845
- query_parameters : Json ,
846
- top_k : i64 ,
847
- pool : & PgPool ,
848
- ) -> anyhow:: Result < Vec < ( f64 , String , Json ) > > {
849
- // TODO: Make this actually work maybe an alias for the new search or something idk
850
- unimplemented ! ( )
851
-
852
- // self.verify_in_database(false).await?;
853
-
854
- // // Have to set the project info before we can get and set the model
855
- // pipeline.set_project_info(
856
- // self.database_data
857
- // .as_ref()
858
- // .context(
859
- // "Collection must be verified to perform vector search with remote embeddings",
860
- // )?
861
- // .project_info
862
- // .clone(),
863
- // );
864
- // // Verify to get and set the model if we don't have it set on the pipeline yet
865
- // pipeline.verify_in_database(false).await?;
866
- // let model = pipeline
867
- // .model
868
- // .as_ref()
869
- // .context("Pipeline must be verified to perform vector search with remote embeddings")?;
870
-
871
- // // We need to make sure we are not mutably and immutably borrowing the same things
872
- // let embedding = {
873
- // let remote_embeddings =
874
- // build_remote_embeddings(model.runtime, &model.name, &query_parameters)?;
875
- // let mut embeddings = remote_embeddings.embed(vec![query.to_string()]).await?;
876
- // std::mem::take(&mut embeddings[0])
877
- // };
878
-
879
- // let embeddings_table_name = format!("{}.{}_embeddings", self.name, pipeline.name);
880
- // sqlx::query_as(&query_builder!(
881
- // queries::VECTOR_SEARCH,
882
- // embeddings_table_name,
883
- // self.chunks_table_name,
884
- // self.documents_table_name
885
- // ))
886
- // .bind(embedding)
887
- // .bind(top_k)
888
- // .fetch_all(pool)
889
- // .await
890
- // .map_err(|e| anyhow::anyhow!(e))
786
+ let ( built_query, values) = build_vector_search_query (
787
+ query,
788
+ self ,
789
+ query_parameters. clone ( ) . unwrap_or_default ( ) ,
790
+ pipeline,
791
+ )
792
+ . await ?;
793
+ let results: Result < Vec < ( Json , String , f64 ) > , _ > =
794
+ sqlx:: query_as_with ( & built_query, values)
795
+ . fetch_all ( & pool)
796
+ . await ;
797
+ match results {
798
+ Ok ( r) => Ok ( r
799
+ . into_iter ( )
800
+ . map ( |v| {
801
+ serde_json:: json!( {
802
+ "document" : v. 0 ,
803
+ "chunk" : v. 1 ,
804
+ "score" : v. 2
805
+ } )
806
+ . into ( )
807
+ } )
808
+ . collect ( ) ) ,
809
+ Err ( e) => match e. as_database_error ( ) {
810
+ Some ( d) => {
811
+ if d. code ( ) == Some ( Cow :: from ( "XX000" ) ) {
812
+ self . verify_in_database ( false ) . await ?;
813
+ let project_info = & self
814
+ . database_data
815
+ . as_ref ( )
816
+ . context ( "Database data must be set to do remote embeddings search" ) ?
817
+ . project_info ;
818
+ pipeline. set_project_info ( project_info. to_owned ( ) ) ;
819
+ pipeline. verify_in_database ( false ) . await ?;
820
+ let ( built_query, values) = build_vector_search_query (
821
+ query,
822
+ self ,
823
+ query_parameters. clone ( ) . unwrap_or_default ( ) ,
824
+ pipeline,
825
+ )
826
+ . await ?;
827
+ let results: Vec < ( Json , String , f64 ) > =
828
+ sqlx:: query_as_with ( & built_query, values)
829
+ . fetch_all ( & pool)
830
+ . await ?;
831
+ Ok ( results
832
+ . into_iter ( )
833
+ . map ( |v| {
834
+ serde_json:: json!( {
835
+ "document" : v. 0 ,
836
+ "chunk" : v. 1 ,
837
+ "score" : v. 2
838
+ } )
839
+ . into ( )
840
+ } )
841
+ . collect ( ) )
842
+ } else {
843
+ Err ( anyhow:: anyhow!( e) )
844
+ }
845
+ }
846
+ None => Err ( anyhow:: anyhow!( e) ) ,
847
+ } ,
848
+ }
891
849
}
892
850
893
851
#[ instrument( skip( self ) ) ]
0 commit comments