Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions pgml-extension/src/bindings/mod.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
use std::fmt::Debug;

use anyhow::{anyhow, Result};
#[allow(unused_imports)] // used for test macros
use pgrx::*;
use pyo3::{PyResult, Python};

use crate::orm::*;

Expand Down Expand Up @@ -40,6 +42,19 @@ pub trait Bindings: Send + Sync + Debug {
Self: Sized;
}

trait TracebackError<T> {
fn format_traceback(self, py: Python<'_>) -> Result<T>;
}

impl<T> TracebackError<T> for PyResult<T> {
fn format_traceback(self, py: Python<'_>) -> Result<T> {
self.map_err(|e| {
let traceback = e.traceback(py).unwrap().format().unwrap();
anyhow!("{traceback} {e}")
})
}
}

#[cfg(any(test, feature = "pg_test"))]
#[pg_schema]
mod tests {
Expand Down
122 changes: 76 additions & 46 deletions pgml-extension/src/bindings/transformers/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,12 @@ use once_cell::sync::Lazy;
use pgrx::*;
use pyo3::prelude::*;
use pyo3::types::PyTuple;
use serde_json::Value;

use crate::orm::{Task, TextDataset};

use super::TracebackError;

pub mod whitelist;

static PY_MODULE: Lazy<Py<PyModule>> = Lazy::new(|| {
Expand Down Expand Up @@ -38,22 +41,36 @@ pub fn transform(
let inputs = serde_json::to_string(&inputs)?;

let results = Python::with_gil(|py| -> Result<String> {
let transform: Py<PyAny> = PY_MODULE.getattr(py, "transform")?;
let transform: Py<PyAny> = PY_MODULE.getattr(py, "transform").format_traceback(py)?;

let output = transform.call1(
py,
PyTuple::new(
let output = transform
.call1(
py,
&[task.into_py(py), args.into_py(py), inputs.into_py(py)],
),
)?;
PyTuple::new(
py,
&[task.into_py(py), args.into_py(py), inputs.into_py(py)],
),
)
.format_traceback(py)?;

Ok(output.extract(py)?)
Ok(output.extract(py).format_traceback(py)?)
})?;

Ok(serde_json::from_str(&results)?)
}

pub fn get_model_from(task: &Value) -> Result<String> {
Ok(Python::with_gil(|py| -> Result<String> {
let get_model_from = PY_MODULE
.getattr(py, "get_model_from")
.format_traceback(py)?;
let model = get_model_from
.call1(py, PyTuple::new(py, &[task.to_string().into_py(py)]))
.format_traceback(py)?;
Ok(model.extract(py).format_traceback(py)?)
})?)
}

pub fn embed(
transformer: &str,
inputs: Vec<&str>,
Expand All @@ -63,20 +80,22 @@ pub fn embed(

let kwargs = serde_json::to_string(kwargs)?;
Python::with_gil(|py| -> Result<Vec<Vec<f32>>> {
let embed: Py<PyAny> = PY_MODULE.getattr(py, "embed")?;
let output = embed.call1(
py,
PyTuple::new(
let embed: Py<PyAny> = PY_MODULE.getattr(py, "embed").format_traceback(py)?;
let output = embed
.call1(
py,
&[
transformer.to_string().into_py(py),
inputs.into_py(py),
kwargs.into_py(py),
],
),
)?;

Ok(output.extract(py)?)
PyTuple::new(
py,
&[
transformer.to_string().into_py(py),
inputs.into_py(py),
kwargs.into_py(py),
],
),
)
.format_traceback(py)?;

Ok(output.extract(py).format_traceback(py)?)
})
}

Expand All @@ -92,30 +111,32 @@ pub fn tune(
let hyperparams = serde_json::to_string(&hyperparams.0)?;

Python::with_gil(|py| -> Result<HashMap<String, f64>> {
let tune = PY_MODULE.getattr(py, "tune")?;
let tune = PY_MODULE.getattr(py, "tune").format_traceback(py)?;
let path = path.to_string_lossy();
let output = tune.call1(
py,
(
&task,
&hyperparams,
path.as_ref(),
dataset.x_train,
dataset.x_test,
dataset.y_train,
dataset.y_test,
),
)?;

Ok(output.extract(py)?)
let output = tune
.call1(
py,
(
&task,
&hyperparams,
path.as_ref(),
dataset.x_train,
dataset.x_test,
dataset.y_train,
dataset.y_test,
),
)
.format_traceback(py)?;

Ok(output.extract(py).format_traceback(py)?)
})
}

pub fn generate(model_id: i64, inputs: Vec<&str>, config: JsonB) -> Result<Vec<String>> {
crate::bindings::venv::activate();

Python::with_gil(|py| -> Result<Vec<String>> {
let generate = PY_MODULE.getattr(py, "generate")?;
let generate = PY_MODULE.getattr(py, "generate").format_traceback(py)?;
let config = serde_json::to_string(&config.0)?;
// cloning inputs in case we have to re-call on error is rather unfortunate here
// similarly, using a json string to pass kwargs is also unfortunate extra parsing
Expand Down Expand Up @@ -143,16 +164,19 @@ pub fn generate(model_id: i64, inputs: Vec<&str>, config: JsonB) -> Result<Vec<S
let load = PY_MODULE.getattr(py, "load_model")?;
let task = Task::from_str(&task)
.map_err(|_| anyhow!("could not make a Task from {task}"))?;
load.call1(py, (model_id, task.to_string(), dir))?;
load.call1(py, (model_id, task.to_string(), dir))
.format_traceback(py)?;

generate.call1(py, (model_id, inputs, config))?
generate
.call1(py, (model_id, inputs, config))
.format_traceback(py)?
} else {
return Err(e.into());
}
}
Ok(o) => o,
};
Ok(result.extract(py)?)
Ok(result.extract(py).format_traceback(py)?)
})
}

Expand Down Expand Up @@ -200,7 +224,7 @@ pub fn load_dataset(
let kwargs = serde_json::to_string(kwargs)?;

let dataset = Python::with_gil(|py| -> Result<String> {
let load_dataset: Py<PyAny> = PY_MODULE.getattr(py, "load_dataset")?;
let load_dataset: Py<PyAny> = PY_MODULE.getattr(py, "load_dataset").format_traceback(py)?;
Ok(load_dataset
.call1(
py,
Expand All @@ -213,8 +237,10 @@ pub fn load_dataset(
kwargs.into_py(py),
],
),
)?
.extract(py)?)
)
.format_traceback(py)?
.extract(py)
.format_traceback(py)?)
})?;

