Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
156 changes: 156 additions & 0 deletions benchmarks/bench_plot_incremental_pca.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,156 @@
"""
========================
IncrementalPCA benchmark
========================

Benchmarks for IncrementalPCA

"""

import numpy as np
import gc
from time import time
from collections import defaultdict
import matplotlib.pyplot as plt
from sklearn.datasets import fetch_lfw_people
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

please reorder imports (numpy first and sklearn last)

from sklearn.decomposition import IncrementalPCA, RandomizedPCA, PCA


def plot_results(X, y, label):
plt.plot(X, y, label=label, marker='o')


def benchmark(estimator, data):
gc.collect()
print("Benching %s" % estimator)
t0 = time()
estimator.fit(data)
training_time = time() - t0
data_t = estimator.transform(data)
data_r = estimator.inverse_transform(data_t)
reconstruction_error = np.mean(np.abs(data - data_r))
return {'time': training_time, 'error': reconstruction_error}


def plot_feature_times(all_times, batch_size, all_components, data):
plt.figure()
plot_results(all_components, all_times['pca'], label="PCA")
plot_results(all_components, all_times['ipca'],
label="IncrementalPCA, bsize=%i" % batch_size)
plot_results(all_components, all_times['rpca'], label="RandomizedPCA")
plt.legend(loc="upper left")
plt.suptitle("Algorithm runtime vs. n_components\n \
LFW, size %i x %i" % data.shape)
plt.xlabel("Number of components (out of max %i)" % data.shape[1])
plt.ylabel("Time (seconds)")


def plot_feature_errors(all_errors, batch_size, all_components, data):
plt.figure()
plot_results(all_components, all_errors['pca'], label="PCA")
plot_results(all_components, all_errors['ipca'],
label="IncrementalPCA, bsize=%i" % batch_size)
plot_results(all_components, all_errors['rpca'], label="RandomizedPCA")
plt.legend(loc="lower left")
plt.suptitle("Algorithm error vs. n_components\n"
"LFW, size %i x %i" % data.shape)
plt.xlabel("Number of components (out of max %i)" % data.shape[1])
plt.ylabel("Mean absolute error")


def plot_batch_times(all_times, n_features, all_batch_sizes, data):
plt.figure()
plot_results(all_batch_sizes, all_times['pca'], label="PCA")
plot_results(all_batch_sizes, all_times['rpca'], label="RandomizedPCA")
plot_results(all_batch_sizes, all_times['ipca'], label="IncrementalPCA")
plt.legend(loc="lower left")
plt.suptitle("Algorithm runtime vs. batch_size for n_components %i\n \
LFW, size %i x %i" % (
n_features, data.shape[0], data.shape[1]))
plt.xlabel("Batch size")
plt.ylabel("Time (seconds)")


def plot_batch_errors(all_errors, n_features, all_batch_sizes, data):
plt.figure()
plot_results(all_batch_sizes, all_errors['pca'], label="PCA")
plot_results(all_batch_sizes, all_errors['ipca'], label="IncrementalPCA")
plt.legend(loc="lower left")
plt.suptitle("Algorithm error vs. batch_size for n_components %i\n \
LFW, size %i x %i" % (
n_features, data.shape[0], data.shape[1]))
plt.xlabel("Batch size")
plt.ylabel("Mean absolute error")


def fixed_batch_size_comparison(data):
all_features = [i.astype(int) for i in np.linspace(data.shape[1] // 10,
data.shape[1], num=5)]
batch_size = 1000
# Compare runtimes and error for fixed batch size
all_times = defaultdict(list)
all_errors = defaultdict(list)
for n_components in all_features:
pca = PCA(n_components=n_components)
rpca = RandomizedPCA(n_components=n_components, random_state=1999)
ipca = IncrementalPCA(n_components=n_components, batch_size=batch_size)
results_dict = {k: benchmark(est, data) for k, est in [('pca', pca),
('ipca', ipca),
('rpca', rpca)]}

for k in sorted(results_dict.keys()):
all_times[k].append(results_dict[k]['time'])
all_errors[k].append(results_dict[k]['error'])

