From f95b897c942a338bd67361c7923c164e3b63a969 Mon Sep 17 00:00:00 2001 From: Xing Guo Date: Mon, 23 Oct 2023 21:30:20 +0800 Subject: [PATCH] Mutable SQL commands should be used in Spi::update(). Commands like 'INSERT', 'CREATE TABLE' are modifying database objects. They should be used in 'Spi::update()'[1][2]. Besides, one of the unnecessary calling of '.clone()' is removed. [1]: https://www.postgresql.org/docs/current/spi-spi-execute.html [2]: https://github.com/pgcentralfoundation/pgrx/blob/64fbc2e2efb06212046082ffcde2e794ccb4d48a/pgrx/src/spi/client.rs#L60 Co-Authored-By: Chen Mulong --- pgml-extension/src/orm/dataset.rs | 2 +- pgml-extension/src/orm/model.rs | 8 ++++---- pgml-extension/src/orm/snapshot.rs | 6 +++--- 3 files changed, 8 insertions(+), 8 deletions(-) diff --git a/pgml-extension/src/orm/dataset.rs b/pgml-extension/src/orm/dataset.rs index 7c88b4575..9e22ef0ae 100644 --- a/pgml-extension/src/orm/dataset.rs +++ b/pgml-extension/src/orm/dataset.rs @@ -95,7 +95,7 @@ impl Display for TextDataset { fn drop_table_if_exists(table_name: &str) { // Avoid the existence for DROP TABLE IF EXISTS warning by checking the schema for the table first let table_count = Spi::get_one_with_args::("SELECT COUNT(*) FROM information_schema.tables WHERE table_name = $1 AND table_schema = 'pgml'", vec![ - (PgBuiltInOids::TEXTOID.oid(), table_name.clone().into_datum()) + (PgBuiltInOids::TEXTOID.oid(), table_name.into_datum()) ]).unwrap().unwrap(); if table_count == 1 { Spi::run(&format!(r#"DROP TABLE pgml.{table_name} CASCADE"#)).unwrap(); diff --git a/pgml-extension/src/orm/model.rs b/pgml-extension/src/orm/model.rs index 370ae7b02..da1940f60 100644 --- a/pgml-extension/src/orm/model.rs +++ b/pgml-extension/src/orm/model.rs @@ -96,8 +96,8 @@ impl Model { let dataset = snapshot.tabular_dataset(); let status = Status::in_progress; // Create the model record. - Spi::connect(|client| { - let result = client.select(" + Spi::connect(|mut client| { + let result = client.update(" INSERT INTO pgml.models (project_id, snapshot_id, algorithm, runtime, hyperparams, status, search, search_params, search_args, num_features) VALUES ($1, $2, cast($3 AS pgml.algorithm), cast($4 AS pgml.runtime), $5, cast($6 as pgml.status), $7, $8, $9, $10) RETURNING id, project_id, snapshot_id, algorithm, runtime, hyperparams, status, metrics, search, search_params, search_args, created_at, updated_at;", @@ -168,8 +168,8 @@ impl Model { let dataset = snapshot.text_dataset(); // Create the model record. - Spi::connect(|client| { - let result = client.select(" + Spi::connect(|mut client| { + let result = client.update(" INSERT INTO pgml.models (project_id, snapshot_id, algorithm, runtime, hyperparams, status, search, search_params, search_args, num_features) VALUES ($1, $2, cast($3 AS pgml.algorithm), cast($4 AS pgml.runtime), $5, cast($6 as pgml.status), $7, $8, $9, $10) RETURNING id, project_id, snapshot_id, algorithm, runtime::TEXT, hyperparams, status, metrics, search, search_params, search_args, created_at, updated_at;", diff --git a/pgml-extension/src/orm/snapshot.rs b/pgml-extension/src/orm/snapshot.rs index 6cf6f776c..85f697508 100644 --- a/pgml-extension/src/orm/snapshot.rs +++ b/pgml-extension/src/orm/snapshot.rs @@ -508,7 +508,7 @@ impl Snapshot { let preprocessors: HashMap = serde_json::from_value(preprocess.0).expect("is valid"); - Spi::connect(|client| { + Spi::connect(|mut client| { let mut columns: Vec = Vec::new(); client.select("SELECT column_name::TEXT, udt_name::TEXT, is_nullable::BOOLEAN, ordinal_position::INTEGER FROM information_schema.columns WHERE table_schema = $1 AND table_name = $2 ORDER BY ordinal_position ASC", None, @@ -587,7 +587,7 @@ impl Snapshot { } } - let result = client.select("INSERT INTO pgml.snapshots (relation_name, y_column_name, test_size, test_sampling, status, columns, materialized) VALUES ($1, $2, $3, $4::pgml.sampling, $5::pgml.status, $6, $7) RETURNING id, relation_name, y_column_name, test_size, test_sampling::TEXT, status::TEXT, columns, analysis, created_at, updated_at;", + let result = client.update("INSERT INTO pgml.snapshots (relation_name, y_column_name, test_size, test_sampling, status, columns, materialized) VALUES ($1, $2, $3, $4::pgml.sampling, $5::pgml.status, $6, $7) RETURNING id, relation_name, y_column_name, test_size, test_sampling::TEXT, status::TEXT, columns, analysis, created_at, updated_at;", Some(1), Some(vec![ (PgBuiltInOids::TEXTOID.oid(), relation_name.into_datum()), @@ -623,7 +623,7 @@ impl Snapshot { if s.test_sampling == Sampling::random { sql += " ORDER BY random()"; } - client.select(&sql, None, None).unwrap(); + client.update(&sql, None, None).unwrap(); } snapshot = Some(s); });