let table_name = format!("pgml.\"{}\"", name);
Expand Down Expand Up @@ -351,10 +377,14 @@ pub fn clear_gpu_cache(memory_usage: Option<f32>) -> Result<bool> {
crate::bindings::venv::activate();

Python::with_gil(|py| -> Result<bool> {
let clear_gpu_cache: Py<PyAny> = PY_MODULE.getattr(py, "clear_gpu_cache")?;
let clear_gpu_cache: Py<PyAny> = PY_MODULE
.getattr(py, "clear_gpu_cache")
.format_traceback(py)?;
let success = clear_gpu_cache
.call1(py, PyTuple::new(py, &[memory_usage.into_py(py)]))?
.extract(py)?;
.call1(py, PyTuple::new(py, &[memory_usage.into_py(py)]))
.format_traceback(py)?
.extract(py)
.format_traceback(py)?;
Ok(success)
})
}
11 changes: 11 additions & 0 deletions pgml-extension/src/bindings/transformers/transformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,17 @@ def __call__(self, inputs, **kwargs):
return self.pipe(inputs, **kwargs)


def get_model_from(task):
task = orjson.loads(task)
if "model" in task:
return task["model"]

if "task" in task:
model = transformers.pipelines.SUPPORTED_TASKS[task["task"]]["default"]["model"]
ty = "tf" if "tf" in model else "pt"
return model[ty][0]


def transform(task, args, inputs):
task = orjson.loads(task)
args = orjson.loads(args)
Expand Down