@@ -9,6 +9,8 @@ use serde_json::json;
 use sqlx::postgres::PgPool;
 use sqlx::Executor;
 use sqlx::PgConnection;
+use sqlx::Postgres;
+use sqlx::Transaction;
 use std::borrow::Cow;
 use std::path::Path;
 use std::sync::Arc;
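
The two new `sqlx` imports exist only to name the shared-transaction handle that threads through the pipeline-sync code below. A minimal sketch of that type, assuming the `Mutex` in scope in this file is tokio's (the alias itself is hypothetical and not part of this change):

```rust
use std::sync::Arc;
use tokio::sync::Mutex;
use sqlx::{Postgres, Transaction};

// Hypothetical alias for the Arc<Mutex<Transaction>> handle handed to pipelines.
type SharedTransaction = Arc<Mutex<Transaction<'static, Postgres>>>;
```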
@@ -274,18 +276,43 @@ impl Collection {
     /// ```
     #[instrument(skip(self))]
     pub async fn add_pipeline(&mut self, pipeline: &mut MultiFieldPipeline) -> anyhow::Result<()> {
+        // The flow for this function:
+        // 1. Create the collection if it does not exist
+        // 2. Create the pipeline if it does not exist and add it to the collection.pipelines table with ACTIVE = FALSE
+        // 3. Create the tables for the collection_pipeline schema
+        // 4. Start a transaction
+        // 5. Sync the pipeline
+        // 6. Set the pipeline ACTIVE = TRUE
+        // 7. Commit the transaction
         self.verify_in_database(false).await?;
         let project_info = &self
             .database_data
             .as_ref()
             .context("Database data must be set to add a pipeline to a collection")?
             .project_info;
         pipeline.set_project_info(project_info.clone());
-        pipeline.verify_in_database(true).await?;
+        pipeline.verify_in_database(false).await?;
+        pipeline.create_tables().await?;
+
+        let pool = get_or_initialize_pool(&self.database_url).await?;
+        let transaction = pool.begin().await?;
+        let transaction = Arc::new(Mutex::new(transaction));
+
         let mp = MultiProgress::new();
         mp.println(format!("Added Pipeline {}, Now Syncing...", pipeline.name))?;
-        self.sync_pipeline(pipeline).await?;
-        eprintln!("Done Syncing {}\n", pipeline.name);
+        pipeline.execute(None, transaction.clone()).await?;
+        let mut transaction = Arc::into_inner(transaction)
+            .context("Error transaction dangling")?
+            .into_inner();
+        sqlx::query(&query_builder!(
+            "UPDATE %s SET active = TRUE WHERE name = $1",
+            self.pipelines_table_name
+        ))
+        .bind(&pipeline.name)
+        .execute(&mut *transaction)
+        .await?;
+        transaction.commit().await?;
+        mp.println(format!("Done Syncing {}\n", pipeline.name))?;
         Ok(())
     }

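The body above leans on a share-then-reclaim pattern: the transaction is wrapped in `Arc<Mutex<...>>` so `pipeline.execute` can use it from concurrent tasks, then unwrapped to issue the final `UPDATE` and commit. A minimal sketch of the pattern, assuming tokio's `Mutex` and `anyhow::Context`:

```rust
use std::sync::Arc;
use anyhow::Context;
use tokio::sync::Mutex;

async fn share_then_commit(pool: &sqlx::PgPool) -> anyhow::Result<()> {
    let transaction = Arc::new(Mutex::new(pool.begin().await?));
    // ... hand clones of `transaction` to concurrent work here ...
    // Reclaim sole ownership; this fails if any clone is still alive.
    let transaction = Arc::into_inner(transaction)
        .context("Error transaction dangling")?
        .into_inner();
    transaction.commit().await?;
    Ok(())
}
```

Because the `active = TRUE` update rides in the same transaction as the sync itself, a crash mid-sync leaves the pipeline inactive rather than half-synced.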
@@ -308,28 +335,28 @@ impl Collection {
     /// }
     /// ```
     #[instrument(skip(self))]
-    pub async fn remove_pipeline(
-        &mut self,
-        pipeline: &mut MultiFieldPipeline,
-    ) -> anyhow::Result<()> {
-        let pool = get_or_initialize_pool(&self.database_url).await?;
+    pub async fn remove_pipeline(&mut self, pipeline: &MultiFieldPipeline) -> anyhow::Result<()> {
+        // The flow for this function:
+        // 1. Create the collection if it does not exist
+        // 2. Begin a transaction
+        // 3. Drop the collection_pipeline schema
+        // 4. Delete the pipeline from the collection.pipelines table
+        // 5. Commit the transaction
         self.verify_in_database(false).await?;
         let project_info = &self
             .database_data
             .as_ref()
-            .context("Database data must be set to remove pipeline from collection")?
+            .context("Database data must be set to remove a pipeline from a collection")?
             .project_info;
-        pipeline.set_project_info(project_info.clone());
-        pipeline.verify_in_database(false).await?;
-
+        let pool = get_or_initialize_pool(&self.database_url).await?;
         let pipeline_schema = format!("{}_{}", project_info.name, pipeline.name);

         let mut transaction = pool.begin().await?;
         transaction
             .execute(query_builder!("DROP SCHEMA IF EXISTS %s CASCADE", pipeline_schema).as_str())
             .await?;
         sqlx::query(&query_builder!(
-            "UPDATE %s SET active = FALSE WHERE name = $1",
+            "DELETE FROM %s WHERE name = $1",
             self.pipelines_table_name
         ))
         .bind(&pipeline.name)
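
Dropping the schema and deleting the registry row inside one transaction keeps the two in lockstep; Postgres DDL is transactional, so either both effects land or neither does. A minimal standalone sketch with illustrative schema and table names (not the generated ones this SDK derives):

```rust
use sqlx::Executor;

async fn remove_atomically(pool: &sqlx::PgPool) -> anyhow::Result<()> {
    let mut transaction = pool.begin().await?;
    // Tear down the per-pipeline schema first.
    transaction
        .execute("DROP SCHEMA IF EXISTS my_collection_my_pipeline CASCADE")
        .await?;
    // Then remove the pipeline's registry row in the same transaction.
    sqlx::query("DELETE FROM my_collection.pipelines WHERE name = $1")
        .bind("my_pipeline")
        .execute(&mut *transaction)
        .await?;
    transaction.commit().await?;
    Ok(())
}
```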
@@ -344,7 +371,7 @@ impl Collection {
     ///
     /// # Arguments
     ///
-    /// * `pipeline` - The [Pipeline] to remove.
+    /// * `pipeline` - The [Pipeline] to enable
     ///
     /// # Example
     ///
@@ -359,22 +386,18 @@ impl Collection {
     /// }
     /// ```
     #[instrument(skip(self))]
-    pub async fn enable_pipeline(&self, pipeline: &Pipeline) -> anyhow::Result<()> {
-        sqlx::query(&query_builder!(
-            "UPDATE %s SET active = TRUE WHERE name = $1",
-            self.pipelines_table_name
-        ))
-        .bind(&pipeline.name)
-        .execute(&get_or_initialize_pool(&self.database_url).await?)
-        .await?;
-        Ok(())
+    pub async fn enable_pipeline(
+        &mut self,
+        pipeline: &mut MultiFieldPipeline,
+    ) -> anyhow::Result<()> {
+        self.add_pipeline(pipeline).await
     }

     /// Disables a [Pipeline] on the [Collection]
     ///
     /// # Arguments
     ///
-    /// * `pipeline` - The [Pipeline] to remove.
+    /// * `pipeline` - The [Pipeline] to disable
     ///
     /// # Example
     ///
@@ -389,14 +412,38 @@ impl Collection {
     /// }
     /// ```
     #[instrument(skip(self))]
-    pub async fn disable_pipeline(&self, pipeline: &Pipeline) -> anyhow::Result<()> {
+    pub async fn disable_pipeline(&mut self, pipeline: &MultiFieldPipeline) -> anyhow::Result<()> {
+        // Our current system for keeping documents, chunks, embeddings, and tsvectors in sync
+        // does not play nicely with disabling and then re-enabling pipelines.
+        // For now, when disabling a pipeline, simply delete its schema and remake it later.
+        // The flow for this function:
+        // 1. Create the collection if it does not exist
+        // 2. Begin a transaction
+        // 3. Set the pipeline's ACTIVE = FALSE in the collection.pipelines table
+        // 4. Drop the collection_pipeline schema (this will get remade if they enable it again)
+        // 5. Commit the transaction
+        self.verify_in_database(false).await?;
+        let project_info = &self
+            .database_data
+            .as_ref()
+            .context("Database data must be set to disable a pipeline on a collection")?
+            .project_info;
+        let pool = get_or_initialize_pool(&self.database_url).await?;
+        let pipeline_schema = format!("{}_{}", project_info.name, pipeline.name);
+
+        let mut transaction = pool.begin().await?;
         sqlx::query(&query_builder!(
             "UPDATE %s SET active = FALSE WHERE name = $1",
             self.pipelines_table_name
         ))
         .bind(&pipeline.name)
-        .execute(&get_or_initialize_pool(&self.database_url).await?)
+        .execute(&mut *transaction)
         .await?;
+        transaction
+            .execute(query_builder!("DROP SCHEMA IF EXISTS %s CASCADE", pipeline_schema).as_str())
+            .await?;
+        transaction.commit().await?;
+
         Ok(())
     }

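With `enable_pipeline` now delegating to `add_pipeline`, disable/enable becomes a destroy-and-rebuild cycle rather than a flag flip: disabling drops the pipeline's schema, and enabling recreates the tables and re-syncs every document. A hypothetical round-trip using the signatures above, assuming a `collection` and `pipeline` already constructed as in the doc examples:

```rust
async fn toggle(
    collection: &mut Collection,
    pipeline: &mut MultiFieldPipeline,
) -> anyhow::Result<()> {
    // Marks the row inactive and drops the collection_pipeline schema.
    collection.disable_pipeline(pipeline).await?;
    // Recreates the schema and re-syncs from scratch via add_pipeline.
    collection.enable_pipeline(pipeline).await?;
    Ok(())
}
```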
@@ -442,13 +489,21 @@ impl Collection {
     ///     Ok(())
     /// }
     /// ```
-    // TODO: Make it so if we upload the same documen twice it doesn't do anything
     #[instrument(skip(self, documents))]
     pub async fn upsert_documents(
         &mut self,
         documents: Vec<Json>,
         _args: Option<Json>,
     ) -> anyhow::Result<()> {
+        // The flow for this function:
+        // 1. Create the collection if it does not exist
+        // 2. Get all pipelines where ACTIVE = TRUE
+        // 3. Create each pipeline and the collection_pipeline schema and tables if they don't already exist
+        // 4. For each document:
+        //    -> Begin a transaction
+        //    -> Insert the document, returning the old document if it existed
+        //    -> For each pipeline, check if we need to resync the document and, if so, sync it
+        //    -> Commit the transaction
         let pool = get_or_initialize_pool(&self.database_url).await?;
         self.verify_in_database(false).await?;
         let mut pipelines = self.get_pipelines().await?;
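
Step 4's insert relies on a snapshot trick, used in the next hunk: a CTE sees the table as it was when the statement began, so a single round trip both upserts the document and reports what it replaced. A minimal sketch under an assumed table shape (`documents(id bigserial, source_uuid uuid UNIQUE, document jsonb)`; the real table names come from `query_builder!`):

```rust
use serde_json::Value as Json;

async fn upsert_returning_previous(
    pool: &sqlx::PgPool,
    source_uuid: uuid::Uuid,
    document: &Json,
) -> anyhow::Result<(i64, Option<Json>)> {
    // `prev` still holds the old row even though the INSERT overwrites it;
    // on a first-time insert the scalar subquery yields NULL, decoded as None.
    let row = sqlx::query_as(
        "WITH prev AS (SELECT document FROM documents WHERE source_uuid = $1) \
         INSERT INTO documents (source_uuid, document) VALUES ($1, $2) \
         ON CONFLICT (source_uuid) DO UPDATE SET document = EXCLUDED.document \
         RETURNING id, (SELECT document FROM prev)",
    )
    .bind(source_uuid)
    .bind(document)
    .fetch_one(pool)
    .await?;
    Ok(row)
}
```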
@@ -468,20 +523,55 @@ impl Collection {
             let md5_digest = md5::compute(id.as_bytes());
             let source_uuid = uuid::Uuid::from_slice(&md5_digest.0)?;

-            let document_id: i64 = sqlx::query_scalar(&query_builder!("INSERT INTO %s (source_uuid, document) VALUES ($1, $2) ON CONFLICT (source_uuid) DO UPDATE SET document = $2 RETURNING id", self.documents_table_name)).bind(source_uuid).bind(document).fetch_one(&mut *transaction).await?;
+            let (document_id, previous_document): (i64, Option<Json>) = sqlx::query_as(&query_builder!(
+                "WITH prev AS (SELECT document FROM %s WHERE source_uuid = $1) INSERT INTO %s (source_uuid, document) VALUES ($1, $2) ON CONFLICT (source_uuid) DO UPDATE SET document = EXCLUDED.document RETURNING id, (SELECT document FROM prev)",
+                self.documents_table_name,
+                self.documents_table_name
+            ))
+            .bind(&source_uuid)
+            .bind(&document)
+            .fetch_one(&mut *transaction)
+            .await?;

             let transaction = Arc::new(Mutex::new(transaction));
             if !pipelines.is_empty() {
                 use futures::stream::StreamExt;
                 futures::stream::iter(&mut pipelines)
                     // Need this map to get around moving the transaction
-                    .map(|pipeline| (pipeline, transaction.clone()))
-                    .for_each_concurrent(10, |(pipeline, transaction)| async move {
-                        pipeline
-                            .execute(Some(document_id), transaction)
-                            .await
-                            .expect("Failed to execute pipeline");
+                    .map(|pipeline| {
+                        (
+                            pipeline,
+                            previous_document.clone(),
+                            document.clone(),
+                            transaction.clone(),
+                        )
                     })
+                    .for_each_concurrent(
+                        10,
+                        |(pipeline, previous_document, document, transaction)| async move {
+                            // Can unwrap here as we know it has a parsed schema from the create_tables call above
+                            match previous_document {
+                                Some(previous_document) => {
+                                    let should_run =
+                                        pipeline.parsed_schema.as_ref().unwrap().iter().any(
+                                            |(key, _)| document[key] != previous_document[key],
+                                        );
+                                    if should_run {
+                                        pipeline
+                                            .execute(Some(document_id), transaction)
+                                            .await
+                                            .expect("Failed to execute pipeline");
+                                    }
+                                }
+                                None => {
+                                    pipeline
+                                        .execute(Some(document_id), transaction)
+                                        .await
+                                        .expect("Failed to execute pipeline");
+                                }
+                            }
+                        },
+                    )
                     .await;
             }

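The resync check above only re-runs a pipeline when a field covered by its parsed schema actually changed. The two match arms collapse to a single condition; a condensed sketch, with `schema_keys` standing in for the parsed schema's keys:

```rust
use serde_json::Value as Json;

// Re-run when the document is new, or when any schema-covered field differs.
fn needs_resync(schema_keys: &[String], document: &Json, previous: Option<&Json>) -> bool {
    match previous {
        // serde_json indexing returns Null for missing keys, so fields that
        // were added or removed also count as changes.
        Some(prev) => schema_keys.iter().any(|key| document[key] != prev[key]),
        None => true,
    }
}
```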
@@ -705,29 +795,30 @@ impl Collection {
     //     Ok(())
     }

-    #[instrument(skip(self))]
-    async fn sync_pipeline(&mut self, pipeline: &mut MultiFieldPipeline) -> anyhow::Result<()> {
-        self.verify_in_database(false).await?;
-        let project_info = &self
-            .database_data
-            .as_ref()
-            .context("Database data must be set to get collection pipelines")?
-            .project_info;
-        pipeline.set_project_info(project_info.clone());
-        pipeline.create_tables().await?;
-
-        let pool = get_or_initialize_pool(&self.database_url).await?;
-        let transaction = pool.begin().await?;
-        let transaction = Arc::new(Mutex::new(transaction));
-        pipeline.execute(None, transaction.clone()).await?;
-
-        Arc::into_inner(transaction)
-            .context("Error transaction dangling")?
-            .into_inner()
-            .commit()
-            .await?;
-        Ok(())
-    }
+    // #[instrument(skip(self))]
+    // async fn sync_pipeline(
+    //     &mut self,
+    //     pipeline: &mut MultiFieldPipeline,
+    //     transaction: Arc<Mutex<Transaction<'static, Postgres>>>,
+    // ) -> anyhow::Result<()> {
+    //     self.verify_in_database(false).await?;
+    //     let project_info = &self
+    //         .database_data
+    //         .as_ref()
+    //         .context("Database data must be set to get collection pipelines")?
+    //         .project_info;
+    //     pipeline.set_project_info(project_info.clone());
+    //     pipeline.create_tables().await?;
+
+    //     pipeline.execute(None, transaction).await?;

+    //     Arc::into_inner(transaction)
+    //         .context("Error transaction dangling")?
+    //         .into_inner()
+    //         .commit()
+    //         .await?;
+    //     Ok(())
+    // }

     #[instrument(skip(self))]
     pub async fn search(