
Commit 9aaa31b

Working conditional pipeline running on document upsert
1 parent 44ab0ed commit 9aaa31b

4 files changed (+378, -260 lines)


pgml-sdks/pgml/src/collection.rs

Lines changed: 148 additions & 57 deletions
@@ -9,6 +9,8 @@ use serde_json::json;
 use sqlx::postgres::PgPool;
 use sqlx::Executor;
 use sqlx::PgConnection;
+use sqlx::Postgres;
+use sqlx::Transaction;
 use std::borrow::Cow;
 use std::path::Path;
 use std::sync::Arc;
@@ -274,18 +276,43 @@ impl Collection {
     /// ```
     #[instrument(skip(self))]
     pub async fn add_pipeline(&mut self, pipeline: &mut MultiFieldPipeline) -> anyhow::Result<()> {
+        // The flow for this function:
+        // 1. Create the collection if it does not exist
+        // 2. Create the pipeline if it does not exist and add it to the collection.pipelines table with ACTIVE = FALSE
+        // 3. Create the tables for the collection_pipeline schema
+        // 4. Start a transaction
+        // 5. Sync the pipeline
+        // 6. Set the pipeline ACTIVE = TRUE
+        // 7. Commit the transaction
         self.verify_in_database(false).await?;
         let project_info = &self
             .database_data
             .as_ref()
             .context("Database data must be set to add a pipeline to a collection")?
             .project_info;
         pipeline.set_project_info(project_info.clone());
-        pipeline.verify_in_database(true).await?;
+        pipeline.verify_in_database(false).await?;
+        pipeline.create_tables().await?;
+
+        let pool = get_or_initialize_pool(&self.database_url).await?;
+        let transaction = pool.begin().await?;
+        let transaction = Arc::new(Mutex::new(transaction));
+
         let mp = MultiProgress::new();
         mp.println(format!("Added Pipeline {}, Now Syncing...", pipeline.name))?;
-        self.sync_pipeline(pipeline).await?;
-        eprintln!("Done Syncing {}\n", pipeline.name);
+        pipeline.execute(None, transaction.clone()).await?;
+        let mut transaction = Arc::into_inner(transaction)
+            .context("Error transaction dangling")?
+            .into_inner();
+        sqlx::query(&query_builder!(
+            "UPDATE %s SET active = TRUE WHERE name = $1",
+            self.pipelines_table_name
+        ))
+        .bind(&pipeline.name)
+        .execute(&mut *transaction)
+        .await?;
+        transaction.commit().await?;
+        mp.println(format!("Done Syncing {}\n", pipeline.name))?;
         Ok(())
     }
 
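Note: the add_pipeline body above relies on a hand-off pattern where the open sqlx transaction is wrapped in Arc<Mutex<...>> so the pipeline sync can borrow it, then unwrapped again to flip ACTIVE and commit. A minimal sketch of that pattern, assuming tokio's Mutex and sqlx as used in this file; the my_collection.pipelines table name and the do_sync_work helper are illustrative stand-ins, not part of this diff:

    use std::sync::Arc;
    use tokio::sync::Mutex;

    use sqlx::{PgPool, Postgres, Transaction};

    // Hypothetical helper standing in for pipeline.execute(None, transaction)
    async fn do_sync_work(_tx: Arc<Mutex<Transaction<'static, Postgres>>>) -> anyhow::Result<()> {
        Ok(())
    }

    async fn sketch(pool: &PgPool) -> anyhow::Result<()> {
        let transaction = pool.begin().await?;
        // Share the open transaction with async sync work that needs &mut access
        let transaction = Arc::new(Mutex::new(transaction));
        do_sync_work(transaction.clone()).await?;
        // Recover sole ownership to run the final statement and commit
        let mut transaction = Arc::into_inner(transaction)
            .ok_or_else(|| anyhow::anyhow!("transaction still shared"))?
            .into_inner();
        sqlx::query("UPDATE my_collection.pipelines SET active = TRUE WHERE name = $1")
            .bind("my_pipeline")
            .execute(&mut *transaction)
            .await?;
        transaction.commit().await?;
        Ok(())
    }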
@@ -308,28 +335,28 @@ impl Collection {
     /// }
     /// ```
     #[instrument(skip(self))]
-    pub async fn remove_pipeline(
-        &mut self,
-        pipeline: &mut MultiFieldPipeline,
-    ) -> anyhow::Result<()> {
-        let pool = get_or_initialize_pool(&self.database_url).await?;
+    pub async fn remove_pipeline(&mut self, pipeline: &MultiFieldPipeline) -> anyhow::Result<()> {
+        // The flow for this function:
+        // Create collection if it does not exist
+        // Begin a transaction
+        // Drop the collection_pipeline schema
+        // Delete the pipeline from the collection.pipelines table
+        // Commit the transaction
         self.verify_in_database(false).await?;
         let project_info = &self
             .database_data
             .as_ref()
-            .context("Database data must be set to remove pipeline from collection")?
+            .context("Database data must be set to remove a pipeline from a collection")?
             .project_info;
-        pipeline.set_project_info(project_info.clone());
-        pipeline.verify_in_database(false).await?;
-
+        let pool = get_or_initialize_pool(&self.database_url).await?;
         let pipeline_schema = format!("{}_{}", project_info.name, pipeline.name);
 
         let mut transaction = pool.begin().await?;
         transaction
             .execute(query_builder!("DROP SCHEMA IF EXISTS %s CASCADE", pipeline_schema).as_str())
             .await?;
         sqlx::query(&query_builder!(
-            "UPDATE %s SET active = FALSE WHERE name = $1",
+            "DELETE FROM %s WHERE name = $1",
             self.pipelines_table_name
         ))
         .bind(&pipeline.name)
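Note: with query_builder! filling in the %s placeholders, the remove_pipeline flow above boils down to two statements run in one transaction. A hedged sketch for a hypothetical collection "docs" and pipeline "p1"; the docs.pipelines table name is assumed, not shown in this diff:

    use sqlx::Executor;

    async fn sketch(pool: &sqlx::PgPool) -> anyhow::Result<()> {
        let mut tx = pool.begin().await?;
        // The per-pipeline schema is named "<collection>_<pipeline>"
        tx.execute("DROP SCHEMA IF EXISTS docs_p1 CASCADE").await?;
        // The pipeline row is now deleted outright rather than marked inactive
        sqlx::query("DELETE FROM docs.pipelines WHERE name = $1")
            .bind("p1")
            .execute(&mut *tx)
            .await?;
        tx.commit().await?;
        Ok(())
    }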
@@ -344,7 +371,7 @@ impl Collection {
     ///
     /// # Arguments
     ///
-    /// * `pipeline` - The [Pipeline] to remove.
+    /// * `pipeline` - The [Pipeline] to enable
     ///
     /// # Example
     ///
@@ -359,22 +386,18 @@ impl Collection {
     /// }
     /// ```
     #[instrument(skip(self))]
-    pub async fn enable_pipeline(&self, pipeline: &Pipeline) -> anyhow::Result<()> {
-        sqlx::query(&query_builder!(
-            "UPDATE %s SET active = TRUE WHERE name = $1",
-            self.pipelines_table_name
-        ))
-        .bind(&pipeline.name)
-        .execute(&get_or_initialize_pool(&self.database_url).await?)
-        .await?;
-        Ok(())
+    pub async fn enable_pipeline(
+        &mut self,
+        pipeline: &mut MultiFieldPipeline,
+    ) -> anyhow::Result<()> {
+        self.add_pipeline(pipeline).await
     }
 
     /// Disables a [Pipeline] on the [Collection]
     ///
     /// # Arguments
     ///
-    /// * `pipeline` - The [Pipeline] to remove.
+    /// * `pipeline` - The [Pipeline] to disable
     ///
     /// # Example
     ///
@@ -389,14 +412,38 @@ impl Collection {
     /// }
     /// ```
     #[instrument(skip(self))]
-    pub async fn disable_pipeline(&self, pipeline: &Pipeline) -> anyhow::Result<()> {
+    pub async fn disable_pipeline(&mut self, pipeline: &MultiFieldPipeline) -> anyhow::Result<()> {
+        // Our current system for keeping documents, chunks, embeddings, and tsvectors in sync
+        // does not play nice with disabling and then re-enabling pipelines.
+        // For now, when disabling a pipeline, simply delete its schema and remake it later
+        // The flow for this function:
+        // 1. Create the collection if it does not exist
+        // 2. Begin a transaction
+        // 3. Set the pipeline's ACTIVE = FALSE in the collection.pipelines table
+        // 4. Drop the collection_pipeline schema (this will get remade if they enable it again)
+        // 5. Commit the transaction
+        self.verify_in_database(false).await?;
+        let project_info = &self
+            .database_data
+            .as_ref()
+            .context("Database data must be set to remove a pipeline from a collection")?
+            .project_info;
+        let pool = get_or_initialize_pool(&self.database_url).await?;
+        let pipeline_schema = format!("{}_{}", project_info.name, pipeline.name);
+
+        let mut transaction = pool.begin().await?;
         sqlx::query(&query_builder!(
             "UPDATE %s SET active = FALSE WHERE name = $1",
             self.pipelines_table_name
         ))
         .bind(&pipeline.name)
-        .execute(&get_or_initialize_pool(&self.database_url).await?)
+        .execute(&mut *transaction)
         .await?;
+        transaction
+            .execute(query_builder!("DROP SCHEMA IF EXISTS %s CASCADE", pipeline_schema).as_str())
+            .await?;
+        transaction.commit().await?;
+
         Ok(())
     }
 
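Note: enable_pipeline and disable_pipeline are now deliberately asymmetric: disabling drops the per-pipeline schema, so enabling re-runs the full add_pipeline sync. A small in-crate usage sketch, assuming the Collection and MultiFieldPipeline types defined in this SDK:

    async fn toggle(
        collection: &mut Collection,
        pipeline: &mut MultiFieldPipeline,
    ) -> anyhow::Result<()> {
        // Marks the pipeline inactive and drops its <collection>_<pipeline> schema
        collection.disable_pipeline(pipeline).await?;
        // Re-enabling now delegates to add_pipeline, so the schema and tables
        // are recreated and the pipeline is fully re-synced before going ACTIVE
        collection.enable_pipeline(pipeline).await?;
        Ok(())
    }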
@@ -442,13 +489,21 @@ impl Collection {
     /// Ok(())
     /// }
     /// ```
-    // TODO: Make it so if we upload the same documen twice it doesn't do anything
     #[instrument(skip(self, documents))]
     pub async fn upsert_documents(
         &mut self,
         documents: Vec<Json>,
         _args: Option<Json>,
     ) -> anyhow::Result<()> {
+        // The flow for this function:
+        // 1. Create the collection if it does not exist
+        // 2. Get all pipelines where ACTIVE = TRUE
+        // 3. Create each pipeline and the collection_pipeline schema and tables if they don't already exist
+        // 4. For each document:
+        //    -> Begin a transaction
+        //    -> Insert the document, returning the old document if it existed
+        //    -> For each pipeline, check whether the document needs to be resynced and, if so, sync it
+        //    -> Commit the transaction
         let pool = get_or_initialize_pool(&self.database_url).await?;
         self.verify_in_database(false).await?;
         let mut pipelines = self.get_pipelines().await?;
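Note: step 4 of the flow above hinges on a single statement that upserts the document and also hands back its previous version (NULL on first insert), which the next hunk builds with query_builder!. A hedged sketch with the placeholders filled in for a hypothetical docs.documents table (assumes sqlx's json and uuid features):

    use serde_json::Value;
    use sqlx::PgConnection;

    // Hypothetical standalone version of the upsert; "docs.documents" stands in
    // for the %s placeholders filled by query_builder! in the real code
    async fn upsert_returning_previous(
        conn: &mut PgConnection,
        source_uuid: uuid::Uuid,
        document: Value,
    ) -> anyhow::Result<(i64, Option<Value>)> {
        let row: (i64, Option<Value>) = sqlx::query_as(
            "WITH prev AS (SELECT document FROM docs.documents WHERE source_uuid = $1) \
             INSERT INTO docs.documents (source_uuid, document) VALUES ($1, $2) \
             ON CONFLICT (source_uuid) DO UPDATE SET document = EXCLUDED.document \
             RETURNING id, (SELECT document FROM prev)",
        )
        .bind(source_uuid)
        .bind(document)
        .fetch_one(conn)
        .await?;
        // row.1 is None the first time a source_uuid is seen, Some(old_json) on update
        Ok(row)
    }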
@@ -468,20 +523,55 @@ impl Collection {
             let md5_digest = md5::compute(id.as_bytes());
             let source_uuid = uuid::Uuid::from_slice(&md5_digest.0)?;
 
-            let document_id: i64 = sqlx::query_scalar(&query_builder!("INSERT INTO %s (source_uuid, document) VALUES ($1, $2) ON CONFLICT (source_uuid) DO UPDATE SET document = $2 RETURNING id", self.documents_table_name)).bind(source_uuid).bind(document).fetch_one(&mut *transaction).await?;
+            let (document_id, previous_document): (i64, Option<Json>) = sqlx::query_as(&query_builder!(
+                "WITH prev AS (SELECT document FROM %s WHERE source_uuid = $1) INSERT INTO %s (source_uuid, document) VALUES ($1, $2) ON CONFLICT (source_uuid) DO UPDATE SET document = EXCLUDED.document RETURNING id, (SELECT document FROM prev)",
+                self.documents_table_name,
+                self.documents_table_name
+            ))
+            .bind(&source_uuid)
+            .bind(&document)
+            .fetch_one(&mut *transaction)
+            .await?;
 
             let transaction = Arc::new(Mutex::new(transaction));
             if !pipelines.is_empty() {
                 use futures::stream::StreamExt;
                 futures::stream::iter(&mut pipelines)
                     // Need this map to get around moving the transaction
-                    .map(|pipeline| (pipeline, transaction.clone()))
-                    .for_each_concurrent(10, |(pipeline, transaction)| async move {
-                        pipeline
-                            .execute(Some(document_id), transaction)
-                            .await
-                            .expect("Failed to execute pipeline");
+                    .map(|pipeline| {
+                        (
+                            pipeline,
+                            previous_document.clone(),
+                            document.clone(),
+                            transaction.clone(),
+                        )
                     })
+                    .for_each_concurrent(
+                        10,
+                        |(pipeline, previous_document, document, transaction)| async move {
+                            // Can unwrap here as we know it has a parsed schema from the create_tables call above
+                            match previous_document {
+                                Some(previous_document) => {
+                                    let should_run =
+                                        pipeline.parsed_schema.as_ref().unwrap().iter().any(
+                                            |(key, _)| document[key] != previous_document[key],
+                                        );
+                                    if should_run {
+                                        pipeline
+                                            .execute(Some(document_id), transaction)
+                                            .await
+                                            .expect("Failed to execute pipeline");
+                                    }
+                                }
+                                None => {
+                                    pipeline
+                                        .execute(Some(document_id), transaction)
+                                        .await
+                                        .expect("Failed to execute pipeline");
+                                }
+                            }
+                        },
+                    )
                     .await;
             }
 
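Note: the re-sync decision above only inspects the fields a pipeline's parsed schema covers, so edits to unrelated fields do not trigger a resync. A self-contained illustration of that comparison using serde_json; the field names are illustrative:

    use serde_json::json;

    fn main() {
        let previous = json!({ "id": 1, "title": "a", "body": "old body" });
        let updated = json!({ "id": 1, "title": "a", "body": "new body" });

        // Stand-in for pipeline.parsed_schema: the fields this pipeline indexes
        let schema_keys = ["body"];

        // Re-run the pipeline only if one of its fields actually changed;
        // serde_json's Value indexing returns Null for missing keys, so a
        // newly added field also counts as a change
        let should_run = schema_keys
            .iter()
            .any(|key| updated[*key] != previous[*key]);
        assert!(should_run);

        // An update that touches only fields outside the schema is skipped
        let updated_title_only = json!({ "id": 1, "title": "b", "body": "old body" });
        let should_run = schema_keys
            .iter()
            .any(|key| updated_title_only[*key] != previous[*key]);
        assert!(!should_run);
    }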
@@ -705,29 +795,30 @@ impl Collection {
         // Ok(())
     }
 
-    #[instrument(skip(self))]
-    async fn sync_pipeline(&mut self, pipeline: &mut MultiFieldPipeline) -> anyhow::Result<()> {
-        self.verify_in_database(false).await?;
-        let project_info = &self
-            .database_data
-            .as_ref()
-            .context("Database data must be set to get collection pipelines")?
-            .project_info;
-        pipeline.set_project_info(project_info.clone());
-        pipeline.create_tables().await?;
-
-        let pool = get_or_initialize_pool(&self.database_url).await?;
-        let transaction = pool.begin().await?;
-        let transaction = Arc::new(Mutex::new(transaction));
-        pipeline.execute(None, transaction.clone()).await?;
-
-        Arc::into_inner(transaction)
-            .context("Error transaction dangling")?
-            .into_inner()
-            .commit()
-            .await?;
-        Ok(())
-    }
+    // #[instrument(skip(self))]
+    // async fn sync_pipeline(
+    //     &mut self,
+    //     pipeline: &mut MultiFieldPipeline,
+    //     transaction: Arc<Mutex<Transaction<'static, Postgres>>>,
+    // ) -> anyhow::Result<()> {
+    //     self.verify_in_database(false).await?;
+    //     let project_info = &self
+    //         .database_data
+    //         .as_ref()
+    //         .context("Database data must be set to get collection pipelines")?
+    //         .project_info;
+    //     pipeline.set_project_info(project_info.clone());
+    //     pipeline.create_tables().await?;
+
+    //     pipeline.execute(None, transaction).await?;
+
+    //     Arc::into_inner(transaction)
+    //         .context("Error transaction dangling")?
+    //         .into_inner()
+    //         .commit()
+    //         .await?;
+    //     Ok(())
+    // }
 
     #[instrument(skip(self))]
     pub async fn search(

pgml-sdks/pgml/src/lib.rs

Lines changed: 14 additions & 2 deletions
@@ -309,7 +309,7 @@ mod tests {
     #[sqlx::test]
     async fn can_add_pipeline_and_upsert_documents() -> anyhow::Result<()> {
         internal_init_logger(None, None).ok();
-        let collection_name = "test_r_c_capaud_44";
+        let collection_name = "test_r_c_capaud_46";
         let pipeline_name = "test_r_p_capaud_6";
         let mut pipeline = MultiFieldPipeline::new(
             pipeline_name,
@@ -361,13 +361,13 @@ mod tests {
                 .fetch_all(&pool)
                 .await?;
         assert!(body_chunks.len() == 4);
-        collection.archive().await?;
         let tsvectors_table = format!("{}_{}.body_tsvectors", collection_name, pipeline_name);
         let tsvectors: Vec<models::TSVector> =
             sqlx::query_as(&query_builder!("SELECT * FROM %s", tsvectors_table))
                 .fetch_all(&pool)
                 .await?;
         assert!(tsvectors.len() == 4);
+        collection.archive().await?;
         Ok(())
     }
 
@@ -588,6 +588,18 @@ mod tests {
         Ok(())
     }
 
+    #[sqlx::test]
+    async fn can_update_documents() -> anyhow::Result<()> {
+        let collection_name = "test_r_c_cud_0";
+        let mut collection = Collection::new(collection_name, None);
+        let mut documents = generate_dummy_documents(1);
+        collection.upsert_documents(documents.clone(), None).await?;
+        documents[0]["body"] = json!("new body");
+        collection.upsert_documents(documents, None).await?;
+        // collection.archive().await?;
+        Ok(())
+    }
+
     #[sqlx::test]
     async fn can_search_with_local_embeddings() -> anyhow::Result<()> {
         internal_init_logger(None, None).ok();
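Note: the new can_update_documents test drives the upsert path twice but does not assert anything yet (its archive call is also commented out). A hedged sketch of the kind of check it could grow, assuming the collection's documents land in a test_r_c_cud_0.documents table as the SDK's naming elsewhere suggests:

    // Hypothetical follow-up assertion for can_update_documents, given a
    // connected sqlx PgPool as used by the other tests in this module
    async fn assert_document_updated(pool: &sqlx::PgPool) -> anyhow::Result<()> {
        let stored: Vec<(serde_json::Value,)> =
            sqlx::query_as("SELECT document FROM test_r_c_cud_0.documents")
                .fetch_all(pool)
                .await?;
        // One document, and its body should reflect the second upsert
        assert!(stored.len() == 1);
        assert!(stored[0].0["body"] == serde_json::json!("new body"));
        Ok(())
    }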

0 commit comments