From 08fa511a5bcc8bca924c3608fb17873785fe7a23 Mon Sep 17 00:00:00 2001 From: Kevin Zimmerman <4733573+kczimm@users.noreply.github.com> Date: Fri, 4 Aug 2023 09:29:04 -0500 Subject: [PATCH 1/4] rename symbols in whitelist module --- .../src/bindings/transformers/error.rs | 8 +-- .../src/bindings/transformers/mod.rs | 3 +- .../src/bindings/transformers/whitelist.rs | 52 +++++++------------ 3 files changed, 25 insertions(+), 38 deletions(-) diff --git a/pgml-extension/src/bindings/transformers/error.rs b/pgml-extension/src/bindings/transformers/error.rs index e199ad7bb..4fa9d4f19 100644 --- a/pgml-extension/src/bindings/transformers/error.rs +++ b/pgml-extension/src/bindings/transformers/error.rs @@ -2,13 +2,13 @@ use std::fmt; use pyo3::PyErr; -use super::whitelist::WhitelistError; +use super::whitelist; #[derive(Debug)] pub enum Error { Serde(serde_json::Error), Python(PyErr), - Model(WhitelistError), + Model(whitelist::Error), Data(String), } @@ -31,8 +31,8 @@ impl From for Error { } } -impl From for Error { - fn from(value: WhitelistError) -> Self { +impl From for Error { + fn from(value: whitelist::Error) -> Self { Self::Model(value) } } diff --git a/pgml-extension/src/bindings/transformers/mod.rs b/pgml-extension/src/bindings/transformers/mod.rs index f9193ff9f..8d3b3e866 100644 --- a/pgml-extension/src/bindings/transformers/mod.rs +++ b/pgml-extension/src/bindings/transformers/mod.rs @@ -11,7 +11,6 @@ use pyo3::types::PyTuple; use crate::orm::{Task, TextDataset}; use self::error::Error; -use self::whitelist::verify_task_against_whitelist; pub mod error; pub mod whitelist; @@ -36,7 +35,7 @@ pub fn transform( ) -> Result { crate::bindings::venv::activate(); - verify_task_against_whitelist(task)?; + whitelist::verify_task(task)?; let task = serde_json::to_string(task)?; let args = serde_json::to_string(args)?; diff --git a/pgml-extension/src/bindings/transformers/whitelist.rs b/pgml-extension/src/bindings/transformers/whitelist.rs index 145c665ec..7a42442ac 100644 --- a/pgml-extension/src/bindings/transformers/whitelist.rs +++ b/pgml-extension/src/bindings/transformers/whitelist.rs @@ -1,4 +1,4 @@ -use std::{error::Error, fmt}; +use std::fmt; #[cfg(any(test, feature = "pg_test"))] use pgrx::{pg_schema, pg_test}; @@ -11,24 +11,24 @@ static CONFIG_HF_TRUST_REMOTE_CODE_BOOL: &str = "pgml.huggingface_trust_remote_c static CONFIG_HF_TRUST_WHITELIST: &str = "pgml.huggingface_trust_remote_code_whitelist"; #[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Copy, Clone)] -pub enum WhitelistError { +pub enum Error { NotInWhitelist, RemoteCodeNotTrusted, } -impl fmt::Display for WhitelistError { +impl fmt::Display for Error { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { - WhitelistError::NotInWhitelist => writeln!(f, "model not in whitelist"), - WhitelistError::RemoteCodeNotTrusted => writeln!(f, "model remote code not trusted"), + Error::NotInWhitelist => writeln!(f, "model not in whitelist"), + Error::RemoteCodeNotTrusted => writeln!(f, "model remote code not trusted"), } } } -impl Error for WhitelistError {} +impl std::error::Error for Error {} /// Verify that the model in the task JSON is allowed based on the huggingface whitelists. -pub fn verify_task_against_whitelist(task: &Value) -> Result<(), WhitelistError> { +pub fn verify_task(task: &Value) -> Result<(), Error> { let task_model = match get_model_name(task) { Some(model) => model.to_string(), None => return Ok(()), @@ -38,7 +38,7 @@ pub fn verify_task_against_whitelist(task: &Value) -> Result<(), WhitelistError> let model_is_allowed = whitelisted_models.is_empty() || whitelisted_models.contains(&task_model); if !model_is_allowed { - return Err(WhitelistError::NotInWhitelist); + return Err(Error::NotInWhitelist); } let task_trust = get_trust_remote_code(task); @@ -52,7 +52,7 @@ pub fn verify_task_against_whitelist(task: &Value) -> Result<(), WhitelistError> let remote_code_allowed = trust_remote_code && model_is_trusted; if !remote_code_allowed && task_trust == Some(true) { - return Err(WhitelistError::RemoteCodeNotTrusted); + return Err(Error::RemoteCodeNotTrusted); } Ok(()) @@ -154,7 +154,7 @@ mod tests { set_config(CONFIG_HF_WHITELIST, "").unwrap(); let task_json = format!(json_template!(), model, false); let task: Value = serde_json::from_str(&task_json).unwrap(); - assert!(verify_task_against_whitelist(&task).is_ok()); + assert!(verify_task(&task).is_ok()); } #[pg_test] @@ -163,15 +163,12 @@ mod tests { set_config(CONFIG_HF_WHITELIST, model).unwrap(); let task_json = format!(json_template!(), model, false); let task: Value = serde_json::from_str(&task_json).unwrap(); - assert!(verify_task_against_whitelist(&task).is_ok()); + assert!(verify_task(&task).is_ok()); set_config(CONFIG_HF_WHITELIST, "other_model").unwrap(); let task_json = format!(json_template!(), model, false); let task: Value = serde_json::from_str(&task_json).unwrap(); - assert_eq!( - verify_task_against_whitelist(&task), - Err(WhitelistError::NotInWhitelist) - ); + assert_eq!(verify_task(&task), Err(Error::NotInWhitelist)); } #[pg_test] @@ -182,23 +179,20 @@ mod tests { let task_json = format!(json_template!(), model, false); let task: Value = serde_json::from_str(&task_json).unwrap(); - assert!(verify_task_against_whitelist(&task).is_ok()); + assert!(verify_task(&task).is_ok()); let task_json = format!(json_template!(), model, true); let task: Value = serde_json::from_str(&task_json).unwrap(); - assert_eq!( - verify_task_against_whitelist(&task), - Err(WhitelistError::RemoteCodeNotTrusted) - ); + assert_eq!(verify_task(&task), Err(Error::RemoteCodeNotTrusted)); set_config(CONFIG_HF_TRUST_REMOTE_CODE_BOOL, "true").unwrap(); let task_json = format!(json_template!(), model, false); let task: Value = serde_json::from_str(&task_json).unwrap(); - assert!(verify_task_against_whitelist(&task).is_ok()); + assert!(verify_task(&task).is_ok()); let task_json = format!(json_template!(), model, true); let task: Value = serde_json::from_str(&task_json).unwrap(); - assert!(verify_task_against_whitelist(&task).is_ok()); + assert!(verify_task(&task).is_ok()); } #[pg_test] @@ -209,25 +203,19 @@ mod tests { let task_json = format!(json_template!(), model, false); let task: Value = serde_json::from_str(&task_json).unwrap(); - assert!(verify_task_against_whitelist(&task).is_ok()); + assert!(verify_task(&task).is_ok()); let task_json = format!(json_template!(), model, true); let task: Value = serde_json::from_str(&task_json).unwrap(); - assert_eq!( - verify_task_against_whitelist(&task), - Err(WhitelistError::RemoteCodeNotTrusted) - ); + assert_eq!(verify_task(&task), Err(Error::RemoteCodeNotTrusted)); set_config(CONFIG_HF_TRUST_REMOTE_CODE_BOOL, "true").unwrap(); let task_json = format!(json_template!(), model, false); let task: Value = serde_json::from_str(&task_json).unwrap(); - assert!(verify_task_against_whitelist(&task).is_ok()); + assert!(verify_task(&task).is_ok()); let task_json = format!(json_template!(), model, true); let task: Value = serde_json::from_str(&task_json).unwrap(); - assert_eq!( - verify_task_against_whitelist(&task), - Err(WhitelistError::RemoteCodeNotTrusted) - ); + assert_eq!(verify_task(&task), Err(Error::RemoteCodeNotTrusted)); } } From bd815b7e0b5d67cd117079d8d355ab91eea2f15b Mon Sep 17 00:00:00 2001 From: Kevin Zimmerman <4733573+kczimm@users.noreply.github.com> Date: Fri, 4 Aug 2023 09:35:37 -0500 Subject: [PATCH 2/4] remove transformers::Error for anyhow --- .../src/bindings/transformers/error.rs | 44 ------------------- .../src/bindings/transformers/mod.rs | 19 ++------ 2 files changed, 4 insertions(+), 59 deletions(-) delete mode 100644 pgml-extension/src/bindings/transformers/error.rs diff --git a/pgml-extension/src/bindings/transformers/error.rs b/pgml-extension/src/bindings/transformers/error.rs deleted file mode 100644 index 4fa9d4f19..000000000 --- a/pgml-extension/src/bindings/transformers/error.rs +++ /dev/null @@ -1,44 +0,0 @@ -use std::fmt; - -use pyo3::PyErr; - -use super::whitelist; - -#[derive(Debug)] -pub enum Error { - Serde(serde_json::Error), - Python(PyErr), - Model(whitelist::Error), - Data(String), -} - -impl fmt::Display for Error { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - match self { - Error::Python(e) => write!(f, "{e}"), - Error::Model(e) => write!(f, "{e}"), - Error::Serde(e) => write!(f, "{e}"), - Error::Data(e) => write!(f, "{e}"), - } - } -} - -impl std::error::Error for Error {} - -impl From for Error { - fn from(value: PyErr) -> Self { - Self::Python(value) - } -} - -impl From for Error { - fn from(value: whitelist::Error) -> Self { - Self::Model(value) - } -} - -impl From for Error { - fn from(value: serde_json::Error) -> Self { - Self::Serde(value) - } -} diff --git a/pgml-extension/src/bindings/transformers/mod.rs b/pgml-extension/src/bindings/transformers/mod.rs index 8d3b3e866..d8d510b48 100644 --- a/pgml-extension/src/bindings/transformers/mod.rs +++ b/pgml-extension/src/bindings/transformers/mod.rs @@ -3,6 +3,7 @@ use std::path::PathBuf; use std::str::FromStr; use std::{collections::HashMap, path::Path}; +use anyhow::{bail, Result}; use once_cell::sync::Lazy; use pgrx::*; use pyo3::prelude::*; @@ -10,13 +11,8 @@ use pyo3::types::PyTuple; use crate::orm::{Task, TextDataset}; -use self::error::Error; - -pub mod error; pub mod whitelist; -pub type Result = std::result::Result; - static PY_MODULE: Lazy> = Lazy::new(|| { Python::with_gil(|py| -> Py { let src = include_str!(concat!( @@ -239,12 +235,7 @@ pub fn load_dataset( "float32" => "FLOAT4", "float16" => "FLOAT4", "bool" => "BOOLEAN", - _ => { - return Err(Error::Data(format!( - "unhandled dataset feature while reading dataset: {}", - type_ - ))) - } + _ => bail!("unhandled dataset feature while reading dataset: {type_}"), }; Ok(format!("{name} {type_}")) }) @@ -299,13 +290,11 @@ pub fn load_dataset( value.as_bool().unwrap().into_datum(), )), type_ => { - return Err(Error::Data(format!( - "unhandled dataset value type while reading dataset: {value:?} {type_:?}", - ))) + bail!("unhandled dataset value type while reading dataset: {value:?} {type_:?}") } } } - Spi::run_with_args(&insert, Some(row)).unwrap(); + Spi::run_with_args(&insert, Some(row))? } Ok(num_rows) From 7bee8be7f407b1a5674e0ed186b10791a8db1d4b Mon Sep 17 00:00:00 2001 From: Kevin Zimmerman <4733573+kczimm@users.noreply.github.com> Date: Fri, 4 Aug 2023 10:08:56 -0500 Subject: [PATCH 3/4] remove unwraps from transformers module --- .../src/bindings/transformers/mod.rs | 134 ++++++++++++------ 1 file changed, 90 insertions(+), 44 deletions(-) diff --git a/pgml-extension/src/bindings/transformers/mod.rs b/pgml-extension/src/bindings/transformers/mod.rs index d8d510b48..85f71d3c8 100644 --- a/pgml-extension/src/bindings/transformers/mod.rs +++ b/pgml-extension/src/bindings/transformers/mod.rs @@ -3,7 +3,7 @@ use std::path::PathBuf; use std::str::FromStr; use std::{collections::HashMap, path::Path}; -use anyhow::{bail, Result}; +use anyhow::{anyhow, bail, Context, Result}; use once_cell::sync::Lazy; use pgrx::*; use pyo3::prelude::*; @@ -93,12 +93,13 @@ pub fn tune( Python::with_gil(|py| -> Result> { let tune = PY_MODULE.getattr(py, "tune")?; + let path = path.to_string_lossy(); let output = tune.call1( py, ( &task, &hyperparams, - path.to_str().unwrap(), + path.as_ref(), dataset.x_train, dataset.x_test, dataset.y_train, @@ -122,12 +123,12 @@ pub fn generate(model_id: i64, inputs: Vec<&str>, config: JsonB) -> Result { - if e.get_type(py).name().unwrap() == "MissingModelError" { + if e.get_type(py).name()? == "MissingModelError" { info!("Loading model into cache for connection reuse"); let mut dir = std::path::PathBuf::from("/tmp/postgresml/models"); dir.push(model_id.to_string()); if !dir.exists() { - dump_model(model_id, dir.clone()); + dump_model(model_id, dir.clone())?; } let task = Spi::get_one_with_args::( "SELECT task::TEXT @@ -136,15 +137,15 @@ pub fn generate(model_id: i64, inputs: Vec<&str>, config: JsonB) -> Result, config: JsonB) -> Result Result<()> { if dir.exists() { - std::fs::remove_dir_all(&dir).unwrap(); + std::fs::remove_dir_all(&dir).context("failed to remove directory while dumping model")?; } - std::fs::create_dir_all(&dir).unwrap(); - Spi::connect(|client| { + std::fs::create_dir_all(&dir).context("failed to create directory while dumping model")?; + Spi::connect(|client| -> Result<()> { let result = client.select("SELECT path, part, data FROM pgml.files WHERE model_id = $1 ORDER BY path ASC, part ASC", None, Some(vec![ (PgBuiltInOids::INT8OID.oid(), model_id.into_datum()), ]) - ).unwrap(); + )?; for row in result { let mut path = dir.clone(); - path.push(row.get::(1).unwrap().unwrap()); - let data: Vec = row.get(3).unwrap().unwrap(); + path.push( + row.get::(1)? + .ok_or(anyhow!("row get ordinal 1 returned None"))?, + ); + let data: Vec = row + .get(3)? + .ok_or(anyhow!("row get ordinal 3 returned None"))?; let mut file = std::fs::OpenOptions::new() .create(true) .append(true) - .open(path) - .unwrap(); - let _num_bytes = file.write(&data).unwrap(); - file.flush().unwrap(); + .open(path)?; + + let _num_bytes = file.write(&data)?; + file.flush()?; } - }); + Ok(()) + }) } pub fn load_dataset( @@ -214,9 +221,19 @@ pub fn load_dataset( // Columns are a (name: String, values: Vec) pair let json: serde_json::Value = serde_json::from_str(&dataset)?; - let json = json.as_object().unwrap(); - let types = json.get("types").unwrap().as_object().unwrap(); - let data = json.get("data").unwrap().as_object().unwrap(); + let json = json + .as_object() + .ok_or(anyhow!("dataset json is not object"))?; + let types = json + .get("types") + .ok_or(anyhow!("dataset json missing `types` key"))? + .as_object() + .ok_or(anyhow!("dataset `types` key is not an object"))?; + let data = json + .get("data") + .ok_or(anyhow!("dataset json missing `data` key"))? + .as_object() + .ok_or(anyhow!("dataset `data` key is not an object"))?; let column_names = types .iter() .map(|(name, _type)| name.clone()) @@ -225,7 +242,10 @@ pub fn load_dataset( let column_types = types .iter() .map(|(name, type_)| -> Result { - let type_ = match type_.as_str().unwrap() { + let type_ = type_ + .as_str() + .ok_or(anyhow!("expected {type_} to be a json string"))?; + let type_ = match type_ { "string" => "TEXT", "dict" | "list" => "JSONB", "int64" => "INT8", @@ -251,27 +271,45 @@ pub fn load_dataset( .collect::>() .join(", "); let num_cols = types.len(); - let num_rows = data.values().next().unwrap().as_array().unwrap().len(); + let num_rows = data + .values() + .next() + .ok_or(anyhow!("dataset json has no fields"))? + .as_array() + .ok_or(anyhow!("dataset json field is not an array"))? + .len(); // Avoid the existence warning by checking the schema for the table first let table_count = Spi::get_one_with_args::("SELECT COUNT(*) FROM information_schema.tables WHERE table_name = $1 AND table_schema = 'pgml'", vec![ (PgBuiltInOids::TEXTOID.oid(), table_name.clone().into_datum()) - ]).unwrap().unwrap(); + ])?.ok_or(anyhow!("table count query returned None"))?; if table_count == 1 { - Spi::run(&format!(r#"DROP TABLE IF EXISTS {table_name}"#)).unwrap() + Spi::run(&format!(r#"DROP TABLE IF EXISTS {table_name}"#))?; } - Spi::run(&format!(r#"CREATE TABLE {table_name} ({column_types})"#)).unwrap(); + Spi::run(&format!(r#"CREATE TABLE {table_name} ({column_types})"#))?; let insert = format!(r#"INSERT INTO {table_name} ({column_names}) VALUES ({column_placeholders})"#); for i in 0..num_rows { let mut row = Vec::with_capacity(num_cols); for (name, values) in data { - let value = values.as_array().unwrap().get(i).unwrap(); - match types.get(name).unwrap().as_str().unwrap() { + let value = values + .as_array() + .ok_or_else(|| anyhow!("expected {values} to be an array"))? + .get(i) + .ok_or_else(|| anyhow!("invalid index {i} for {values}"))?; + match types + .get(name) + .ok_or_else(|| anyhow!("{types:?} expected to have key {name}"))? + .as_str() + .ok_or_else(|| anyhow!("json field {name} expected to be string"))? + { "string" => row.push(( PgBuiltInOids::TEXTOID.oid(), - value.as_str().unwrap().into_datum(), + value + .as_str() + .ok_or_else(|| anyhow!("expected {value} to be string"))? + .into_datum(), )), "dict" | "list" => row.push(( PgBuiltInOids::JSONBOID.oid(), @@ -279,15 +317,24 @@ pub fn load_dataset( )), "int64" | "int32" | "int16" => row.push(( PgBuiltInOids::INT8OID.oid(), - value.as_i64().unwrap().into_datum(), + value + .as_i64() + .ok_or_else(|| anyhow!("expected {value} to be i64"))? + .into_datum(), )), "float64" | "float32" | "float16" => row.push(( PgBuiltInOids::FLOAT8OID.oid(), - value.as_f64().unwrap().into_datum(), + value + .as_f64() + .ok_or_else(|| anyhow!("expected {value} to be f64"))? + .into_datum(), )), "bool" => row.push(( PgBuiltInOids::BOOLOID.oid(), - value.as_bool().unwrap().into_datum(), + value + .as_bool() + .ok_or_else(|| anyhow!("expected {value} to be bool"))? + .into_datum(), )), type_ => { bail!("unhandled dataset value type while reading dataset: {value:?} {type_:?}") @@ -300,13 +347,12 @@ pub fn load_dataset( Ok(num_rows) } -pub fn clear_gpu_cache(memory_usage: Option) -> bool { - Python::with_gil(|py| -> bool { - let clear_gpu_cache: Py = PY_MODULE.getattr(py, "clear_gpu_cache").unwrap(); - clear_gpu_cache - .call1(py, PyTuple::new(py, &[memory_usage.into_py(py)])) - .unwrap() - .extract(py) - .unwrap() +pub fn clear_gpu_cache(memory_usage: Option) -> Result { + Python::with_gil(|py| -> Result { + let clear_gpu_cache: Py = PY_MODULE.getattr(py, "clear_gpu_cache")?; + let success = clear_gpu_cache + .call1(py, PyTuple::new(py, &[memory_usage.into_py(py)]))? + .extract(py)?; + Ok(success) }) } From a7e4076a402fea370e834a222d45ce26edfa23e3 Mon Sep 17 00:00:00 2001 From: Kevin Zimmerman <4733573+kczimm@users.noreply.github.com> Date: Fri, 4 Aug 2023 10:12:44 -0500 Subject: [PATCH 4/4] fix compile error in api --- pgml-extension/src/api.rs | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/pgml-extension/src/api.rs b/pgml-extension/src/api.rs index 175c2db60..0a6418e26 100644 --- a/pgml-extension/src/api.rs +++ b/pgml-extension/src/api.rs @@ -602,7 +602,10 @@ pub fn embed_batch( /// ``` #[pg_extern(immutable, parallel_safe, name = "clear_gpu_cache")] pub fn clear_gpu_cache(memory_usage: default!(Option, "NULL")) -> bool { - crate::bindings::transformers::clear_gpu_cache(memory_usage) + match crate::bindings::transformers::clear_gpu_cache(memory_usage) { + Ok(success) => success, + Err(e) => error!("{e}"), + } } #[pg_extern(immutable, parallel_safe)]