@@ -13,16 +13,25 @@ use serde_json::json;
13
13
use crate :: bindings:: sklearn:: package_version;
14
14
use crate :: orm:: * ;
15
15
16
+ macro_rules! unwrap_or_error {
17
+ ( $i: expr) => {
18
+ match $i {
19
+ Ok ( v) => v,
20
+ Err ( e) => error!( "{e}" ) ,
21
+ }
22
+ } ;
23
+ }
24
+
16
25
#[ cfg( feature = "python" ) ]
17
26
#[ pg_extern]
18
27
pub fn activate_venv ( venv : & str ) -> bool {
19
- crate :: bindings:: venv:: activate_venv ( venv)
28
+ unwrap_or_error ! ( crate :: bindings:: venv:: activate_venv( venv) )
20
29
}
21
30
22
31
#[ cfg( feature = "python" ) ]
23
32
#[ pg_extern( immutable, parallel_safe) ]
24
33
pub fn validate_python_dependencies ( ) -> bool {
25
- crate :: bindings:: venv:: activate ( ) ;
34
+ unwrap_or_error ! ( crate :: bindings:: venv:: activate( ) ) ;
26
35
27
36
Python :: with_gil ( |py| {
28
37
let sys = PyModule :: import ( py, "sys" ) . unwrap ( ) ;
@@ -40,13 +49,12 @@ pub fn validate_python_dependencies() -> bool {
40
49
}
41
50
} ) ;
42
51
43
- info ! (
44
- "Scikit-learn {}, XGBoost {}, LightGBM {}, NumPy {}" ,
45
- package_version( "sklearn" ) ,
46
- package_version( "xgboost" ) ,
47
- package_version( "lightgbm" ) ,
48
- package_version( "numpy" ) ,
49
- ) ;
52
+ let sklearn = unwrap_or_error ! ( package_version( "sklearn" ) ) ;
53
+ let xgboost = unwrap_or_error ! ( package_version( "xgboost" ) ) ;
54
+ let lightgbm = unwrap_or_error ! ( package_version( "lightgbm" ) ) ;
55
+ let numpy = unwrap_or_error ! ( package_version( "numpy" ) ) ;
56
+
57
+ info ! ( "Scikit-learn {sklearn}, XGBoost {xgboost}, LightGBM {lightgbm}, NumPy {numpy}" , ) ;
50
58
51
59
true
52
60
}
@@ -58,8 +66,8 @@ pub fn validate_python_dependencies() {}
58
66
#[ cfg( feature = "python" ) ]
59
67
#[ pg_extern]
60
68
pub fn python_package_version ( name : & str ) -> String {
61
- crate :: bindings:: venv:: activate ( ) ;
62
- package_version ( name)
69
+ unwrap_or_error ! ( crate :: bindings:: venv:: activate( ) ) ;
70
+ unwrap_or_error ! ( package_version( name) )
63
71
}
64
72
65
73
#[ cfg( not( feature = "python" ) ) ]
@@ -71,9 +79,9 @@ pub fn python_package_version(name: &str) {
71
79
#[ cfg( feature = "python" ) ]
72
80
#[ pg_extern]
73
81
pub fn python_pip_freeze ( ) -> TableIterator < ' static , ( name ! ( package, String ) , ) > {
74
- crate :: bindings:: venv:: activate ( ) ;
82
+ unwrap_or_error ! ( crate :: bindings:: venv:: activate( ) ) ;
75
83
76
- let packages = crate :: bindings:: venv:: freeze ( )
84
+ let packages = unwrap_or_error ! ( crate :: bindings:: venv:: freeze( ) )
77
85
. into_iter ( )
78
86
. map ( |package| ( package, ) ) ;
79
87
@@ -99,7 +107,7 @@ pub fn validate_shared_library() {
99
107
#[ cfg( feature = "python" ) ]
100
108
#[ pg_extern]
101
109
fn python_version ( ) -> String {
102
- crate :: bindings:: venv:: activate ( ) ;
110
+ unwrap_or_error ! ( crate :: bindings:: venv:: activate( ) ) ;
103
111
let mut version = String :: new ( ) ;
104
112
105
113
Python :: with_gil ( |py| {
@@ -479,27 +487,31 @@ fn predict_row(project_name: &str, row: pgrx::datum::AnyElement) -> f32 {
479
487
480
488
#[ pg_extern( immutable, parallel_safe, strict, name = "predict" ) ]
481
489
fn predict_model ( model_id : i64 , features : Vec < f32 > ) -> f32 {
482
- Model :: find_cached ( model_id) . predict ( & features)
490
+ let model = unwrap_or_error ! ( Model :: find_cached( model_id) ) ;
491
+ unwrap_or_error ! ( model. predict( & features) )
483
492
}
484
493
485
494
#[ pg_extern( immutable, parallel_safe, strict, name = "predict_proba" ) ]
486
495
fn predict_model_proba ( model_id : i64 , features : Vec < f32 > ) -> Vec < f32 > {
487
- Model :: find_cached ( model_id) . predict_proba ( & features)
496
+ let model = unwrap_or_error ! ( Model :: find_cached( model_id) ) ;
497
+ unwrap_or_error ! ( model. predict_proba( & features) )
488
498
}
489
499
490
500
#[ pg_extern( immutable, parallel_safe, strict, name = "predict_joint" ) ]
491
501
fn predict_model_joint ( model_id : i64 , features : Vec < f32 > ) -> Vec < f32 > {
492
- Model :: find_cached ( model_id) . predict_joint ( & features)
502
+ let model = unwrap_or_error ! ( Model :: find_cached( model_id) ) ;
503
+ unwrap_or_error ! ( model. predict_joint( & features) )
493
504
}
494
505
495
506
#[ pg_extern( immutable, parallel_safe, strict, name = "predict_batch" ) ]
496
507
fn predict_model_batch ( model_id : i64 , features : Vec < f32 > ) -> Vec < f32 > {
497
- Model :: find_cached ( model_id) . predict_batch ( & features)
508
+ let model = unwrap_or_error ! ( Model :: find_cached( model_id) ) ;
509
+ unwrap_or_error ! ( model. predict_batch( & features) )
498
510
}
499
511
500
512
#[ pg_extern( immutable, parallel_safe, strict, name = "predict" ) ]
501
513
fn predict_model_row ( model_id : i64 , row : pgrx:: datum:: AnyElement ) -> f32 {
502
- let model = Model :: find_cached ( model_id) ;
514
+ let model = unwrap_or_error ! ( Model :: find_cached( model_id) ) ;
503
515
let snapshot = & model. snapshot ;
504
516
let numeric_encoded_features = model. numeric_encode_features ( & [ row] ) ;
505
517
let features_width = snapshot. features_width ( ) ;
@@ -514,7 +526,7 @@ fn predict_model_row(model_id: i64, row: pgrx::datum::AnyElement) -> f32 {
514
526
let column = & snapshot. columns [ position. column_position - 1 ] ;
515
527
column. preprocess ( & data, & mut processed, features_width, position. row_position ) ;
516
528
} ) ;
517
- model. predict ( & processed)
529
+ unwrap_or_error ! ( model. predict( & processed) )
518
530
}
519
531
520
532
#[ pg_extern]
@@ -617,7 +629,11 @@ pub fn chunk(
617
629
text : & str ,
618
630
kwargs : default ! ( JsonB , "'{}'" ) ,
619
631
) -> TableIterator < ' static , ( name ! ( chunk_index, i64 ) , name ! ( chunk, String ) ) > {
620
- let chunks = crate :: bindings:: langchain:: chunk ( splitter, text, & kwargs. 0 ) ;
632
+ let chunks = match crate :: bindings:: langchain:: chunk ( splitter, text, & kwargs. 0 ) {
633
+ Ok ( chunks) => chunks,
634
+ Err ( e) => error ! ( "{e}" ) ,
635
+ } ;
636
+
621
637
let chunks = chunks
622
638
. into_iter ( )
623
639
. enumerate ( )
@@ -838,28 +854,23 @@ fn tune(
838
854
#[ cfg( feature = "python" ) ]
839
855
#[ pg_extern( name = "sklearn_f1_score" ) ]
840
856
pub fn sklearn_f1_score ( ground_truth : Vec < f32 > , y_hat : Vec < f32 > ) -> f32 {
841
- crate :: bindings:: sklearn:: f1 ( & ground_truth, & y_hat)
857
+ unwrap_or_error ! ( crate :: bindings:: sklearn:: f1( & ground_truth, & y_hat) )
842
858
}
843
859
844
860
#[ cfg( feature = "python" ) ]
845
861
#[ pg_extern( name = "sklearn_r2_score" ) ]
846
862
pub fn sklearn_r2_score ( ground_truth : Vec < f32 > , y_hat : Vec < f32 > ) -> f32 {
847
- crate :: bindings:: sklearn:: r2 ( & ground_truth, & y_hat)
863
+ unwrap_or_error ! ( crate :: bindings:: sklearn:: r2( & ground_truth, & y_hat) )
848
864
}
849
865
850
866
#[ cfg( feature = "python" ) ]
851
867
#[ pg_extern( name = "sklearn_regression_metrics" ) ]
852
868
pub fn sklearn_regression_metrics ( ground_truth : Vec < f32 > , y_hat : Vec < f32 > ) -> JsonB {
853
- JsonB (
854
- serde_json:: from_str (
855
- & serde_json:: to_string ( & crate :: bindings:: sklearn:: regression_metrics (
856
- & ground_truth,
857
- & y_hat,
858
- ) )
859
- . unwrap ( ) ,
860
- )
861
- . unwrap ( ) ,
862
- )
869
+ let metrics = unwrap_or_error ! ( crate :: bindings:: sklearn:: regression_metrics(
870
+ & ground_truth,
871
+ & y_hat,
872
+ ) ) ;
873
+ JsonB ( json ! ( metrics) )
863
874
}
864
875
865
876
#[ cfg( feature = "python" ) ]
@@ -869,17 +880,13 @@ pub fn sklearn_classification_metrics(
869
880
y_hat : Vec < f32 > ,
870
881
num_classes : i64 ,
871
882
) -> JsonB {
872
- JsonB (
873
- serde_json:: from_str (
874
- & serde_json:: to_string ( & crate :: bindings:: sklearn:: classification_metrics (
875
- & ground_truth,
876
- & y_hat,
877
- num_classes as usize ,
878
- ) )
879
- . unwrap ( ) ,
880
- )
881
- . unwrap ( ) ,
882
- )
883
+ let metrics = unwrap_or_error ! ( crate :: bindings:: sklearn:: classification_metrics(
884
+ & ground_truth,
885
+ & y_hat,
886
+ num_classes as _
887
+ ) ) ;
888
+
889
+ JsonB ( json ! ( metrics) )
883
890
}
884
891
885
892
#[ pg_extern]
0 commit comments