Skip to content

Commit c30fd38

Browse files
authored
remove unwraps from transformers module (#894)
1 parent 21ad471 commit c30fd38

File tree

4 files changed

+118
-137
lines changed

4 files changed

+118
-137
lines changed

pgml-extension/src/api.rs

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -602,7 +602,10 @@ pub fn embed_batch(
602602
/// ```
603603
#[pg_extern(immutable, parallel_safe, name = "clear_gpu_cache")]
604604
pub fn clear_gpu_cache(memory_usage: default!(Option<f32>, "NULL")) -> bool {
605-
crate::bindings::transformers::clear_gpu_cache(memory_usage)
605+
match crate::bindings::transformers::clear_gpu_cache(memory_usage) {
606+
Ok(success) => success,
607+
Err(e) => error!("{e}"),
608+
}
606609
}
607610

608611
#[pg_extern(immutable, parallel_safe)]

pgml-extension/src/bindings/transformers/error.rs

Lines changed: 0 additions & 44 deletions
This file was deleted.

pgml-extension/src/bindings/transformers/mod.rs

Lines changed: 94 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -3,21 +3,16 @@ use std::path::PathBuf;
33
use std::str::FromStr;
44
use std::{collections::HashMap, path::Path};
55

6+
use anyhow::{anyhow, bail, Context, Result};
67
use once_cell::sync::Lazy;
78
use pgrx::*;
89
use pyo3::prelude::*;
910
use pyo3::types::PyTuple;
1011

1112
use crate::orm::{Task, TextDataset};
1213

13-
use self::error::Error;
14-
use self::whitelist::verify_task_against_whitelist;
15-
16-
pub mod error;
1714
pub mod whitelist;
1815

19-
pub type Result<T> = std::result::Result<T, error::Error>;
20-
2116
static PY_MODULE: Lazy<Py<PyModule>> = Lazy::new(|| {
2217
Python::with_gil(|py| -> Py<PyModule> {
2318
let src = include_str!(concat!(
@@ -36,7 +31,7 @@ pub fn transform(
3631
) -> Result<serde_json::Value> {
3732
crate::bindings::venv::activate();
3833

39-
verify_task_against_whitelist(task)?;
34+
whitelist::verify_task(task)?;
4035

4136
let task = serde_json::to_string(task)?;
4237
let args = serde_json::to_string(args)?;
@@ -98,12 +93,13 @@ pub fn tune(
9893

9994
Python::with_gil(|py| -> Result<HashMap<String, f64>> {
10095
let tune = PY_MODULE.getattr(py, "tune")?;
96+
let path = path.to_string_lossy();
10197
let output = tune.call1(
10298
py,
10399
(
104100
&task,
105101
&hyperparams,
106-
path.to_str().unwrap(),
102+
path.as_ref(),
107103
dataset.x_train,
108104
dataset.x_test,
109105
dataset.y_train,
@@ -127,12 +123,12 @@ pub fn generate(model_id: i64, inputs: Vec<&str>, config: JsonB) -> Result<Vec<S
127123
let result = generate.call1(py, (model_id, inputs.clone(), &config));
128124
let result = match result {
129125
Err(e) => {
130-
if e.get_type(py).name().unwrap() == "MissingModelError" {
126+
if e.get_type(py).name()? == "MissingModelError" {
131127
info!("Loading model into cache for connection reuse");
132128
let mut dir = std::path::PathBuf::from("/tmp/postgresml/models");
133129
dir.push(model_id.to_string());
134130
if !dir.exists() {
135-
dump_model(model_id, dir.clone());
131+
dump_model(model_id, dir.clone())?;
136132
}
137133
let task = Spi::get_one_with_args::<String>(
138134
"SELECT task::TEXT
@@ -141,15 +137,15 @@ pub fn generate(model_id: i64, inputs: Vec<&str>, config: JsonB) -> Result<Vec<S
141137
ON models.project_id = projects.id
142138
WHERE models.id = $1",
143139
vec![(PgBuiltInOids::INT8OID.oid(), model_id.into_datum())],
144-
)
145-
.unwrap()
146-
.unwrap();
140+
)?
141+
.ok_or(anyhow!("task query returned None"))?;
147142

148143
let load = PY_MODULE.getattr(py, "load_model")?;
149-
let task = Task::from_str(&task).unwrap();
150-
load.call1(py, (model_id, task.to_string(), dir)).unwrap();
144+
let task = Task::from_str(&task)
145+
.map_err(|_| anyhow!("could not make a Task from {task}"))?;
146+
load.call1(py, (model_id, task.to_string(), dir))?;
151147

152-
generate.call1(py, (model_id, inputs, config)).unwrap()
148+
generate.call1(py, (model_id, inputs, config))?
153149
} else {
154150
return Err(e.into());
155151
}
@@ -160,31 +156,37 @@ pub fn generate(model_id: i64, inputs: Vec<&str>, config: JsonB) -> Result<Vec<S
160156
})
161157
}
162158

163-
fn dump_model(model_id: i64, dir: PathBuf) {
159+
fn dump_model(model_id: i64, dir: PathBuf) -> Result<()> {
164160
if dir.exists() {
165-
std::fs::remove_dir_all(&dir).unwrap();
161+
std::fs::remove_dir_all(&dir).context("failed to remove directory while dumping model")?;
166162
}
167-
std::fs::create_dir_all(&dir).unwrap();
168-
Spi::connect(|client| {
163+
std::fs::create_dir_all(&dir).context("failed to create directory while dumping model")?;
164+
Spi::connect(|client| -> Result<()> {
169165
let result = client.select("SELECT path, part, data FROM pgml.files WHERE model_id = $1 ORDER BY path ASC, part ASC",
170166
None,
171167
Some(vec![
172168
(PgBuiltInOids::INT8OID.oid(), model_id.into_datum()),
173169
])
174-
).unwrap();
170+
)?;
175171
for row in result {
176172
let mut path = dir.clone();
177-
path.push(row.get::<String>(1).unwrap().unwrap());
178-
let data: Vec<u8> = row.get(3).unwrap().unwrap();
173+
path.push(
174+
row.get::<String>(1)?
175+
.ok_or(anyhow!("row get ordinal 1 returned None"))?,
176+
);
177+
let data: Vec<u8> = row
178+
.get(3)?
179+
.ok_or(anyhow!("row get ordinal 3 returned None"))?;
179180
let mut file = std::fs::OpenOptions::new()
180181
.create(true)
181182
.append(true)
182-
.open(path)
183-
.unwrap();
184-
let _num_bytes = file.write(&data).unwrap();
185-
file.flush().unwrap();
183+
.open(path)?;
184+
185+
let _num_bytes = file.write(&data)?;
186+
file.flush()?;
186187
}
187-
});
188+
Ok(())
189+
})
188190
}
189191

190192
pub fn load_dataset(
@@ -219,9 +221,19 @@ pub fn load_dataset(
219221

220222
// Columns are a (name: String, values: Vec<Value>) pair
221223
let json: serde_json::Value = serde_json::from_str(&dataset)?;
222-
let json = json.as_object().unwrap();
223-
let types = json.get("types").unwrap().as_object().unwrap();
224-
let data = json.get("data").unwrap().as_object().unwrap();
224+
let json = json
225+
.as_object()
226+
.ok_or(anyhow!("dataset json is not object"))?;
227+
let types = json
228+
.get("types")
229+
.ok_or(anyhow!("dataset json missing `types` key"))?
230+
.as_object()
231+
.ok_or(anyhow!("dataset `types` key is not an object"))?;
232+
let data = json
233+
.get("data")
234+
.ok_or(anyhow!("dataset json missing `data` key"))?
235+
.as_object()
236+
.ok_or(anyhow!("dataset `data` key is not an object"))?;
225237
let column_names = types
226238
.iter()
227239
.map(|(name, _type)| name.clone())
@@ -230,7 +242,10 @@ pub fn load_dataset(
230242
let column_types = types
231243
.iter()
232244
.map(|(name, type_)| -> Result<String> {
233-
let type_ = match type_.as_str().unwrap() {
245+
let type_ = type_
246+
.as_str()
247+
.ok_or(anyhow!("expected {type_} to be a json string"))?;
248+
let type_ = match type_ {
234249
"string" => "TEXT",
235250
"dict" | "list" => "JSONB",
236251
"int64" => "INT8",
@@ -240,12 +255,7 @@ pub fn load_dataset(
240255
"float32" => "FLOAT4",
241256
"float16" => "FLOAT4",
242257
"bool" => "BOOLEAN",
243-
_ => {
244-
return Err(Error::Data(format!(
245-
"unhandled dataset feature while reading dataset: {}",
246-
type_
247-
)))
248-
}
258+
_ => bail!("unhandled dataset feature while reading dataset: {type_}"),
249259
};
250260
Ok(format!("{name} {type_}"))
251261
})
@@ -261,64 +271,88 @@ pub fn load_dataset(
261271
.collect::<Vec<String>>()
262272
.join(", ");
263273
let num_cols = types.len();
264-
let num_rows = data.values().next().unwrap().as_array().unwrap().len();
274+
let num_rows = data
275+
.values()
276+
.next()
277+
.ok_or(anyhow!("dataset json has no fields"))?
278+
.as_array()
279+
.ok_or(anyhow!("dataset json field is not an array"))?
280+
.len();
265281

266282
// Avoid the existence warning by checking the schema for the table first
267283
let table_count = Spi::get_one_with_args::<i64>("SELECT COUNT(*) FROM information_schema.tables WHERE table_name = $1 AND table_schema = 'pgml'", vec![
268284
(PgBuiltInOids::TEXTOID.oid(), table_name.clone().into_datum())
269-
]).unwrap().unwrap();
285+
])?.ok_or(anyhow!("table count query returned None"))?;
270286
if table_count == 1 {
271-
Spi::run(&format!(r#"DROP TABLE IF EXISTS {table_name}"#)).unwrap()
287+
Spi::run(&format!(r#"DROP TABLE IF EXISTS {table_name}"#))?;
272288
}
273289

274-
Spi::run(&format!(r#"CREATE TABLE {table_name} ({column_types})"#)).unwrap();
290+
Spi::run(&format!(r#"CREATE TABLE {table_name} ({column_types})"#))?;
275291
let insert =
276292
format!(r#"INSERT INTO {table_name} ({column_names}) VALUES ({column_placeholders})"#);
277293
for i in 0..num_rows {
278294
let mut row = Vec::with_capacity(num_cols);
279295
for (name, values) in data {
280-
let value = values.as_array().unwrap().get(i).unwrap();
281-
match types.get(name).unwrap().as_str().unwrap() {
296+
let value = values
297+
.as_array()
298+
.ok_or_else(|| anyhow!("expected {values} to be an array"))?
299+
.get(i)
300+
.ok_or_else(|| anyhow!("invalid index {i} for {values}"))?;
301+
match types
302+
.get(name)
303+
.ok_or_else(|| anyhow!("{types:?} expected to have key {name}"))?
304+
.as_str()
305+
.ok_or_else(|| anyhow!("json field {name} expected to be string"))?
306+
{
282307
"string" => row.push((
283308
PgBuiltInOids::TEXTOID.oid(),
284-
value.as_str().unwrap().into_datum(),
309+
value
310+
.as_str()
311+
.ok_or_else(|| anyhow!("expected {value} to be string"))?
312+
.into_datum(),
285313
)),
286314
"dict" | "list" => row.push((
287315
PgBuiltInOids::JSONBOID.oid(),
288316
JsonB(value.clone()).into_datum(),
289317
)),
290318
"int64" | "int32" | "int16" => row.push((
291319
PgBuiltInOids::INT8OID.oid(),
292-
value.as_i64().unwrap().into_datum(),
320+
value
321+
.as_i64()
322+
.ok_or_else(|| anyhow!("expected {value} to be i64"))?
323+
.into_datum(),
293324
)),
294325
"float64" | "float32" | "float16" => row.push((
295326
PgBuiltInOids::FLOAT8OID.oid(),
296-
value.as_f64().unwrap().into_datum(),
327+
value
328+
.as_f64()
329+
.ok_or_else(|| anyhow!("expected {value} to be f64"))?
330+
.into_datum(),
297331
)),
298332
"bool" => row.push((
299333
PgBuiltInOids::BOOLOID.oid(),
300-
value.as_bool().unwrap().into_datum(),
334+
value
335+
.as_bool()
336+
.ok_or_else(|| anyhow!("expected {value} to be bool"))?
337+
.into_datum(),
301338
)),
302339
type_ => {
303-
return Err(Error::Data(format!(
304-
"unhandled dataset value type while reading dataset: {value:?} {type_:?}",
305-
)))
340+
bail!("unhandled dataset value type while reading dataset: {value:?} {type_:?}")
306341
}
307342
}
308343
}
309-
Spi::run_with_args(&insert, Some(row)).unwrap();
344+
Spi::run_with_args(&insert, Some(row))?
310345
}
311346

312347
Ok(num_rows)
313348
}
314349

315-
pub fn clear_gpu_cache(memory_usage: Option<f32>) -> bool {
316-
Python::with_gil(|py| -> bool {
317-
let clear_gpu_cache: Py<PyAny> = PY_MODULE.getattr(py, "clear_gpu_cache").unwrap();
318-
clear_gpu_cache
319-
.call1(py, PyTuple::new(py, &[memory_usage.into_py(py)]))
320-
.unwrap()
321-
.extract(py)
322-
.unwrap()
350+
pub fn clear_gpu_cache(memory_usage: Option<f32>) -> Result<bool> {
351+
Python::with_gil(|py| -> Result<bool> {
352+
let clear_gpu_cache: Py<PyAny> = PY_MODULE.getattr(py, "clear_gpu_cache")?;
353+
let success = clear_gpu_cache
354+
.call1(py, PyTuple::new(py, &[memory_usage.into_py(py)]))?
355+
.extract(py)?;
356+
Ok(success)
323357
})
324358
}

0 commit comments

Comments
 (0)