5 changes: 4 additions & 1 deletion pgml-extension/src/api.rs
@@ -602,7 +602,10 @@ pub fn embed_batch(
/// ```
#[pg_extern(immutable, parallel_safe, name = "clear_gpu_cache")]
pub fn clear_gpu_cache(memory_usage: default!(Option<f32>, "NULL")) -> bool {
- crate::bindings::transformers::clear_gpu_cache(memory_usage)
+ match crate::bindings::transformers::clear_gpu_cache(memory_usage) {
+ Ok(success) => success,
+ Err(e) => error!("{e}"),
+ }
}

#[pg_extern(immutable, parallel_safe)]
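The hunk above is the boundary pattern this PR applies at the SQL-facing entry point: the binding now returns a `Result`, and the `#[pg_extern]` wrapper converts an `Err` into a Postgres `ERROR` through pgrx's `error!` macro. Because `error!` diverges (it never returns), it can sit in a match arm whose sibling arm returns `bool`. A minimal standalone sketch of that shape, with `fallible_op` as a hypothetical stand-in for the real binding and `panic!` standing in for `error!`:

use anyhow::{bail, Result};

// Hypothetical stand-in for crate::bindings::transformers::clear_gpu_cache.
fn fallible_op(memory_usage: Option<f32>) -> Result<bool> {
    match memory_usage {
        Some(m) if !(0.0..=1.0).contains(&m) => bail!("memory_usage must be in [0, 1], got {m}"),
        _ => Ok(true),
    }
}

// Shape of the wrapper: unwrap the Result at the boundary. In the extension,
// pgrx's `error!` replaces the panic and reports the message through
// Postgres' error machinery instead of unwinding with an opaque backtrace.
pub fn clear_gpu_cache_entry(memory_usage: Option<f32>) -> bool {
    match fallible_op(memory_usage) {
        Ok(success) => success,
        Err(e) => panic!("{e}"),
    }
}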
44 changes: 0 additions & 44 deletions pgml-extension/src/bindings/transformers/error.rs

This file was deleted.

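For context: the deleted module defined the extension's hand-rolled error type (its `Error::Data` variant still appears in the old lines below) together with the module-local `Result` alias removed from `mod.rs` in the next file. The PR replaces both with `anyhow`. A rough before/after sketch; the old enum's shape is an assumption reconstructed from its call sites, not from the deleted source:

// Before (assumed shape): a bespoke error enum and a local Result alias.
#[derive(Debug)]
pub enum Error {
    Data(String),
}
pub type LocalResult<T> = std::result::Result<T, Error>;

// After: anyhow's type-erased error, with bail! for ad-hoc messages.
use anyhow::{bail, Result};

fn pg_type_for(dtype: &str) -> Result<&'static str> {
    match dtype {
        "string" => Ok("TEXT"),
        "int64" => Ok("INT8"),
        "bool" => Ok("BOOLEAN"),
        other => bail!("unhandled dataset feature while reading dataset: {other}"),
    }
}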
154 changes: 94 additions & 60 deletions pgml-extension/src/bindings/transformers/mod.rs
@@ -3,21 +3,16 @@ use std::path::PathBuf;
use std::str::FromStr;
use std::{collections::HashMap, path::Path};

+ use anyhow::{anyhow, bail, Context, Result};
use once_cell::sync::Lazy;
use pgrx::*;
use pyo3::prelude::*;
use pyo3::types::PyTuple;

use crate::orm::{Task, TextDataset};

- use self::error::Error;
- use self::whitelist::verify_task_against_whitelist;

- pub mod error;
pub mod whitelist;

- pub type Result<T> = std::result::Result<T, error::Error>;

static PY_MODULE: Lazy<Py<PyModule>> = Lazy::new(|| {
Python::with_gil(|py| -> Py<PyModule> {
let src = include_str!(concat!(
@@ -36,7 +31,7 @@ pub fn transform(
) -> Result<serde_json::Value> {
crate::bindings::venv::activate();

- verify_task_against_whitelist(task)?;
+ whitelist::verify_task(task)?;

let task = serde_json::to_string(task)?;
let args = serde_json::to_string(args)?;
@@ -98,12 +93,13 @@ pub fn tune(

Python::with_gil(|py| -> Result<HashMap<String, f64>> {
let tune = PY_MODULE.getattr(py, "tune")?;
+ let path = path.to_string_lossy();
let output = tune.call1(
py,
(
&task,
&hyperparams,
- path.to_str().unwrap(),
+ path.as_ref(),
dataset.x_train,
dataset.x_test,
dataset.y_train,
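A detail worth calling out in the `tune` hunk: `Path::to_str` returns `None` for paths that are not valid UTF-8, so the old `.unwrap()` was a real (if unlikely) panic. `Path::to_string_lossy` cannot fail; it substitutes U+FFFD for invalid bytes and yields a `Cow<str>` that `as_ref()` borrows as `&str` for the Python call. A self-contained illustration:

use std::borrow::Cow;
use std::path::Path;

fn main() {
    let path = Path::new("/tmp/postgresml/models/42");
    // Infallible: invalid UTF-8 is replaced with U+FFFD, never rejected.
    let lossy: Cow<str> = path.to_string_lossy();
    // Borrow the Cow as &str, as the new tune() code does with path.as_ref().
    let as_str: &str = lossy.as_ref();
    assert_eq!(as_str, "/tmp/postgresml/models/42");
}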
@@ -127,12 +123,12 @@ pub fn generate(model_id: i64, inputs: Vec<&str>, config: JsonB) -> Result<Vec<S
let result = generate.call1(py, (model_id, inputs.clone(), &config));
let result = match result {
Err(e) => {
- if e.get_type(py).name().unwrap() == "MissingModelError" {
+ if e.get_type(py).name()? == "MissingModelError" {
info!("Loading model into cache for connection reuse");
let mut dir = std::path::PathBuf::from("/tmp/postgresml/models");
dir.push(model_id.to_string());
if !dir.exists() {
- dump_model(model_id, dir.clone());
+ dump_model(model_id, dir.clone())?;
}
let task = Spi::get_one_with_args::<String>(
"SELECT task::TEXT
@@ -141,15 +137,15 @@ ON models.project_id = projects.id
ON models.project_id = projects.id
WHERE models.id = $1",
vec![(PgBuiltInOids::INT8OID.oid(), model_id.into_datum())],
- )
- .unwrap()
- .unwrap();
+ )?
+ .ok_or(anyhow!("task query returned None"))?;

let load = PY_MODULE.getattr(py, "load_model")?;
- let task = Task::from_str(&task).unwrap();
- load.call1(py, (model_id, task.to_string(), dir)).unwrap();
+ let task = Task::from_str(&task)
+ .map_err(|_| anyhow!("could not make a Task from {task}"))?;
+ load.call1(py, (model_id, task.to_string(), dir))?;

- generate.call1(py, (model_id, inputs, config)).unwrap()
+ generate.call1(py, (model_id, inputs, config))?
} else {
return Err(e.into());
}
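The control flow being hardened above: on a call failure, `generate` checks the Python exception's type name, and only for `MissingModelError` does it dump the model's files out of `pgml.files` to `/tmp/postgresml/models/<id>`, load them into the Python-side cache, and retry once; with this PR every step of that recovery propagates failure via `?` instead of unwrapping. A minimal sketch of the retry-once shape, with hypothetical closures standing in for the pyo3 calls:

use anyhow::Result;

// `call` stands in for generate.call1, `recover` for the dump_model +
// load_model sequence (hypothetical names). The real code also inspects
// the error's type name before attempting recovery.
fn call_with_recovery<T>(
    mut call: impl FnMut() -> Result<T>,
    recover: impl FnOnce() -> Result<()>,
) -> Result<T> {
    match call() {
        Ok(v) => Ok(v),
        Err(_missing_model) => {
            recover()?; // a failed recovery propagates immediately
            call() // a second call failure also propagates; no retry loop
        }
    }
}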
@@ -160,31 +156,37 @@ pub fn generate(model_id: i64, inputs: Vec<&str>, config: JsonB) -> Result<Vec<S
})
}

- fn dump_model(model_id: i64, dir: PathBuf) {
+ fn dump_model(model_id: i64, dir: PathBuf) -> Result<()> {
if dir.exists() {
- std::fs::remove_dir_all(&dir).unwrap();
+ std::fs::remove_dir_all(&dir).context("failed to remove directory while dumping model")?;
}
- std::fs::create_dir_all(&dir).unwrap();
- Spi::connect(|client| {
+ std::fs::create_dir_all(&dir).context("failed to create directory while dumping model")?;
+ Spi::connect(|client| -> Result<()> {
let result = client.select("SELECT path, part, data FROM pgml.files WHERE model_id = $1 ORDER BY path ASC, part ASC",
None,
Some(vec![
(PgBuiltInOids::INT8OID.oid(), model_id.into_datum()),
])
- ).unwrap();
+ )?;
for row in result {
let mut path = dir.clone();
- path.push(row.get::<String>(1).unwrap().unwrap());
- let data: Vec<u8> = row.get(3).unwrap().unwrap();
+ path.push(
+ row.get::<String>(1)?
+ .ok_or(anyhow!("row get ordinal 1 returned None"))?,
+ );
+ let data: Vec<u8> = row
+ .get(3)?
+ .ok_or(anyhow!("row get ordinal 3 returned None"))?;
let mut file = std::fs::OpenOptions::new()
.create(true)
.append(true)
- .open(path)
- .unwrap();
- let _num_bytes = file.write(&data).unwrap();
- file.flush().unwrap();
+ .open(path)?;
+
+ let _num_bytes = file.write(&data)?;
+ file.flush()?;
}
- });
+ Ok(())
+ })
}

pub fn load_dataset(
@@ -219,9 +221,19 @@ pub fn load_dataset(

// Columns are a (name: String, values: Vec<Value>) pair
let json: serde_json::Value = serde_json::from_str(&dataset)?;
- let json = json.as_object().unwrap();
- let types = json.get("types").unwrap().as_object().unwrap();
- let data = json.get("data").unwrap().as_object().unwrap();
+ let json = json
+ .as_object()
+ .ok_or(anyhow!("dataset json is not object"))?;
+ let types = json
+ .get("types")
+ .ok_or(anyhow!("dataset json missing `types` key"))?
+ .as_object()
+ .ok_or(anyhow!("dataset `types` key is not an object"))?;
+ let data = json
+ .get("data")
+ .ok_or(anyhow!("dataset json missing `data` key"))?
+ .as_object()
+ .ok_or(anyhow!("dataset `data` key is not an object"))?;
let column_names = types
.iter()
.map(|(name, _type)| name.clone())
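The rewritten JSON handling above is one idea applied uniformly: each `Option` returned by `serde_json`'s accessors (`as_object`, `get`, `as_str`) is turned into a `Result` with `ok_or`, so a malformed dataset now reports which expectation failed instead of panicking. A runnable miniature against a toy payload with the same `types`/`data` shape:

use anyhow::{anyhow, Result};
use serde_json::Value;

fn main() -> Result<()> {
    let raw = r#"{"types": {"len": "int64"}, "data": {"len": [1, 2, 3]}}"#;
    let json: Value = serde_json::from_str(raw)?;
    let json = json
        .as_object()
        .ok_or(anyhow!("dataset json is not object"))?;
    let types = json
        .get("types")
        .ok_or(anyhow!("dataset json missing `types` key"))?
        .as_object()
        .ok_or(anyhow!("dataset `types` key is not an object"))?;
    println!("{} column(s) described", types.len());
    Ok(())
}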
@@ -230,7 +242,10 @@ pub fn load_dataset(
let column_types = types
.iter()
.map(|(name, type_)| -> Result<String> {
- let type_ = match type_.as_str().unwrap() {
+ let type_ = type_
+ .as_str()
+ .ok_or(anyhow!("expected {type_} to be a json string"))?;
+ let type_ = match type_ {
"string" => "TEXT",
"dict" | "list" => "JSONB",
"int64" => "INT8",
@@ -240,12 +255,7 @@ pub fn load_dataset(
"float32" => "FLOAT4",
"float16" => "FLOAT4",
"bool" => "BOOLEAN",
- _ => {
- return Err(Error::Data(format!(
- "unhandled dataset feature while reading dataset: {}",
- type_
- )))
- }
+ _ => bail!("unhandled dataset feature while reading dataset: {type_}"),
};
Ok(format!("{name} {type_}"))
})
@@ -261,64 +271,88 @@ pub fn load_dataset(
.collect::<Vec<String>>()
.join(", ");
let num_cols = types.len();
- let num_rows = data.values().next().unwrap().as_array().unwrap().len();
+ let num_rows = data
+ .values()
+ .next()
+ .ok_or(anyhow!("dataset json has no fields"))?
+ .as_array()
+ .ok_or(anyhow!("dataset json field is not an array"))?
+ .len();

// Avoid the existence warning by checking the schema for the table first
let table_count = Spi::get_one_with_args::<i64>("SELECT COUNT(*) FROM information_schema.tables WHERE table_name = $1 AND table_schema = 'pgml'", vec![
(PgBuiltInOids::TEXTOID.oid(), table_name.clone().into_datum())
- ]).unwrap().unwrap();
+ ])?.ok_or(anyhow!("table count query returned None"))?;
if table_count == 1 {
- Spi::run(&format!(r#"DROP TABLE IF EXISTS {table_name}"#)).unwrap()
+ Spi::run(&format!(r#"DROP TABLE IF EXISTS {table_name}"#))?;
}

- Spi::run(&format!(r#"CREATE TABLE {table_name} ({column_types})"#)).unwrap();
+ Spi::run(&format!(r#"CREATE TABLE {table_name} ({column_types})"#))?;
let insert =
format!(r#"INSERT INTO {table_name} ({column_names}) VALUES ({column_placeholders})"#);
for i in 0..num_rows {
let mut row = Vec::with_capacity(num_cols);
for (name, values) in data {
- let value = values.as_array().unwrap().get(i).unwrap();
- match types.get(name).unwrap().as_str().unwrap() {
+ let value = values
+ .as_array()
+ .ok_or_else(|| anyhow!("expected {values} to be an array"))?
+ .get(i)
+ .ok_or_else(|| anyhow!("invalid index {i} for {values}"))?;
+ match types
+ .get(name)
+ .ok_or_else(|| anyhow!("{types:?} expected to have key {name}"))?
+ .as_str()
+ .ok_or_else(|| anyhow!("json field {name} expected to be string"))?
+ {
"string" => row.push((
PgBuiltInOids::TEXTOID.oid(),
- value.as_str().unwrap().into_datum(),
+ value
+ .as_str()
+ .ok_or_else(|| anyhow!("expected {value} to be string"))?
+ .into_datum(),
)),
"dict" | "list" => row.push((
PgBuiltInOids::JSONBOID.oid(),
JsonB(value.clone()).into_datum(),
)),
"int64" | "int32" | "int16" => row.push((
PgBuiltInOids::INT8OID.oid(),
- value.as_i64().unwrap().into_datum(),
+ value
+ .as_i64()
+ .ok_or_else(|| anyhow!("expected {value} to be i64"))?
+ .into_datum(),
)),
"float64" | "float32" | "float16" => row.push((
PgBuiltInOids::FLOAT8OID.oid(),
- value.as_f64().unwrap().into_datum(),
+ value
+ .as_f64()
+ .ok_or_else(|| anyhow!("expected {value} to be f64"))?
+ .into_datum(),
)),
"bool" => row.push((
PgBuiltInOids::BOOLOID.oid(),
- value.as_bool().unwrap().into_datum(),
+ value
+ .as_bool()
+ .ok_or_else(|| anyhow!("expected {value} to be bool"))?
+ .into_datum(),
)),
type_ => {
- return Err(Error::Data(format!(
- "unhandled dataset value type while reading dataset: {value:?} {type_:?}",
- )))
+ bail!("unhandled dataset value type while reading dataset: {value:?} {type_:?}")
}
}
}
- Spi::run_with_args(&insert, Some(row)).unwrap();
+ Spi::run_with_args(&insert, Some(row))?
}

Ok(num_rows)
}
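One stylistic wrinkle in `load_dataset` as rewritten: the one-time checks use `ok_or(anyhow!(...))`, which builds its error message even on the success path, while the checks inside the per-row insert loop use `ok_or_else(|| anyhow!(...))`, which defers the formatting to the failure branch. The difference in miniature:

use anyhow::{anyhow, Result};

fn first_row(rows: &[i64]) -> Result<i64> {
    // Eager: the anyhow! message is constructed even when `first` is Some.
    rows.first().copied().ok_or(anyhow!("dataset json has no fields"))
}

fn nth_row(rows: &[i64], i: usize) -> Result<i64> {
    // Lazy: the closure runs only when `get` returns None.
    rows.get(i).copied().ok_or_else(|| anyhow!("invalid index {i} for row slice"))
}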

- pub fn clear_gpu_cache(memory_usage: Option<f32>) -> bool {
- Python::with_gil(|py| -> bool {
- let clear_gpu_cache: Py<PyAny> = PY_MODULE.getattr(py, "clear_gpu_cache").unwrap();
- clear_gpu_cache
- .call1(py, PyTuple::new(py, &[memory_usage.into_py(py)]))
- .unwrap()
- .extract(py)
- .unwrap()
+ pub fn clear_gpu_cache(memory_usage: Option<f32>) -> Result<bool> {
+ Python::with_gil(|py| -> Result<bool> {
+ let clear_gpu_cache: Py<PyAny> = PY_MODULE.getattr(py, "clear_gpu_cache")?;
+ let success = clear_gpu_cache
+ .call1(py, PyTuple::new(py, &[memory_usage.into_py(py)]))?
+ .extract(py)?;
+ Ok(success)
})
}
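Finally, the reason a bare `?` works on pyo3 calls throughout this diff: `PyErr` implements `std::error::Error`, so it converts into `anyhow::Error` automatically. A minimal sketch of the same `with_gil` structure, run against Python's stdlib rather than the extension's `PY_MODULE` (assuming a pyo3 version with the `py.import` API, as used here):

use anyhow::Result;
use pyo3::prelude::*;

fn py_sqrt(x: f64) -> Result<f64> {
    Python::with_gil(|py| -> Result<f64> {
        // Every ? below converts a PyErr into anyhow::Error.
        let math = py.import("math")?;
        let value: f64 = math.getattr("sqrt")?.call1((x,))?.extract()?;
        Ok(value)
    })
}

fn main() -> Result<()> {
    println!("sqrt(2) = {}", py_sqrt(2.0)?);
    Ok(())
}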