From b2d6a30aeb155948fbee1c3a9f0ce9cb5bfb278a Mon Sep 17 00:00:00 2001 From: Kyle Kastner Date: Tue, 17 Jun 2014 11:55:52 -0500 Subject: [PATCH 1/2] IncrementalPCA implementation --- benchmarks/bench_plot_incremental_pca.py | 156 +++++++++++ doc/modules/classes.rst | 1 + doc/modules/decomposition.rst | 43 ++- doc/modules/scaling_strategies.rst | 1 + .../decomposition/plot_incremental_pca.py | 57 ++++ sklearn/decomposition/__init__.py | 2 + sklearn/decomposition/base.py | 158 +++++++++++ sklearn/decomposition/incremental_pca.py | 252 ++++++++++++++++++ .../tests/test_incremental_pca.py | 224 ++++++++++++++++ sklearn/utils/estimator_checks.py | 4 +- sklearn/utils/extmath.py | 77 +++++- sklearn/utils/tests/test_extmath.py | 82 ++++++ 12 files changed, 1045 insertions(+), 12 deletions(-) create mode 100644 benchmarks/bench_plot_incremental_pca.py create mode 100644 examples/decomposition/plot_incremental_pca.py create mode 100644 sklearn/decomposition/base.py create mode 100644 sklearn/decomposition/incremental_pca.py create mode 100644 sklearn/decomposition/tests/test_incremental_pca.py diff --git a/benchmarks/bench_plot_incremental_pca.py b/benchmarks/bench_plot_incremental_pca.py new file mode 100644 index 0000000000000..495d58f0f43ee --- /dev/null +++ b/benchmarks/bench_plot_incremental_pca.py @@ -0,0 +1,156 @@ +""" +======================== +IncrementalPCA benchmark +======================== + +Benchmarks for IncrementalPCA + +""" + +import numpy as np +import gc +from time import time +from collections import defaultdict +import matplotlib.pyplot as plt +from sklearn.datasets import fetch_lfw_people +from sklearn.decomposition import IncrementalPCA, RandomizedPCA, PCA + + +def plot_results(X, y, label): + plt.plot(X, y, label=label, marker='o') + + +def benchmark(estimator, data): + gc.collect() + print("Benching %s" % estimator) + t0 = time() + estimator.fit(data) + training_time = time() - t0 + data_t = estimator.transform(data) + data_r = estimator.inverse_transform(data_t) + reconstruction_error = np.mean(np.abs(data - data_r)) + return {'time': training_time, 'error': reconstruction_error} + + +def plot_feature_times(all_times, batch_size, all_components, data): + plt.figure() + plot_results(all_components, all_times['pca'], label="PCA") + plot_results(all_components, all_times['ipca'], + label="IncrementalPCA, bsize=%i" % batch_size) + plot_results(all_components, all_times['rpca'], label="RandomizedPCA") + plt.legend(loc="upper left") + plt.suptitle("Algorithm runtime vs. n_components\n \ + LFW, size %i x %i" % data.shape) + plt.xlabel("Number of components (out of max %i)" % data.shape[1]) + plt.ylabel("Time (seconds)") + + +def plot_feature_errors(all_errors, batch_size, all_components, data): + plt.figure() + plot_results(all_components, all_errors['pca'], label="PCA") + plot_results(all_components, all_errors['ipca'], + label="IncrementalPCA, bsize=%i" % batch_size) + plot_results(all_components, all_errors['rpca'], label="RandomizedPCA") + plt.legend(loc="lower left") + plt.suptitle("Algorithm error vs. 
n_components\n" + "LFW, size %i x %i" % data.shape) + plt.xlabel("Number of components (out of max %i)" % data.shape[1]) + plt.ylabel("Mean absolute error") + + +def plot_batch_times(all_times, n_features, all_batch_sizes, data): + plt.figure() + plot_results(all_batch_sizes, all_times['pca'], label="PCA") + plot_results(all_batch_sizes, all_times['rpca'], label="RandomizedPCA") + plot_results(all_batch_sizes, all_times['ipca'], label="IncrementalPCA") + plt.legend(loc="lower left") + plt.suptitle("Algorithm runtime vs. batch_size for n_components %i\n \ + LFW, size %i x %i" % ( + n_features, data.shape[0], data.shape[1])) + plt.xlabel("Batch size") + plt.ylabel("Time (seconds)") + + +def plot_batch_errors(all_errors, n_features, all_batch_sizes, data): + plt.figure() + plot_results(all_batch_sizes, all_errors['pca'], label="PCA") + plot_results(all_batch_sizes, all_errors['ipca'], label="IncrementalPCA") + plt.legend(loc="lower left") + plt.suptitle("Algorithm error vs. batch_size for n_components %i\n \ + LFW, size %i x %i" % ( + n_features, data.shape[0], data.shape[1])) + plt.xlabel("Batch size") + plt.ylabel("Mean absolute error") + + +def fixed_batch_size_comparison(data): + all_features = [i.astype(int) for i in np.linspace(data.shape[1] // 10, + data.shape[1], num=5)] + batch_size = 1000 + # Compare runtimes and error for fixed batch size + all_times = defaultdict(list) + all_errors = defaultdict(list) + for n_components in all_features: + pca = PCA(n_components=n_components) + rpca = RandomizedPCA(n_components=n_components, random_state=1999) + ipca = IncrementalPCA(n_components=n_components, batch_size=batch_size) + results_dict = {k: benchmark(est, data) for k, est in [('pca', pca), + ('ipca', ipca), + ('rpca', rpca)]} + + for k in sorted(results_dict.keys()): + all_times[k].append(results_dict[k]['time']) + all_errors[k].append(results_dict[k]['error']) + + plot_feature_times(all_times, batch_size, all_features, data) + plot_feature_errors(all_errors, batch_size, all_features, data) + + +def variable_batch_size_comparison(data): + batch_sizes = [i.astype(int) for i in np.linspace(data.shape[0] // 10, + data.shape[0], num=10)] + + for n_components in [i.astype(int) for i in + np.linspace(data.shape[1] // 10, + data.shape[1], num=4)]: + all_times = defaultdict(list) + all_errors = defaultdict(list) + pca = PCA(n_components=n_components) + rpca = RandomizedPCA(n_components=n_components, random_state=1999) + results_dict = {k: benchmark(est, data) for k, est in [('pca', pca), + ('rpca', rpca)]} + + # Create flat baselines to compare the variation over batch size + all_times['pca'].extend([results_dict['pca']['time']] * + len(batch_sizes)) + all_errors['pca'].extend([results_dict['pca']['error']] * + len(batch_sizes)) + all_times['rpca'].extend([results_dict['rpca']['time']] * + len(batch_sizes)) + all_errors['rpca'].extend([results_dict['rpca']['error']] * + len(batch_sizes)) + for batch_size in batch_sizes: + ipca = IncrementalPCA(n_components=n_components, + batch_size=batch_size) + results_dict = {k: benchmark(est, data) for k, est in [('ipca', + ipca)]} + all_times['ipca'].append(results_dict['ipca']['time']) + all_errors['ipca'].append(results_dict['ipca']['error']) + + plot_batch_times(all_times, n_components, batch_sizes, data) + # RandomizedPCA error is always worse (approx 100x) than other PCA + # tests + plot_batch_errors(all_errors, n_components, batch_sizes, data) + +faces = fetch_lfw_people(resize=.2, min_faces_per_person=5) +# limit dataset to 5000 people (don't care 
who they are!) +X = faces.data[:5000] +n_samples, h, w = faces.images.shape +n_features = X.shape[1] + +X -= X.mean(axis=0) +X /= X.std(axis=0) + +fixed_batch_size_comparison(X) +variable_batch_size_comparison(X) +plt.show() diff --git a/doc/modules/classes.rst b/doc/modules/classes.rst index 495bcd223f6d2..4fd228c0e5c46 100644 --- a/doc/modules/classes.rst +++ b/doc/modules/classes.rst @@ -270,6 +270,7 @@ Samples generator :template: class.rst decomposition.PCA + decomposition.IncrementalPCA decomposition.ProjectedGradientNMF decomposition.RandomizedPCA decomposition.KernelPCA diff --git a/doc/modules/decomposition.rst b/doc/modules/decomposition.rst index d0edaf5e235d0..f34c4cd1e8e1b 100644 --- a/doc/modules/decomposition.rst +++ b/doc/modules/decomposition.rst @@ -28,8 +28,7 @@ project the data onto the singular space while scaling each component to unit variance. This is often useful if the models down-stream make strong assumptions on the isotropy of the signal: this is for example the case for Support Vector Machines with the RBF kernel and the K-Means -clustering algorithm. However in that case the inverse transform is no -longer exact since some information is lost while forward transforming. +clustering algorithm. Below is an example of the iris dataset, which is comprised of 4 features, projected on the 2 dimensions that explain most variance: @@ -57,6 +56,46 @@ data based on the amount of variance it explains. As such it implements a * :ref:`example_decomposition_plot_pca_vs_fa_model_selection.py` +.. _IncrementalPCA: + +Incremental PCA +--------------- + +The :class:`PCA` object is very useful, but has certain limitations for +large datasets. The biggest limitation is that :class:`PCA` only supports +batch processing, which means all of the data to be processed must fit in main +memory. The :class:`IncrementalPCA` object uses a different form of +processing and allows for partial computations which almost +exactly match the results of :class:`PCA` while processing the data in a +minibatch fashion. :class:`IncrementalPCA` makes it possible to implement +out-of-core Principal Component Analysis either by: + + * Using its ``partial_fit`` method on chunks of data fetched sequentially + from the local hard drive or a network database. + + * Calling its fit method on a memory mapped file using ``numpy.memmap``. + +:class:`IncrementalPCA` only stores estimates of component and noise variances, +in order update ``explained_variance_ratio_`` incrementally. This is why +memory usage depends on the number of samples per batch, rather than the +number of samples to be processed in the dataset. + +.. figure:: ../auto_examples/decomposition/images/plot_incremental_pca_001.png + :target: ../auto_examples/decomposition/plot_incremental_pca.html + :align: center + :scale: 75% + +.. figure:: ../auto_examples/decomposition/images/plot_incremental_pca_002.png + :target: ../auto_examples/decomposition/plot_incremental_pca.html + :align: center + :scale: 75% + + +.. topic:: Examples: + + * :ref:`example_decomposition_plot_incremental_pca.py` + + .. 
_RandomizedPCA: Approximate PCA diff --git a/doc/modules/scaling_strategies.rst b/doc/modules/scaling_strategies.rst index 087512a02dfde..56f7f5ec8f3b7 100644 --- a/doc/modules/scaling_strategies.rst +++ b/doc/modules/scaling_strategies.rst @@ -69,6 +69,7 @@ Here is a list of incremental estimators for different tasks: + :class:`sklearn.cluster.MiniBatchKMeans` - Decomposition / feature Extraction + :class:`sklearn.decomposition.MiniBatchDictionaryLearning` + + :class:`sklearn.decomposition.IncrementalPCA` + :class:`sklearn.cluster.MiniBatchKMeans` For classification, a somewhat important thing to note is that although a diff --git a/examples/decomposition/plot_incremental_pca.py b/examples/decomposition/plot_incremental_pca.py new file mode 100644 index 0000000000000..bfcf5d3158fa0 --- /dev/null +++ b/examples/decomposition/plot_incremental_pca.py @@ -0,0 +1,57 @@ +""" + +=============== +Incremental PCA +=============== + +Incremental principal component analysis (IPCA) is typically used as a +replacement for principal component analysis (PCA) when the dataset to be +decomposed is too large to fit in memory. IPCA builds a low-rank approximation +for the input data using an amount of memory which is independent of the +input data size. + +This example serves as a visual check that IPCA is able to find a similar +projection of the data to PCA (to a sign flip), while only processing a +few samples at a time. This can be considered a "toy example", as IPCA is +intended for large datasets which do not fit in main memory, requiring +incremental approaches. + +""" +print(__doc__) + +# Authors: Kyle Kastner +# License: BSD 3 clause + +import numpy as np +import matplotlib.pyplot as plt + +from sklearn.datasets import load_iris +from sklearn.decomposition import PCA, IncrementalPCA + +iris = load_iris() +X = iris.data +y = iris.target + +n_components = 2 +ipca = IncrementalPCA(n_components=n_components, batch_size=10) +X_ipca = ipca.fit_transform(X) + +pca = PCA(n_components=n_components) +X_pca = pca.fit_transform(X) + +for X_transformed, title in [(X_ipca, "Incremental PCA"), (X_pca, "PCA")]: + plt.figure(figsize=(8, 8)) + for c, i, target_name in zip("rgb", [0, 1, 2], iris.target_names): + plt.scatter(X_transformed[y == i, 0], X_transformed[y == i, 1], + c=c, label=target_name) + + if "Incremental" in title: + err = np.abs(np.abs(X_pca) - np.abs(X_ipca)).mean() + plt.title(title + " of iris dataset\nMean absolute unsigned error " + "%.6f" % err) + else: + plt.title(title + " of iris dataset") + plt.legend(loc="best") + plt.axis([-4, 4, -1.5, 1.5]) + +plt.show() diff --git a/sklearn/decomposition/__init__.py b/sklearn/decomposition/__init__.py index 3eb364af94a2d..6b88a86f9bb5f 100644 --- a/sklearn/decomposition/__init__.py +++ b/sklearn/decomposition/__init__.py @@ -6,6 +6,7 @@ from .nmf import NMF, ProjectedGradientNMF from .pca import PCA, RandomizedPCA +from .incremental_pca import IncrementalPCA from .kernel_pca import KernelPCA from .sparse_pca import SparsePCA, MiniBatchSparsePCA from .truncated_svd import TruncatedSVD @@ -18,6 +19,7 @@ __all__ = ['DictionaryLearning', 'FastICA', + 'IncrementalPCA', 'KernelPCA', 'MiniBatchDictionaryLearning', 'MiniBatchSparsePCA', diff --git a/sklearn/decomposition/base.py b/sklearn/decomposition/base.py new file mode 100644 index 0000000000000..132fd9a8ca645 --- /dev/null +++ b/sklearn/decomposition/base.py @@ -0,0 +1,158 @@ +"""Principal Component Analysis Base Classes""" + +# Author: Alexandre Gramfort +# Olivier Grisel +# Mathieu Blondel +# Denis A. 
Engemann +# Kyle Kastner +# +# License: BSD 3 clause + +import numpy as np +from scipy import linalg + +from ..base import BaseEstimator, TransformerMixin +from ..utils import check_array +from ..utils.extmath import fast_dot +from ..externals import six +from abc import ABCMeta, abstractmethod + + +class _BasePCA(six.with_metaclass(ABCMeta, BaseEstimator, TransformerMixin)): + """Base class for PCA methods. + + Warning: This class should not be used directly. + Use derived classes instead. + """ + def get_covariance(self): + """Compute data covariance with the generative model. + + ``cov = components_.T * S**2 * components_ + sigma2 * eye(n_features)`` + where S**2 contains the explained variances, and sigma2 contains the + noise variances. + + Returns + ------- + cov : array, shape=(n_features, n_features) + Estimated covariance of data. + """ + components_ = self.components_ + exp_var = self.explained_variance_ + if self.whiten: + components_ = components_ * np.sqrt(exp_var[:, np.newaxis]) + exp_var_diff = np.maximum(exp_var - self.noise_variance_, 0.) + cov = np.dot(components_.T * exp_var_diff, components_) + cov.flat[::len(cov) + 1] += self.noise_variance_ # modify diag inplace + return cov + + def get_precision(self): + """Compute data precision matrix with the generative model. + + Equals the inverse of the covariance but computed with + the matrix inversion lemma for efficiency. + + Returns + ------- + precision : array, shape=(n_features, n_features) + Estimated precision of data. + """ + n_features = self.components_.shape[1] + + # handle corner cases first + if self.n_components_ == 0: + return np.eye(n_features) / self.noise_variance_ + if self.n_components_ == n_features: + return linalg.inv(self.get_covariance()) + + # Get precision using matrix inversion lemma + components_ = self.components_ + exp_var = self.explained_variance_ + if self.whiten: + components_ = components_ * np.sqrt(exp_var[:, np.newaxis]) + exp_var_diff = np.maximum(exp_var - self.noise_variance_, 0.) + precision = np.dot(components_, components_.T) / self.noise_variance_ + precision.flat[::len(precision) + 1] += 1. / exp_var_diff + precision = np.dot(components_.T, + np.dot(linalg.inv(precision), components_)) + precision /= -(self.noise_variance_ ** 2) + precision.flat[::len(precision) + 1] += 1. / self.noise_variance_ + return precision + + @abstractmethod + def fit(X, y=None): + """Placeholder for fit. Subclasses should implement this method! + + Fit the model with X. + + Parameters + ---------- + X: array-like, shape (n_samples, n_features) + Training data, where n_samples is the number of samples and + n_features is the number of features. + + Returns + ------- + self: object + Returns the instance itself. + """ + + def transform(self, X, y=None): + """Apply dimensionality reduction to X. + + X is projected on the first principal components previously extracted + from a training set. + + Parameters + ---------- + X : array-like, shape (n_samples, n_features) + New data, where n_samples is the number of samples + and n_features is the number of features. 
+ + Returns + ------- + X_new : array-like, shape (n_samples, n_components) + + Examples + -------- + + >>> import numpy as np + >>> from sklearn.decomposition import IncrementalPCA + >>> X = np.array([[-1, -1], [-2, -1], [-3, -2], [1, 1], [2, 1], [3, 2]]) + >>> ipca = IncrementalPCA(n_components=2, batch_size=3) + >>> ipca.fit(X) + IncrementalPCA(batch_size=3, copy=True, n_components=2, whiten=False) + >>> ipca.transform(X) # doctest: +SKIP + """ + X = check_array(X) + if self.mean_ is not None: + X = X - self.mean_ + X_transformed = fast_dot(X, self.components_.T) + if self.whiten: + X_transformed /= np.sqrt(self.explained_variance_) + return X_transformed + + def inverse_transform(self, X, y=None): + """Transform data back to its original space. + + In other words, return an input X_original whose transform would be X. + + Parameters + ---------- + X : array-like, shape (n_samples, n_components) + New data, where n_samples is the number of samples + and n_components is the number of components. + + Returns + ------- + X_original array-like, shape (n_samples, n_features) + + Notes + ----- + If whitening is enabled, inverse_transform will compute the + exact inverse operation, which includes reversing whitening. + """ + if self.whiten: + return fast_dot(X, np.sqrt(self.explained_variance_[:, np.newaxis]) * + self.components_) + self.mean_ + else: + return fast_dot(X, self.components_) + self.mean_ diff --git a/sklearn/decomposition/incremental_pca.py b/sklearn/decomposition/incremental_pca.py new file mode 100644 index 0000000000000..cf3ca5337cb3e --- /dev/null +++ b/sklearn/decomposition/incremental_pca.py @@ -0,0 +1,252 @@ +"""Incremental Principal Components Analysis.""" + +# Author: Kyle Kastner +# License: BSD 3 clause + +import numpy as np +from scipy import linalg + +from .base import _BasePCA +from ..utils import check_array, gen_batches +from ..utils.extmath import svd_flip, _batch_mean_variance_update + + +class IncrementalPCA(_BasePCA): + """Incremental principal components analysis (IPCA). + + Linear dimensionality reduction using Singular Value Decomposition of + centered data, keeping only the most significant singular vectors to + project the data to a lower dimensional space. + + Depending on the size of the input data, this algorithm can be much more + memory efficient than a PCA. + + This algorithm has constant memory complexity, on the order + of ``batch_size``, enabling use of np.memmap files without loading the + entire file into memory. + + The computational overhead of each SVD is + ``O(batch_size * n_features ** 2)``, but only 2 * batch_size samples + remain in memory at a time. There will be ``n_samples / batch_size`` SVD + computations to get the principal components, versus 1 large SVD of + complexity ``O(n_samples * n_features ** 2)`` for PCA. + + Parameters + ---------- + n_components : int or None, (default=None) + Number of components to keep. If ``n_components `` is ``None``, + then ``n_components`` is set to ``min(n_samples, n_features)``. + + batch_size : int or None, (default=None) + The number of samples to use for each batch. Only used when calling + ``fit``. If ``batch_size`` is ``None``, then ``batch_size`` + is inferred from the data and set to ``5 * n_features``, to provide a + balance between approximation accuracy and memory consumption. + + copy : bool, (default=True) + If False, X will be overwritten. ``copy=False`` can be used to + save memory but is unsafe for general use. 
+ + whiten : bool, optional + When True (False by default) the ``components_`` vectors are divided + by ``n_samples`` times ``components_`` to ensure uncorrelated outputs + with unit component-wise variances. + + Whitening will remove some information from the transformed signal + (the relative variance scales of the components) but can sometimes + improve the predictive accuracy of the downstream estimators by + making data respect some hard-wired assumptions. + + Attributes + ---------- + components_ : array, shape (n_components, n_features) + Components with maximum variance. + + explained_variance_ : array, shape (n_components,) + Variance explained by each of the selected components. + + explained_variance_ratio_ : array, shape (n_components,) + Percentage of variance explained by each of the selected components. + If all components are stored, the sum of explained variances is equal + to 1.0 + + mean_ : array, shape (n_features,) + Per-feature empirical mean, aggregate over calls to ``partial_fit``. + + var_ : array, shape (n_features,) + Per-feature empirical variance, aggregate over calls to ``partial_fit``. + + noise_variance_ : float + The estimated noise covariance following the Probabilistic PCA model + from Tipping and Bishop 1999. See "Pattern Recognition and + Machine Learning" by C. Bishop, 12.2.1 p. 574 or + http://www.miketipping.com/papers/met-mppca.pdf. + + n_components_ : int + The estimated number of components. Relevant when ``n_components=None``. + + n_samples_seen_ : int + The number of samples processed by the estimator. Will be reset on + new calls to fit, but increments across ``partial_fit`` calls. + + Notes + ----- + Implements the incremental PCA model from: + `D. Ross, J. Lim, R. Lin, M. Yang, Incremental Learning for Robust Visual + Tracking, International Journal of Computer Vision, Volume 77, Issue 1-3, + pp. 125-141, May 2008.` + See http://www.cs.toronto.edu/~dross/ivt/RossLimLinYang_ijcv.pdf + + This model is an extension of the Sequential Karhunen-Loeve Transform from: + `A. Levy and M. Lindenbaum, Sequential Karhunen-Loeve Basis Extraction and + its Application to Images, IEEE Transactions on Image Processing, Volume 9, + Number 8, pp. 1371-1374, August 2000.` + See http://www.cs.technion.ac.il/~mic/doc/skl-ip.pdf + + We have specifically abstained from an optimization used by authors of both + papers, a QR decomposition used in specific situations to reduce the + algorithmic complexity of the SVD. The source for this technique is + `Matrix Computations, Third Edition, G. Holub and C. Van Loan, Chapter 5, + section 5.4.4, pp 252-253.`. This technique has been omitted because it is + advantageous only when decomposing a matrix with ``n_samples`` (rows) + >= 5/3 * ``n_features`` (columns), and hurts the readability of the + implemented algorithm. This would be a good opportunity for future + optimization, if it is deemed necessary. + + References + ---------- + D. Ross, J. Lim, R. Lin, M. Yang. Incremental Learning for Robust Visual + Tracking, International Journal of Computer Vision, Volume 77, + Issue 1-3, pp. 125-141, May 2008. + + G. Golub and C. Van Loan. Matrix Computations, Third Edition, Chapter 5, + Section 5.4.4, pp. 252-253. 
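+
+    Examples
+    --------
+    A minimal usage sketch on synthetic data, feeding the model two chunks
+    through ``partial_fit`` (only the API defined in this class is used):
+
+    >>> import numpy as np
+    >>> from sklearn.decomposition import IncrementalPCA
+    >>> X = np.random.RandomState(0).randn(100, 10)
+    >>> ipca = IncrementalPCA(n_components=2)
+    >>> ipca.partial_fit(X[:50])  # doctest: +SKIP
+    >>> ipca.partial_fit(X[50:])  # doctest: +SKIP
+    >>> ipca.transform(X).shape
+    (100, 2)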
+ + See also + -------- + PCA + RandomizedPCA + KernelPCA + SparsePCA + TruncatedSVD + """ + + def __init__(self, n_components=None, whiten=False, copy=True, + batch_size=None): + self.n_components = n_components + self.whiten = whiten + self.copy = copy + self.batch_size = batch_size + self.components_ = None + + def fit(self, X, y=None): + """Fit the model with X, using minibatches of size batch_size. + + Parameters + ---------- + X: array-like, shape (n_samples, n_features) + Training data, where n_samples is the number of samples and + n_features is the number of features. + + y: Passthrough for ``Pipeline`` compatibility. + + Returns + ------- + self: object + Returns the instance itself. + """ + self.components_ = None + self.mean_ = None + self.singular_values_ = None + self.explained_variance_ = None + self.explained_variance_ratio_ = None + self.noise_variance_ = None + self.var_ = None + self.n_samples_seen_ = 0 + X = check_array(X, dtype=np.float) + n_samples, n_features = X.shape + + if self.batch_size is None: + self.batch_size_ = 5 * n_features + else: + self.batch_size_ = self.batch_size + + for batch in gen_batches(n_samples, self.batch_size_): + self.partial_fit(X[batch]) + return self + + def partial_fit(self, X): + """Incremental fit with X. All of X is processed as a single batch. + + Parameters + ---------- + X: array-like, shape (n_samples, n_features) + Training data, where n_samples is the number of samples and + n_features is the number of features. + + Returns + ------- + self: object + Returns the instance itself. + """ + X = check_array(X, copy=self.copy, dtype=np.float) + n_samples, n_features = X.shape + + if self.n_components is None: + self.n_components_ = n_features + elif not 1 <= self.n_components <= n_features: + raise ValueError("n_components=%r invalid for n_features=%d, need " + "more rows than columns for IncrementalPCA " + "processing" % (self.n_components, n_features)) + else: + self.n_components_ = self.n_components + + if (self.components_ is not None) and (self.components_.shape[0] + != self.n_components_): + raise ValueError("Number of input features has changed from %i " + "to %i between calls to partial_fit! Try " + "setting n_components to a fixed value." 
% ( + self.components_.shape[0], self.n_components_)) + + if self.components_ is None: + # This is the first pass through partial_fit + self.n_samples_seen_ = 0 + col_var = X.var(axis=0) + col_mean = X.mean(axis=0) + X -= col_mean + U, S, V = linalg.svd(X, full_matrices=False) + U, V = svd_flip(U, V, u_based_decision=False) + explained_variance = S ** 2 / n_samples + explained_variance_ratio = S ** 2 / np.sum(col_var * + n_samples) + else: + col_batch_mean = X.mean(axis=0) + col_mean, col_var, n_total_samples = _batch_mean_variance_update( + X, self.mean_, self.var_, self.n_samples_seen_) + X -= col_batch_mean + # Build matrix of combined previous basis and new data + mean_correction = np.sqrt((self.n_samples_seen_ * n_samples) / + n_total_samples) * (self.mean_ - + col_batch_mean) + X_combined = np.vstack((self.singular_values_.reshape((-1, 1)) * + self.components_, X, + mean_correction)) + U, S, V = linalg.svd(X_combined, full_matrices=False) + U, V = svd_flip(U, V, u_based_decision=False) + explained_variance = S ** 2 / n_total_samples + explained_variance_ratio = S ** 2 / np.sum(col_var * + n_total_samples) + self.n_samples_seen_ += n_samples + self.components_ = V[:self.n_components_] + self.singular_values_ = S[:self.n_components_] + self.mean_ = col_mean + self.var_ = col_var + self.explained_variance_ = explained_variance[:self.n_components_] + self.explained_variance_ratio_ = \ + explained_variance_ratio[:self.n_components_] + if self.n_components_ < n_features: + self.noise_variance_ = \ + explained_variance[self.n_components_:].mean() + else: + self.noise_variance_ = 0. + return self diff --git a/sklearn/decomposition/tests/test_incremental_pca.py b/sklearn/decomposition/tests/test_incremental_pca.py new file mode 100644 index 0000000000000..db75dfe7f9960 --- /dev/null +++ b/sklearn/decomposition/tests/test_incremental_pca.py @@ -0,0 +1,224 @@ +"""Tests for Incremental PCA.""" +import numpy as np + +from sklearn.utils.testing import assert_almost_equal +from sklearn.utils.testing import assert_array_almost_equal +from sklearn.utils.testing import assert_raises + +from sklearn import datasets +from sklearn.decomposition import PCA, IncrementalPCA + +iris = datasets.load_iris() + + +def test_incremental_pca(): + """Incremental PCA on dense arrays.""" + X = iris.data + batch_size = X.shape[0] // 3 + ipca = IncrementalPCA(n_components=2, batch_size=batch_size) + pca = PCA(n_components=2) + pca.fit_transform(X) + + X_transformed = ipca.fit_transform(X) + + np.testing.assert_equal(X_transformed.shape, (X.shape[0], 2)) + assert_almost_equal(ipca.explained_variance_ratio_.sum(), + pca.explained_variance_ratio_.sum(), 1) + + for n_components in [1, 2, X.shape[1]]: + ipca = IncrementalPCA(n_components, batch_size=batch_size) + ipca.fit(X) + cov = ipca.get_covariance() + precision = ipca.get_precision() + assert_array_almost_equal(np.dot(cov, precision), + np.eye(X.shape[1])) + + +def test_incremental_pca_check_projection(): + """Test that the projection of data is correct.""" + rng = np.random.RandomState(1999) + n, p = 100, 3 + X = rng.randn(n, p) * .1 + X[:10] += np.array([3, 4, 5]) + Xt = 0.1 * rng.randn(1, p) + np.array([3, 4, 5]) + + # Get the reconstruction of the generated data X + # Note that Xt has the same "components" as X, just separated + # This is what we want to ensure is recreated correctly + Yt = IncrementalPCA(n_components=2).fit(X).transform(Xt) + + # Normalize + Yt /= np.sqrt((Yt ** 2).sum()) + + # Make sure that the first element of Yt is ~1, this means + # the 
reconstruction worked as expected + assert_almost_equal(np.abs(Yt[0][0]), 1., 1) + + +def test_incremental_pca_inverse(): + """Test that the projection of data can be inverted.""" + rng = np.random.RandomState(1999) + n, p = 50, 3 + X = rng.randn(n, p) # spherical data + X[:, 1] *= .00001 # make middle component relatively small + X += [5, 4, 3] # make a large mean + + # same check that we can find the original data from the transformed + # signal (since the data is almost of rank n_components) + ipca = IncrementalPCA(n_components=2, batch_size=10).fit(X) + Y = ipca.transform(X) + Y_inverse = ipca.inverse_transform(Y) + assert_almost_equal(X, Y_inverse, decimal=3) + + +def test_incremental_pca_validation(): + """Test that n_components is >=1 and <= n_features.""" + X = [[0, 1], [1, 0]] + for n_components in [-1, 0, .99, 3]: + assert_raises(ValueError, IncrementalPCA(n_components, + batch_size=10).fit, X) + + +def test_incremental_pca_set_params(): + """Test that components_ sign is stable over batch sizes.""" + rng = np.random.RandomState(1999) + n_samples = 100 + n_features = 20 + X = rng.randn(n_samples, n_features) + X2 = rng.randn(n_samples, n_features) + X3 = rng.randn(n_samples, n_features) + ipca = IncrementalPCA(n_components=20) + ipca.fit(X) + # Decreasing number of components + ipca.set_params(n_components=10) + assert_raises(ValueError, ipca.partial_fit, X2) + # Increasing number of components + ipca.set_params(n_components=15) + assert_raises(ValueError, ipca.partial_fit, X3) + # Returning to original setting + ipca.set_params(n_components=20) + ipca.partial_fit(X) + + +def test_incremental_pca_num_features_change(): + """Test that changing n_components will raise an error.""" + rng = np.random.RandomState(1999) + n_samples = 100 + X = rng.randn(n_samples, 20) + X2 = rng.randn(n_samples, 50) + ipca = IncrementalPCA(n_components=None) + ipca.fit(X) + assert_raises(ValueError, ipca.partial_fit, X2) + + +def test_incremental_pca_batch_signs(): + """Test that components_ sign is stable over batch sizes.""" + rng = np.random.RandomState(1999) + n_samples = 100 + n_features = 3 + X = rng.randn(n_samples, n_features) + all_components = [] + batch_sizes = np.arange(10, 20) + for batch_size in batch_sizes: + ipca = IncrementalPCA(n_components=None, batch_size=batch_size).fit(X) + all_components.append(ipca.components_) + + for i, j in zip(all_components[:-1], all_components[1:]): + assert_almost_equal(np.sign(i), np.sign(j), decimal=6) + + +def test_incremental_pca_batch_values(): + """Test that components_ values are stable over batch sizes.""" + rng = np.random.RandomState(1999) + n_samples = 100 + n_features = 3 + X = rng.randn(n_samples, n_features) + all_components = [] + batch_sizes = np.arange(20, 40, 3) + for batch_size in batch_sizes: + ipca = IncrementalPCA(n_components=None, batch_size=batch_size).fit(X) + all_components.append(ipca.components_) + + for i, j in zip(all_components[:-1], all_components[1:]): + assert_almost_equal(i, j, decimal=1) + + +def test_incremental_pca_partial_fit(): + """Test that fit and partial_fit get equivalent results.""" + rng = np.random.RandomState(1999) + n, p = 50, 3 + X = rng.randn(n, p) # spherical data + X[:, 1] *= .00001 # make middle component relatively small + X += [5, 4, 3] # make a large mean + + # same check that we can find the original data from the transformed + # signal (since the data is almost of rank n_components) + batch_size = 10 + ipca = IncrementalPCA(n_components=2, batch_size=batch_size).fit(X) + pipca = 
IncrementalPCA(n_components=2, batch_size=batch_size) + # Add one to make sure endpoint is included + batch_itr = np.arange(0, n + 1, batch_size) + for i, j in zip(batch_itr[:-1], batch_itr[1:]): + pipca.partial_fit(X[i:j, :]) + assert_almost_equal(ipca.components_, pipca.components_, decimal=3) + + +def test_incremental_pca_against_pca_iris(): + """Test that IncrementalPCA and PCA are approximate (to a sign flip).""" + X = iris.data + + Y_pca = PCA(n_components=2).fit_transform(X) + Y_ipca = IncrementalPCA(n_components=2, batch_size=25).fit_transform(X) + + assert_almost_equal(np.abs(Y_pca), np.abs(Y_ipca), 1) + + +def test_incremental_pca_against_pca_random_data(): + """Test that IncrementalPCA and PCA are approximate (to a sign flip).""" + rng = np.random.RandomState(1999) + n_samples = 100 + n_features = 3 + X = rng.randn(n_samples, n_features) + 5 * rng.rand(1, n_features) + + Y_pca = PCA(n_components=3).fit_transform(X) + Y_ipca = IncrementalPCA(n_components=3, batch_size=25).fit_transform(X) + + assert_almost_equal(np.abs(Y_pca), np.abs(Y_ipca), 1) + + +def test_explained_variances(): + """Test that PCA and IncrementalPCA calculations match""" + X = datasets.make_low_rank_matrix(1000, 100, tail_strength=0., + effective_rank=10, random_state=1999) + prec = 3 + n_samples, n_features = X.shape + for nc in [None, 99]: + pca = PCA(n_components=nc).fit(X) + ipca = IncrementalPCA(n_components=nc, batch_size=100).fit(X) + assert_almost_equal(pca.explained_variance_, ipca.explained_variance_, + decimal=prec) + assert_almost_equal(pca.explained_variance_ratio_, + ipca.explained_variance_ratio_, decimal=prec) + assert_almost_equal(pca.noise_variance_, ipca.noise_variance_, + decimal=prec) + + +def test_whitening(): + """Test that PCA and IncrementalPCA transforms match to sign flip.""" + X = datasets.make_low_rank_matrix(1000, 10, tail_strength=0., + effective_rank=2, random_state=1999) + prec = 3 + n_samples, n_features = X.shape + for nc in [None, 9]: + pca = PCA(whiten=True, n_components=nc).fit(X) + ipca = IncrementalPCA(whiten=True, n_components=nc, + batch_size=250).fit(X) + + Xt_pca = pca.transform(X) + Xt_ipca = ipca.transform(X) + assert_almost_equal(np.abs(Xt_pca), np.abs(Xt_ipca), decimal=prec) + Xinv_ipca = ipca.inverse_transform(Xt_ipca) + Xinv_pca = pca.inverse_transform(Xt_pca) + assert_almost_equal(X, Xinv_ipca, decimal=prec) + assert_almost_equal(X, Xinv_pca, decimal=prec) + assert_almost_equal(Xinv_pca, Xinv_ipca, decimal=prec) diff --git a/sklearn/utils/estimator_checks.py b/sklearn/utils/estimator_checks.py index 701a0a5efd27a..acdf451648830 100644 --- a/sklearn/utils/estimator_checks.py +++ b/sklearn/utils/estimator_checks.py @@ -764,9 +764,9 @@ def check_estimators_overwrite_params(name, Estimator): # catch deprecation warnings estimator = Estimator() - if hasattr(estimator, 'batch_size'): + if name == 'MiniBatchDictLearning' or name == 'MiniBatchSparsePCA': # FIXME - # for MiniBatchDictLearning + # for MiniBatchDictLearning and MiniBatchSparsePCA estimator.batch_size = 1 set_fast_parameters(estimator) diff --git a/sklearn/utils/extmath.py b/sklearn/utils/extmath.py index e9eb901256467..65e1cae3685df 100644 --- a/sklearn/utils/extmath.py +++ b/sklearn/utils/extmath.py @@ -7,8 +7,10 @@ # Olivier Grisel # Lars Buitinck # Stefan van der Walt +# Kyle Kastner # License: BSD 3 clause +from __future__ import division from functools import partial import warnings @@ -519,27 +521,41 @@ def cartesian(arrays, out=None): return out -def svd_flip(u, v): - """Sign correction to ensure 
deterministic output from SVD +def svd_flip(u, v, u_based_decision=True): + """Sign correction to ensure deterministic output from SVD. Adjusts the columns of u and the rows of v such that the loadings in the columns in u that are largest in absolute value are always positive. Parameters ---------- - u, v: arrays + u, v : arrays The output of `linalg.svd` or `sklearn.utils.extmath.randomized_svd`, with matching inner dimensions so one can compute `np.dot(u * s, v)`. + u_based_decision : boolean, (default=True) + If True, use the columns of u as the basis for sign flipping. Otherwise, + use the rows of v. The choice of which variable to base the decision on + is generally algorithm dependent. + + Returns ------- - u_adjusted, s, v_adjusted: arrays with the same dimensions as the input. + u_adjusted, v_adjusted : arrays with the same dimensions as the input. """ - max_abs_cols = np.argmax(np.abs(u), axis=0) - signs = np.sign(u[max_abs_cols, xrange(u.shape[1])]) - u *= signs - v *= signs[:, np.newaxis] + if u_based_decision: + # columns of u, rows of v + max_abs_cols = np.argmax(np.abs(u), axis=0) + signs = np.sign(u[max_abs_cols, xrange(u.shape[1])]) + u *= signs + v *= signs[:, np.newaxis] + else: + # rows of v, columns of u + max_abs_rows = np.argmax(np.abs(v), axis=1) + signs = np.sign(v[xrange(v.shape[0]), max_abs_rows]) + u *= signs + v *= signs[:, np.newaxis] return u, v @@ -621,3 +637,48 @@ def make_nonnegative(X, min_value=0): " make it no longer sparse.") X = X + (min_value - min_) return X + + +def _batch_mean_variance_update(X, old_mean, old_variance, old_sample_count): + """Calculate an average mean update and a Youngs and Cramer variance update. + + From the paper "Algorithms for computing the sample variance: analysis and + recommendations", by Chan, Golub, and LeVeque. + + Parameters + ---------- + X : array-like, shape (n_samples, n_features) + Data to use for variance update + + old_mean : array-like, shape: (n_features,) + + old_variance : array-like, shape: (n_features,) + + old_sample_count : int + + Returns + ------- + updated_mean : array, shape (n_features,) + + updated_variance : array, shape (n_features,) + + updated_sample_count : int + + References + ---------- + T. Chan, G. Golub, R. LeVeque. Algorithms for computing the sample variance: + recommendations, The American Statistician, Vol. 37, No. 3, pp. 
242-247 + + """ + new_sum = X.sum(axis=0) + new_variance = X.var(axis=0) * X.shape[0] + old_sum = old_mean * old_sample_count + n_samples = X.shape[0] + updated_sample_count = old_sample_count + n_samples + partial_variance = old_sample_count / (n_samples * updated_sample_count) * ( + n_samples / old_sample_count * old_sum - new_sum) ** 2 + unnormalized_variance = old_variance * old_sample_count + new_variance + \ + partial_variance + return ((old_sum + new_sum) / updated_sample_count, + unnormalized_variance / updated_sample_count, + updated_sample_count) diff --git a/sklearn/utils/tests/test_extmath.py b/sklearn/utils/tests/test_extmath.py index 8676f1ceaa23a..b4a651e0971dd 100644 --- a/sklearn/utils/tests/test_extmath.py +++ b/sklearn/utils/tests/test_extmath.py @@ -27,6 +27,8 @@ from sklearn.utils.extmath import cartesian from sklearn.utils.extmath import log_logistic, logistic_sigmoid from sklearn.utils.extmath import fast_dot, _fast_dot +from sklearn.utils.extmath import svd_flip +from sklearn.utils.extmath import _batch_mean_variance_update from sklearn.datasets.samples_generator import make_low_rank_matrix @@ -246,6 +248,31 @@ def test_randomized_svd_transpose_consistency(): assert_almost_equal(s2, s3) +def test_svd_flip(): + """Check that svd_flip works in both situations, and reconstructs input.""" + rs = np.random.RandomState(1999) + n_samples = 20 + n_features = 10 + X = rs.randn(n_samples, n_features) + + # Check matrix reconstruction + U, S, V = linalg.svd(X, full_matrices=False) + U1, V1 = svd_flip(U, V, u_based_decision=False) + assert_almost_equal(np.dot(U1 * S, V1), X, decimal=6) + + # Check transposed matrix reconstruction + XT = X.T + U, S, V = linalg.svd(XT, full_matrices=False) + U2, V2 = svd_flip(U, V, u_based_decision=True) + assert_almost_equal(np.dot(U2 * S, V2), XT, decimal=6) + + # Check that different flip methods are equivalent under reconstruction + U_flip1, V_flip1 = svd_flip(U, V, u_based_decision=True) + assert_almost_equal(np.dot(U_flip1 * S, V_flip1), XT, decimal=6) + U_flip2, V_flip2 = svd_flip(U, V, u_based_decision=False) + assert_almost_equal(np.dot(U_flip2 * S, V_flip2), XT, decimal=6) + + def test_randomized_svd_sign_flip(): a = np.array([[2.0, 0.0], [0.0, 1.0]]) u1, s1, v1 = randomized_svd(a, 2, flip_sign=True, random_state=41) @@ -375,6 +402,61 @@ def test_fast_dot(): for x in [np.array([[d] * 10] * 2) for d in [np.inf, np.nan]]: assert_raises(ValueError, _fast_dot, x, x.T) + +def test_incremental_variance_update_formulas(): + """Test Youngs and Cramer incremental variance formulas.""" + # Doggie data from http://www.mathsisfun.com/data/standard-deviation.html + A = np.array([[600, 470, 170, 430, 300], + [600, 470, 170, 430, 300], + [600, 470, 170, 430, 300], + [600, 470, 170, 430, 300]]).T + idx = 2 + X1 = A[:idx, :] + X2 = A[idx:, :] + + old_means = X1.mean(axis=0) + old_variances = X1.var(axis=0) + old_sample_count = X1.shape[0] + final_means, final_variances, final_count = _batch_mean_variance_update( + X2, old_means, old_variances, old_sample_count) + assert_almost_equal(final_means, A.mean(axis=0), 6) + assert_almost_equal(final_variances, A.var(axis=0), 6) + assert_almost_equal(final_count, A.shape[0]) + + +def test_incremental_variance_ddof(): + """Test that degrees of freedom parameter for calculations are correct.""" + rng = np.random.RandomState(1999) + X = rng.randn(50, 10) + n_samples, n_features = X.shape + for batch_size in [11, 20, 37]: + steps = np.arange(0, X.shape[0], batch_size) + if steps[-1] != X.shape[0]: + steps = 
np.hstack([steps, n_samples]) + + for i, j in zip(steps[:-1], steps[1:]): + batch = X[i:j, :] + if i == 0: + incremental_means = batch.mean(axis=0) + incremental_variances = batch.var(axis=0) + # Assign this twice so that the test logic is consistent + incremental_count = batch.shape[0] + sample_count = batch.shape[0] + else: + result = _batch_mean_variance_update(batch, incremental_means, + incremental_variances, + sample_count) + (incremental_means, incremental_variances, + incremental_count) = result + sample_count += batch.shape[0] + + calculated_means = np.mean(X[:j], axis=0) + calculated_variances = np.var(X[:j], axis=0) + assert_almost_equal(incremental_means, calculated_means, 6) + assert_almost_equal(incremental_variances, + calculated_variances, 6) + assert_equal(incremental_count, sample_count) + if __name__ == '__main__': import nose nose.runmodule() From 5f8271f60c7293e04fddb3d524dd9e0a2c793655 Mon Sep 17 00:00:00 2001 From: Kyle Kastner Date: Fri, 19 Sep 2014 10:52:35 -0400 Subject: [PATCH 2/2] Updated what's new to add IncrementalPCA --- doc/whats_new.rst | 4 ++++ examples/decomposition/plot_incremental_pca.py | 3 ++- sklearn/decomposition/incremental_pca.py | 3 ++- 3 files changed, 8 insertions(+), 2 deletions(-) diff --git a/doc/whats_new.rst b/doc/whats_new.rst index 9591c5d67721f..94c218e080f4e 100644 --- a/doc/whats_new.rst +++ b/doc/whats_new.rst @@ -32,6 +32,10 @@ New features :class:`ensemble.GradientBoostingRegressor`. By `Peter Prettenhofer`_. + - Added :class:`decomposition.IncrementalPCA`, an implementation of the PCA + algorithm that supports out-of-core learning with a ``partial_fit`` + method. By `Kyle Kastner`_. + Enhancements ............ diff --git a/examples/decomposition/plot_incremental_pca.py b/examples/decomposition/plot_incremental_pca.py index bfcf5d3158fa0..899ae6ce1cff0 100644 --- a/examples/decomposition/plot_incremental_pca.py +++ b/examples/decomposition/plot_incremental_pca.py @@ -8,7 +8,8 @@ replacement for principal component analysis (PCA) when the dataset to be decomposed is too large to fit in memory. IPCA builds a low-rank approximation for the input data using an amount of memory which is independent of the -input data size. +number of input data samples. It is still dependent on the input data features, +but changing the batch size allows for control of memory usage. This example serves as a visual check that IPCA is able to find a similar projection of the data to PCA (to a sign flip), while only processing a diff --git a/sklearn/decomposition/incremental_pca.py b/sklearn/decomposition/incremental_pca.py index cf3ca5337cb3e..4991bc7ea3680 100644 --- a/sklearn/decomposition/incremental_pca.py +++ b/sklearn/decomposition/incremental_pca.py @@ -137,7 +137,6 @@ def __init__(self, n_components=None, whiten=False, copy=True, self.whiten = whiten self.copy = copy self.batch_size = batch_size - self.components_ = None def fit(self, X, y=None): """Fit the model with X, using minibatches of size batch_size. @@ -191,6 +190,8 @@ def partial_fit(self, X): """ X = check_array(X, copy=self.copy, dtype=np.float) n_samples, n_features = X.shape + if not hasattr(self, 'components_'): + self.components_ = None if self.n_components is None: self.n_components_ = n_features
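
The out-of-core usage described in doc/modules/decomposition.rst above reduces to calling ``partial_fit`` on one chunk of data at a time. The sketch below illustrates that pattern with synthetic in-memory data standing in for chunks streamed from disk or read through ``numpy.memmap``; it assumes only the ``IncrementalPCA`` API introduced in this patch (``fit``, ``partial_fit``, ``n_components``, ``batch_size``) and is a sketch rather than a reference implementation.

import numpy as np
from sklearn.decomposition import IncrementalPCA

rng = np.random.RandomState(1999)
X = rng.randn(1000, 20)  # stand-in for a dataset too large to hold in memory

chunk_size = 100
ipca = IncrementalPCA(n_components=5)
for start in range(0, X.shape[0], chunk_size):
    # each chunk could instead come from np.load, a database cursor,
    # or a slice of a numpy.memmap-backed array
    ipca.partial_fit(X[start:start + chunk_size])

# fit() with batch_size=chunk_size performs the same sequence of updates
# internally, so the incrementally learned components should match closely
batch_ipca = IncrementalPCA(n_components=5, batch_size=chunk_size).fit(X)
print(np.allclose(np.abs(ipca.components_), np.abs(batch_ipca.components_)))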