Skip to content

Commit f75a2ec

Browse files
committed
Finished pipeline as a pass through and more tests
1 parent 9df12b5 commit f75a2ec

File tree

3 files changed

+374
-280
lines changed

3 files changed

+374
-280
lines changed

pgml-sdks/pgml/src/lib.rs

Lines changed: 220 additions & 93 deletions
Original file line numberDiff line numberDiff line change
@@ -646,6 +646,158 @@ mod tests {
646646
Ok(())
647647
}
648648

649+
#[sqlx::test]
650+
async fn pipeline_sync_status() -> anyhow::Result<()> {
651+
internal_init_logger(None, None).ok();
652+
let collection_name = "test_r_c_pss_5";
653+
let mut collection = Collection::new(collection_name, None);
654+
let pipeline_name = "test_r_p_pss_0";
655+
let mut pipeline = MultiFieldPipeline::new(
656+
pipeline_name,
657+
Some(
658+
json!({
659+
"title": {
660+
"embed": {
661+
"model": "intfloat/e5-small"
662+
},
663+
"full_text_search": {
664+
"configuration": "english"
665+
},
666+
"splitter": {
667+
"model": "recursive_character"
668+
}
669+
}
670+
})
671+
.into(),
672+
),
673+
)?;
674+
collection.add_pipeline(&mut pipeline).await?;
675+
let documents = generate_dummy_documents(4);
676+
collection
677+
.upsert_documents(documents[..2].to_owned(), None)
678+
.await?;
679+
let status = pipeline.get_status().await?;
680+
assert_eq!(
681+
status.0,
682+
json!({
683+
"title": {
684+
"chunks": {
685+
"not_synced": 0,
686+
"synced": 2,
687+
"total": 2
688+
},
689+
"embeddings": {
690+
"not_synced": 0,
691+
"synced": 2,
692+
"total": 2
693+
},
694+
"tsvectors": {
695+
"not_synced": 0,
696+
"synced": 2,
697+
"total": 2
698+
},
699+
}
700+
})
701+
);
702+
collection.disable_pipeline(&mut pipeline).await?;
703+
collection
704+
.upsert_documents(documents[2..4].to_owned(), None)
705+
.await?;
706+
let status = pipeline.get_status().await?;
707+
assert_eq!(
708+
status.0,
709+
json!({
710+
"title": {
711+
"chunks": {
712+
"not_synced": 2,
713+
"synced": 2,
714+
"total": 4
715+
},
716+
"embeddings": {
717+
"not_synced": 0,
718+
"synced": 2,
719+
"total": 2
720+
},
721+
"tsvectors": {
722+
"not_synced": 0,
723+
"synced": 2,
724+
"total": 2
725+
},
726+
}
727+
})
728+
);
729+
collection.enable_pipeline(&mut pipeline).await?;
730+
let status = pipeline.get_status().await?;
731+
assert_eq!(
732+
status.0,
733+
json!({
734+
"title": {
735+
"chunks": {
736+
"not_synced": 0,
737+
"synced": 4,
738+
"total": 4
739+
},
740+
"embeddings": {
741+
"not_synced": 0,
742+
"synced": 4,
743+
"total": 4
744+
},
745+
"tsvectors": {
746+
"not_synced": 0,
747+
"synced": 4,
748+
"total": 4
749+
},
750+
}
751+
})
752+
);
753+
collection.archive().await?;
754+
Ok(())
755+
}
756+
757+
#[sqlx::test]
758+
async fn can_specify_custom_hnsw_parameters_for_pipelines() -> anyhow::Result<()> {
759+
internal_init_logger(None, None).ok();
760+
let collection_name = "test_r_c_cschpfp_4";
761+
let mut collection = Collection::new(collection_name, None);
762+
let pipeline_name = "test_r_p_cschpfp_0";
763+
let mut pipeline = MultiFieldPipeline::new(
764+
pipeline_name,
765+
Some(
766+
json!({
767+
"title": {
768+
"embed": {
769+
"model": "intfloat/e5-small",
770+
"hnsw": {
771+
"m": 100,
772+
"ef_construction": 200
773+
}
774+
}
775+
}
776+
})
777+
.into(),
778+
),
779+
)?;
780+
collection.add_pipeline(&mut pipeline).await?;
781+
let schema = format!("{collection_name}_{pipeline_name}");
782+
let full_embeddings_table_name = format!("{schema}.title_embeddings");
783+
let embeddings_table_name = full_embeddings_table_name.split('.').collect::<Vec<_>>()[1];
784+
let pool = get_or_initialize_pool(&None).await?;
785+
let results: Vec<(String, String)> = sqlx::query_as(&query_builder!(
786+
"select indexname, indexdef from pg_indexes where tablename = '%d' and schemaname = '%d'",
787+
embeddings_table_name,
788+
schema
789+
)).fetch_all(&pool).await?;
790+
let names = results.iter().map(|(name, _)| name).collect::<Vec<_>>();
791+
let definitions = results
792+
.iter()
793+
.map(|(_, definition)| definition)
794+
.collect::<Vec<_>>();
795+
assert!(names.contains(&&"title_pipeline_embedding_hnsw_vector_index".to_string()));
796+
assert!(definitions.contains(&&format!("CREATE INDEX title_pipeline_embedding_hnsw_vector_index ON {full_embeddings_table_name} USING hnsw (embedding vector_cosine_ops) WITH (m='100', ef_construction='200')")));
797+
collection.archive().await?;
798+
Ok(())
799+
}
800+
649801
///////////////////////////////
650802
// Searches ///////////////////
651803
///////////////////////////////
@@ -959,99 +1111,6 @@ mod tests {
9591111
Ok(())
9601112
}
9611113