plot_feature_times(all_times, batch_size, all_features, data)
plot_feature_errors(all_errors, batch_size, all_features, data)


def variable_batch_size_comparison(data):
batch_sizes = [i.astype(int) for i in np.linspace(data.shape[0] // 10,
data.shape[0], num=10)]

for n_components in [i.astype(int) for i in
np.linspace(data.shape[1] // 10,
data.shape[1], num=4)]:
all_times = defaultdict(list)
all_errors = defaultdict(list)
pca = PCA(n_components=n_components)
rpca = RandomizedPCA(n_components=n_components, random_state=1999)
results_dict = {k: benchmark(est, data) for k, est in [('pca', pca),
('rpca', rpca)]}

# Create flat baselines to compare the variation over batch size
all_times['pca'].extend([results_dict['pca']['time']] *
len(batch_sizes))
all_errors['pca'].extend([results_dict['pca']['error']] *
len(batch_sizes))
all_times['rpca'].extend([results_dict['rpca']['time']] *
len(batch_sizes))
all_errors['rpca'].extend([results_dict['rpca']['error']] *
len(batch_sizes))
for batch_size in batch_sizes:
ipca = IncrementalPCA(n_components=n_components,
batch_size=batch_size)
results_dict = {k: benchmark(est, data) for k, est in [('ipca',
ipca)]}
all_times['ipca'].append(results_dict['ipca']['time'])
all_errors['ipca'].append(results_dict['ipca']['error'])

plot_batch_times(all_times, n_components, batch_sizes, data)
# RandomizedPCA error is always worse (approx 100x) than other PCA
# tests
plot_batch_errors(all_errors, n_components, batch_sizes, data)

faces = fetch_lfw_people(resize=.2, min_faces_per_person=5)
# limit dataset to 5000 people (don't care who they are!)
X = faces.data[:5000]
n_samples, h, w = faces.images.shape
n_features = X.shape[1]

X -= X.mean(axis=0)
X /= X.std(axis=0)

fixed_batch_size_comparison(X)
variable_batch_size_comparison(X)
plt.show()
1 change: 1 addition & 0 deletions doc/modules/classes.rst
Original file line number Diff line number Diff line change
Expand Up @@ -270,6 +270,7 @@ Samples generator
:template: class.rst

decomposition.PCA
decomposition.IncrementalPCA
decomposition.ProjectedGradientNMF
decomposition.RandomizedPCA
decomposition.KernelPCA
Expand Down
43 changes: 41 additions & 2 deletions doc/modules/decomposition.rst
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,7 @@ project the data onto the singular space while scaling each component
to unit variance. This is often useful if the models down-stream make
strong assumptions on the isotropy of the signal: this is for example
the case for Support Vector Machines with the RBF kernel and the K-Means
clustering algorithm. However in that case the inverse transform is no
longer exact since some information is lost while forward transforming.
clustering algorithm.

Below is an example of the iris dataset, which is comprised of 4
features, projected on the 2 dimensions that explain most variance:
Expand Down Expand Up @@ -57,6 +56,46 @@ data based on the amount of variance it explains. As such it implements a
* :ref:`example_decomposition_plot_pca_vs_fa_model_selection.py`


.. _IncrementalPCA:

Incremental PCA
---------------

The :class:`PCA` object is very useful, but has certain limitations for
large datasets. The biggest limitation is that :class:`PCA` only supports
batch processing, which means all of the data to be processed must fit in main
memory. The :class:`IncrementalPCA` object uses a different form of
processing and allows for partial computations which almost
exactly match the results of :class:`PCA` while processing the data in a
minibatch fashion. :class:`IncrementalPCA` makes it possible to implement
out-of-core Principal Component Analysis either by:

* Using its ``partial_fit`` method on chunks of data fetched sequentially
from the local hard drive or a network database.

* Calling its fit method on a memory mapped file using ``numpy.memmap``.

:class:`IncrementalPCA` only stores estimates of component and noise variances,
in order update ``explained_variance_ratio_`` incrementally. This is why
memory usage depends on the number of samples per batch, rather than the
number of samples to be processed in the dataset.

