From c77770aa13c09117b5e5acee2cdcee78377c68b0 Mon Sep 17 00:00:00 2001 From: Kevin Zimmerman <4733573+kczimm@users.noreply.github.com> Date: Mon, 14 Aug 2023 12:56:24 -0500 Subject: [PATCH] format Python tracebacks like they used to be --- pgml-extension/src/bindings/mod.rs | 15 +++ .../src/bindings/transformers/mod.rs | 122 +++++++++++------- .../src/bindings/transformers/transformers.py | 11 ++ 3 files changed, 102 insertions(+), 46 deletions(-) diff --git a/pgml-extension/src/bindings/mod.rs b/pgml-extension/src/bindings/mod.rs index 3af906c28..5c32608f1 100644 --- a/pgml-extension/src/bindings/mod.rs +++ b/pgml-extension/src/bindings/mod.rs @@ -1,7 +1,9 @@ use std::fmt::Debug; +use anyhow::{anyhow, Result}; #[allow(unused_imports)] // used for test macros use pgrx::*; +use pyo3::{PyResult, Python}; use crate::orm::*; @@ -40,6 +42,19 @@ pub trait Bindings: Send + Sync + Debug { Self: Sized; } +trait TracebackError { + fn format_traceback(self, py: Python<'_>) -> Result; +} + +impl TracebackError for PyResult { + fn format_traceback(self, py: Python<'_>) -> Result { + self.map_err(|e| { + let traceback = e.traceback(py).unwrap().format().unwrap(); + anyhow!("{traceback} {e}") + }) + } +} + #[cfg(any(test, feature = "pg_test"))] #[pg_schema] mod tests { diff --git a/pgml-extension/src/bindings/transformers/mod.rs b/pgml-extension/src/bindings/transformers/mod.rs index d94e87de7..c5fffa1c6 100644 --- a/pgml-extension/src/bindings/transformers/mod.rs +++ b/pgml-extension/src/bindings/transformers/mod.rs @@ -8,9 +8,12 @@ use once_cell::sync::Lazy; use pgrx::*; use pyo3::prelude::*; use pyo3::types::PyTuple; +use serde_json::Value; use crate::orm::{Task, TextDataset}; +use super::TracebackError; + pub mod whitelist; static PY_MODULE: Lazy> = Lazy::new(|| { @@ -38,22 +41,36 @@ pub fn transform( let inputs = serde_json::to_string(&inputs)?; let results = Python::with_gil(|py| -> Result { - let transform: Py = PY_MODULE.getattr(py, "transform")?; + let transform: Py = PY_MODULE.getattr(py, "transform").format_traceback(py)?; - let output = transform.call1( - py, - PyTuple::new( + let output = transform + .call1( py, - &[task.into_py(py), args.into_py(py), inputs.into_py(py)], - ), - )?; + PyTuple::new( + py, + &[task.into_py(py), args.into_py(py), inputs.into_py(py)], + ), + ) + .format_traceback(py)?; - Ok(output.extract(py)?) + Ok(output.extract(py).format_traceback(py)?) })?; Ok(serde_json::from_str(&results)?) } +pub fn get_model_from(task: &Value) -> Result { + Ok(Python::with_gil(|py| -> Result { + let get_model_from = PY_MODULE + .getattr(py, "get_model_from") + .format_traceback(py)?; + let model = get_model_from + .call1(py, PyTuple::new(py, &[task.to_string().into_py(py)])) + .format_traceback(py)?; + Ok(model.extract(py).format_traceback(py)?) + })?) +} + pub fn embed( transformer: &str, inputs: Vec<&str>, @@ -63,20 +80,22 @@ pub fn embed( let kwargs = serde_json::to_string(kwargs)?; Python::with_gil(|py| -> Result>> { - let embed: Py = PY_MODULE.getattr(py, "embed")?; - let output = embed.call1( - py, - PyTuple::new( + let embed: Py = PY_MODULE.getattr(py, "embed").format_traceback(py)?; + let output = embed + .call1( py, - &[ - transformer.to_string().into_py(py), - inputs.into_py(py), - kwargs.into_py(py), - ], - ), - )?; - - Ok(output.extract(py)?) + PyTuple::new( + py, + &[ + transformer.to_string().into_py(py), + inputs.into_py(py), + kwargs.into_py(py), + ], + ), + ) + .format_traceback(py)?; + + Ok(output.extract(py).format_traceback(py)?) }) } @@ -92,22 +111,24 @@ pub fn tune( let hyperparams = serde_json::to_string(&hyperparams.0)?; Python::with_gil(|py| -> Result> { - let tune = PY_MODULE.getattr(py, "tune")?; + let tune = PY_MODULE.getattr(py, "tune").format_traceback(py)?; let path = path.to_string_lossy(); - let output = tune.call1( - py, - ( - &task, - &hyperparams, - path.as_ref(), - dataset.x_train, - dataset.x_test, - dataset.y_train, - dataset.y_test, - ), - )?; - - Ok(output.extract(py)?) + let output = tune + .call1( + py, + ( + &task, + &hyperparams, + path.as_ref(), + dataset.x_train, + dataset.x_test, + dataset.y_train, + dataset.y_test, + ), + ) + .format_traceback(py)?; + + Ok(output.extract(py).format_traceback(py)?) }) } @@ -115,7 +136,7 @@ pub fn generate(model_id: i64, inputs: Vec<&str>, config: JsonB) -> Result Result> { - let generate = PY_MODULE.getattr(py, "generate")?; + let generate = PY_MODULE.getattr(py, "generate").format_traceback(py)?; let config = serde_json::to_string(&config.0)?; // cloning inputs in case we have to re-call on error is rather unfortunate here // similarly, using a json string to pass kwargs is also unfortunate extra parsing @@ -143,16 +164,19 @@ pub fn generate(model_id: i64, inputs: Vec<&str>, config: JsonB) -> Result o, }; - Ok(result.extract(py)?) + Ok(result.extract(py).format_traceback(py)?) }) } @@ -200,7 +224,7 @@ pub fn load_dataset( let kwargs = serde_json::to_string(kwargs)?; let dataset = Python::with_gil(|py| -> Result { - let load_dataset: Py = PY_MODULE.getattr(py, "load_dataset")?; + let load_dataset: Py = PY_MODULE.getattr(py, "load_dataset").format_traceback(py)?; Ok(load_dataset .call1( py, @@ -213,8 +237,10 @@ pub fn load_dataset( kwargs.into_py(py), ], ), - )? - .extract(py)?) + ) + .format_traceback(py)? + .extract(py) + .format_traceback(py)?) })?; let table_name = format!("pgml.\"{}\"", name); @@ -351,10 +377,14 @@ pub fn clear_gpu_cache(memory_usage: Option) -> Result { crate::bindings::venv::activate(); Python::with_gil(|py| -> Result { - let clear_gpu_cache: Py = PY_MODULE.getattr(py, "clear_gpu_cache")?; + let clear_gpu_cache: Py = PY_MODULE + .getattr(py, "clear_gpu_cache") + .format_traceback(py)?; let success = clear_gpu_cache - .call1(py, PyTuple::new(py, &[memory_usage.into_py(py)]))? - .extract(py)?; + .call1(py, PyTuple::new(py, &[memory_usage.into_py(py)])) + .format_traceback(py)? + .extract(py) + .format_traceback(py)?; Ok(success) }) } diff --git a/pgml-extension/src/bindings/transformers/transformers.py b/pgml-extension/src/bindings/transformers/transformers.py index 4cdab2d44..fe2f6b3e7 100644 --- a/pgml-extension/src/bindings/transformers/transformers.py +++ b/pgml-extension/src/bindings/transformers/transformers.py @@ -173,6 +173,17 @@ def __call__(self, inputs, **kwargs): return self.pipe(inputs, **kwargs) +def get_model_from(task): + task = orjson.loads(task) + if "model" in task: + return task["model"] + + if "task" in task: + model = transformers.pipelines.SUPPORTED_TASKS[task["task"]]["default"]["model"] + ty = "tf" if "tf" in model else "pt" + return model[ty][0] + + def transform(task, args, inputs): task = orjson.loads(task) args = orjson.loads(args)