Skip to content

Commit 099ea60

Browse files
committed
Cleaned up
1 parent 6a9fd14 commit 099ea60

File tree

4 files changed

+29
-19
lines changed

4 files changed

+29
-19
lines changed

pgml-sdks/pgml/src/collection.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -507,7 +507,7 @@ impl Collection {
507507
)
508508
};
509509
let (document_id, previous_document): (i64, Option<Json>) = sqlx::query_as(&query)
510-
.bind(&source_uuid)
510+
.bind(source_uuid)
511511
.bind(&document)
512512
.fetch_one(&mut *transaction)
513513
.await?;

pgml-sdks/pgml/src/lib.rs

Lines changed: 19 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ async fn get_or_initialize_pool(database_url: &Option<String>) -> anyhow::Result
7575

7676
let pool = PgPoolOptions::new()
7777
.acquire_timeout(std::time::Duration::from_millis(timeout))
78-
.connect_lazy(&url)?;
78+
.connect_lazy(url)?;
7979

8080
pools.insert(url.to_string(), pool.clone());
8181
Ok(pool)
@@ -289,7 +289,7 @@ mod tests {
289289
assert!(collection.database_data.is_none());
290290
collection.add_pipeline(&mut pipeline).await?;
291291
assert!(collection.database_data.is_some());
292-
collection.remove_pipeline(&mut pipeline).await?;
292+
collection.remove_pipeline(&pipeline).await?;
293293
let pipelines = collection.get_pipelines().await?;
294294
assert!(pipelines.is_empty());
295295
collection.archive().await?;
@@ -306,7 +306,7 @@ mod tests {
306306
collection.add_pipeline(&mut pipeline2).await?;
307307
let pipelines = collection.get_pipelines().await?;
308308
assert!(pipelines.len() == 2);
309-
collection.remove_pipeline(&mut pipeline1).await?;
309+
collection.remove_pipeline(&pipeline1).await?;
310310
let pipelines = collection.get_pipelines().await?;
311311
assert!(pipelines.len() == 1);
312312
assert!(collection.get_pipeline("test_r_p_carps_1").await.is_err());
@@ -317,7 +317,7 @@ mod tests {
317317
#[tokio::test]
318318
async fn can_add_pipeline_and_upsert_documents() -> anyhow::Result<()> {
319319
internal_init_logger(None, None).ok();
320-
let collection_name = "test_r_c_capaud_47";
320+
let collection_name = "test_r_c_capaud_48";
321321
let pipeline_name = "test_r_p_capaud_6";
322322
let mut pipeline = MultiFieldPipeline::new(
323323
pipeline_name,
@@ -333,7 +333,10 @@ mod tests {
333333
"model": "recursive_character"
334334
},
335335
"semantic_search": {
336-
"model": "intfloat/e5-small",
336+
"model": "hkunlp/instructor-base",
337+
"parameters": {
338+
"instruction": "Represent the Wikipedia document for retrieval"
339+
}
337340
},
338341
"full_text_search": {
339342
"configuration": "english"
@@ -490,7 +493,7 @@ mod tests {
490493
sqlx::query_as(&query_builder!("SELECT * FROM %s", chunks_table))
491494
.fetch_all(&pool)
492495
.await?;
493-
assert!(title_chunks.len() == 0);
496+
assert!(title_chunks.is_empty());
494497
collection.enable_pipeline(&mut pipeline).await?;
495498
let chunks_table = format!("{}_{}.title_chunks", collection_name, pipeline_name);
496499
let title_chunks: Vec<models::Chunk> =
@@ -707,7 +710,7 @@ mod tests {
707710
}
708711
})
709712
);
710-
collection.disable_pipeline(&mut pipeline).await?;
713+
collection.disable_pipeline(&pipeline).await?;
711714
collection
712715
.upsert_documents(documents[2..4].to_owned(), None)
713716
.await?;
@@ -813,7 +816,7 @@ mod tests {
813816
#[tokio::test]
814817
async fn can_search_with_local_embeddings() -> anyhow::Result<()> {
815818
internal_init_logger(None, None).ok();
816-
let collection_name = "test_r_c_cs_67";
819+
let collection_name = "test_r_c_cs_70";
817820
let mut collection = Collection::new(collection_name, None);
818821
let documents = generate_dummy_documents(10);
819822
collection.upsert_documents(documents.clone(), None).await?;
@@ -835,7 +838,10 @@ mod tests {
835838
"model": "recursive_character"
836839
},
837840
"semantic_search": {
838-
"model": "intfloat/e5-small"
841+
"model": "hkunlp/instructor-base",
842+
"parameters": {
843+
"instruction": "Represent the Wikipedia document for retrieval"
844+
}
839845
},
840846
"full_text_search": {
841847
"configuration": "english"
@@ -872,6 +878,9 @@ mod tests {
872878
},
873879
"body": {
874880
"query": "This is the body test",
881+
"parameters": {
882+
"instruction": "Represent the Wikipedia question for retrieving supporting documents: ",
883+
},
875884
"boost": 1.01
876885
},
877886
"notes": {
@@ -896,7 +905,7 @@ mod tests {
896905
.into_iter()
897906
.map(|r| r["document"]["id"].as_u64().unwrap())
898907
.collect();
899-
assert_eq!(ids, vec![3, 8, 2, 7, 4]);
908+
assert_eq!(ids, vec![7, 8, 2, 3, 4]);
900909
collection.archive().await?;
901910
Ok(())
902911
}

pgml-sdks/pgml/src/multi_field_pipeline.rs

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ struct ValidFieldAction {
5050
full_text_search: Option<FullTextSearchAction>,
5151
}
5252

53+
#[allow(clippy::upper_case_acronyms)]
5354
#[derive(Debug, Clone)]
5455
pub struct HNSW {
5556
m: u64,
@@ -113,7 +114,7 @@ impl TryFrom<ValidFieldAction> for FieldAction {
113114
let model = Model::new(Some(v.model), v.source, v.parameters);
114115
let hnsw = v
115116
.hnsw
116-
.map(|v2| HNSW::try_from(v2))
117+
.map(HNSW::try_from)
117118
.unwrap_or_else(|| Ok(HNSW::default()))?;
118119
anyhow::Ok(SemanticSearchAction { model, hnsw })
119120
})
@@ -203,7 +204,7 @@ fn json_to_schema(schema: &Json) -> anyhow::Result<ParsedSchema> {
203204
#[alias_methods(new, get_status, to_dict)]
204205
impl MultiFieldPipeline {
205206
pub fn new(name: &str, schema: Option<Json>) -> anyhow::Result<Self> {
206-
let parsed_schema = schema.as_ref().map(|s| json_to_schema(s)).transpose()?;
207+
let parsed_schema = schema.as_ref().map(json_to_schema).transpose()?;
207208
Ok(Self {
208209
name: name.to_string(),
209210
schema,
@@ -250,7 +251,7 @@ impl MultiFieldPipeline {
250251

251252
results[key] = json!({});
252253

253-
if let Some(_) = value.splitter {
254+
if value.splitter.is_some() {
254255
let chunks_status: (Option<i64>, Option<i64>) = sqlx::query_as(&query_builder!(
255256
"SELECT (SELECT COUNT(DISTINCT document_id) FROM %s), COUNT(id) FROM %s",
256257
chunks_table_name,
@@ -265,7 +266,7 @@ impl MultiFieldPipeline {
265266
});
266267
}
267268

268-
if let Some(_) = value.semantic_search {
269+
if value.semantic_search.is_some() {
269270
let embeddings_table_name = format!("{schema}.{key}_embeddings");
270271
let embeddings_status: (Option<i64>, Option<i64>) =
271272
sqlx::query_as(&query_builder!(
@@ -282,7 +283,7 @@ impl MultiFieldPipeline {
282283
});
283284
}
284285

285-
if let Some(_) = value.full_text_search {
286+
if value.full_text_search.is_some() {
286287
let tsvectors_table_name = format!("{schema}.{key}_tsvectors");
287288
let tsvectors_status: (Option<i64>, Option<i64>) = sqlx::query_as(&query_builder!(
288289
"SELECT (SELECT count(*) FROM %s), (SELECT count(*) FROM %s)",

pgml-sdks/pgml/src/search_query_builder.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ use crate::{
2121
#[derive(Debug, Deserialize)]
2222
struct ValidSemanticSearchAction {
2323
query: String,
24-
model_parameters: Option<Json>,
24+
parameters: Option<Json>,
2525
boost: Option<f32>,
2626
}
2727

@@ -107,7 +107,7 @@ pub async fn build_search_query(
107107
"transformer => (SELECT schema #>> '{{{key},semantic_search,model}}' FROM pipeline)",
108108
)),
109109
Expr::cust_with_values("text => $1", [&vsa.query]),
110-
Expr::cust(format!("kwargs => COALESCE((SELECT schema #> '{{{key},semantic_search,model_parameters}}' FROM pipeline), '{{}}'::jsonb)")),
110+
Expr::cust_with_values("kwargs => $1", [vsa.parameters.unwrap_or_default().0]),
111111
]),
112112
Alias::new("embedding"),
113113
);
@@ -143,7 +143,7 @@ pub async fn build_search_query(
143143
let remote_embeddings = build_remote_embeddings(
144144
model.runtime,
145145
&model.name,
146-
vsa.model_parameters.as_ref(),
146+
vsa.parameters.as_ref(),
147147
)?;
148148
let mut embeddings = remote_embeddings.embed(vec![vsa.query]).await?;
149149
std::mem::take(&mut embeddings[0])

0 commit comments

Comments
 (0)