diff --git a/pgml-extension/src/orm/dataset.rs b/pgml-extension/src/orm/dataset.rs index 7c88b4575..9e22ef0ae 100644 --- a/pgml-extension/src/orm/dataset.rs +++ b/pgml-extension/src/orm/dataset.rs @@ -95,7 +95,7 @@ impl Display for TextDataset { fn drop_table_if_exists(table_name: &str) { // Avoid the existence for DROP TABLE IF EXISTS warning by checking the schema for the table first let table_count = Spi::get_one_with_args::("SELECT COUNT(*) FROM information_schema.tables WHERE table_name = $1 AND table_schema = 'pgml'", vec![ - (PgBuiltInOids::TEXTOID.oid(), table_name.clone().into_datum()) + (PgBuiltInOids::TEXTOID.oid(), table_name.into_datum()) ]).unwrap().unwrap(); if table_count == 1 { Spi::run(&format!(r#"DROP TABLE pgml.{table_name} CASCADE"#)).unwrap(); diff --git a/pgml-extension/src/orm/model.rs b/pgml-extension/src/orm/model.rs index 370ae7b02..da1940f60 100644 --- a/pgml-extension/src/orm/model.rs +++ b/pgml-extension/src/orm/model.rs @@ -96,8 +96,8 @@ impl Model { let dataset = snapshot.tabular_dataset(); let status = Status::in_progress; // Create the model record. - Spi::connect(|client| { - let result = client.select(" + Spi::connect(|mut client| { + let result = client.update(" INSERT INTO pgml.models (project_id, snapshot_id, algorithm, runtime, hyperparams, status, search, search_params, search_args, num_features) VALUES ($1, $2, cast($3 AS pgml.algorithm), cast($4 AS pgml.runtime), $5, cast($6 as pgml.status), $7, $8, $9, $10) RETURNING id, project_id, snapshot_id, algorithm, runtime, hyperparams, status, metrics, search, search_params, search_args, created_at, updated_at;", @@ -168,8 +168,8 @@ impl Model { let dataset = snapshot.text_dataset(); // Create the model record. - Spi::connect(|client| { - let result = client.select(" + Spi::connect(|mut client| { + let result = client.update(" INSERT INTO pgml.models (project_id, snapshot_id, algorithm, runtime, hyperparams, status, search, search_params, search_args, num_features) VALUES ($1, $2, cast($3 AS pgml.algorithm), cast($4 AS pgml.runtime), $5, cast($6 as pgml.status), $7, $8, $9, $10) RETURNING id, project_id, snapshot_id, algorithm, runtime::TEXT, hyperparams, status, metrics, search, search_params, search_args, created_at, updated_at;", diff --git a/pgml-extension/src/orm/snapshot.rs b/pgml-extension/src/orm/snapshot.rs index 6cf6f776c..85f697508 100644 --- a/pgml-extension/src/orm/snapshot.rs +++ b/pgml-extension/src/orm/snapshot.rs @@ -508,7 +508,7 @@ impl Snapshot { let preprocessors: HashMap = serde_json::from_value(preprocess.0).expect("is valid"); - Spi::connect(|client| { + Spi::connect(|mut client| { let mut columns: Vec = Vec::new(); client.select("SELECT column_name::TEXT, udt_name::TEXT, is_nullable::BOOLEAN, ordinal_position::INTEGER FROM information_schema.columns WHERE table_schema = $1 AND table_name = $2 ORDER BY ordinal_position ASC", None, @@ -587,7 +587,7 @@ impl Snapshot { } } - let result = client.select("INSERT INTO pgml.snapshots (relation_name, y_column_name, test_size, test_sampling, status, columns, materialized) VALUES ($1, $2, $3, $4::pgml.sampling, $5::pgml.status, $6, $7) RETURNING id, relation_name, y_column_name, test_size, test_sampling::TEXT, status::TEXT, columns, analysis, created_at, updated_at;", + let result = client.update("INSERT INTO pgml.snapshots (relation_name, y_column_name, test_size, test_sampling, status, columns, materialized) VALUES ($1, $2, $3, $4::pgml.sampling, $5::pgml.status, $6, $7) RETURNING id, relation_name, y_column_name, test_size, test_sampling::TEXT, status::TEXT, columns, analysis, created_at, updated_at;", Some(1), Some(vec![ (PgBuiltInOids::TEXTOID.oid(), relation_name.into_datum()), @@ -623,7 +623,7 @@ impl Snapshot { if s.test_sampling == Sampling::random { sql += " ORDER BY random()"; } - client.select(&sql, None, None).unwrap(); + client.update(&sql, None, None).unwrap(); } snapshot = Some(s); });