fix and test preprocessing examples
montanalow committed Jun 11, 2024
commit fb773bb307720ccfc7f7f566a448efb0b7c673bf
33 changes: 33 additions & 0 deletions pgml-extension/examples/preprocessing.sql
@@ -0,0 +1,33 @@
-- load the diamonds dataset, which contains text categorical variables
SELECT pgml.load_dataset('jdxcosta/diamonds');

-- view the data
SELECT * FROM pgml."jdxcosta/diamonds" LIMIT 10;

-- drop the "Unnamed: 0" column, since it's not useful for training (you could create a view instead)
ALTER TABLE pgml."jdxcosta/diamonds" DROP COLUMN "Unnamed: 0";

-- train a model using preprocessors to scale the numeric variables and target-encode the categoricals
SELECT pgml.train(
project_name => 'Diamond prices',
task => 'regression',
relation_name => 'pgml.jdxcosta/diamonds',
y_column_name => 'price',
algorithm => 'lightgbm',
preprocess => '{
"carat": {"scale": "standard"},
"depth": {"scale": "standard"},
"table": {"scale": "standard"},
"cut": {"encode": "target", "scale": "standard"},
"color": {"encode": "target", "scale": "standard"},
"clarity": {"encode": "target", "scale": "standard"}
}'
);

-- run some predictions; notice we're passing a heterogeneous row (tuple) as input, rather than a homogeneous ARRAY[].
SELECT price, pgml.predict('Diamond prices', (carat, cut, color, clarity, depth, "table", x, y, z)) AS prediction
FROM pgml."jdxcosta/diamonds"
LIMIT 10;

-- This is a difficult dataset for most algorithms, which makes it a good challenge for preprocessing and additional
-- feature engineering. What's next?
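
A quick sanity check, as a sketch assuming only the table and project created above: compare the model's predictions against the actual prices.

-- sketch: mean absolute error of the model over the training data
SELECT avg(abs(price - pgml.predict('Diamond prices', (carat, cut, color, clarity, depth, "table", x, y, z)))) AS mean_abs_error
FROM pgml."jdxcosta/diamonds";
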
7 changes: 4 additions & 3 deletions pgml-extension/src/bindings/transformers/mod.rs
@@ -380,7 +380,7 @@ pub fn load_dataset(
.ok_or(anyhow!("dataset `data` key is not an object"))?;
let column_names = types
.iter()
- .map(|(name, _type)| name.clone())
+ .map(|(name, _type)| format!("\"{}\"", name))
.collect::<Vec<String>>()
.join(", ");
let column_types = types
@@ -393,13 +393,14 @@ pub fn load_dataset(
"int64" => "INT8",
"int32" => "INT4",
"int16" => "INT2",
"int8" => "INT2",
"float64" => "FLOAT8",
"float32" => "FLOAT4",
"float16" => "FLOAT4",
"bool" => "BOOLEAN",
_ => bail!("unhandled dataset feature while reading dataset: {type_}"),
};
Ok(format!("{name} {type_}"))
Ok(format!("\"{name}\" {type_}"))
})
.collect::<Result<Vec<String>>>()?
.join(", ");
@@ -455,7 +456,7 @@ pub fn load_dataset(
.into_datum(),
)),
"dict" | "list" => row.push((PgBuiltInOids::JSONBOID.oid(), JsonB(value.clone()).into_datum())),
"int64" | "int32" | "int16" => row.push((
"int64" | "int32" | "int16" | "int8" => row.push((
PgBuiltInOids::INT8OID.oid(),
value
.as_i64()
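
The quoting added above matters because Postgres folds unquoted identifiers to lowercase and rejects names containing spaces or punctuation, like the dataset's "Unnamed: 0" column. A minimal illustration with a hypothetical table:

-- hypothetical table showing why generated DDL must quote column names
CREATE TABLE quoting_demo ("Unnamed: 0" INT8);
SELECT "Unnamed: 0" FROM quoting_demo; -- valid
-- SELECT Unnamed: 0 FROM quoting_demo; -- syntax error without the quotes
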
13 changes: 6 additions & 7 deletions pgml-extension/src/orm/model.rs
@@ -344,13 +344,12 @@ impl Model {
).unwrap().first();

if !result.is_empty() {
- let project_id = result.get(2).unwrap().unwrap();
- let project = Project::find(project_id).unwrap();
- let snapshot_id = result.get(3).unwrap().unwrap();
- let snapshot = Snapshot::find(snapshot_id).unwrap();
- let algorithm = Algorithm::from_str(result.get(4).unwrap().unwrap()).unwrap();
- let runtime = Runtime::from_str(result.get(5).unwrap().unwrap()).unwrap();
-
+ let project_id = result.get(2).unwrap().expect("project_id is i64");
+ let project = Project::find(project_id).expect("project doesn't exist");
+ let snapshot_id = result.get(3).unwrap().expect("snapshot_id is i64");
+ let snapshot = Snapshot::find(snapshot_id).expect("snapshot doesn't exist");
+ let algorithm = Algorithm::from_str(result.get(4).unwrap().unwrap()).expect("algorithm is malformed");
+ let runtime = Runtime::from_str(result.get(5).unwrap().unwrap()).expect("runtime is malformed");
let data = Spi::get_one_with_args::<Vec<u8>>(
"
SELECT data
2 changes: 1 addition & 1 deletion pgml-extension/src/orm/sampling.rs
@@ -55,7 +55,7 @@ impl Sampling {
Sampling::stratified => {
format!(
"
- SELECT *
+ SELECT {col_string}
FROM (
SELECT
*,
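
For context, the stratified strategy samples within each stratum using a window function, and selecting the explicit {col_string} list instead of * keeps the helper window column out of the result. A minimal sketch of the assumed query shape, using the diamonds table and cut as a hypothetical stratum:

-- sketch of stratified sampling (assumed query shape; 100 rows per cut)
SELECT carat, cut, price -- the explicit column list; SELECT * would also return rn
FROM (
    SELECT *, row_number() OVER (PARTITION BY cut ORDER BY random()) AS rn
    FROM pgml."jdxcosta/diamonds"
) stratified
WHERE rn <= 100;
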
23 changes: 16 additions & 7 deletions pgml-extension/src/orm/snapshot.rs
@@ -230,16 +230,24 @@ impl Column {
if self.preprocessor.encode == Encode::target {
let categories = self.statistics.categories.as_mut().unwrap();
let mut sums = vec![0_f32; categories.len() + 1];
+ let mut total = 0.;
Zip::from(array).and(target).for_each(|&value, &target| {
+ total += target;
sums[value as usize] += target;
});
+ let avg_target = total / categories.len() as f32;
for category in categories.values_mut() {
- let sum = sums[category.value as usize];
- category.value = sum / category.members as f32;
+ if category.members > 0 {
+     let sum = sums[category.value as usize];
+     category.value = sum / category.members as f32;
+ } else {
+     // use avg target for categories w/ no members, e.g. __NULL__ category in a complete dataset
+     category.value = avg_target;
+ }
}
}

- // Data is filtered for NaN because it is not well defined statistically, and they are counted as separate stat
+ // Data is filtered for NaN because it is not well-defined statistically, and they are counted as separate stat
let mut data = array
.iter()
.filter_map(|n| if n.is_nan() { None } else { Some(*n) })
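
Target encoding replaces each category with the mean target value over its members; the new avg_target fallback covers categories with no members, like the __NULL__ placeholder in a dataset that contains no NULLs. The encoding itself is equivalent to a simple aggregate (illustrative only; the snapshot code computes it incrementally):

-- illustrative target encoding: each cut level becomes the average price of its members
SELECT cut, avg(price) AS encoded_value
FROM pgml."jdxcosta/diamonds"
GROUP BY cut;
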
@@ -404,7 +412,7 @@ impl Snapshot {
.first();
if !result.is_empty() {
let jsonb: JsonB = result.get(7).unwrap().unwrap();
- let columns: Vec<Column> = serde_json::from_value(jsonb.0).unwrap();
+ let columns: Vec<Column> = serde_json::from_value(jsonb.0).expect("invalid json description of columns");
// let jsonb: JsonB = result.get(8).unwrap();
// let analysis: Option<IndexMap<String, f32>> = Some(serde_json::from_value(jsonb.0).unwrap());
let mut s = Snapshot {
@@ -500,9 +508,10 @@ impl Snapshot {

let preprocessors: HashMap<String, Preprocessor> = serde_json::from_value(preprocess.0).expect("is valid");

+ let mut position = 0; // Postgres column positions are not updated when other columns are dropped, but we expect consecutive positions when we read the table.
Spi::connect(|mut client| {
let mut columns: Vec<Column> = Vec::new();
client.select("SELECT column_name::TEXT, udt_name::TEXT, is_nullable::BOOLEAN, ordinal_position::INTEGER FROM information_schema.columns WHERE table_schema = $1 AND table_name = $2 ORDER BY ordinal_position ASC",
client.select("SELECT column_name::TEXT, udt_name::TEXT, is_nullable::BOOLEAN FROM information_schema.columns WHERE table_schema = $1 AND table_name = $2 ORDER BY ordinal_position ASC",
None,
Some(vec![
(PgBuiltInOids::TEXTOID.oid(), schema_name.into_datum()),
@@ -520,7 +529,7 @@ impl Snapshot {
pg_type = pg_type[1..].to_string() + "[]";
}
let nullable = row[3].value::<bool>().unwrap().unwrap();
- let position = row[4].value::<i32>().unwrap().unwrap() as usize;
+ position += 1;
let label = match y_column_name {
Some(ref y_column_name) => y_column_name.contains(&name),
None => false,
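
The manual counter replaces ordinal_position because Postgres does not renumber columns when one is dropped, so positions keep gaps, exactly what the earlier ALTER TABLE ... DROP COLUMN creates. A quick demonstration with a hypothetical table:

-- demonstration: ordinal_position keeps its gap after a drop
CREATE TABLE position_demo (a INT, b INT, c INT);
ALTER TABLE position_demo DROP COLUMN b;
SELECT column_name, ordinal_position
FROM information_schema.columns
WHERE table_name = 'position_demo'
ORDER BY ordinal_position;
-- returns (a, 1) and (c, 3), hence counting consecutive positions while reading
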
@@ -1158,7 +1167,7 @@ impl Snapshot {
pub fn numeric_encoded_dataset(&mut self) -> Dataset {
let mut data = None;
Spi::connect(|client| {
- // Postgres Arrays arrays are 1 indexed and so are SPI tuples...
+ // Postgres arrays are 1 indexed and so are SPI tuples...
let result = client.select(&self.select_sql(), None, None).unwrap();
let num_rows = result.len();
let (num_train_rows, num_test_rows) = self.train_test_split(num_rows);
1 change: 1 addition & 0 deletions pgml-extension/tests/test.sql
@@ -30,5 +30,6 @@ SELECT pgml.load_dataset('wine');
\i examples/regression.sql
\i examples/vectors.sql
\i examples/chunking.sql
+ \i examples/preprocessing.sql
-- transformers are generally too slow to run in the test suite
--\i examples/transformers.sql