From 0574574397373f56fbe15e5c77eb20017a8a59dc Mon Sep 17 00:00:00 2001 From: Adam Hendel Date: Tue, 16 Jan 2024 22:37:31 -0600 Subject: [PATCH 1/9] handle model deploy when no metrics to compare --- pgml-extension/src/api.rs | 59 ++++++++++++++++++++++++--------------- 1 file changed, 37 insertions(+), 22 deletions(-) diff --git a/pgml-extension/src/api.rs b/pgml-extension/src/api.rs index 380bfb330..24164a9f1 100644 --- a/pgml-extension/src/api.rs +++ b/pgml-extension/src/api.rs @@ -264,37 +264,52 @@ fn train_joint( ); let mut deploy = true; + match automatic_deploy { // Deploy only if metrics are better than previous model. Some(true) | None => { if let Ok(Some(deployed_metrics)) = deployed_metrics { - let deployed_metrics = deployed_metrics.0.as_object().unwrap(); - let deployed_metric = deployed_metrics - .get(&project.task.default_target_metric()) - .unwrap() - .as_f64() - .unwrap(); - info!( - "Comparing to deployed model {}: {:?}", - project.task.default_target_metric(), - deployed_metric - ); - if project.task.value_is_better( - deployed_metric, - new_metrics - .get(&project.task.default_target_metric()) - .unwrap() - .as_f64() - .unwrap(), - ) { + if let Some(deployed_metrics_obj) = deployed_metrics.0.as_object() { + let default_target_metric = project.task.default_target_metric(); + let deployed_metric = deployed_metrics_obj + .get(&default_target_metric) + .and_then(|v| v.as_f64()); + info!( + "Comparing to deployed model {}: {:?}", + default_target_metric, deployed_metric + ); + if let (Some(deployed_metric_value), Some(new_metric_value)) = ( + deployed_metric, + new_metrics.get(&default_target_metric).and_then(|v| v.as_f64()), + ) { + if project.task.value_is_better(deployed_metric_value, new_metric_value) { + warning!( + "New model's {} is not better than old model: {} is not better than {}", + &project.task.default_target_metric(), + new_metric_value, + deployed_metric_value + ); + deploy = false; + } + } else { + warning!("Failed to retrieve or 
parse deployed/new metrics for {}. Ensure train/test split results in both positive and negative label records.", + &project.task.default_target_metric()); + deploy = false; + } + } else { + warning!("Failed to parse deployed model metrics. Ensure train/test split results in both positive and negative label records."); deploy = false; } + } else { + warning!("Failed to obtain currently deployed model metrics. Check if the deployed model metrics are available and correctly formatted."); + deploy = false; } } - - Some(false) => deploy = false, + Some(false) => { + warning!("Automatic deployment disabled via configuration."); + deploy = false; + } }; - if deploy { project.deploy(model.id, Strategy::new_score); } else { From 667f68ec88ab530e96828a11d25941080d7783a9 Mon Sep 17 00:00:00 2001 From: Adam Hendel Date: Tue, 16 Jan 2024 22:40:31 -0600 Subject: [PATCH 2/9] better warn msg --- pgml-extension/src/api.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pgml-extension/src/api.rs b/pgml-extension/src/api.rs index 24164a9f1..f2308a139 100644 --- a/pgml-extension/src/api.rs +++ b/pgml-extension/src/api.rs @@ -284,7 +284,7 @@ fn train_joint( ) { if project.task.value_is_better(deployed_metric_value, new_metric_value) { warning!( - "New model's {} is not better than old model: {} is not better than {}", + "New model's {} is not better than current model. 
New: {}, Current {}", &project.task.default_target_metric(), new_metric_value, deployed_metric_value From d0ff7251d381ac575a7376ded27022431d21cf17 Mon Sep 17 00:00:00 2001 From: Adam Hendel Date: Wed, 17 Jan 2024 15:17:24 +0000 Subject: [PATCH 3/9] fix first run case --- pgml-extension/src/api.rs | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/pgml-extension/src/api.rs b/pgml-extension/src/api.rs index f2308a139..ada756c7a 100644 --- a/pgml-extension/src/api.rs +++ b/pgml-extension/src/api.rs @@ -266,7 +266,7 @@ fn train_joint( let mut deploy = true; match automatic_deploy { - // Deploy only if metrics are better than previous model. + // Deploy only if metrics are better than previous model, or if it's the first model Some(true) | None => { if let Ok(Some(deployed_metrics)) = deployed_metrics { if let Some(deployed_metrics_obj) = deployed_metrics.0.as_object() { @@ -297,12 +297,9 @@ fn train_joint( deploy = false; } } else { - warning!("Failed to parse deployed model metrics. Ensure train/test split results in both positive and negative label records."); + warning!("Failed to parse deployed model metrics. Check data types of model metadata on pgml.models.metrics"); deploy = false; } - } else { - warning!("Failed to obtain currently deployed model metrics. 
Check if the deployed model metrics are available and correctly formatted."); - deploy = false; } } Some(false) => { From accaab001f02130048871bf2ab6b407660683452 Mon Sep 17 00:00:00 2001 From: Adam Hendel Date: Wed, 17 Jan 2024 22:12:49 +0000 Subject: [PATCH 4/9] impl stratified --- pgml-extension/src/orm/sampling.rs | 109 +++++++++++++++++++++++++++++ pgml-extension/src/orm/snapshot.rs | 34 ++------- 2 files changed, 114 insertions(+), 29 deletions(-) diff --git a/pgml-extension/src/orm/sampling.rs b/pgml-extension/src/orm/sampling.rs index 6bb3d7b5a..442f683ac 100644 --- a/pgml-extension/src/orm/sampling.rs +++ b/pgml-extension/src/orm/sampling.rs @@ -1,11 +1,14 @@ use pgrx::*; use serde::Deserialize; +use super::snapshot::Column; + #[derive(PostgresEnum, Copy, Clone, Eq, PartialEq, Debug, Deserialize)] #[allow(non_camel_case_types)] pub enum Sampling { random, last, + stratified_random, } impl std::str::FromStr for Sampling { @@ -15,6 +18,7 @@ impl std::str::FromStr for Sampling { match input { "random" => Ok(Sampling::random), "last" => Ok(Sampling::last), + "stratified_random" => Ok(Sampling::stratified_random), _ => Err(()), } } @@ -25,6 +29,111 @@ impl std::string::ToString for Sampling { match *self { Sampling::random => "random".to_string(), Sampling::last => "last".to_string(), + Sampling::stratified_random => "stratified_random".to_string(), } } } + +impl Sampling { + // Implementing the sampling strategy in SQL + // Effectively orders the table according to the train/test split + // e.g. 
first N rows are train, last M rows are test + // where M is configured by the user + pub fn get_sql(&self, relation_name: &str, y_column_names: Vec) -> String { + let col_string = y_column_names + .iter() + .map(|c| c.quoted_name()) + .collect::>() + .join(", "); + match *self { + Sampling::random => { + format!("SELECT {col_string} FROM {relation_name} ORDER BY RANDOM()") + } + Sampling::last => { + format!("SELECT {col_string} FROM {relation_name}") + } + Sampling::stratified_random => { + format!( + " + SELECT * + FROM ( + SELECT + *, + ROW_NUMBER() OVER(PARTITION BY {col_string} ORDER BY RANDOM()) AS rn + FROM {relation_name} + ) AS subquery + ORDER BY rn, RANDOM(); + " + ) + } + } + } +} + +#[cfg(test)] +mod tests { + use crate::orm::snapshot::{Preprocessor, Statistics}; + + use super::*; + + fn get_column_fixtures() -> Vec { + vec![ + Column { + name: "col1".to_string(), + pg_type: "text".to_string(), + nullable: false, + label: true, + position: 0, + size: 0, + array: false, + preprocessor: Preprocessor::default(), + statistics: Statistics::default(), + }, + Column { + name: "col2".to_string(), + pg_type: "text".to_string(), + nullable: false, + label: true, + position: 0, + size: 0, + array: false, + preprocessor: Preprocessor::default(), + statistics: Statistics::default(), + }, + ] + } + + #[test] + fn test_get_sql_random_sampling() { + let sampling = Sampling::random; + let columns = get_column_fixtures(); + let sql = sampling.get_sql("my_table", columns); + assert_eq!(sql, "SELECT \"col1\", \"col2\" FROM my_table ORDER BY RANDOM()"); + } + + #[test] + fn test_get_sql_last_sampling() { + let sampling = Sampling::last; + let columns = get_column_fixtures(); + let sql = sampling.get_sql("my_table", columns); + assert_eq!(sql, "SELECT \"col1\", \"col2\" FROM my_table"); + } + + #[test] + fn test_get_sql_stratified_random_sampling() { + let sampling = Sampling::stratified_random; + let columns = get_column_fixtures(); + let sql = 
sampling.get_sql("my_table", columns); + let expected_sql = " + SELECT * + FROM ( + SELECT + *, + ROW_NUMBER() OVER(PARTITION BY \"col1\", \"col2\" ORDER BY RANDOM()) AS rn + FROM my_table + ) AS subquery + ORDER BY rn, RANDOM(); + "; + assert_eq!(sql, expected_sql); + } +} diff --git a/pgml-extension/src/orm/snapshot.rs b/pgml-extension/src/orm/snapshot.rs index 6a5973148..9b478fe8a 100644 --- a/pgml-extension/src/orm/snapshot.rs +++ b/pgml-extension/src/orm/snapshot.rs @@ -119,7 +119,7 @@ pub(crate) struct Preprocessor { } #[derive(Debug, PartialEq, Serialize, Deserialize, Clone)] -pub(crate) struct Column { +pub struct Column { pub(crate) name: String, pub(crate) pg_type: String, pub(crate) nullable: bool, @@ -147,7 +147,7 @@ impl Column { ) } - fn quoted_name(&self) -> String { + pub(crate) fn quoted_name(&self) -> String { format!(r#""{}""#, self.name) } @@ -608,13 +608,8 @@ impl Snapshot { }; if materialized { - let mut sql = format!( - r#"CREATE TABLE "pgml"."snapshot_{}" AS SELECT * FROM {}"#, - s.id, s.relation_name - ); - if s.test_sampling == Sampling::random { - sql += " ORDER BY random()"; - } + let sampled_query = s.test_sampling.get_sql(&s.relation_name, s.columns.clone()); + let sql = format!(r#"CREATE TABLE "pgml"."snapshot_{}" AS {}"#, s.id, sampled_query); client.update(&sql, None, None).unwrap(); } snapshot = Some(s); @@ -742,26 +737,7 @@ impl Snapshot { } fn select_sql(&self) -> String { - format!( - "SELECT {} FROM {} {}", - self.columns - .iter() - .map(|c| c.quoted_name()) - .collect::>() - .join(", "), - self.relation_name(), - match self.materialized { - // If the snapshot is materialized, we already randomized it. 
- true => "", - false => { - if self.test_sampling == Sampling::random { - "ORDER BY random()" - } else { - "" - } - } - }, - ) + self.test_sampling.get_sql(&self.relation_name(), self.columns.clone()) } fn train_test_split(&self, num_rows: usize) -> (usize, usize) { From 58310256fae05912aedda7fda7c0aee1ee10f9ea Mon Sep 17 00:00:00 2001 From: Adam Hendel Date: Wed, 17 Jan 2024 22:29:24 +0000 Subject: [PATCH 5/9] handle case where exists has no metrics --- pgml-extension/src/api.rs | 36 +++++++++++++++++++++--------------- 1 file changed, 21 insertions(+), 15 deletions(-) diff --git a/pgml-extension/src/api.rs b/pgml-extension/src/api.rs index ada756c7a..1580de944 100644 --- a/pgml-extension/src/api.rs +++ b/pgml-extension/src/api.rs @@ -278,23 +278,29 @@ fn train_joint( "Comparing to deployed model {}: {:?}", default_target_metric, deployed_metric ); - if let (Some(deployed_metric_value), Some(new_metric_value)) = ( - deployed_metric, - new_metrics.get(&default_target_metric).and_then(|v| v.as_f64()), - ) { - if project.task.value_is_better(deployed_metric_value, new_metric_value) { - warning!( - "New model's {} is not better than current model. New: {}, Current {}", - &project.task.default_target_metric(), - new_metric_value, - deployed_metric_value - ); + let new_metric = new_metrics.get(&default_target_metric).and_then(|v| v.as_f64()); + + match (deployed_metric, new_metric) { + (Some(deployed), Some(new)) => { + // only compare metrics when both new and old model have metrics to compare + if project.task.value_is_better(deployed, new) { + warning!( + "New model's {} is not better than current model. New: {}, Current {}", + &default_target_metric, + new, + deployed + ); + deploy = false; + } + } + (None, None) => { + warning!("No metrics available for both deployed and new model. Deploying new model.") + } + (Some(_deployed), None) => { + warning!("No metrics for new model. 
Retaining old model."); deploy = false; } - } else { - warning!("Failed to retrieve or parse deployed/new metrics for {}. Ensure train/test split results in both positive and negative label records.", - &project.task.default_target_metric()); - deploy = false; + (None, Some(_new)) => warning!("No metrics for deployed model. Deploying new model."), } } else { warning!("Failed to parse deployed model metrics. Check data types of model metadata on pgml.models.metrics"); From 7710b13378adf5431cdc7afbec92368d55b590dc Mon Sep 17 00:00:00 2001 From: Adam Hendel Date: Wed, 17 Jan 2024 22:44:26 +0000 Subject: [PATCH 6/9] change default sampling to stratified --- pgml-extension/src/api.rs | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/pgml-extension/src/api.rs b/pgml-extension/src/api.rs index 1580de944..17e3af172 100644 --- a/pgml-extension/src/api.rs +++ b/pgml-extension/src/api.rs @@ -100,7 +100,7 @@ fn train( search_params: default!(JsonB, "'{}'"), search_args: default!(JsonB, "'{}'"), test_size: default!(f32, 0.25), - test_sampling: default!(Sampling, "'last'"), + test_sampling: default!(Sampling, "'stratified_random'"), runtime: default!(Option, "NULL"), automatic_deploy: default!(Option, true), materialize_snapshot: default!(bool, false), @@ -146,7 +146,7 @@ fn train_joint( search_params: default!(JsonB, "'{}'"), search_args: default!(JsonB, "'{}'"), test_size: default!(f32, 0.25), - test_sampling: default!(Sampling, "'last'"), + test_sampling: default!(Sampling, "'stratified_random'"), runtime: default!(Option, "NULL"), automatic_deploy: default!(Option, true), materialize_snapshot: default!(bool, false), @@ -535,7 +535,7 @@ fn snapshot( relation_name: &str, y_column_name: &str, test_size: default!(f32, 0.25), - test_sampling: default!(Sampling, "'last'"), + test_sampling: default!(Sampling, "'stratified_random'"), preprocess: default!(JsonB, "'{}'"), ) -> TableIterator<'static, (name!(relation, String), name!(y_column_name, String))> { 
Snapshot::create( @@ -807,7 +807,7 @@ fn tune( model_name: default!(Option<&str>, "NULL"), hyperparams: default!(JsonB, "'{}'"), test_size: default!(f32, 0.25), - test_sampling: default!(Sampling, "'last'"), + test_sampling: default!(Sampling, "'stratified_random'"), automatic_deploy: default!(Option, true), materialize_snapshot: default!(bool, false), ) -> TableIterator< From 9b2c44aa1d716aa6111ed22c5a2c7669cbde92de Mon Sep 17 00:00:00 2001 From: Adam Hendel Date: Fri, 19 Jan 2024 04:19:31 +0000 Subject: [PATCH 7/9] no rando when already materialized --- pgml-extension/src/orm/sampling.rs | 8 ++++---- pgml-extension/src/orm/snapshot.rs | 15 ++++++++++++++- 2 files changed, 18 insertions(+), 5 deletions(-) diff --git a/pgml-extension/src/orm/sampling.rs b/pgml-extension/src/orm/sampling.rs index 442f683ac..e0b3bf238 100644 --- a/pgml-extension/src/orm/sampling.rs +++ b/pgml-extension/src/orm/sampling.rs @@ -47,10 +47,10 @@ impl Sampling { .join(", "); match *self { Sampling::random => { - format!("SELECT {col_string} FROM {relation_name} ORDER BY RANDOM()") + format!("SELECT * FROM {relation_name} ORDER BY RANDOM()") } Sampling::last => { - format!("SELECT {col_string} FROM {relation_name}") + format!("SELECT * FROM {relation_name}") } Sampling::stratified_random => { format!( @@ -108,7 +108,7 @@ mod tests { let sampling = Sampling::random; let columns = get_column_fixtures(); let sql = sampling.get_sql("my_table", columns); - assert_eq!(sql, "SELECT \"col1\", \"col2\" FROM my_table ORDER BY RANDOM()"); + assert_eq!(sql, "SELECT * FROM my_table ORDER BY RANDOM()"); } #[test] @@ -116,7 +116,7 @@ mod tests { let sampling = Sampling::last; let columns = get_column_fixtures(); let sql = sampling.get_sql("my_table", columns); - assert_eq!(sql, "SELECT \"col1\", \"col2\" FROM my_table"); + assert_eq!(sql, "SELECT * FROM my_table"); } #[test] diff --git a/pgml-extension/src/orm/snapshot.rs b/pgml-extension/src/orm/snapshot.rs index 9b478fe8a..12b9e9813 100644 --- 
a/pgml-extension/src/orm/snapshot.rs +++ b/pgml-extension/src/orm/snapshot.rs @@ -737,7 +737,20 @@ impl Snapshot { } fn select_sql(&self) -> String { - self.test_sampling.get_sql(&self.relation_name(), self.columns.clone()) + match self.materialized { + true => { + format!( + "SELECT {} FROM {}", + self.columns + .iter() + .map(|c| c.quoted_name()) + .collect::>() + .join(", "), + self.relation_name() + ) + } + false => self.test_sampling.get_sql(&self.relation_name(), self.columns.clone()), + } } fn train_test_split(&self, num_rows: usize) -> (usize, usize) { From 489bce180212e8e8692b76b5a00b2231036c8621 Mon Sep 17 00:00:00 2001 From: Adam Hendel Date: Fri, 19 Jan 2024 04:32:03 +0000 Subject: [PATCH 8/9] update enum and function signatures --- pgml-extension/sql/pgml--2.8.1--2.8.2.sql | 106 ++++++++++++++++++++++ pgml-extension/src/api.rs | 8 +- pgml-extension/src/orm/sampling.rs | 12 +-- 3 files changed, 116 insertions(+), 10 deletions(-) diff --git a/pgml-extension/sql/pgml--2.8.1--2.8.2.sql b/pgml-extension/sql/pgml--2.8.1--2.8.2.sql index 2c6264fb9..3f00d9b71 100644 --- a/pgml-extension/sql/pgml--2.8.1--2.8.2.sql +++ b/pgml-extension/sql/pgml--2.8.1--2.8.2.sql @@ -25,3 +25,109 @@ CREATE FUNCTION pgml."deploy"( AS 'MODULE_PATHNAME', 'deploy_strategy_wrapper'; ALTER TYPE pgml.strategy ADD VALUE 'specific'; + +-- src/orm/sampling.rs:6 +-- pgml::orm::sampling::Sampling +DROP TYPE IF EXISTS pgml.Sampling; +CREATE TYPE pgml.Sampling AS ENUM ( + 'random', + 'last', + 'stratified' +); + +-- src/api.rs:534 +-- pgml::api::snapshot +DROP FUNCTION IF EXISTS pgml."snapshot"(text, text, real, pgml.Sampling, jsonb); +CREATE FUNCTION pgml."snapshot"( + "relation_name" TEXT, /* &str */ + "y_column_name" TEXT, /* &str */ + "test_size" real DEFAULT 0.25, /* f32 */ + "test_sampling" pgml.Sampling DEFAULT 'stratified', /* pgml::orm::sampling::Sampling */ + "preprocess" jsonb DEFAULT '{}' /* pgrx::datum::json::JsonB */ +) RETURNS TABLE ( + "relation" TEXT, /* alloc::string::String 
*/ + "y_column_name" TEXT /* alloc::string::String */ +) +STRICT +LANGUAGE c /* Rust */ +AS 'MODULE_PATHNAME', 'snapshot_wrapper'; + +-- src/api.rs:802 +-- pgml::api::tune +DROP FUNCTION IF EXISTS pgml."tune"(text, text, text, text, text, jsonb, real, pgml.Sampling, bool, bool); +CREATE FUNCTION pgml."tune"( + "project_name" TEXT, /* &str */ + "task" TEXT DEFAULT NULL, /* core::option::Option<&str> */ + "relation_name" TEXT DEFAULT NULL, /* core::option::Option<&str> */ + "y_column_name" TEXT DEFAULT NULL, /* core::option::Option<&str> */ + "model_name" TEXT DEFAULT NULL, /* core::option::Option<&str> */ + "hyperparams" jsonb DEFAULT '{}', /* pgrx::datum::json::JsonB */ + "test_size" real DEFAULT 0.25, /* f32 */ + "test_sampling" pgml.Sampling DEFAULT 'stratified', /* pgml::orm::sampling::Sampling */ + "automatic_deploy" bool DEFAULT true, /* core::option::Option */ + "materialize_snapshot" bool DEFAULT false /* bool */ +) RETURNS TABLE ( + "status" TEXT, /* alloc::string::String */ + "task" TEXT, /* alloc::string::String */ + "algorithm" TEXT, /* alloc::string::String */ + "deployed" bool /* bool */ +) +PARALLEL SAFE +LANGUAGE c /* Rust */ +AS 'MODULE_PATHNAME', 'tune_wrapper'; + +-- src/api.rs:92 +-- pgml::api::train +DROP FUNCTION IF EXISTS pgml."train"(text, text, text, text, pgml.Algorithm, jsonb, pgml.Search, jsonb, jsonb, real, pgml.Sampling, pgml.Runtime, bool, bool, jsonb); +CREATE FUNCTION pgml."train"( + "project_name" TEXT, /* &str */ + "task" TEXT DEFAULT NULL, /* core::option::Option<&str> */ + "relation_name" TEXT DEFAULT NULL, /* core::option::Option<&str> */ + "y_column_name" TEXT DEFAULT NULL, /* core::option::Option<&str> */ + "algorithm" pgml.Algorithm DEFAULT 'linear', /* pgml::orm::algorithm::Algorithm */ + "hyperparams" jsonb DEFAULT '{}', /* pgrx::datum::json::JsonB */ + "search" pgml.Search DEFAULT NULL, /* core::option::Option */ + "search_params" jsonb DEFAULT '{}', /* pgrx::datum::json::JsonB */ + "search_args" jsonb DEFAULT '{}', /* 
pgrx::datum::json::JsonB */ + "test_size" real DEFAULT 0.25, /* f32 */ + "test_sampling" pgml.Sampling DEFAULT 'stratified', /* pgml::orm::sampling::Sampling */ + "runtime" pgml.Runtime DEFAULT NULL, /* core::option::Option */ + "automatic_deploy" bool DEFAULT true, /* core::option::Option */ + "materialize_snapshot" bool DEFAULT false, /* bool */ + "preprocess" jsonb DEFAULT '{}' /* pgrx::datum::json::JsonB */ +) RETURNS TABLE ( + "project" TEXT, /* alloc::string::String */ + "task" TEXT, /* alloc::string::String */ + "algorithm" TEXT, /* alloc::string::String */ + "deployed" bool /* bool */ +) +LANGUAGE c /* Rust */ +AS 'MODULE_PATHNAME', 'train_wrapper'; + +-- src/api.rs:138 +-- pgml::api::train_joint +DROP FUNCTION IF EXISTS pgml."train_joint"(text, text, text, text, pgml.Algorithm, jsonb, pgml.Search, jsonb, jsonb, real, pgml.Sampling, pgml.Runtime, bool, bool, jsonb); +CREATE FUNCTION pgml."train_joint"( + "project_name" TEXT, /* &str */ + "task" TEXT DEFAULT NULL, /* core::option::Option<&str> */ + "relation_name" TEXT DEFAULT NULL, /* core::option::Option<&str> */ + "y_column_name" TEXT[] DEFAULT NULL, /* core::option::Option> */ + "algorithm" pgml.Algorithm DEFAULT 'linear', /* pgml::orm::algorithm::Algorithm */ + "hyperparams" jsonb DEFAULT '{}', /* pgrx::datum::json::JsonB */ + "search" pgml.Search DEFAULT NULL, /* core::option::Option */ + "search_params" jsonb DEFAULT '{}', /* pgrx::datum::json::JsonB */ + "search_args" jsonb DEFAULT '{}', /* pgrx::datum::json::JsonB */ + "test_size" real DEFAULT 0.25, /* f32 */ + "test_sampling" pgml.Sampling DEFAULT 'stratified', /* pgml::orm::sampling::Sampling */ + "runtime" pgml.Runtime DEFAULT NULL, /* core::option::Option */ + "automatic_deploy" bool DEFAULT true, /* core::option::Option */ + "materialize_snapshot" bool DEFAULT false, /* bool */ + "preprocess" jsonb DEFAULT '{}' /* pgrx::datum::json::JsonB */ +) RETURNS TABLE ( + "project" TEXT, /* alloc::string::String */ + "task" TEXT, /* alloc::string::String 
*/ + "algorithm" TEXT, /* alloc::string::String */ + "deployed" bool /* bool */ +) +LANGUAGE c /* Rust */ +AS 'MODULE_PATHNAME', 'train_joint_wrapper'; diff --git a/pgml-extension/src/api.rs b/pgml-extension/src/api.rs index 17e3af172..7fd5012c8 100644 --- a/pgml-extension/src/api.rs +++ b/pgml-extension/src/api.rs @@ -100,7 +100,7 @@ fn train( search_params: default!(JsonB, "'{}'"), search_args: default!(JsonB, "'{}'"), test_size: default!(f32, 0.25), - test_sampling: default!(Sampling, "'stratified_random'"), + test_sampling: default!(Sampling, "'stratified'"), runtime: default!(Option, "NULL"), automatic_deploy: default!(Option, true), materialize_snapshot: default!(bool, false), @@ -146,7 +146,7 @@ fn train_joint( search_params: default!(JsonB, "'{}'"), search_args: default!(JsonB, "'{}'"), test_size: default!(f32, 0.25), - test_sampling: default!(Sampling, "'stratified_random'"), + test_sampling: default!(Sampling, "'stratified'"), runtime: default!(Option, "NULL"), automatic_deploy: default!(Option, true), materialize_snapshot: default!(bool, false), @@ -535,7 +535,7 @@ fn snapshot( relation_name: &str, y_column_name: &str, test_size: default!(f32, 0.25), - test_sampling: default!(Sampling, "'stratified_random'"), + test_sampling: default!(Sampling, "'stratified'"), preprocess: default!(JsonB, "'{}'"), ) -> TableIterator<'static, (name!(relation, String), name!(y_column_name, String))> { Snapshot::create( @@ -807,7 +807,7 @@ fn tune( model_name: default!(Option<&str>, "NULL"), hyperparams: default!(JsonB, "'{}'"), test_size: default!(f32, 0.25), - test_sampling: default!(Sampling, "'stratified_random'"), + test_sampling: default!(Sampling, "'stratified'"), automatic_deploy: default!(Option, true), materialize_snapshot: default!(bool, false), ) -> TableIterator< diff --git a/pgml-extension/src/orm/sampling.rs b/pgml-extension/src/orm/sampling.rs index e0b3bf238..2ecd66f5d 100644 --- a/pgml-extension/src/orm/sampling.rs +++ b/pgml-extension/src/orm/sampling.rs 
@@ -8,7 +8,7 @@ use super::snapshot::Column; pub enum Sampling { random, last, - stratified_random, + stratified, } impl std::str::FromStr for Sampling { @@ -18,7 +18,7 @@ impl std::str::FromStr for Sampling { match input { "random" => Ok(Sampling::random), "last" => Ok(Sampling::last), - "stratified_random" => Ok(Sampling::stratified_random), + "stratified" => Ok(Sampling::stratified), _ => Err(()), } } @@ -29,7 +29,7 @@ impl std::string::ToString for Sampling { match *self { Sampling::random => "random".to_string(), Sampling::last => "last".to_string(), - Sampling::stratified_random => "stratified_random".to_string(), + Sampling::stratified => "stratified".to_string(), } } } @@ -52,7 +52,7 @@ impl Sampling { Sampling::last => { format!("SELECT * FROM {relation_name}") } - Sampling::stratified_random => { + Sampling::stratified => { format!( " SELECT * @@ -120,8 +120,8 @@ mod tests { } #[test] - fn test_get_sql_stratified_random_sampling() { - let sampling = Sampling::stratified_random; + fn test_get_sql_stratified_sampling() { + let sampling = Sampling::stratified; let columns = get_column_fixtures(); let sql = sampling.get_sql("my_table", columns); let expected_sql = " From 0af75a8efdb6f287c09b2f965addbe32b4670427 Mon Sep 17 00:00:00 2001 From: Adam Hendel Date: Thu, 29 Feb 2024 15:20:47 +0000 Subject: [PATCH 9/9] add upgrade test --- .github/workflows/ci.yml | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index e9b0b1412..e1859d8aa 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -47,7 +47,12 @@ jobs: if: steps.pgml_extension_changed.outputs.PGML_EXTENSION_CHANGED_FILES != '0' run: | git submodule update --init --recursive + - name: Get current version + id: current-version + run: echo "CI_BRANCH=$(git name-rev --name-only HEAD)" >> $GITHUB_OUTPUT - name: Run tests + env: + CI_BRANCH: ${{ steps.current-version.outputs.CI_BRANCH }} if: 
steps.pgml_extension_changed.outputs.PGML_EXTENSION_CHANGED_FILES != '0' run: | curl https://sh.rustup.rs -sSf | sh -s -- -y @@ -58,8 +63,13 @@ jobs: cargo pgrx init fi + git checkout master + echo "\q" | cargo pgrx run + psql -p 28816 -h localhost -d pgml -P pager -c "CREATE EXTENSION pgml;" + git checkout $CI_BRANCH + echo "\q" | cargo pgrx run + psql -p 28816 -h localhost -d pgml -P pager -c "ALTER EXTENSION pgml UPDATE;" cargo pgrx test - # cargo pgrx start # psql -p 28815 -h 127.0.0.1 -d pgml -P pager -f tests/test.sql # cargo pgrx stop