962-
// #[sqlx::test]
963-
// async fn can_specify_custom_hnsw_parameters_for_pipelines() -> anyhow::Result<()> {
964-
// internal_init_logger(None, None).ok();
965-
// let model = Model::default();
966-
// let splitter = Splitter::default();
967-
// let mut pipeline = Pipeline::new(
968-
// "test_r_p_cschpfp_0",
969-
// Some(model),
970-
// Some(splitter),
971-
// Some(
972-
// serde_json::json!({
973-
// "hnsw": {
974-
// "m": 100,
975-
// "ef_construction": 200
976-
// }
977-
// })
978-
// .into(),
979-
// ),
980-
// );
981-
// let collection_name = "test_r_c_cschpfp_1";
982-
// let mut collection = Collection::new(collection_name, None);
983-
// collection.add_pipeline(&mut pipeline).await?;
984-
// let full_embeddings_table_name = pipeline.create_or_get_embeddings_table().await?;
985-
// let embeddings_table_name = full_embeddings_table_name.split('.').collect::<Vec<_>>()[1];
986-
// let pool = get_or_initialize_pool(&None).await?;
987-
// let results: Vec<(String, String)> = sqlx::query_as(&query_builder!(
988-
// "select indexname, indexdef from pg_indexes where tablename = '%d' and schemaname = '%d'",
989-
// embeddings_table_name,
990-
// collection_name
991-
// )).fetch_all(&pool).await?;
992-
// let names = results.iter().map(|(name, _)| name).collect::<Vec<_>>();
993-
// let definitions = results
994-
// .iter()
995-
// .map(|(_, definition)| definition)
996-
// .collect::<Vec<_>>();
997-
// assert!(names.contains(&&format!("{}_pipeline_hnsw_vector_index", pipeline.name)));
998-
// assert!(definitions.contains(&&format!("CREATE INDEX {}_pipeline_hnsw_vector_index ON {} USING hnsw (embedding vector_cosine_ops) WITH (m='100', ef_construction='200')", pipeline.name, full_embeddings_table_name)));
999-
// Ok(())
1000-
// }
1001-
1002-
// #[sqlx::test]
1003-
// async fn sync_multiple_pipelines() -> anyhow::Result<()> {
1004-
// internal_init_logger(None, None).ok();
1005-
// let model = Model::default();
1006-
// let splitter = Splitter::default();
1007-
// let mut pipeline1 = Pipeline::new(
1008-
// "test_r_p_smp_0",
1009-
// Some(model.clone()),
1010-
// Some(splitter.clone()),
1011-
// Some(
1012-
// serde_json::json!({
1013-
// "full_text_search": {
1014-
// "active": true,
1015-
// "configuration": "english"
1016-
// }
1017-
// })
1018-
// .into(),
1019-
// ),
1020-
// );
1021-
// let mut pipeline2 = Pipeline::new(
1022-
// "test_r_p_smp_1",
1023-
// Some(model),
1024-
// Some(splitter),
1025-
// Some(
1026-
// serde_json::json!({
1027-
// "full_text_search": {
1028-
// "active": true,
1029-
// "configuration": "english"
1030-
// }
1031-
// })
1032-
// .into(),
1033-
// ),
1034-
// );
1035-
// let mut collection = Collection::new("test_r_c_smp_3", None);
1036-
// collection.add_pipeline(&mut pipeline1).await?;
1037-
// collection.add_pipeline(&mut pipeline2).await?;
1038-
// collection
1039-
// .upsert_documents(generate_dummy_documents(3), None)
1040-
// .await?;
1041-
// let status_1 = pipeline1.get_status().await?;
1042-
// let status_2 = pipeline2.get_status().await?;
1043-
// assert!(
1044-
// status_1.chunks_status.synced == status_1.chunks_status.total
1045-
// && status_1.chunks_status.not_synced == 0
1046-
// );
1047-
// assert!(
1048-
// status_2.chunks_status.synced == status_2.chunks_status.total
1049-
// && status_2.chunks_status.not_synced == 0
1050-
// );
1051-
// collection.archive().await?;
1052-
// Ok(())
1053-
// }
1054-
10551114
///////////////////////////////
10561115
// Working With Documents /////
10571116
///////////////////////////////
@@ -1532,6 +1591,74 @@ mod tests {
15321591
Ok(())
15331592
}
15341593

