Skip to content

Commit a46825a

Browse files
higuoxingbeeender
andauthored
Mutable SQL commands should be used with Spi::update(). (#1114)
Co-authored-by: Chen Mulong <chenmulong@gmail.com>
1 parent 5797d26 commit a46825a

File tree

3 files changed

+8
-8
lines changed

3 files changed

+8
-8
lines changed

pgml-extension/src/orm/dataset.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,7 @@ impl Display for TextDataset {
9595
fn drop_table_if_exists(table_name: &str) {
9696
// Avoid the existence for DROP TABLE IF EXISTS warning by checking the schema for the table first
9797
let table_count = Spi::get_one_with_args::<i64>("SELECT COUNT(*) FROM information_schema.tables WHERE table_name = $1 AND table_schema = 'pgml'", vec![
98-
(PgBuiltInOids::TEXTOID.oid(), table_name.clone().into_datum())
98+
(PgBuiltInOids::TEXTOID.oid(), table_name.into_datum())
9999
]).unwrap().unwrap();
100100
if table_count == 1 {
101101
Spi::run(&format!(r#"DROP TABLE pgml.{table_name} CASCADE"#)).unwrap();

pgml-extension/src/orm/model.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -96,8 +96,8 @@ impl Model {
9696
let dataset = snapshot.tabular_dataset();
9797
let status = Status::in_progress;
9898
// Create the model record.
99-
Spi::connect(|client| {
100-
let result = client.select("
99+
Spi::connect(|mut client| {
100+
let result = client.update("
101101
INSERT INTO pgml.models (project_id, snapshot_id, algorithm, runtime, hyperparams, status, search, search_params, search_args, num_features)
102102
VALUES ($1, $2, cast($3 AS pgml.algorithm), cast($4 AS pgml.runtime), $5, cast($6 as pgml.status), $7, $8, $9, $10)
103103
RETURNING id, project_id, snapshot_id, algorithm, runtime, hyperparams, status, metrics, search, search_params, search_args, created_at, updated_at;",
@@ -168,8 +168,8 @@ impl Model {
168168
let dataset = snapshot.text_dataset();
169169

170170
// Create the model record.
171-
Spi::connect(|client| {
172-
let result = client.select("
171+
Spi::connect(|mut client| {
172+
let result = client.update("
173173
INSERT INTO pgml.models (project_id, snapshot_id, algorithm, runtime, hyperparams, status, search, search_params, search_args, num_features)
174174
VALUES ($1, $2, cast($3 AS pgml.algorithm), cast($4 AS pgml.runtime), $5, cast($6 as pgml.status), $7, $8, $9, $10)
175175
RETURNING id, project_id, snapshot_id, algorithm, runtime::TEXT, hyperparams, status, metrics, search, search_params, search_args, created_at, updated_at;",

pgml-extension/src/orm/snapshot.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -508,7 +508,7 @@ impl Snapshot {
508508
let preprocessors: HashMap<String, Preprocessor> =
509509
serde_json::from_value(preprocess.0).expect("is valid");
510510

511-
Spi::connect(|client| {
511+
Spi::connect(|mut client| {
512512
let mut columns: Vec<Column> = Vec::new();
513513
client.select("SELECT column_name::TEXT, udt_name::TEXT, is_nullable::BOOLEAN, ordinal_position::INTEGER FROM information_schema.columns WHERE table_schema = $1 AND table_name = $2 ORDER BY ordinal_position ASC",
514514
None,
@@ -587,7 +587,7 @@ impl Snapshot {
587587
}
588588
}
589589

590-
let result = client.select("INSERT INTO pgml.snapshots (relation_name, y_column_name, test_size, test_sampling, status, columns, materialized) VALUES ($1, $2, $3, $4::pgml.sampling, $5::pgml.status, $6, $7) RETURNING id, relation_name, y_column_name, test_size, test_sampling::TEXT, status::TEXT, columns, analysis, created_at, updated_at;",
590+
let result = client.update("INSERT INTO pgml.snapshots (relation_name, y_column_name, test_size, test_sampling, status, columns, materialized) VALUES ($1, $2, $3, $4::pgml.sampling, $5::pgml.status, $6, $7) RETURNING id, relation_name, y_column_name, test_size, test_sampling::TEXT, status::TEXT, columns, analysis, created_at, updated_at;",
591591
Some(1),
592592
Some(vec![
593593
(PgBuiltInOids::TEXTOID.oid(), relation_name.into_datum()),
@@ -623,7 +623,7 @@ impl Snapshot {
623623
if s.test_sampling == Sampling::random {
624624
sql += " ORDER BY random()";
625625
}
626-
client.select(&sql, None, None).unwrap();
626+
client.update(&sql, None, None).unwrap();
627627
}
628628
snapshot = Some(s);
629629
});

0 commit comments

Comments
 (0)