.. figure:: ../auto_examples/decomposition/images/plot_incremental_pca_001.png
:target: ../auto_examples/decomposition/plot_incremental_pca.html
:align: center
:scale: 75%

.. figure:: ../auto_examples/decomposition/images/plot_incremental_pca_002.png
:target: ../auto_examples/decomposition/plot_incremental_pca.html
:align: center
:scale: 75%


.. topic:: Examples:

* :ref:`example_decomposition_plot_incremental_pca.py`


.. _RandomizedPCA:

Approximate PCA
Expand Down
1 change: 1 addition & 0 deletions doc/modules/scaling_strategies.rst
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ Here is a list of incremental estimators for different tasks:
+ :class:`sklearn.cluster.MiniBatchKMeans`
- Decomposition / feature Extraction
+ :class:`sklearn.decomposition.MiniBatchDictionaryLearning`
+ :class:`sklearn.decomposition.IncrementalPCA`
+ :class:`sklearn.cluster.MiniBatchKMeans`

For classification, a somewhat important thing to note is that although a
Expand Down
4 changes: 4 additions & 0 deletions doc/whats_new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,10 @@ New features
:class:`ensemble.GradientBoostingRegressor`. By
`Peter Prettenhofer`_.

- Added :class:`decomposition.IncrementalPCA`, an implementation of the PCA
algorithm that supports out-of-core learning with a ``partial_fit``
method. By `Kyle Kastner`_.


Enhancements
............
Expand Down
58 changes: 58 additions & 0 deletions examples/decomposition/plot_incremental_pca.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
"""

===============
Incremental PCA
===============

Incremental principal component analysis (IPCA) is typically used as a
replacement for principal component analysis (PCA) when the dataset to be
decomposed is too large to fit in memory. IPCA builds a low-rank approximation
for the input data using an amount of memory which is independent of the
number of input data samples. It is still dependent on the input data features,
but changing the batch size allows for control of memory usage.

This example serves as a visual check that IPCA is able to find a similar
projection of the data to PCA (to a sign flip), while only processing a
few samples at a time. This can be considered a "toy example", as IPCA is
intended for large datasets which do not fit in main memory, requiring
incremental approaches.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would add that the purpose of this example is to only serve as a visual sanity check as IPCA is only interesting on datasets with large number of samples.

"""
print(__doc__)

# Authors: Kyle Kastner
# License: BSD 3 clause

import numpy as np
import matplotlib.pyplot as plt

from sklearn.datasets import load_iris
from sklearn.decomposition import PCA, IncrementalPCA

iris = load_iris()
X = iris.data
y = iris.target

n_components = 2
ipca = IncrementalPCA(n_components=n_components, batch_size=10)
X_ipca = ipca.fit_transform(X)

pca = PCA(n_components=n_components)
X_pca = pca.fit_transform(X)

for X_transformed, title in [(X_ipca, "Incremental PCA"), (X_pca, "PCA")]:
plt.figure(figsize=(8, 8))
for c, i, target_name in zip("rgb", [0, 1, 2], iris.target_names):
plt.scatter(X_transformed[y == i, 0], X_transformed[y == i, 1],
c=c, label=target_name)

if "Incremental" in title:
err = np.abs(np.abs(X_pca) - np.abs(X_ipca)).mean()
plt.title(title + " of iris dataset\nMean absolute unsigned error "
"%.6f" % err)
else:
plt.title(title + " of iris dataset")
plt.legend(loc="best")
plt.axis([-4, 4, -1.5, 1.5])

plt.show()
2 changes: 2 additions & 0 deletions sklearn/decomposition/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

from .nmf import NMF, ProjectedGradientNMF
from .pca import PCA, RandomizedPCA
from .incremental_pca import IncrementalPCA
from .kernel_pca import KernelPCA
from .sparse_pca import SparsePCA, MiniBatchSparsePCA
from .truncated_svd import TruncatedSVD
Expand All @@ -18,6 +19,7 @@

__all__ = ['DictionaryLearning',
'FastICA',
'IncrementalPCA',
'KernelPCA',
'MiniBatchDictionaryLearning',
'MiniBatchSparsePCA',
Expand Down
Loading