1594+
///////////////////////////////
1595+
// Pipeline -> MultiFieldPipeline
1596+
///////////////////////////////
1597+
1598+
#[test]
/// Verifies that a legacy single-field `Pipeline` (model + splitter +
/// parameters) is converted into the equivalent multi-field schema keyed
/// on the default "text" field.
fn pipeline_to_multi_field_pipeline() -> anyhow::Result<()> {
    // Legacy-style components with distinctive parameters so we can see
    // them surface in the converted schema.
    let model = Model::new(
        Some("test_model".to_string()),
        Some("pgml".to_string()),
        Some(json!({ "test_parameter": 10 }).into()),
    );
    let splitter = Splitter::new(
        Some("test_splitter".to_string()),
        Some(json!({ "test_parameter": 11 }).into()),
    );
    let parameters = json!({
        "full_text_search": {
            "active": true,
            "configuration": "test_configuration"
        },
        "hnsw": {
            "m": 16,
            "ef_construction": 64
        }
    });
    let multi_field_pipeline = Pipeline::new(
        "test_name",
        Some(model),
        Some(splitter),
        Some(parameters.into()),
    );

    // The expected multi-field schema: splitter/embed settings nested under
    // "text", with hnsw folded into "embed" and full_text_search alongside.
    let schema = json!({
        "text": {
            "splitter": {
                "model": "test_splitter",
                "parameters": {
                    "test_parameter": 11
                }
            },
            "embed": {
                "model": "test_model",
                "parameters": {
                    "test_parameter": 10
                },
                "hnsw": {
                    "m": 16,
                    "ef_construction": 64
                }
            },
            "full_text_search": {
                "configuration": "test_configuration"
            }
        }
    });
    assert_eq!(schema, multi_field_pipeline.schema.unwrap().0);
    Ok(())
}
1661+
15351662
///////////////////////////////
15361663
// ER Diagram /////////////////
15371664
///////////////////////////////

0 commit comments

Comments
 (0)