diff --git a/pgml-cms/docs/api/sql-extension/pgml.decompose.md b/pgml-cms/docs/api/sql-extension/pgml.decompose.md
new file mode 100644
index 000000000..a322b4c99
--- /dev/null
+++ b/pgml-cms/docs/api/sql-extension/pgml.decompose.md
@@ -0,0 +1,50 @@
+---
+description: Decompose an input vector into it's principal components
+---
+
+# pgml.decompose()
+
+
+Chunks are pieces of documents split using some specified splitter. This is typically done before embedding.
+
+## API
+
+```sql
+pgml.decompose(
+ project_name TEXT, -- project name
+ vector REAL[] -- features to decompose
+)
+```
+
+### Parameters
+
+| Parameter | Example | Description |
+|----------------|---------------------------------|----------------------------------------------------------|
+| `project_name` | `'My First PostgresML Project'` | The project name used to train models in `pgml.train()`. |
+| `vector` | `ARRAY[0.1, 0.45, 1.0]` | The feature vector that needs decomposition. |
+
+## Example
+
+```sql
+SELECT pgml.decompose('My PCA', ARRAY[0.1, 2.0, 5.0]);
+```
+
+!!! example
+
+```sql
+SELECT *,
+ pgml.decompose(
+ 'Buy it Again',
+ ARRAY[
+ user.location_id,
+ NOW() - user.created_at,
+ user.total_purchases_in_dollars
+ ]
+ ) AS buying_score
+FROM users
+WHERE tenant_id = 5
+ORDER BY buying_score
+LIMIT 25;
+```
+
+!!!
\ No newline at end of file
diff --git a/pgml-cms/docs/api/sql-extension/pgml.train/clustering.md b/pgml-cms/docs/api/sql-extension/pgml.train/clustering.md
index 16554f54a..5ecf0b552 100644
--- a/pgml-cms/docs/api/sql-extension/pgml.train/clustering.md
+++ b/pgml-cms/docs/api/sql-extension/pgml.train/clustering.md
@@ -16,8 +16,8 @@ SELECT image FROM pgml.digits;
-- view the dataset
SELECT left(image::text, 40) || ',...}' FROM pgml.digit_vectors LIMIT 10;
--- train a simple model to classify the data
-SELECT * FROM pgml.train('Handwritten Digit Clusters', 'cluster', 'pgml.digit_vectors', hyperparams => '{"n_clusters": 10}');
+-- train a simple model to cluster the data
+SELECT * FROM pgml.train('Handwritten Digit Clusters', 'clustering', 'pgml.digit_vectors', hyperparams => '{"n_clusters": 10}');
-- check out the predictions
SELECT target, pgml.predict('Handwritten Digit Clusters', image) AS prediction
@@ -27,7 +27,7 @@ LIMIT 10;
## Algorithms
-All clustering algorithms implemented by PostgresML are online versions. You may use the [pgml.predict](../../../api/sql-extension/pgml.predict/ "mention")function to cluster novel datapoints after the clustering model has been trained.
+All clustering algorithms implemented by PostgresML are online versions. You may use the [pgml.predict](../../../api/sql-extension/pgml.predict/ "mention")function to cluster novel data points after the clustering model has been trained.
| Algorithm | Reference |
| ---------------------- | ----------------------------------------------------------------------------------------------------------------- |
diff --git a/pgml-cms/docs/api/sql-extension/pgml.train/decomposition.md b/pgml-cms/docs/api/sql-extension/pgml.train/decomposition.md
new file mode 100644
index 000000000..be8420df2
--- /dev/null
+++ b/pgml-cms/docs/api/sql-extension/pgml.train/decomposition.md
@@ -0,0 +1,42 @@
+# Decomposition
+
+Models can be trained using `pgml.train` on unlabeled data to identify important features within the data. To decompose a dataset into it's principal components, we can use the table or a view. Since decomposition is an unsupervised algorithm, we don't need a column that represents a label as one of the inputs to `pgml.train`.
+
+## Example
+
+This example trains models on the sklearn digits dataset -- which is a copy of the test set of the [UCI ML hand-written digits datasets](https://archive.ics.uci.edu/ml/datasets/Optical+Recognition+of+Handwritten+Digits). This demonstrates using a table with a single array feature column for principal component analysis. You could do something similar with a vector column.
+
+```sql
+SELECT pgml.load_dataset('digits');
+
+-- create an unlabeled table of the images for unsupervised learning
+CREATE VIEW pgml.digit_vectors AS
+SELECT image FROM pgml.digits;
+
+-- view the dataset
+SELECT left(image::text, 40) || ',...}' FROM pgml.digit_vectors LIMIT 10;
+
+-- train a simple model to cluster the data
+SELECT * FROM pgml.train('Handwritten Digit Components', 'decomposition', 'pgml.digit_vectors', hyperparams => '{"n_components": 3}');
+
+-- check out the compenents
+SELECT target, pgml.decompose('Handwritten Digit Components', image) AS pca
+FROM pgml.digits
+LIMIT 10;
+```
+
+Note that the input vectors have been reduced from 64 dimensions to 3, which explain nearly half of the variance across all samples.
+
+## Algorithms
+
+All decomposition algorithms implemented by PostgresML are online versions. You may use the [pgml.decompose](../../../api/sql-extension/pgml.decompose "mention") function to decompose novel data points after the model has been trained.
+
+| Algorithm | Reference |
+|---------------------------|---------------------------------------------------------------------------------------------------------------------|
+| `pca` | [PCA](https://scikit-learn.org/stable/modules/generated/sklearn.decomposition.PCA.html) |
+
+### Examples
+
+```sql
+SELECT * FROM pgml.train('Handwritten Digit Clusters', algorithm => 'pca', hyperparams => '{"n_components": 10}');
+```
diff --git a/pgml-dashboard/src/models.rs b/pgml-dashboard/src/models.rs
index c26ca363f..c2168eb0e 100644
--- a/pgml-dashboard/src/models.rs
+++ b/pgml-dashboard/src/models.rs
@@ -55,10 +55,11 @@ impl Project {
match self.task.as_ref().unwrap().as_str() {
"classification" | "text_classification" | "question_answering" => Ok("f1"),
"regression" => Ok("r2"),
+ "clustering" => Ok("silhouette"),
+ "decomposition" => Ok("cumulative_explained_variance"),
"summarization" => Ok("rouge_ngram_f1"),
"translation" => Ok("bleu"),
"text_generation" | "text2text" => Ok("perplexity"),
- "cluster" => Ok("silhouette"),
task => Err(anyhow::anyhow!("Unhandled task: {}", task)),
}
}
@@ -67,10 +68,11 @@ impl Project {
match self.task.as_ref().unwrap().as_str() {
"classification" | "text_classification" | "question_answering" => Ok("F1"),
"regression" => Ok("R2"),
+ "clustering" => Ok("silhouette"),
+ "decomposition" => Ok("Cumulative Explained Variance"),
"summarization" => Ok("Rouge Ngram F1"),
"translation" => Ok("Bleu"),
"text_generation" | "text2text" => Ok("Perplexity"),
- "cluster" => Ok("silhouette"),
task => Err(anyhow::anyhow!("Unhandled task: {}", task)),
}
}
diff --git a/pgml-extension/.cargo/config b/pgml-extension/.cargo/config.toml
similarity index 100%
rename from pgml-extension/.cargo/config
rename to pgml-extension/.cargo/config.toml
diff --git a/pgml-extension/Cargo.lock b/pgml-extension/Cargo.lock
index ad7dd7b0f..8dbfba0f1 100644
--- a/pgml-extension/Cargo.lock
+++ b/pgml-extension/Cargo.lock
@@ -1746,7 +1746,7 @@ dependencies = [
[[package]]
name = "pgml"
-version = "2.8.3"
+version = "2.8.4"
dependencies = [
"anyhow",
"blas",
diff --git a/pgml-extension/Cargo.toml b/pgml-extension/Cargo.toml
index 7aea7ba7c..86d94c124 100644
--- a/pgml-extension/Cargo.toml
+++ b/pgml-extension/Cargo.toml
@@ -1,6 +1,6 @@
[package]
name = "pgml"
-version = "2.8.3"
+version = "2.8.4"
edition = "2021"
[lib]
diff --git a/pgml-extension/examples/cluster.sql b/pgml-extension/examples/clustering.sql
similarity index 94%
rename from pgml-extension/examples/cluster.sql
rename to pgml-extension/examples/clustering.sql
index f12609a1e..cb60d4af6 100644
--- a/pgml-extension/examples/cluster.sql
+++ b/pgml-extension/examples/clustering.sql
@@ -20,7 +20,7 @@ SELECT image FROM pgml.digits;
SELECT left(image::text, 40) || ',...}' FROM pgml.digit_vectors LIMIT 10;
-- train a simple model to classify the data
-SELECT * FROM pgml.train('Handwritten Digit Clusters', 'cluster', 'pgml.digit_vectors', hyperparams => '{"n_clusters": 10}');
+SELECT * FROM pgml.train('Handwritten Digit Clusters', 'clustering', 'pgml.digit_vectors', hyperparams => '{"n_clusters": 10}');
-- check out the predictions
SELECT target, pgml.predict('Handwritten Digit Clusters', image) AS prediction
diff --git a/pgml-extension/examples/decomposition.sql b/pgml-extension/examples/decomposition.sql
new file mode 100644
index 000000000..d9e387d90
--- /dev/null
+++ b/pgml-extension/examples/decomposition.sql
@@ -0,0 +1,60 @@
+-- This example reduces the dimensionality of images in the sklean digits dataset
+-- which is a copy of the test set of the UCI ML hand-written digits datasets
+-- https://archive.ics.uci.edu/ml/datasets/Optical+Recognition+of+Handwritten+Digits
+--
+-- This demonstrates using a table with a single array feature column
+-- for decomposition to reduce dimensionality.
+--
+-- Exit on error (psql)
+-- \set ON_ERROR_STOP true
+\timing on
+
+SELECT pgml.load_dataset('digits');
+
+-- view the dataset
+SELECT left(image::text, 40) || ',...}', target FROM pgml.digits LIMIT 10;
+
+-- create a view of just the vectors for decomposition, without any labels
+CREATE VIEW digit_vectors AS
+SELECT image FROM pgml.digits;
+
+SELECT * FROM pgml.train('Handwritten Digits Reduction', 'decomposition', 'digit_vectors');
+
+-- check out the decomposed vectors
+SELECT target, pgml.decompose('Handwritten Digits Reduction', image) AS pca
+FROM pgml.digits
+LIMIT 10;
+
+--
+-- After a project has been trained, omitted parameters will be reused from previous training runs
+-- In these examples we'll reuse the training data snapshots from the initial call.
+--
+
+-- We can reduce the image vectors from 64 dimensions to 3 components
+SELECT * FROM pgml.train('Handwritten Digits Reduction', hyperparams => '{"n_components": 3}');
+
+-- check out the reduced vectors
+SELECT target, pgml.decompose('Handwritten Digits Reduction', image) AS pca
+FROM pgml.digits
+LIMIT 10;
+
+-- check out all that hard work
+SELECT trained_models.* FROM pgml.trained_models
+ JOIN pgml.models on models.id = trained_models.id
+ORDER BY models.metrics->>'cumulative_explained_variance' DESC LIMIT 5;
+
+-- deploy the PCA model for prediction use
+SELECT * FROM pgml.deploy('Handwritten Digits Reduction', 'most_recent', 'pca');
+-- check out that throughput
+SELECT * FROM pgml.deployed_models ORDER BY deployed_at DESC LIMIT 5;
+
+-- deploy the "best" model for prediction use
+SELECT * FROM pgml.deploy('Handwritten Digits Reduction', 'best_score');
+SELECT * FROM pgml.deploy('Handwritten Digits Reduction', 'most_recent');
+SELECT * FROM pgml.deploy('Handwritten Digits Reduction', 'rollback');
+SELECT * FROM pgml.deploy('Handwritten Digits Reduction', 'best_score', 'pca');
+
+-- check out the improved predictions
+SELECT target, pgml.predict('Handwritten Digits Reduction', image) AS prediction
+FROM pgml.digits
+LIMIT 10;
diff --git a/pgml-extension/examples/image_classification.sql b/pgml-extension/examples/image_classification.sql
index 0dea5749a..f9a7888a6 100644
--- a/pgml-extension/examples/image_classification.sql
+++ b/pgml-extension/examples/image_classification.sql
@@ -5,9 +5,8 @@
-- This demonstrates using a table with a single array feature column
-- for classification.
--
--- The final result after a few seconds of training is not terrible. Maybe not perfect
--- enough for mission critical applications, but it's telling how quickly "off the shelf"
--- solutions can solve problems these days.
+-- Some algorithms converge on this trivial dataset in under a second, demonstrating the
+-- speed with which modern machines can "learn" from example data.
-- Exit on error (psql)
-- \set ON_ERROR_STOP true
diff --git a/pgml-extension/examples/regression.sql b/pgml-extension/examples/regression.sql
index 2970e7e59..e355b6393 100644
--- a/pgml-extension/examples/regression.sql
+++ b/pgml-extension/examples/regression.sql
@@ -1,4 +1,4 @@
--- This example trains models on the sklean diabetes dataset
+-- This example trains models on the sklearn diabetes dataset
-- Source URL: https://www4.stat.ncsu.edu/~boos/var.select/diabetes.html
-- For more information see:
-- Bradley Efron, Trevor Hastie, Iain Johnstone and Robert Tibshirani (2004)
diff --git a/pgml-extension/sql/pgml--2.8.3--2.8.4.sql b/pgml-extension/sql/pgml--2.8.3--2.8.4.sql
new file mode 100644
index 000000000..bcaa0e7b9
--- /dev/null
+++ b/pgml-extension/sql/pgml--2.8.3--2.8.4.sql
@@ -0,0 +1,13 @@
+ALTER TYPE pgml.task RENAME VALUE 'cluster' TO 'clustering';
+ALTER TYPE pgml.task ADD VALUE IF NOT EXISTS 'decomposition';
+
+ALTER TYPE pgml.algorithm ADD VALUE IF NOT EXISTS 'pca';
+
+-- pgml::api::decompose
+CREATE FUNCTION pgml."decompose"(
+ "project_name" TEXT, /* alloc::string::String */
+ "vector" FLOAT4[] /* Vec */
+) RETURNS FLOAT4[] /* Vec */
+ IMMUTABLE STRICT PARALLEL SAFE
+LANGUAGE c /* Rust */
+AS 'MODULE_PATHNAME', 'decompose_wrapper';
diff --git a/pgml-extension/src/api.rs b/pgml-extension/src/api.rs
index 54bb17799..697d6390b 100644
--- a/pgml-extension/src/api.rs
+++ b/pgml-extension/src/api.rs
@@ -225,8 +225,10 @@ fn train_joint(
};
// fix up default algorithm for clustering
- let algorithm = if algorithm == Algorithm::linear && project.task == Task::cluster {
+ let algorithm = if algorithm == Algorithm::linear && project.task == Task::clustering {
Algorithm::kmeans
+ } else if algorithm == Algorithm::linear && project.task == Task::decomposition {
+ Algorithm::pca
} else {
algorithm
};
@@ -482,6 +484,13 @@ fn predict_batch(project_name: &str, features: Vec) -> SetOfIterator<'stati
))
}
+#[pg_extern(immutable, parallel_safe, strict, name = "decompose")]
+fn decompose(project_name: &str, vector: Vec) -> Vec {
+ let model_id = Project::get_deployed_model_id(project_name);
+ let model = unwrap_or_error!(Model::find_cached(model_id));
+ unwrap_or_error!(model.decompose(&vector))
+}
+
#[pg_extern(immutable, parallel_safe, strict, name = "predict")]
fn predict_row(project_name: &str, row: pgrx::datum::AnyElement) -> f32 {
predict_model_row(Project::get_deployed_model_id(project_name), row)
diff --git a/pgml-extension/src/bindings/mod.rs b/pgml-extension/src/bindings/mod.rs
index 294e0fe3a..52592fe94 100644
--- a/pgml-extension/src/bindings/mod.rs
+++ b/pgml-extension/src/bindings/mod.rs
@@ -78,12 +78,24 @@ pub mod xgboost;
pub type Fit = fn(dataset: &Dataset, hyperparams: &Hyperparams) -> Result>;
+use std::any::Any;
+
+pub trait AToAny: 'static {
+ fn as_any(&self) -> &dyn Any;
+}
+
+impl AToAny for T {
+ fn as_any(&self) -> &dyn Any {
+ self
+ }
+}
+
/// The Bindings trait that has to be implemented by all algorithm
/// providers we use in PostgresML. We don't rely on Serde serialization,
/// since scikit-learn estimators were originally serialized in pure Python as
-/// pickled objects, and neither xgboost or linfa estimators completely
+/// pickled objects, and neither xgboost nor linfa estimators completely
/// implement serde.
-pub trait Bindings: Send + Sync + Debug {
+pub trait Bindings: Send + Sync + Debug + AToAny {
/// Predict a set of datapoints.
fn predict(&self, features: &[f32], num_features: usize, num_classes: usize) -> Result>;
diff --git a/pgml-extension/src/bindings/sklearn/mod.rs b/pgml-extension/src/bindings/sklearn/mod.rs
index bee066b87..ccd49a50f 100644
--- a/pgml-extension/src/bindings/sklearn/mod.rs
+++ b/pgml-extension/src/bindings/sklearn/mod.rs
@@ -14,7 +14,11 @@ use anyhow::Result;
use pyo3::prelude::*;
use pyo3::types::PyTuple;
-use crate::{bindings::Bindings, create_pymodule, orm::*};
+use crate::{
+ bindings::{Bindings, TracebackError},
+ create_pymodule,
+ orm::*,
+};
create_pymodule!("/src/bindings/sklearn/sklearn.py");
@@ -35,8 +39,8 @@ wrap_fit!(random_forest_regression, "random_forest_regression");
wrap_fit!(xgboost_regression, "xgboost_regression");
wrap_fit!(xgboost_random_forest_regression, "xgboost_random_forest_regression");
wrap_fit!(
- orthogonal_matching_persuit_regression,
- "orthogonal_matching_persuit_regression"
+ orthogonal_matching_pursuit_regression,
+ "orthogonal_matching_pursuit_regression"
);
wrap_fit!(bayesian_ridge_regression, "bayesian_ridge_regression");
wrap_fit!(
@@ -109,6 +113,8 @@ wrap_fit!(spectral, "spectral_clustering");
wrap_fit!(spectral_bi, "spectral_biclustering");
wrap_fit!(spectral_co, "spectral_coclustering");
+wrap_fit!(pca, "pca_decomposition");
+
fn fit(dataset: &Dataset, hyperparams: &Hyperparams, algorithm_task: &'static str) -> Result> {
let hyperparams = serde_json::to_string(hyperparams).unwrap();
@@ -293,9 +299,9 @@ pub fn classification_metrics(ground_truth: &[f32], y_hat: &[f32], num_classes:
Ok(scores)
}
-pub fn cluster_metrics(num_features: usize, inputs: &[f32], labels: &[f32]) -> Result> {
+pub fn clustering_metrics(num_features: usize, inputs: &[f32], labels: &[f32]) -> Result> {
Python::with_gil(|py| {
- let calculate_metric = get_module!(PY_MODULE).getattr(py, "cluster_metrics")?;
+ let calculate_metric = get_module!(PY_MODULE).getattr(py, "clustering_metrics")?;
let scores: HashMap = calculate_metric
.call1(py, (num_features, PyTuple::new(py, [inputs, labels])))?
@@ -304,3 +310,15 @@ pub fn cluster_metrics(num_features: usize, inputs: &[f32], labels: &[f32]) -> R
Ok(scores)
})
}
+
+pub fn decomposition_metrics(bindings: &Box) -> Result> {
+ Python::with_gil(|py| match bindings.as_any().downcast_ref::() {
+ Some(estimator) => {
+ let calculate_metric = get_module!(PY_MODULE).getattr(py, "decomposition_metrics")?;
+ let metrics = calculate_metric.call1(py, PyTuple::new(py, [&estimator.estimator]));
+ let metrics = metrics.format_traceback(py)?.extract(py).format_traceback(py)?;
+ Ok(metrics)
+ }
+ None => error!("Can't compute decomposition metrics for bindings other than sklearn"),
+ })
+}
diff --git a/pgml-extension/src/bindings/sklearn/sklearn.py b/pgml-extension/src/bindings/sklearn/sklearn.py
index b27638a55..eab8faf57 100644
--- a/pgml-extension/src/bindings/sklearn/sklearn.py
+++ b/pgml-extension/src/bindings/sklearn/sklearn.py
@@ -43,7 +43,7 @@
"elastic_net_regression": sklearn.linear_model.ElasticNet,
"least_angle_regression": sklearn.linear_model.Lars,
"lasso_least_angle_regression": sklearn.linear_model.LassoLars,
- "orthogonal_matching_persuit_regression": sklearn.linear_model.OrthogonalMatchingPursuit,
+ "orthogonal_matching_pursuit_regression": sklearn.linear_model.OrthogonalMatchingPursuit,
"bayesian_ridge_regression": sklearn.linear_model.BayesianRidge,
"automatic_relevance_determination_regression": sklearn.linear_model.ARDRegression,
"stochastic_gradient_descent_regression": sklearn.linear_model.SGDRegressor,
@@ -95,6 +95,7 @@
"spectral_clustering": sklearn.cluster.SpectralClustering,
"spectral_biclustering": sklearn.cluster.SpectralBiclustering,
"spectral_coclustering": sklearn.cluster.SpectralCoclustering,
+ "pca_decomposition": sklearn.decomposition.PCA,
}
@@ -182,7 +183,10 @@ def predictor_joint(estimator, num_targets):
def predict(X):
X = np.asarray(X).reshape((-1, estimator.n_features_in_))
- y_hat = estimator.predict(X)
+ if hasattr(estimator.__class__, 'predict'):
+ y_hat = estimator.predict(X)
+ else:
+ y_hat = estimator.transform(X)
# Only support single value models for just now.
if num_targets == 1:
@@ -238,6 +242,8 @@ def calculate_metric(metric_name):
func = mean_absolute_error
elif metric_name == "confusion_matrix":
func = confusion_matrix
+ elif metric_name == "variance":
+ func = variance
else:
raise Exception(f"Unknown metric requested: {metric_name}")
@@ -300,10 +306,15 @@ def classification_metrics(y_true, y_hat):
}
-def cluster_metrics(num_features, inputs_labels):
+def clustering_metrics(num_features, inputs_labels):
inputs = np.asarray(inputs_labels[0]).reshape((-1, num_features))
labels = np.asarray(inputs_labels[1]).reshape((-1, 1))
return {
"silhouette": silhouette_score(inputs, labels),
}
+
+def decomposition_metrics(pca):
+ return {
+ "cumulative_explained_variance": sum(pca.explained_variance_ratio_)
+ }
diff --git a/pgml-extension/src/orm/algorithm.rs b/pgml-extension/src/orm/algorithm.rs
index 21a87e3bf..64a754d9c 100644
--- a/pgml-extension/src/orm/algorithm.rs
+++ b/pgml-extension/src/orm/algorithm.rs
@@ -48,6 +48,7 @@ pub enum Algorithm {
spectral_bi,
spectral_co,
catboost,
+ pca,
}
impl std::str::FromStr for Algorithm {
@@ -99,6 +100,7 @@ impl std::str::FromStr for Algorithm {
"spectral_bi" => Ok(Algorithm::spectral_bi),
"spectral_co" => Ok(Algorithm::spectral_co),
"catboost" => Ok(Algorithm::catboost),
+ "pca" => Ok(Algorithm::pca),
_ => Err(()),
}
}
@@ -151,6 +153,7 @@ impl std::string::ToString for Algorithm {
Algorithm::spectral_bi => "spectral_bi".to_string(),
Algorithm::spectral_co => "spectral_co".to_string(),
Algorithm::catboost => "catboost".to_string(),
+ Algorithm::pca => "pca".to_string(),
}
}
}
diff --git a/pgml-extension/src/orm/model.rs b/pgml-extension/src/orm/model.rs
index a45cbd970..fb9eaae47 100644
--- a/pgml-extension/src/orm/model.rs
+++ b/pgml-extension/src/orm/model.rs
@@ -370,27 +370,27 @@ impl Model {
Runtime::rust => {
match algorithm {
Algorithm::xgboost => {
- crate::bindings::xgboost::Estimator::from_bytes(&data)?
+ xgboost::Estimator::from_bytes(&data)?
}
Algorithm::lightgbm => {
- crate::bindings::lightgbm::Estimator::from_bytes(&data)?
+ lightgbm::Estimator::from_bytes(&data)?
}
Algorithm::linear => match project.task {
Task::regression => {
- crate::bindings::linfa::LinearRegression::from_bytes(&data)?
+ linfa::LinearRegression::from_bytes(&data)?
}
Task::classification => {
- crate::bindings::linfa::LogisticRegression::from_bytes(&data)?
+ linfa::LogisticRegression::from_bytes(&data)?
}
_ => bail!("No default runtime available for tasks other than `classification` and `regression` when using a linear algorithm."),
},
- Algorithm::svm => crate::bindings::linfa::Svm::from_bytes(&data)?,
+ Algorithm::svm => linfa::Svm::from_bytes(&data)?,
_ => todo!(), //smartcore_load(&data, task, algorithm, &hyperparams),
}
}
#[cfg(feature = "python")]
- Runtime::python => crate::bindings::sklearn::Estimator::from_bytes(&data)?,
+ Runtime::python => sklearn::Estimator::from_bytes(&data)?,
#[cfg(not(feature = "python"))]
Runtime::python => {
@@ -468,7 +468,8 @@ impl Model {
Algorithm::svm => linfa::Svm::fit,
_ => todo!(),
},
- Task::cluster => todo!(),
+ Task::decomposition => todo!(),
+ Task::clustering => todo!(),
_ => error!("use pgml.tune for transformers tasks"),
},
@@ -488,7 +489,7 @@ impl Model {
Algorithm::random_forest => sklearn::random_forest_regression,
Algorithm::xgboost => sklearn::xgboost_regression,
Algorithm::xgboost_random_forest => sklearn::xgboost_random_forest_regression,
- Algorithm::orthogonal_matching_pursuit => sklearn::orthogonal_matching_persuit_regression,
+ Algorithm::orthogonal_matching_pursuit => sklearn::orthogonal_matching_pursuit_regression,
Algorithm::bayesian_ridge => sklearn::bayesian_ridge_regression,
Algorithm::automatic_relevance_determination => {
sklearn::automatic_relevance_determination_regression
@@ -512,7 +513,7 @@ impl Model {
Algorithm::linear_svm => sklearn::linear_svm_regression,
Algorithm::lightgbm => sklearn::lightgbm_regression,
Algorithm::catboost => sklearn::catboost_regression,
- _ => panic!("{:?} does not support regression", self.algorithm),
+ _ => error!("{:?} does not support regression", self.algorithm),
},
Task::classification => match self.algorithm {
Algorithm::linear => sklearn::linear_classification,
@@ -534,15 +535,19 @@ impl Model {
Algorithm::linear_svm => sklearn::linear_svm_classification,
Algorithm::lightgbm => sklearn::lightgbm_classification,
Algorithm::catboost => sklearn::catboost_classification,
- _ => panic!("{:?} does not support classification", self.algorithm),
+ _ => error!("{:?} does not support classification", self.algorithm),
},
- Task::cluster => match self.algorithm {
+ Task::clustering => match self.algorithm {
Algorithm::affinity_propagation => sklearn::affinity_propagation,
Algorithm::birch => sklearn::birch,
Algorithm::kmeans => sklearn::kmeans,
Algorithm::mini_batch_kmeans => sklearn::mini_batch_kmeans,
Algorithm::mean_shift => sklearn::mean_shift,
- _ => panic!("{:?} does not support clustering", self.algorithm),
+ _ => error!("{:?} does not support clustering", self.algorithm),
+ },
+ Task::decomposition => match self.algorithm {
+ Algorithm::pca => sklearn::pca,
+ _ => error!("{:?} does not support clustering", self.algorithm),
},
_ => error!("use pgml.tune for transformers tasks"),
},
@@ -618,7 +623,7 @@ impl Model {
Task::regression => {
#[cfg(all(feature = "python", any(test, feature = "pg_test")))]
{
- let sklearn_metrics = crate::bindings::sklearn::regression_metrics(y_test, &y_hat).unwrap();
+ let sklearn_metrics = sklearn::regression_metrics(y_test, &y_hat).unwrap();
metrics.insert("sklearn_r2".to_string(), sklearn_metrics["r2"]);
metrics.insert("sklearn_mean_absolute_error".to_string(), sklearn_metrics["mae"]);
metrics.insert("sklearn_mean_squared_error".to_string(), sklearn_metrics["mse"]);
@@ -641,8 +646,7 @@ impl Model {
#[cfg(all(feature = "python", any(test, feature = "pg_test")))]
{
let sklearn_metrics =
- crate::bindings::sklearn::classification_metrics(y_test, &y_hat, dataset.num_distinct_labels)
- .unwrap();
+ sklearn::classification_metrics(y_test, &y_hat, dataset.num_distinct_labels).unwrap();
if dataset.num_distinct_labels == 2 {
metrics.insert("sklearn_roc_auc".to_string(), sklearn_metrics["roc_auc"]);
@@ -692,15 +696,24 @@ impl Model {
// This one is inaccurate, I have it in my TODO to reimplement.
metrics.insert("mcc".to_string(), confusion_matrix.mcc());
}
- Task::cluster => {
+ Task::clustering => {
#[cfg(feature = "python")]
{
let sklearn_metrics =
- crate::bindings::sklearn::cluster_metrics(dataset.num_features, &dataset.x_test, &y_hat)
- .unwrap();
+ sklearn::clustering_metrics(dataset.num_features, &dataset.x_test, &y_hat).unwrap();
metrics.insert("silhouette".to_string(), sklearn_metrics["silhouette"]);
}
}
+ Task::decomposition => {
+ #[cfg(feature = "python")]
+ {
+ let sklearn_metrics = sklearn::decomposition_metrics(self.bindings.as_ref().unwrap()).unwrap();
+ metrics.insert(
+ "cumulative_explained_variance".to_string(),
+ sklearn_metrics["cumulative_explained_variance"],
+ );
+ }
+ }
task => error!("No test metrics available for task: {:?}", task),
}
@@ -1165,4 +1178,11 @@ impl Model {
.unwrap()
.predict(features, self.num_features, self.num_classes)
}
+
+ pub fn decompose(&self, vector: &[f32]) -> Result> {
+ self.bindings
+ .as_ref()
+ .unwrap()
+ .predict(vector, self.num_features, self.num_classes)
+ }
}
diff --git a/pgml-extension/src/orm/task.rs b/pgml-extension/src/orm/task.rs
index 1116d98ae..7c23d0861 100644
--- a/pgml-extension/src/orm/task.rs
+++ b/pgml-extension/src/orm/task.rs
@@ -6,31 +6,33 @@ use serde::Deserialize;
pub enum Task {
regression,
classification,
+ decomposition,
+ clustering,
question_answering,
summarization,
translation,
text_classification,
text_generation,
text2text,
- cluster,
embedding,
text_pair_classification,
conversation,
}
-// unfortunately the pgrx macro expands the enum names to underscore, but huggingface uses dash
+// unfortunately the pgrx macro expands the enum names to underscore, but hugging face uses dash
impl Task {
pub fn to_pg_enum(&self) -> String {
match *self {
Task::regression => "regression".to_string(),
Task::classification => "classification".to_string(),
+ Task::decomposition => "decomposition".to_string(),
+ Task::clustering => "clustering".to_string(),
Task::question_answering => "question_answering".to_string(),
Task::summarization => "summarization".to_string(),
Task::translation => "translation".to_string(),
Task::text_classification => "text_classification".to_string(),
Task::text_generation => "text_generation".to_string(),
Task::text2text => "text2text".to_string(),
- Task::cluster => "cluster".to_string(),
Task::embedding => "embedding".to_string(),
Task::text_pair_classification => "text_pair_classification".to_string(),
Task::conversation => "conversation".to_string(),
@@ -45,13 +47,14 @@ impl Task {
match self {
Task::regression => "r2",
Task::classification => "f1",
+ Task::decomposition => "cumulative_explained_variance",
+ Task::clustering => "silhouette",
Task::question_answering => "f1",
Task::translation => "blue",
Task::summarization => "rouge_ngram_f1",
Task::text_classification => "f1",
Task::text_generation => "perplexity",
Task::text2text => "perplexity",
- Task::cluster => "silhouette",
Task::embedding => error!("No default target metric for embedding task"),
Task::text_pair_classification => "f1",
Task::conversation => "bleu",
@@ -63,13 +66,14 @@ impl Task {
match self {
Task::regression => true,
Task::classification => true,
+ Task::decomposition => true,
+ Task::clustering => true,
Task::question_answering => true,
Task::translation => true,
Task::summarization => true,
Task::text_classification => true,
Task::text_generation => false,
Task::text2text => false,
- Task::cluster => true,
Task::embedding => error!("No default target metric positive for embedding task"),
Task::text_pair_classification => true,
Task::conversation => true,
@@ -105,13 +109,14 @@ impl std::str::FromStr for Task {
match input {
"regression" => Ok(Task::regression),
"classification" => Ok(Task::classification),
+ "decomposition" => Ok(Task::decomposition),
+ "clustering" => Ok(Task::clustering),
"question-answering" | "question_answering" => Ok(Task::question_answering),
"summarization" => Ok(Task::summarization),
"translation" => Ok(Task::translation),
"text-classification" | "text_classification" => Ok(Task::text_classification),
"text-generation" | "text_generation" => Ok(Task::text_generation),
"text2text" => Ok(Task::text2text),
- "cluster" => Ok(Task::cluster),
"text-pair-classification" | "text_pair_classification" => Ok(Task::text_pair_classification),
"conversation" => Ok(Task::conversation),
_ => Err(()),
@@ -124,13 +129,14 @@ impl std::string::ToString for Task {
match *self {
Task::regression => "regression".to_string(),
Task::classification => "classification".to_string(),
+ Task::decomposition => "decomposition".to_string(),
+ Task::clustering => "clustering".to_string(),
Task::question_answering => "question-answering".to_string(),
Task::summarization => "summarization".to_string(),
Task::translation => "translation".to_string(),
Task::text_classification => "text-classification".to_string(),
Task::text_generation => "text-generation".to_string(),
Task::text2text => "text2text".to_string(),
- Task::cluster => "cluster".to_string(),
Task::embedding => "embedding".to_string(),
Task::text_pair_classification => "text-pair-classification".to_string(),
Task::conversation => "conversation".to_string(),
diff --git a/pgml-extension/tests/test.sql b/pgml-extension/tests/test.sql
index a6c75dee9..2490678ee 100644
--- a/pgml-extension/tests/test.sql
+++ b/pgml-extension/tests/test.sql
@@ -21,7 +21,8 @@ SELECT pgml.load_dataset('iris');
SELECT pgml.load_dataset('linnerud');
SELECT pgml.load_dataset('wine');
-\i examples/cluster.sql
+\i examples/clustering.sql
+\i examples/decomposition.sql
\i examples/binary_classification.sql
\i examples/image_classification.sql
\i examples/joint_regression.sql