@@ -8,9 +8,12 @@ use once_cell::sync::Lazy;
8
8
use pgrx:: * ;
9
9
use pyo3:: prelude:: * ;
10
10
use pyo3:: types:: PyTuple ;
11
+ use serde_json:: Value ;
11
12
12
13
use crate :: orm:: { Task , TextDataset } ;
13
14
15
+ use super :: TracebackError ;
16
+
14
17
pub mod whitelist;
15
18
16
19
static PY_MODULE : Lazy < Py < PyModule > > = Lazy :: new ( || {
@@ -38,22 +41,36 @@ pub fn transform(
38
41
let inputs = serde_json:: to_string ( & inputs) ?;
39
42
40
43
let results = Python :: with_gil ( |py| -> Result < String > {
41
- let transform: Py < PyAny > = PY_MODULE . getattr ( py, "transform" ) ?;
44
+ let transform: Py < PyAny > = PY_MODULE . getattr ( py, "transform" ) . format_traceback ( py ) ?;
42
45
43
- let output = transform. call1 (
44
- py,
45
- PyTuple :: new (
46
+ let output = transform
47
+ . call1 (
46
48
py,
47
- & [ task. into_py ( py) , args. into_py ( py) , inputs. into_py ( py) ] ,
48
- ) ,
49
- ) ?;
49
+ PyTuple :: new (
50
+ py,
51
+ & [ task. into_py ( py) , args. into_py ( py) , inputs. into_py ( py) ] ,
52
+ ) ,
53
+ )
54
+ . format_traceback ( py) ?;
50
55
51
- Ok ( output. extract ( py) ?)
56
+ Ok ( output. extract ( py) . format_traceback ( py ) ?)
52
57
} ) ?;
53
58
54
59
Ok ( serde_json:: from_str ( & results) ?)
55
60
}
56
61
62
+ pub fn get_model_from ( task : & Value ) -> Result < String > {
63
+ Ok ( Python :: with_gil ( |py| -> Result < String > {
64
+ let get_model_from = PY_MODULE
65
+ . getattr ( py, "get_model_from" )
66
+ . format_traceback ( py) ?;
67
+ let model = get_model_from
68
+ . call1 ( py, PyTuple :: new ( py, & [ task. to_string ( ) . into_py ( py) ] ) )
69
+ . format_traceback ( py) ?;
70
+ Ok ( model. extract ( py) . format_traceback ( py) ?)
71
+ } ) ?)
72
+ }
73
+
57
74
pub fn embed (
58
75
transformer : & str ,
59
76
inputs : Vec < & str > ,
@@ -63,20 +80,22 @@ pub fn embed(
63
80
64
81
let kwargs = serde_json:: to_string ( kwargs) ?;
65
82
Python :: with_gil ( |py| -> Result < Vec < Vec < f32 > > > {
66
- let embed: Py < PyAny > = PY_MODULE . getattr ( py, "embed" ) ?;
67
- let output = embed. call1 (
68
- py,
69
- PyTuple :: new (
83
+ let embed: Py < PyAny > = PY_MODULE . getattr ( py, "embed" ) . format_traceback ( py) ?;
84
+ let output = embed
85
+ . call1 (
70
86
py,
71
- & [
72
- transformer. to_string ( ) . into_py ( py) ,
73
- inputs. into_py ( py) ,
74
- kwargs. into_py ( py) ,
75
- ] ,
76
- ) ,
77
- ) ?;
78
-
79
- Ok ( output. extract ( py) ?)
87
+ PyTuple :: new (
88
+ py,
89
+ & [
90
+ transformer. to_string ( ) . into_py ( py) ,
91
+ inputs. into_py ( py) ,
92
+ kwargs. into_py ( py) ,
93
+ ] ,
94
+ ) ,
95
+ )
96
+ . format_traceback ( py) ?;
97
+
98
+ Ok ( output. extract ( py) . format_traceback ( py) ?)
80
99
} )
81
100
}
82
101
@@ -92,30 +111,32 @@ pub fn tune(
92
111
let hyperparams = serde_json:: to_string ( & hyperparams. 0 ) ?;
93
112
94
113
Python :: with_gil ( |py| -> Result < HashMap < String , f64 > > {
95
- let tune = PY_MODULE . getattr ( py, "tune" ) ?;
114
+ let tune = PY_MODULE . getattr ( py, "tune" ) . format_traceback ( py ) ?;
96
115
let path = path. to_string_lossy ( ) ;
97
- let output = tune. call1 (
98
- py,
99
- (
100
- & task,
101
- & hyperparams,
102
- path. as_ref ( ) ,
103
- dataset. x_train ,
104
- dataset. x_test ,
105
- dataset. y_train ,
106
- dataset. y_test ,
107
- ) ,
108
- ) ?;
109
-
110
- Ok ( output. extract ( py) ?)
116
+ let output = tune
117
+ . call1 (
118
+ py,
119
+ (
120
+ & task,
121
+ & hyperparams,
122
+ path. as_ref ( ) ,
123
+ dataset. x_train ,
124
+ dataset. x_test ,
125
+ dataset. y_train ,
126
+ dataset. y_test ,
127
+ ) ,
128
+ )
129
+ . format_traceback ( py) ?;
130
+
131
+ Ok ( output. extract ( py) . format_traceback ( py) ?)
111
132
} )
112
133
}
113
134
114
135
pub fn generate ( model_id : i64 , inputs : Vec < & str > , config : JsonB ) -> Result < Vec < String > > {
115
136
crate :: bindings:: venv:: activate ( ) ;
116
137
117
138
Python :: with_gil ( |py| -> Result < Vec < String > > {
118
- let generate = PY_MODULE . getattr ( py, "generate" ) ?;
139
+ let generate = PY_MODULE . getattr ( py, "generate" ) . format_traceback ( py ) ?;
119
140
let config = serde_json:: to_string ( & config. 0 ) ?;
120
141
// cloning inputs in case we have to re-call on error is rather unfortunate here
121
142
// similarly, using a json string to pass kwargs is also unfortunate extra parsing
@@ -143,16 +164,19 @@ pub fn generate(model_id: i64, inputs: Vec<&str>, config: JsonB) -> Result<Vec<S
143
164
let load = PY_MODULE . getattr ( py, "load_model" ) ?;
144
165
let task = Task :: from_str ( & task)
145
166
. map_err ( |_| anyhow ! ( "could not make a Task from {task}" ) ) ?;
146
- load. call1 ( py, ( model_id, task. to_string ( ) , dir) ) ?;
167
+ load. call1 ( py, ( model_id, task. to_string ( ) , dir) )
168
+ . format_traceback ( py) ?;
147
169
148
- generate. call1 ( py, ( model_id, inputs, config) ) ?
170
+ generate
171
+ . call1 ( py, ( model_id, inputs, config) )
172
+ . format_traceback ( py) ?
149
173
} else {
150
174
return Err ( e. into ( ) ) ;
151
175
}
152
176
}
153
177
Ok ( o) => o,
154
178
} ;
155
- Ok ( result. extract ( py) ?)
179
+ Ok ( result. extract ( py) . format_traceback ( py ) ?)
156
180
} )
157
181
}
158
182
@@ -200,7 +224,7 @@ pub fn load_dataset(
200
224
let kwargs = serde_json:: to_string ( kwargs) ?;
201
225
202
226
let dataset = Python :: with_gil ( |py| -> Result < String > {
203
- let load_dataset: Py < PyAny > = PY_MODULE . getattr ( py, "load_dataset" ) ?;
227
+ let load_dataset: Py < PyAny > = PY_MODULE . getattr ( py, "load_dataset" ) . format_traceback ( py ) ?;
204
228
Ok ( load_dataset
205
229
. call1 (
206
230
py,
@@ -213,8 +237,10 @@ pub fn load_dataset(
213
237
kwargs. into_py ( py) ,
214
238
] ,
215
239
) ,
216
- ) ?
217
- . extract ( py) ?)
240
+ )
241
+ . format_traceback ( py) ?
242
+ . extract ( py)
243
+ . format_traceback ( py) ?)
218
244
} ) ?;
219
245
220
246
let table_name = format ! ( "pgml.\" {}\" " , name) ;
@@ -351,10 +377,14 @@ pub fn clear_gpu_cache(memory_usage: Option<f32>) -> Result<bool> {
351
377
crate :: bindings:: venv:: activate ( ) ;
352
378
353
379
Python :: with_gil ( |py| -> Result < bool > {
354
- let clear_gpu_cache: Py < PyAny > = PY_MODULE . getattr ( py, "clear_gpu_cache" ) ?;
380
+ let clear_gpu_cache: Py < PyAny > = PY_MODULE
381
+ . getattr ( py, "clear_gpu_cache" )
382
+ . format_traceback ( py) ?;
355
383
let success = clear_gpu_cache
356
- . call1 ( py, PyTuple :: new ( py, & [ memory_usage. into_py ( py) ] ) ) ?
357
- . extract ( py) ?;
384
+ . call1 ( py, PyTuple :: new ( py, & [ memory_usage. into_py ( py) ] ) )
385
+ . format_traceback ( py) ?
386
+ . extract ( py)
387
+ . format_traceback ( py) ?;
358
388
Ok ( success)
359
389
} )
360
390
}
0 commit comments