Skip to content

Commit 88ca844

Browse files
authored
add Result to lazy pymodules then trickled up... (#925)
1 parent 8fcb39a commit 88ca844

File tree

11 files changed, with 444 additions and 414 deletions.

pgml-extension/src/api.rs

Lines changed: 51 additions & 44 deletions
Original file line number | Diff line number | Diff line change
@@ -13,16 +13,25 @@ use serde_json::json;
1313
use crate::bindings::sklearn::package_version;
1414
use crate::orm::*;
1515

16+
macro_rules! unwrap_or_error {
17+
($i:expr) => {
18+
match $i {
19+
Ok(v) => v,
20+
Err(e) => error!("{e}"),
21+
}
22+
};
23+
}
24+
1625
#[cfg(feature = "python")]
1726
#[pg_extern]
1827
pub fn activate_venv(venv: &str) -> bool {
19-
crate::bindings::venv::activate_venv(venv)
28+
unwrap_or_error!(crate::bindings::venv::activate_venv(venv))
2029
}
2130

2231
#[cfg(feature = "python")]
2332
#[pg_extern(immutable, parallel_safe)]
2433
pub fn validate_python_dependencies() -> bool {
25-
crate::bindings::venv::activate();
34+
unwrap_or_error!(crate::bindings::venv::activate());
2635

2736
Python::with_gil(|py| {
2837
let sys = PyModule::import(py, "sys").unwrap();
@@ -40,13 +49,12 @@ pub fn validate_python_dependencies() -> bool {
4049
}
4150
});
4251

43-
info!(
44-
"Scikit-learn {}, XGBoost {}, LightGBM {}, NumPy {}",
45-
package_version("sklearn"),
46-
package_version("xgboost"),
47-
package_version("lightgbm"),
48-
package_version("numpy"),
49-
);
52+
let sklearn = unwrap_or_error!(package_version("sklearn"));
53+
let xgboost = unwrap_or_error!(package_version("xgboost"));
54+
let lightgbm = unwrap_or_error!(package_version("lightgbm"));
55+
let numpy = unwrap_or_error!(package_version("numpy"));
56+
57+
info!("Scikit-learn {sklearn}, XGBoost {xgboost}, LightGBM {lightgbm}, NumPy {numpy}",);
5058

5159
true
5260
}
@@ -58,8 +66,8 @@ pub fn validate_python_dependencies() {}
5866
#[cfg(feature = "python")]
5967
#[pg_extern]
6068
pub fn python_package_version(name: &str) -> String {
61-
crate::bindings::venv::activate();
62-
package_version(name)
69+
unwrap_or_error!(crate::bindings::venv::activate());
70+
unwrap_or_error!(package_version(name))
6371
}
6472

6573
#[cfg(not(feature = "python"))]
@@ -71,9 +79,9 @@ pub fn python_package_version(name: &str) {
7179
#[cfg(feature = "python")]
7280
#[pg_extern]
7381
pub fn python_pip_freeze() -> TableIterator<'static, (name!(package, String),)> {
74-
crate::bindings::venv::activate();
82+
unwrap_or_error!(crate::bindings::venv::activate());
7583

76-
let packages = crate::bindings::venv::freeze()
84+
let packages = unwrap_or_error!(crate::bindings::venv::freeze())
7785
.into_iter()
7886
.map(|package| (package,));
7987

@@ -99,7 +107,7 @@ pub fn validate_shared_library() {
99107
#[cfg(feature = "python")]
100108
#[pg_extern]
101109
fn python_version() -> String {
102-
crate::bindings::venv::activate();
110+
unwrap_or_error!(crate::bindings::venv::activate());
103111
let mut version = String::new();
104112

105113
Python::with_gil(|py| {
@@ -479,27 +487,31 @@ fn predict_row(project_name: &str, row: pgrx::datum::AnyElement) -> f32 {
479487

480488
#[pg_extern(immutable, parallel_safe, strict, name = "predict")]
481489
fn predict_model(model_id: i64, features: Vec<f32>) -> f32 {
482-
Model::find_cached(model_id).predict(&features)
490+
let model = unwrap_or_error!(Model::find_cached(model_id));
491+
unwrap_or_error!(model.predict(&features))
483492
}
484493

485494
#[pg_extern(immutable, parallel_safe, strict, name = "predict_proba")]
486495
fn predict_model_proba(model_id: i64, features: Vec<f32>) -> Vec<f32> {
487-
Model::find_cached(model_id).predict_proba(&features)
496+
let model = unwrap_or_error!(Model::find_cached(model_id));
497+
unwrap_or_error!(model.predict_proba(&features))
488498
}
489499

490500
#[pg_extern(immutable, parallel_safe, strict, name = "predict_joint")]
491501
fn predict_model_joint(model_id: i64, features: Vec<f32>) -> Vec<f32> {
492-
Model::find_cached(model_id).predict_joint(&features)
502+
let model = unwrap_or_error!(Model::find_cached(model_id));
503+
unwrap_or_error!(model.predict_joint(&features))
493504
}
494505

495506
#[pg_extern(immutable, parallel_safe, strict, name = "predict_batch")]
496507
fn predict_model_batch(model_id: i64, features: Vec<f32>) -> Vec<f32> {
497-
Model::find_cached(model_id).predict_batch(&features)
508+
let model = unwrap_or_error!(Model::find_cached(model_id));
509+
unwrap_or_error!(model.predict_batch(&features))
498510
}
499511

500512
#[pg_extern(immutable, parallel_safe, strict, name = "predict")]
501513
fn predict_model_row(model_id: i64, row: pgrx::datum::AnyElement) -> f32 {
502-
let model = Model::find_cached(model_id);
514+
let model = unwrap_or_error!(Model::find_cached(model_id));
503515
let snapshot = &model.snapshot;
504516
let numeric_encoded_features = model.numeric_encode_features(&[row]);
505517
let features_width = snapshot.features_width();
@@ -514,7 +526,7 @@ fn predict_model_row(model_id: i64, row: pgrx::datum::AnyElement) -> f32 {
514526
let column = &snapshot.columns[position.column_position - 1];
515527
column.preprocess(&data, &mut processed, features_width, position.row_position);
516528
});
517-
model.predict(&processed)
529+
unwrap_or_error!(model.predict(&processed))
518530
}
519531

520532
#[pg_extern]
@@ -617,7 +629,11 @@ pub fn chunk(
617629
text: &str,
618630
kwargs: default!(JsonB, "'{}'"),
619631
) -> TableIterator<'static, (name!(chunk_index, i64), name!(chunk, String))> {
620-
let chunks = crate::bindings::langchain::chunk(splitter, text, &kwargs.0);
632+
let chunks = match crate::bindings::langchain::chunk(splitter, text, &kwargs.0) {
633+
Ok(chunks) => chunks,
634+
Err(e) => error!("{e}"),
635+
};
636+
621637
let chunks = chunks
622638
.into_iter()
623639
.enumerate()
@@ -838,28 +854,23 @@ fn tune(
838854
#[cfg(feature = "python")]
839855
#[pg_extern(name = "sklearn_f1_score")]
840856
pub fn sklearn_f1_score(ground_truth: Vec<f32>, y_hat: Vec<f32>) -> f32 {
841-
crate::bindings::sklearn::f1(&ground_truth, &y_hat)
857+
unwrap_or_error!(crate::bindings::sklearn::f1(&ground_truth, &y_hat))
842858
}
843859

844860
#[cfg(feature = "python")]
845861
#[pg_extern(name = "sklearn_r2_score")]
846862
pub fn sklearn_r2_score(ground_truth: Vec<f32>, y_hat: Vec<f32>) -> f32 {
847-
crate::bindings::sklearn::r2(&ground_truth, &y_hat)
863+
unwrap_or_error!(crate::bindings::sklearn::r2(&ground_truth, &y_hat))
848864
}
849865

850866
#[cfg(feature = "python")]
851867
#[pg_extern(name = "sklearn_regression_metrics")]
852868
pub fn sklearn_regression_metrics(ground_truth: Vec<f32>, y_hat: Vec<f32>) -> JsonB {
853-
JsonB(
854-
serde_json::from_str(
855-
&serde_json::to_string(&crate::bindings::sklearn::regression_metrics(
856-
&ground_truth,
857-
&y_hat,
858-
))
859-
.unwrap(),
860-
)
861-
.unwrap(),
862-
)
869+
let metrics = unwrap_or_error!(crate::bindings::sklearn::regression_metrics(
870+
&ground_truth,
871+
&y_hat,
872+
));
873+
JsonB(json!(metrics))
863874
}
864875

865876
#[cfg(feature = "python")]
@@ -869,17 +880,13 @@ pub fn sklearn_classification_metrics(
869880
y_hat: Vec<f32>,
870881
num_classes: i64,
871882
) -> JsonB {
872-
JsonB(
873-
serde_json::from_str(
874-
&serde_json::to_string(&crate::bindings::sklearn::classification_metrics(
875-
&ground_truth,
876-
&y_hat,
877-
num_classes as usize,
878-
))
879-
.unwrap(),
880-
)
881-
.unwrap(),
882-
)
883+
let metrics = unwrap_or_error!(crate::bindings::sklearn::classification_metrics(
884+
&ground_truth,
885+
&y_hat,
886+
num_classes as _
887+
));
888+
889+
JsonB(json!(metrics))
883890
}
884891

885892
#[pg_extern]
pgml-extension/src/bindings/langchain.rs

Lines changed: 10 additions & 18 deletions
Original file line number | Diff line number | Diff line change
@@ -1,37 +1,29 @@
1+
use anyhow::Result;
12
use once_cell::sync::Lazy;
23
use pgrx::*;
34
use pyo3::prelude::*;
45
use pyo3::types::PyTuple;
56

6-
static PY_MODULE: Lazy<Py<PyModule>> = Lazy::new(|| {
7-
Python::with_gil(|py| -> Py<PyModule> {
8-
let src = include_str!(concat!(
9-
env!("CARGO_MANIFEST_DIR"),
10-
"/src/bindings/langchain.py"
11-
));
7+
use crate::{bindings::TracebackError, create_pymodule};
128

13-
PyModule::from_code(py, src, "", "").unwrap().into()
14-
})
15-
});
9+
create_pymodule!("/src/bindings/langchain.py");
1610

17-
pub fn chunk(splitter: &str, text: &str, kwargs: &serde_json::Value) -> Vec<String> {
18-
crate::bindings::venv::activate();
11+
pub fn chunk(splitter: &str, text: &str, kwargs: &serde_json::Value) -> Result<Vec<String>> {
12+
crate::bindings::venv::activate()?;
1913

2014
let kwargs = serde_json::to_string(kwargs).unwrap();
2115

22-
Python::with_gil(|py| -> Vec<String> {
23-
let chunk: Py<PyAny> = PY_MODULE.getattr(py, "chunk").unwrap();
16+
Python::with_gil(|py| -> Result<Vec<String>> {
17+
let chunk: Py<PyAny> = get_module!(PY_MODULE).getattr(py, "chunk")?;
2418

25-
chunk
19+
Ok(chunk
2620
.call1(
2721
py,
2822
PyTuple::new(
2923
py,
3024
&[splitter.into_py(py), text.into_py(py), kwargs.into_py(py)],
3125
),
32-
)
33-
.unwrap()
34-
.extract(py)
35-
.unwrap()
26+
)?
27+
.extract(py)?)
3628
})
3729
}

pgml-extension/src/bindings/lightgbm.rs

Lines changed: 34 additions & 25 deletions
Original file line number | Diff line number | Diff line change
@@ -2,6 +2,8 @@ use crate::bindings::Bindings;
22
use crate::orm::dataset::Dataset;
33
use crate::orm::task::Task;
44
use crate::orm::Hyperparams;
5+
6+
use anyhow::Result;
57
use lightgbm;
68
use pgrx::*;
79
use serde_json::json;
@@ -22,15 +24,18 @@ impl std::fmt::Debug for Estimator {
2224
}
2325
}
2426

25-
pub fn fit_regression(dataset: &Dataset, hyperparams: &Hyperparams) -> Box<dyn Bindings> {
27+
pub fn fit_regression(dataset: &Dataset, hyperparams: &Hyperparams) -> Result<Box<dyn Bindings>> {
2628
fit(dataset, hyperparams, Task::regression)
2729
}
2830

29-
pub fn fit_classification(dataset: &Dataset, hyperparams: &Hyperparams) -> Box<dyn Bindings> {
31+
pub fn fit_classification(
32+
dataset: &Dataset,
33+
hyperparams: &Hyperparams,
34+
) -> Result<Box<dyn Bindings>> {
3035
fit(dataset, hyperparams, Task::classification)
3136
}
3237

33-
fn fit(dataset: &Dataset, hyperparams: &Hyperparams, task: Task) -> Box<dyn Bindings> {
38+
fn fit(dataset: &Dataset, hyperparams: &Hyperparams, task: Task) -> Result<Box<dyn Bindings>> {
3439
let mut hyperparams = hyperparams.clone();
3540
match task {
3641
Task::regression => {
@@ -65,14 +70,19 @@ fn fit(dataset: &Dataset, hyperparams: &Hyperparams, task: Task) -> Box<dyn Bind
6570

6671
let estimator = lightgbm::Booster::train(data, &json! {hyperparams}).unwrap();
6772

68-
Box::new(Estimator { estimator })
73+
Ok(Box::new(Estimator { estimator }))
6974
}
7075

7176
impl Bindings for Estimator {
7277
/// Predict a set of datapoints.
73-
fn predict(&self, features: &[f32], num_features: usize, num_classes: usize) -> Vec<f32> {
74-
let results = self.predict_proba(features, num_features);
75-
match num_classes {
78+
fn predict(
79+
&self,
80+
features: &[f32],
81+
num_features: usize,
82+
num_classes: usize,
83+
) -> Result<Vec<f32>> {
84+
let results = self.predict_proba(features, num_features)?;
85+
Ok(match num_classes {
7686
// TODO make lightgbm predict both classes like scikit and xgboost
7787
0 => results,
7888
2 => results.iter().map(|i| i.round()).collect(),
@@ -87,47 +97,46 @@ impl Bindings for Estimator {
8797
.unwrap() as f32
8898
})
8999
.collect(),
90-
}
100+
})
91101
}
92102

93103
// Predict the raw probability of classes for a classifier.
94-
fn predict_proba(&self, features: &[f32], num_features: usize) -> Vec<f32> {
95-
self.estimator
96-
.predict(features, num_features as i32)
97-
.unwrap()
104+
fn predict_proba(&self, features: &[f32], num_features: usize) -> Result<Vec<f32>> {
105+
Ok(self
106+
.estimator
107+
.predict(features, num_features as i32)?
98108
.into_iter()
99109
.map(|i| i as f32)
100-
.collect()
110+
.collect())
101111
}
102112

103113
/// Serialize self to bytes
104-
fn to_bytes(&self) -> Vec<u8> {
114+
fn to_bytes(&self) -> Result<Vec<u8>> {
105115
let r: u64 = rand::random();
106116
let path = format!("/tmp/pgml_{}.bin", r);
107-
self.estimator.save_file(&path).unwrap();
108-
let bytes = std::fs::read(&path).unwrap();
109-
std::fs::remove_file(&path).unwrap();
117+
self.estimator.save_file(&path)?;
118+
let bytes = std::fs::read(&path)?;
119+
std::fs::remove_file(&path)?;
110120

111-
bytes
121+
Ok(bytes)
112122
}
113123

114124
/// Deserialize self from bytes, with additional context
115-
fn from_bytes(bytes: &[u8]) -> Box<dyn Bindings>
125+
fn from_bytes(bytes: &[u8]) -> Result<Box<dyn Bindings>>
116126
where
117127
Self: Sized,
118128
{
119129
let r: u64 = rand::random();
120130
let path = format!("/tmp/pgml_{}.bin", r);
121-
std::fs::write(&path, bytes).unwrap();
131+
std::fs::write(&path, bytes)?;
122132
let mut estimator = lightgbm::Booster::from_file(&path);
123133
if estimator.is_err() {
124134
// backward compatibility w/ 2.0.0
125-
std::fs::write(&path, &bytes[16..]).unwrap();
135+
std::fs::write(&path, &bytes[16..])?;
126136
estimator = lightgbm::Booster::from_file(&path);
127137
}
128-
std::fs::remove_file(&path).unwrap();
129-
Box::new(Estimator {
130-
estimator: estimator.unwrap(),
131-
})
138+
std::fs::remove_file(&path)?;
139+
let estimator = estimator?;
140+
Ok(Box::new(Estimator { estimator }))
132141
}
133142
}

0 commit comments

Comments
 (0)