Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
95 changes: 51 additions & 44 deletions pgml-extension/src/api.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,16 +13,25 @@ use serde_json::json;
use crate::bindings::sklearn::package_version;
use crate::orm::*;

macro_rules! unwrap_or_error {
($i:expr) => {
match $i {
Ok(v) => v,
Err(e) => error!("{e}"),
}
};
}

#[cfg(feature = "python")]
#[pg_extern]
pub fn activate_venv(venv: &str) -> bool {
crate::bindings::venv::activate_venv(venv)
unwrap_or_error!(crate::bindings::venv::activate_venv(venv))
}

#[cfg(feature = "python")]
#[pg_extern(immutable, parallel_safe)]
pub fn validate_python_dependencies() -> bool {
crate::bindings::venv::activate();
unwrap_or_error!(crate::bindings::venv::activate());

Python::with_gil(|py| {
let sys = PyModule::import(py, "sys").unwrap();
Expand All @@ -40,13 +49,12 @@ pub fn validate_python_dependencies() -> bool {
}
});

info!(
"Scikit-learn {}, XGBoost {}, LightGBM {}, NumPy {}",
package_version("sklearn"),
package_version("xgboost"),
package_version("lightgbm"),
package_version("numpy"),
);
let sklearn = unwrap_or_error!(package_version("sklearn"));
let xgboost = unwrap_or_error!(package_version("xgboost"));
let lightgbm = unwrap_or_error!(package_version("lightgbm"));
let numpy = unwrap_or_error!(package_version("numpy"));

info!("Scikit-learn {sklearn}, XGBoost {xgboost}, LightGBM {lightgbm}, NumPy {numpy}",);

true
}
Expand All @@ -58,8 +66,8 @@ pub fn validate_python_dependencies() {}
#[cfg(feature = "python")]
#[pg_extern]
pub fn python_package_version(name: &str) -> String {
crate::bindings::venv::activate();
package_version(name)
unwrap_or_error!(crate::bindings::venv::activate());
unwrap_or_error!(package_version(name))
}

