@@ -3,21 +3,16 @@ use std::path::PathBuf;
3
3
use std:: str:: FromStr ;
4
4
use std:: { collections:: HashMap , path:: Path } ;
5
5
6
+ use anyhow:: { anyhow, bail, Context , Result } ;
6
7
use once_cell:: sync:: Lazy ;
7
8
use pgrx:: * ;
8
9
use pyo3:: prelude:: * ;
9
10
use pyo3:: types:: PyTuple ;
10
11
11
12
use crate :: orm:: { Task , TextDataset } ;
12
13
13
- use self :: error:: Error ;
14
- use self :: whitelist:: verify_task_against_whitelist;
15
-
16
- pub mod error;
17
14
pub mod whitelist;
18
15
19
- pub type Result < T > = std:: result:: Result < T , error:: Error > ;
20
-
21
16
static PY_MODULE : Lazy < Py < PyModule > > = Lazy :: new ( || {
22
17
Python :: with_gil ( |py| -> Py < PyModule > {
23
18
let src = include_str ! ( concat!(
@@ -36,7 +31,7 @@ pub fn transform(
36
31
) -> Result < serde_json:: Value > {
37
32
crate :: bindings:: venv:: activate ( ) ;
38
33
39
- verify_task_against_whitelist ( task) ?;
34
+ whitelist :: verify_task ( task) ?;
40
35
41
36
let task = serde_json:: to_string ( task) ?;
42
37
let args = serde_json:: to_string ( args) ?;
@@ -98,12 +93,13 @@ pub fn tune(
98
93
99
94
Python :: with_gil ( |py| -> Result < HashMap < String , f64 > > {
100
95
let tune = PY_MODULE . getattr ( py, "tune" ) ?;
96
+ let path = path. to_string_lossy ( ) ;
101
97
let output = tune. call1 (
102
98
py,
103
99
(
104
100
& task,
105
101
& hyperparams,
106
- path. to_str ( ) . unwrap ( ) ,
102
+ path. as_ref ( ) ,
107
103
dataset. x_train ,
108
104
dataset. x_test ,
109
105
dataset. y_train ,
@@ -127,12 +123,12 @@ pub fn generate(model_id: i64, inputs: Vec<&str>, config: JsonB) -> Result<Vec<S
127
123
let result = generate. call1 ( py, ( model_id, inputs. clone ( ) , & config) ) ;
128
124
let result = match result {
129
125
Err ( e) => {
130
- if e. get_type ( py) . name ( ) . unwrap ( ) == "MissingModelError" {
126
+ if e. get_type ( py) . name ( ) ? == "MissingModelError" {
131
127
info ! ( "Loading model into cache for connection reuse" ) ;
132
128
let mut dir = std:: path:: PathBuf :: from ( "/tmp/postgresml/models" ) ;
133
129
dir. push ( model_id. to_string ( ) ) ;
134
130
if !dir. exists ( ) {
135
- dump_model ( model_id, dir. clone ( ) ) ;
131
+ dump_model ( model_id, dir. clone ( ) ) ? ;
136
132
}
137
133
let task = Spi :: get_one_with_args :: < String > (
138
134
"SELECT task::TEXT
@@ -141,15 +137,15 @@ pub fn generate(model_id: i64, inputs: Vec<&str>, config: JsonB) -> Result<Vec<S
141
137
ON models.project_id = projects.id
142
138
WHERE models.id = $1" ,
143
139
vec ! [ ( PgBuiltInOids :: INT8OID . oid( ) , model_id. into_datum( ) ) ] ,
144
- )
145
- . unwrap ( )
146
- . unwrap ( ) ;
140
+ ) ?
141
+ . ok_or ( anyhow ! ( "task query returned None" ) ) ?;
147
142
148
143
let load = PY_MODULE . getattr ( py, "load_model" ) ?;
149
- let task = Task :: from_str ( & task) . unwrap ( ) ;
150
- load. call1 ( py, ( model_id, task. to_string ( ) , dir) ) . unwrap ( ) ;
144
+ let task = Task :: from_str ( & task)
145
+ . map_err ( |_| anyhow ! ( "could not make a Task from {task}" ) ) ?;
146
+ load. call1 ( py, ( model_id, task. to_string ( ) , dir) ) ?;
151
147
152
- generate. call1 ( py, ( model_id, inputs, config) ) . unwrap ( )
148
+ generate. call1 ( py, ( model_id, inputs, config) ) ?
153
149
} else {
154
150
return Err ( e. into ( ) ) ;
155
151
}
@@ -160,31 +156,37 @@ pub fn generate(model_id: i64, inputs: Vec<&str>, config: JsonB) -> Result<Vec<S
160
156
} )
161
157
}
162
158
163
- fn dump_model ( model_id : i64 , dir : PathBuf ) {
159
+ fn dump_model ( model_id : i64 , dir : PathBuf ) -> Result < ( ) > {
164
160
if dir. exists ( ) {
165
- std:: fs:: remove_dir_all ( & dir) . unwrap ( ) ;
161
+ std:: fs:: remove_dir_all ( & dir) . context ( "failed to remove directory while dumping model" ) ? ;
166
162
}
167
- std:: fs:: create_dir_all ( & dir) . unwrap ( ) ;
168
- Spi :: connect ( |client| {
163
+ std:: fs:: create_dir_all ( & dir) . context ( "failed to create directory while dumping model" ) ? ;
164
+ Spi :: connect ( |client| -> Result < ( ) > {
169
165
let result = client. select ( "SELECT path, part, data FROM pgml.files WHERE model_id = $1 ORDER BY path ASC, part ASC" ,
170
166
None ,
171
167
Some ( vec ! [
172
168
( PgBuiltInOids :: INT8OID . oid( ) , model_id. into_datum( ) ) ,
173
169
] )
174
- ) . unwrap ( ) ;
170
+ ) ? ;
175
171
for row in result {
176
172
let mut path = dir. clone ( ) ;
177
- path. push ( row. get :: < String > ( 1 ) . unwrap ( ) . unwrap ( ) ) ;
178
- let data: Vec < u8 > = row. get ( 3 ) . unwrap ( ) . unwrap ( ) ;
173
+ path. push (
174
+ row. get :: < String > ( 1 ) ?
175
+ . ok_or ( anyhow ! ( "row get ordinal 1 returned None" ) ) ?,
176
+ ) ;
177
+ let data: Vec < u8 > = row
178
+ . get ( 3 ) ?
179
+ . ok_or ( anyhow ! ( "row get ordinal 3 returned None" ) ) ?;
179
180
let mut file = std:: fs:: OpenOptions :: new ( )
180
181
. create ( true )
181
182
. append ( true )
182
- . open ( path)
183
- . unwrap ( ) ;
184
- let _num_bytes = file. write ( & data) . unwrap ( ) ;
185
- file. flush ( ) . unwrap ( ) ;
183
+ . open ( path) ? ;
184
+
185
+ let _num_bytes = file. write ( & data) ? ;
186
+ file. flush ( ) ? ;
186
187
}
187
- } ) ;
188
+ Ok ( ( ) )
189
+ } )
188
190
}
189
191
190
192
pub fn load_dataset (
@@ -219,9 +221,19 @@ pub fn load_dataset(
219
221
220
222
// Columns are a (name: String, values: Vec<Value>) pair
221
223
let json: serde_json:: Value = serde_json:: from_str ( & dataset) ?;
222
- let json = json. as_object ( ) . unwrap ( ) ;
223
- let types = json. get ( "types" ) . unwrap ( ) . as_object ( ) . unwrap ( ) ;
224
- let data = json. get ( "data" ) . unwrap ( ) . as_object ( ) . unwrap ( ) ;
224
+ let json = json
225
+ . as_object ( )
226
+ . ok_or ( anyhow ! ( "dataset json is not object" ) ) ?;
227
+ let types = json
228
+ . get ( "types" )
229
+ . ok_or ( anyhow ! ( "dataset json missing `types` key" ) ) ?
230
+ . as_object ( )
231
+ . ok_or ( anyhow ! ( "dataset `types` key is not an object" ) ) ?;
232
+ let data = json
233
+ . get ( "data" )
234
+ . ok_or ( anyhow ! ( "dataset json missing `data` key" ) ) ?
235
+ . as_object ( )
236
+ . ok_or ( anyhow ! ( "dataset `data` key is not an object" ) ) ?;
225
237
let column_names = types
226
238
. iter ( )
227
239
. map ( |( name, _type) | name. clone ( ) )
@@ -230,7 +242,10 @@ pub fn load_dataset(
230
242
let column_types = types
231
243
. iter ( )
232
244
. map ( |( name, type_) | -> Result < String > {
233
- let type_ = match type_. as_str ( ) . unwrap ( ) {
245
+ let type_ = type_
246
+ . as_str ( )
247
+ . ok_or ( anyhow ! ( "expected {type_} to be a json string" ) ) ?;
248
+ let type_ = match type_ {
234
249
"string" => "TEXT" ,
235
250
"dict" | "list" => "JSONB" ,
236
251
"int64" => "INT8" ,
@@ -240,12 +255,7 @@ pub fn load_dataset(
240
255
"float32" => "FLOAT4" ,
241
256
"float16" => "FLOAT4" ,
242
257
"bool" => "BOOLEAN" ,
243
- _ => {
244
- return Err ( Error :: Data ( format ! (
245
- "unhandled dataset feature while reading dataset: {}" ,
246
- type_
247
- ) ) )
248
- }
258
+ _ => bail ! ( "unhandled dataset feature while reading dataset: {type_}" ) ,
249
259
} ;
250
260
Ok ( format ! ( "{name} {type_}" ) )
251
261
} )
@@ -261,64 +271,88 @@ pub fn load_dataset(
261
271
. collect :: < Vec < String > > ( )
262
272
. join ( ", " ) ;
263
273
let num_cols = types. len ( ) ;
264
- let num_rows = data. values ( ) . next ( ) . unwrap ( ) . as_array ( ) . unwrap ( ) . len ( ) ;
274
+ let num_rows = data
275
+ . values ( )
276
+ . next ( )
277
+ . ok_or ( anyhow ! ( "dataset json has no fields" ) ) ?
278
+ . as_array ( )
279
+ . ok_or ( anyhow ! ( "dataset json field is not an array" ) ) ?
280
+ . len ( ) ;
265
281
266
282
// Avoid the existence warning by checking the schema for the table first
267
283
let table_count = Spi :: get_one_with_args :: < i64 > ( "SELECT COUNT(*) FROM information_schema.tables WHERE table_name = $1 AND table_schema = 'pgml'" , vec ! [
268
284
( PgBuiltInOids :: TEXTOID . oid( ) , table_name. clone( ) . into_datum( ) )
269
- ] ) . unwrap ( ) . unwrap ( ) ;
285
+ ] ) ? . ok_or ( anyhow ! ( "table count query returned None" ) ) ? ;
270
286
if table_count == 1 {
271
- Spi :: run ( & format ! ( r#"DROP TABLE IF EXISTS {table_name}"# ) ) . unwrap ( )
287
+ Spi :: run ( & format ! ( r#"DROP TABLE IF EXISTS {table_name}"# ) ) ? ;
272
288
}
273
289
274
- Spi :: run ( & format ! ( r#"CREATE TABLE {table_name} ({column_types})"# ) ) . unwrap ( ) ;
290
+ Spi :: run ( & format ! ( r#"CREATE TABLE {table_name} ({column_types})"# ) ) ? ;
275
291
let insert =
276
292
format ! ( r#"INSERT INTO {table_name} ({column_names}) VALUES ({column_placeholders})"# ) ;
277
293
for i in 0 ..num_rows {
278
294
let mut row = Vec :: with_capacity ( num_cols) ;
279
295
for ( name, values) in data {
280
- let value = values. as_array ( ) . unwrap ( ) . get ( i) . unwrap ( ) ;
281
- match types. get ( name) . unwrap ( ) . as_str ( ) . unwrap ( ) {
296
+ let value = values
297
+ . as_array ( )
298
+ . ok_or_else ( || anyhow ! ( "expected {values} to be an array" ) ) ?
299
+ . get ( i)
300
+ . ok_or_else ( || anyhow ! ( "invalid index {i} for {values}" ) ) ?;
301
+ match types
302
+ . get ( name)
303
+ . ok_or_else ( || anyhow ! ( "{types:?} expected to have key {name}" ) ) ?
304
+ . as_str ( )
305
+ . ok_or_else ( || anyhow ! ( "json field {name} expected to be string" ) ) ?
306
+ {
282
307
"string" => row. push ( (
283
308
PgBuiltInOids :: TEXTOID . oid ( ) ,
284
- value. as_str ( ) . unwrap ( ) . into_datum ( ) ,
309
+ value
310
+ . as_str ( )
311
+ . ok_or_else ( || anyhow ! ( "expected {value} to be string" ) ) ?
312
+ . into_datum ( ) ,
285
313
) ) ,
286
314
"dict" | "list" => row. push ( (
287
315
PgBuiltInOids :: JSONBOID . oid ( ) ,
288
316
JsonB ( value. clone ( ) ) . into_datum ( ) ,
289
317
) ) ,
290
318
"int64" | "int32" | "int16" => row. push ( (
291
319
PgBuiltInOids :: INT8OID . oid ( ) ,
292
- value. as_i64 ( ) . unwrap ( ) . into_datum ( ) ,
320
+ value
321
+ . as_i64 ( )
322
+ . ok_or_else ( || anyhow ! ( "expected {value} to be i64" ) ) ?
323
+ . into_datum ( ) ,
293
324
) ) ,
294
325
"float64" | "float32" | "float16" => row. push ( (
295
326
PgBuiltInOids :: FLOAT8OID . oid ( ) ,
296
- value. as_f64 ( ) . unwrap ( ) . into_datum ( ) ,
327
+ value
328
+ . as_f64 ( )
329
+ . ok_or_else ( || anyhow ! ( "expected {value} to be f64" ) ) ?
330
+ . into_datum ( ) ,
297
331
) ) ,
298
332
"bool" => row. push ( (
299
333
PgBuiltInOids :: BOOLOID . oid ( ) ,
300
- value. as_bool ( ) . unwrap ( ) . into_datum ( ) ,
334
+ value
335
+ . as_bool ( )
336
+ . ok_or_else ( || anyhow ! ( "expected {value} to be bool" ) ) ?
337
+ . into_datum ( ) ,
301
338
) ) ,
302
339
type_ => {
303
- return Err ( Error :: Data ( format ! (
304
- "unhandled dataset value type while reading dataset: {value:?} {type_:?}" ,
305
- ) ) )
340
+ bail ! ( "unhandled dataset value type while reading dataset: {value:?} {type_:?}" )
306
341
}
307
342
}
308
343
}
309
- Spi :: run_with_args ( & insert, Some ( row) ) . unwrap ( ) ;
344
+ Spi :: run_with_args ( & insert, Some ( row) ) ?
310
345
}
311
346
312
347
Ok ( num_rows)
313
348
}
314
349
315
- pub fn clear_gpu_cache ( memory_usage : Option < f32 > ) -> bool {
316
- Python :: with_gil ( |py| -> bool {
317
- let clear_gpu_cache: Py < PyAny > = PY_MODULE . getattr ( py, "clear_gpu_cache" ) . unwrap ( ) ;
318
- clear_gpu_cache
319
- . call1 ( py, PyTuple :: new ( py, & [ memory_usage. into_py ( py) ] ) )
320
- . unwrap ( )
321
- . extract ( py)
322
- . unwrap ( )
350
+ pub fn clear_gpu_cache ( memory_usage : Option < f32 > ) -> Result < bool > {
351
+ Python :: with_gil ( |py| -> Result < bool > {
352
+ let clear_gpu_cache: Py < PyAny > = PY_MODULE . getattr ( py, "clear_gpu_cache" ) ?;
353
+ let success = clear_gpu_cache
354
+ . call1 ( py, PyTuple :: new ( py, & [ memory_usage. into_py ( py) ] ) ) ?
355
+ . extract ( py) ?;
356
+ Ok ( success)
323
357
} )
324
358
}
0 commit comments