Commit 6d67b81

format Python tracebacks like they used to be (#916)
1 parent 0eea430 commit 6d67b81

3 files changed, +102 -46 lines changed

3 files changed

+102
-46
lines changed

pgml-extension/src/bindings/mod.rs

Lines changed: 15 additions & 0 deletions
@@ -1,7 +1,9 @@
 use std::fmt::Debug;
 
+use anyhow::{anyhow, Result};
 #[allow(unused_imports)] // used for test macros
 use pgrx::*;
+use pyo3::{PyResult, Python};
 
 use crate::orm::*;
 
@@ -40,6 +42,19 @@ pub trait Bindings: Send + Sync + Debug {
         Self: Sized;
 }
 
+trait TracebackError<T> {
+    fn format_traceback(self, py: Python<'_>) -> Result<T>;
+}
+
+impl<T> TracebackError<T> for PyResult<T> {
+    fn format_traceback(self, py: Python<'_>) -> Result<T> {
+        self.map_err(|e| {
+            let traceback = e.traceback(py).unwrap().format().unwrap();
+            anyhow!("{traceback} {e}")
+        })
+    }
+}
+
 #[cfg(any(test, feature = "pg_test"))]
 #[pg_schema]
 mod tests {
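
The new TracebackError trait is crate-private, so it can only be used from inside the bindings module; transformers/mod.rs below brings it into scope with use super::TracebackError;. A minimal sketch of the calling pattern it enables, using a hypothetical demo function and a throwaway inline Python module that are not part of this commit:

use anyhow::Result;
use pyo3::prelude::*;

// Assumes the crate-private TracebackError trait above is in scope.
fn demo(py: Python<'_>) -> Result<i64> {
    // Throwaway module for illustration; the extension itself loads transformers.py.
    let module = PyModule::from_code(
        py,
        "def fails():\n    raise ValueError('boom')\n",
        "demo.py",
        "demo",
    )
    .format_traceback(py)?;

    module
        .getattr("fails")
        .format_traceback(py)?
        .call0()
        .format_traceback(py)? // the ValueError surfaces here, Python traceback included
        .extract::<i64>()
        .format_traceback(py)
}

Every PyResult in the chain becomes an anyhow::Result, and when the Python side raises, the error message now starts with the formatted traceback rather than just the exception text. Note that the implementation unwraps e.traceback(py), so it expects the error to carry a traceback, i.e. to have come from executed Python code.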

pgml-extension/src/bindings/transformers/mod.rs

Lines changed: 76 additions & 46 deletions
@@ -8,9 +8,12 @@ use once_cell::sync::Lazy;
 use pgrx::*;
 use pyo3::prelude::*;
 use pyo3::types::PyTuple;
+use serde_json::Value;
 
 use crate::orm::{Task, TextDataset};
 
+use super::TracebackError;
+
 pub mod whitelist;
 
 static PY_MODULE: Lazy<Py<PyModule>> = Lazy::new(|| {
@@ -38,22 +41,36 @@ pub fn transform(
     let inputs = serde_json::to_string(&inputs)?;
 
     let results = Python::with_gil(|py| -> Result<String> {
-        let transform: Py<PyAny> = PY_MODULE.getattr(py, "transform")?;
+        let transform: Py<PyAny> = PY_MODULE.getattr(py, "transform").format_traceback(py)?;
 
-        let output = transform.call1(
-            py,
-            PyTuple::new(
+        let output = transform
+            .call1(
                 py,
-                &[task.into_py(py), args.into_py(py), inputs.into_py(py)],
-            ),
-        )?;
+                PyTuple::new(
+                    py,
+                    &[task.into_py(py), args.into_py(py), inputs.into_py(py)],
+                ),
+            )
+            .format_traceback(py)?;
 
-        Ok(output.extract(py)?)
+        Ok(output.extract(py).format_traceback(py)?)
     })?;
 
     Ok(serde_json::from_str(&results)?)
 }
 
+pub fn get_model_from(task: &Value) -> Result<String> {
+    Ok(Python::with_gil(|py| -> Result<String> {
+        let get_model_from = PY_MODULE
+            .getattr(py, "get_model_from")
+            .format_traceback(py)?;
+        let model = get_model_from
+            .call1(py, PyTuple::new(py, &[task.to_string().into_py(py)]))
+            .format_traceback(py)?;
+        Ok(model.extract(py).format_traceback(py)?)
+    })?)
+}
+
 pub fn embed(
     transformer: &str,
     inputs: Vec<&str>,
@@ -63,20 +80,22 @@ pub fn embed(
 
     let kwargs = serde_json::to_string(kwargs)?;
     Python::with_gil(|py| -> Result<Vec<Vec<f32>>> {
-        let embed: Py<PyAny> = PY_MODULE.getattr(py, "embed")?;
-        let output = embed.call1(
-            py,
-            PyTuple::new(
+        let embed: Py<PyAny> = PY_MODULE.getattr(py, "embed").format_traceback(py)?;
+        let output = embed
+            .call1(
                 py,
-                &[
-                    transformer.to_string().into_py(py),
-                    inputs.into_py(py),
-                    kwargs.into_py(py),
-                ],
-            ),
-        )?;
-
-        Ok(output.extract(py)?)
+                PyTuple::new(
+                    py,
+                    &[
+                        transformer.to_string().into_py(py),
+                        inputs.into_py(py),
+                        kwargs.into_py(py),
+                    ],
+                ),
+            )
+            .format_traceback(py)?;
+
+        Ok(output.extract(py).format_traceback(py)?)
     })
 }
 
@@ -92,30 +111,32 @@ pub fn tune(
     let hyperparams = serde_json::to_string(&hyperparams.0)?;
 
     Python::with_gil(|py| -> Result<HashMap<String, f64>> {
-        let tune = PY_MODULE.getattr(py, "tune")?;
+        let tune = PY_MODULE.getattr(py, "tune").format_traceback(py)?;
         let path = path.to_string_lossy();
-        let output = tune.call1(
-            py,
-            (
-                &task,
-                &hyperparams,
-                path.as_ref(),
-                dataset.x_train,
-                dataset.x_test,
-                dataset.y_train,
-                dataset.y_test,
-            ),
-        )?;
-
-        Ok(output.extract(py)?)
+        let output = tune
+            .call1(
+                py,
+                (
+                    &task,
+                    &hyperparams,
+                    path.as_ref(),
+                    dataset.x_train,
+                    dataset.x_test,
+                    dataset.y_train,
+                    dataset.y_test,
+                ),
+            )
+            .format_traceback(py)?;
+
+        Ok(output.extract(py).format_traceback(py)?)
     })
 }
 
 pub fn generate(model_id: i64, inputs: Vec<&str>, config: JsonB) -> Result<Vec<String>> {
     crate::bindings::venv::activate();
 
     Python::with_gil(|py| -> Result<Vec<String>> {
-        let generate = PY_MODULE.getattr(py, "generate")?;
+        let generate = PY_MODULE.getattr(py, "generate").format_traceback(py)?;
         let config = serde_json::to_string(&config.0)?;
         // cloning inputs in case we have to re-call on error is rather unfortunate here
         // similarly, using a json string to pass kwargs is also unfortunate extra parsing
@@ -143,16 +164,19 @@ pub fn generate(model_id: i64, inputs: Vec<&str>, config: JsonB) -> Result<Vec<S
                     let load = PY_MODULE.getattr(py, "load_model")?;
                     let task = Task::from_str(&task)
                         .map_err(|_| anyhow!("could not make a Task from {task}"))?;
-                    load.call1(py, (model_id, task.to_string(), dir))?;
+                    load.call1(py, (model_id, task.to_string(), dir))
+                        .format_traceback(py)?;
 
-                    generate.call1(py, (model_id, inputs, config))?
+                    generate
+                        .call1(py, (model_id, inputs, config))
+                        .format_traceback(py)?
                 } else {
                     return Err(e.into());
                 }
             }
            Ok(o) => o,
        };
-        Ok(result.extract(py)?)
+        Ok(result.extract(py).format_traceback(py)?)
     })
 }
 
@@ -200,7 +224,7 @@ pub fn load_dataset(
     let kwargs = serde_json::to_string(kwargs)?;
 
     let dataset = Python::with_gil(|py| -> Result<String> {
-        let load_dataset: Py<PyAny> = PY_MODULE.getattr(py, "load_dataset")?;
+        let load_dataset: Py<PyAny> = PY_MODULE.getattr(py, "load_dataset").format_traceback(py)?;
         Ok(load_dataset
             .call1(
                 py,
@@ -213,8 +237,10 @@ pub fn load_dataset(
                     kwargs.into_py(py),
                 ],
                ),
-            )?
-            .extract(py)?)
+            )
+            .format_traceback(py)?
+            .extract(py)
+            .format_traceback(py)?)
     })?;
 
     let table_name = format!("pgml.\"{}\"", name);
@@ -351,10 +377,14 @@ pub fn clear_gpu_cache(memory_usage: Option<f32>) -> Result<bool> {
     crate::bindings::venv::activate();
 
     Python::with_gil(|py| -> Result<bool> {
-        let clear_gpu_cache: Py<PyAny> = PY_MODULE.getattr(py, "clear_gpu_cache")?;
+        let clear_gpu_cache: Py<PyAny> = PY_MODULE
+            .getattr(py, "clear_gpu_cache")
+            .format_traceback(py)?;
         let success = clear_gpu_cache
-            .call1(py, PyTuple::new(py, &[memory_usage.into_py(py)]))?
-            .extract(py)?;
+            .call1(py, PyTuple::new(py, &[memory_usage.into_py(py)]))
+            .format_traceback(py)?
+            .extract(py)
+            .format_traceback(py)?;
         Ok(success)
     })
 }

pgml-extension/src/bindings/transformers/transformers.py

Lines changed: 11 additions & 0 deletions
@@ -173,6 +173,17 @@ def __call__(self, inputs, **kwargs):
         return self.pipe(inputs, **kwargs)
 
 
+def get_model_from(task):
+    task = orjson.loads(task)
+    if "model" in task:
+        return task["model"]
+
+    if "task" in task:
+        model = transformers.pipelines.SUPPORTED_TASKS[task["task"]]["default"]["model"]
+        ty = "tf" if "tf" in model else "pt"
+        return model[ty][0]
+
+
 def transform(task, args, inputs):
     task = orjson.loads(task)
     args = orjson.loads(args)
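
On the Rust side, the new get_model_from wrapper in transformers/mod.rs passes the task JSON to this function and extracts the returned string. A hedged usage sketch with a hypothetical caller; it assumes the extension's Python environment is already initialized, and only the explicit "model" branch is asserted, since the task-default branch depends on the installed transformers version:

use anyhow::Result;
use serde_json::json;

use crate::bindings::transformers::get_model_from;

fn resolve_models() -> Result<()> {
    // An explicit "model" key is returned verbatim.
    let explicit = json!({ "task": "text-classification", "model": "distilbert-base-uncased" });
    assert_eq!(get_model_from(&explicit)?, "distilbert-base-uncased");

    // With only a "task", the default model listed in
    // transformers.pipelines.SUPPORTED_TASKS for that task comes back
    // (the exact name depends on the installed transformers version).
    let implicit = json!({ "task": "text-classification" });
    println!("default model: {}", get_model_from(&implicit)?);

    Ok(())
}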