#[cfg(not(feature = "python"))]
Expand All @@ -71,9 +79,9 @@ pub fn python_package_version(name: &str) {
#[cfg(feature = "python")]
#[pg_extern]
pub fn python_pip_freeze() -> TableIterator<'static, (name!(package, String),)> {
crate::bindings::venv::activate();
unwrap_or_error!(crate::bindings::venv::activate());

let packages = crate::bindings::venv::freeze()
let packages = unwrap_or_error!(crate::bindings::venv::freeze())
.into_iter()
.map(|package| (package,));

Expand All @@ -99,7 +107,7 @@ pub fn validate_shared_library() {
#[cfg(feature = "python")]
#[pg_extern]
fn python_version() -> String {
crate::bindings::venv::activate();
unwrap_or_error!(crate::bindings::venv::activate());
let mut version = String::new();

Python::with_gil(|py| {
Expand Down Expand Up @@ -479,27 +487,31 @@ fn predict_row(project_name: &str, row: pgrx::datum::AnyElement) -> f32 {

#[pg_extern(immutable, parallel_safe, strict, name = "predict")]
fn predict_model(model_id: i64, features: Vec<f32>) -> f32 {
Model::find_cached(model_id).predict(&features)
let model = unwrap_or_error!(Model::find_cached(model_id));
unwrap_or_error!(model.predict(&features))
}

#[pg_extern(immutable, parallel_safe, strict, name = "predict_proba")]
fn predict_model_proba(model_id: i64, features: Vec<f32>) -> Vec<f32> {
Model::find_cached(model_id).predict_proba(&features)
let model = unwrap_or_error!(Model::find_cached(model_id));
unwrap_or_error!(model.predict_proba(&features))
}

#[pg_extern(immutable, parallel_safe, strict, name = "predict_joint")]
fn predict_model_joint(model_id: i64, features: Vec<f32>) -> Vec<f32> {
Model::find_cached(model_id).predict_joint(&features)
let model = unwrap_or_error!(Model::find_cached(model_id));
unwrap_or_error!(model.predict_joint(&features))
}

#[pg_extern(immutable, parallel_safe, strict, name = "predict_batch")]
fn predict_model_batch(model_id: i64, features: Vec<f32>) -> Vec<f32> {
Model::find_cached(model_id).predict_batch(&features)
let model = unwrap_or_error!(Model::find_cached(model_id));
unwrap_or_error!(model.predict_batch(&features))
}

#[pg_extern(immutable, parallel_safe, strict, name = "predict")]
fn predict_model_row(model_id: i64, row: pgrx::datum::AnyElement) -> f32 {
let model = Model::find_cached(model_id);
let model = unwrap_or_error!(Model::find_cached(model_id));
let snapshot = &model.snapshot;
let numeric_encoded_features = model.numeric_encode_features(&[row]);
let features_width = snapshot.features_width();
Expand All @@ -514,7 +526,7 @@ fn predict_model_row(model_id: i64, row: pgrx::datum::AnyElement) -> f32 {
let column = &snapshot.columns[position.column_position - 1];
column.preprocess(&data, &mut processed, features_width, position.row_position);
});
model.predict(&processed)
unwrap_or_error!(model.predict(&processed))
}

#[pg_extern]
Expand Down Expand Up @@ -617,7 +629,11 @@ pub fn chunk(
text: &str,
kwargs: default!(JsonB, "'{}'"),
) -> TableIterator<'static, (name!(chunk_index, i64), name!(chunk, String))> {
let chunks = crate::bindings::langchain::chunk(splitter, text, &kwargs.0);
let chunks = match crate::bindings::langchain::chunk(splitter, text, &kwargs.0) {
Ok(chunks) => chunks,
Err(e) => error!("{e}"),
};

let chunks = chunks
.into_iter()
.enumerate()
Expand Down Expand Up @@ -838,28 +854,23 @@ fn tune(
#[cfg(feature = "python")]
#[pg_extern(name = "sklearn_f1_score")]
pub fn sklearn_f1_score(ground_truth: Vec<f32>, y_hat: Vec<f32>) -> f32 {
crate::bindings::sklearn::f1(&ground_truth, &y_hat)
unwrap_or_error!(crate::bindings::sklearn::f1(&ground_truth, &y_hat))
}

#[cfg(feature = "python")]
#[pg_extern(name = "sklearn_r2_score")]
pub fn sklearn_r2_score(ground_truth: Vec<f32>, y_hat: Vec<f32>) -> f32 {
crate::bindings::sklearn::r2(&ground_truth, &y_hat)
unwrap_or_error!(crate::bindings::sklearn::r2(&ground_truth, &y_hat))
}

#[cfg(feature = "python")]
#[pg_extern(name = "sklearn_regression_metrics")]
pub fn sklearn_regression_metrics(ground_truth: Vec<f32>, y_hat: Vec<f32>) -> JsonB {
JsonB(
serde_json::from_str(
&serde_json::to_string(&crate::bindings::sklearn::regression_metrics(
&ground_truth,
&y_hat,
))
.unwrap(),
)
.unwrap(),
)
let metrics = unwrap_or_error!(crate::bindings::sklearn::regression_metrics(
&ground_truth,
&y_hat,
));
JsonB(json!(metrics))
}

#[cfg(feature = "python")]
Expand All @@ -869,17 +880,13 @@ pub fn sklearn_classification_metrics(
y_hat: Vec<f32>,
num_classes: i64,
) -> JsonB {
JsonB(
serde_json::from_str(
&serde_json::to_string(&crate::bindings::sklearn::classification_metrics(
&ground_truth,
&y_hat,
num_classes as usize,
))
.unwrap(),
)
.unwrap(),
)
let metrics = unwrap_or_error!(crate::bindings::sklearn::classification_metrics(
&ground_truth,
&y_hat,
num_classes as _
));

JsonB(json!(metrics))
}

#[pg_extern]
Expand Down
28 changes: 10 additions & 18 deletions pgml-extension/src/bindings/langchain.rs
Original file line number Diff line number Diff line change
@@ -1,37 +1,29 @@
use anyhow::Result;
use once_cell::sync::Lazy;
use pgrx::*;
use pyo3::prelude::*;
use pyo3::types::PyTuple;

static PY_MODULE: Lazy<Py<PyModule>> = Lazy::new(|| {
Python::with_gil(|py| -> Py<PyModule> {
let src = include_str!(concat!(
env!("CARGO_MANIFEST_DIR"),
"/src/bindings/langchain.py"
));
use crate::{bindings::TracebackError, create_pymodule};

PyModule::from_code(py, src, "", "").unwrap().into()
})
});
create_pymodule!("/src/bindings/langchain.py");

pub fn chunk(splitter: &str, text: &str, kwargs: &serde_json::Value) -> Vec<String> {
crate::bindings::venv::activate();
pub fn chunk(splitter: &str, text: &str, kwargs: &serde_json::Value) -> Result<Vec<String>> {
crate::bindings::venv::activate()?;

let kwargs = serde_json::to_string(kwargs).unwrap();

Python::with_gil(|py| -> Vec<String> {
let chunk: Py<PyAny> = PY_MODULE.getattr(py, "chunk").unwrap();
Python::with_gil(|py| -> Result<Vec<String>> {
let chunk: Py<PyAny> = get_module!(PY_MODULE).getattr(py, "chunk")?;

chunk
Ok(chunk
.call1(
py,
PyTuple::new(
py,
&[splitter.into_py(py), text.into_py(py), kwargs.into_py(py)],
),
)
.unwrap()
.extract(py)
.unwrap()
)?
.extract(py)?)
})
}
59 changes: 34 additions & 25 deletions pgml-extension/src/bindings/lightgbm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ use crate::bindings::Bindings;
use crate::orm::dataset::Dataset;
use crate::orm::task::Task;
use crate::orm::Hyperparams;

use anyhow::Result;
use lightgbm;
use pgrx::*;
use serde_json::json;
Expand All @@ -22,15 +24,18 @@ impl std::fmt::Debug for Estimator {
}
}

pub fn fit_regression(dataset: &Dataset, hyperparams: &Hyperparams) -> Box<dyn Bindings> {
pub fn fit_regression(dataset: &Dataset, hyperparams: &Hyperparams) -> Result<Box<dyn Bindings>> {
fit(dataset, hyperparams, Task::regression)
}

pub fn fit_classification(dataset: &Dataset, hyperparams: &Hyperparams) -> Box<dyn Bindings> {
pub fn fit_classification(
dataset: &Dataset,
hyperparams: &Hyperparams,
) -> Result<Box<dyn Bindings>> {
fit(dataset, hyperparams, Task::classification)
}

fn fit(dataset: &Dataset, hyperparams: &Hyperparams, task: Task) -> Box<dyn Bindings> {
fn fit(dataset: &Dataset, hyperparams: &Hyperparams, task: Task) -> Result<Box<dyn Bindings>> {
let mut hyperparams = hyperparams.clone();
match task {
Task::regression => {
Expand Down Expand Up @@ -65,14 +70,19 @@ fn fit(dataset: &Dataset, hyperparams: &Hyperparams, task: Task) -> Box<dyn Bind

let estimator = lightgbm::Booster::train(data, &json! {hyperparams}).unwrap();

Box::new(Estimator { estimator })
Ok(Box::new(Estimator { estimator }))
}

impl Bindings for Estimator {
/// Predict a set of datapoints.
fn predict(&self, features: &[f32], num_features: usize, num_classes: usize) -> Vec<f32> {
let results = self.predict_proba(features, num_features);
match num_classes {
fn predict(
&self,
features: &[f32],
num_features: usize,
num_classes: usize,
) -> Result<Vec<f32>> {
let results = self.predict_proba(features, num_features)?;
Ok(match num_classes {
// TODO make lightgbm predict both classes like scikit and xgboost
0 => results,
2 => results.iter().map(|i| i.round()).collect(),
Expand All @@ -87,47 +97,46 @@ impl Bindings for Estimator {
.unwrap() as f32
})
.collect(),
}
})
}

// Predict the raw probability of classes for a classifier.
fn predict_proba(&self, features: &[f32], num_features: usize) -> Vec<f32> {
self.estimator
.predict(features, num_features as i32)
.unwrap()
fn predict_proba(&self, features: &[f32], num_features: usize) -> Result<Vec<f32>> {
Ok(self
.estimator
.predict(features, num_features as i32)?
.into_iter()
.map(|i| i as f32)
.collect()
.collect())
}

/// Serialize self to bytes
fn to_bytes(&self) -> Vec<u8> {
fn to_bytes(&self) -> Result<Vec<u8>> {
let r: u64 = rand::random();
let path = format!("/tmp/pgml_{}.bin", r);
self.estimator.save_file(&path).unwrap();
let bytes = std::fs::read(&path).unwrap();
std::fs::remove_file(&path).unwrap();
self.estimator.save_file(&path)?;
let bytes = std::fs::read(&path)?;
std::fs::remove_file(&path)?;

bytes
Ok(bytes)
}

/// Deserialize self from bytes, with additional context
fn from_bytes(bytes: &[u8]) -> Box<dyn Bindings>
fn from_bytes(bytes: &[u8]) -> Result<Box<dyn Bindings>>
where
Self: Sized,
{
let r: u64 = rand::random();
let path = format!("/tmp/pgml_{}.bin", r);
std::fs::write(&path, bytes).unwrap();
std::fs::write(&path, bytes)?;
let mut estimator = lightgbm::Booster::from_file(&path);
if estimator.is_err() {
// backward compatibility w/ 2.0.0
std::fs::write(&path, &bytes[16..]).unwrap();
std::fs::write(&path, &bytes[16..])?;
estimator = lightgbm::Booster::from_file(&path);
}
std::fs::remove_file(&path).unwrap();
Box::new(Estimator {
estimator: estimator.unwrap(),
})
std::fs::remove_file(&path)?;
let estimator = estimator?;
Ok(Box::new(Estimator { estimator }))
}
}
Loading