
Commit 7426e6a

Merge pull request #3285 from kastnerkyle/incremental_pca
[MRG+2] Incremental PCA
2 parents fd4ba4d + 5f8271f commit 7426e6a

File tree

13 files changed: +1051 / -12 lines

Lines changed: 156 additions & 0 deletions
@@ -0,0 +1,156 @@
"""
========================
IncrementalPCA benchmark
========================

Benchmarks for IncrementalPCA

"""
import numpy as np
import gc
from time import time
from collections import defaultdict
import matplotlib.pyplot as plt
from sklearn.datasets import fetch_lfw_people
from sklearn.decomposition import IncrementalPCA, RandomizedPCA, PCA


def plot_results(X, y, label):
    plt.plot(X, y, label=label, marker='o')


def benchmark(estimator, data):
    gc.collect()
    print("Benching %s" % estimator)
    t0 = time()
    estimator.fit(data)
    training_time = time() - t0
    data_t = estimator.transform(data)
    data_r = estimator.inverse_transform(data_t)
    reconstruction_error = np.mean(np.abs(data - data_r))
    return {'time': training_time, 'error': reconstruction_error}


def plot_feature_times(all_times, batch_size, all_components, data):
    plt.figure()
    plot_results(all_components, all_times['pca'], label="PCA")
    plot_results(all_components, all_times['ipca'],
                 label="IncrementalPCA, bsize=%i" % batch_size)
    plot_results(all_components, all_times['rpca'], label="RandomizedPCA")
    plt.legend(loc="upper left")
    plt.suptitle("Algorithm runtime vs. n_components\n \
                 LFW, size %i x %i" % data.shape)
    plt.xlabel("Number of components (out of max %i)" % data.shape[1])
    plt.ylabel("Time (seconds)")


def plot_feature_errors(all_errors, batch_size, all_components, data):
    plt.figure()
    plot_results(all_components, all_errors['pca'], label="PCA")
    plot_results(all_components, all_errors['ipca'],
                 label="IncrementalPCA, bsize=%i" % batch_size)
    plot_results(all_components, all_errors['rpca'], label="RandomizedPCA")
    plt.legend(loc="lower left")
    plt.suptitle("Algorithm error vs. n_components\n"
                 "LFW, size %i x %i" % data.shape)
    plt.xlabel("Number of components (out of max %i)" % data.shape[1])
    plt.ylabel("Mean absolute error")


def plot_batch_times(all_times, n_components, all_batch_sizes, data):
    plt.figure()
    plot_results(all_batch_sizes, all_times['pca'], label="PCA")
    plot_results(all_batch_sizes, all_times['rpca'], label="RandomizedPCA")
    plot_results(all_batch_sizes, all_times['ipca'], label="IncrementalPCA")
    plt.legend(loc="lower left")
    plt.suptitle("Algorithm runtime vs. batch_size for n_components %i\n \
                 LFW, size %i x %i" % (
                 n_components, data.shape[0], data.shape[1]))
    plt.xlabel("Batch size")
    plt.ylabel("Time (seconds)")


def plot_batch_errors(all_errors, n_components, all_batch_sizes, data):
    plt.figure()
    plot_results(all_batch_sizes, all_errors['pca'], label="PCA")
    plot_results(all_batch_sizes, all_errors['ipca'], label="IncrementalPCA")
    plt.legend(loc="lower left")
    plt.suptitle("Algorithm error vs. batch_size for n_components %i\n \
                 LFW, size %i x %i" % (
                 n_components, data.shape[0], data.shape[1]))
    plt.xlabel("Batch size")
    plt.ylabel("Mean absolute error")


def fixed_batch_size_comparison(data):
    all_features = [i.astype(int) for i in np.linspace(data.shape[1] // 10,
                                                       data.shape[1], num=5)]
    batch_size = 1000
    # Compare runtimes and error for fixed batch size
    all_times = defaultdict(list)
    all_errors = defaultdict(list)
    for n_components in all_features:
        pca = PCA(n_components=n_components)
        rpca = RandomizedPCA(n_components=n_components, random_state=1999)
        ipca = IncrementalPCA(n_components=n_components, batch_size=batch_size)
        results_dict = {k: benchmark(est, data) for k, est in [('pca', pca),
                                                               ('ipca', ipca),
                                                               ('rpca', rpca)]}

        for k in sorted(results_dict.keys()):
            all_times[k].append(results_dict[k]['time'])
            all_errors[k].append(results_dict[k]['error'])

    plot_feature_times(all_times, batch_size, all_features, data)
    plot_feature_errors(all_errors, batch_size, all_features, data)


def variable_batch_size_comparison(data):
    batch_sizes = [i.astype(int) for i in np.linspace(data.shape[0] // 10,
                                                      data.shape[0], num=10)]

    for n_components in [i.astype(int) for i in
                         np.linspace(data.shape[1] // 10,
                                     data.shape[1], num=4)]:
        all_times = defaultdict(list)
        all_errors = defaultdict(list)
        pca = PCA(n_components=n_components)
        rpca = RandomizedPCA(n_components=n_components, random_state=1999)
        results_dict = {k: benchmark(est, data) for k, est in [('pca', pca),
                                                               ('rpca', rpca)]}

        # Create flat baselines to compare the variation over batch size
        all_times['pca'].extend([results_dict['pca']['time']] *
                                len(batch_sizes))
        all_errors['pca'].extend([results_dict['pca']['error']] *
                                 len(batch_sizes))
        all_times['rpca'].extend([results_dict['rpca']['time']] *
                                 len(batch_sizes))
        all_errors['rpca'].extend([results_dict['rpca']['error']] *
                                  len(batch_sizes))
        for batch_size in batch_sizes:
            ipca = IncrementalPCA(n_components=n_components,
                                  batch_size=batch_size)
            results_dict = {k: benchmark(est, data) for k, est in [('ipca',
                                                                   ipca)]}
            all_times['ipca'].append(results_dict['ipca']['time'])
            all_errors['ipca'].append(results_dict['ipca']['error'])

        plot_batch_times(all_times, n_components, batch_sizes, data)
        # RandomizedPCA error is always worse (approx 100x) than other PCA
        # tests
        plot_batch_errors(all_errors, n_components, batch_sizes, data)

faces = fetch_lfw_people(resize=.2, min_faces_per_person=5)
# limit dataset to 5000 people (don't care who they are!)
X = faces.data[:5000]
n_samples, h, w = faces.images.shape
n_features = X.shape[1]

X -= X.mean(axis=0)
X /= X.std(axis=0)

fixed_batch_size_comparison(X)
variable_batch_size_comparison(X)
plt.show()
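
For reference, the error metric reported by the benchmark() helper above is the mean absolute reconstruction error after a transform/inverse_transform round trip. A minimal sketch of that same quantity on synthetic data (the random matrix and parameter choices here are illustrative, not part of the benchmark):

import numpy as np
from sklearn.decomposition import IncrementalPCA

rng = np.random.RandomState(0)
X = rng.randn(200, 20)                      # illustrative synthetic data

ipca = IncrementalPCA(n_components=5, batch_size=50)
X_rec = ipca.inverse_transform(ipca.fit_transform(X))

# Same quantity as benchmark(): mean absolute reconstruction error
print("error: %.4f" % np.mean(np.abs(X - X_rec)))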

doc/modules/classes.rst

Lines changed: 1 addition & 0 deletions
@@ -270,6 +270,7 @@ Samples generator
    :template: class.rst
 
    decomposition.PCA
+   decomposition.IncrementalPCA
    decomposition.ProjectedGradientNMF
    decomposition.RandomizedPCA
    decomposition.KernelPCA

doc/modules/decomposition.rst

Lines changed: 41 additions & 2 deletions
@@ -28,8 +28,7 @@ project the data onto the singular space while scaling each component
 to unit variance. This is often useful if the models down-stream make
 strong assumptions on the isotropy of the signal: this is for example
 the case for Support Vector Machines with the RBF kernel and the K-Means
-clustering algorithm. However in that case the inverse transform is no
-longer exact since some information is lost while forward transforming.
+clustering algorithm.
 
 Below is an example of the iris dataset, which is comprised of 4
 features, projected on the 2 dimensions that explain most variance:
@@ -57,6 +56,46 @@ data based on the amount of variance it explains. As such it implements a
     * :ref:`example_decomposition_plot_pca_vs_fa_model_selection.py`
 
 
+.. _IncrementalPCA:
+
+Incremental PCA
+---------------
+
+The :class:`PCA` object is very useful, but has certain limitations for
+large datasets. The biggest limitation is that :class:`PCA` only supports
+batch processing, which means all of the data to be processed must fit in main
+memory. The :class:`IncrementalPCA` object uses a different form of
+processing and allows for partial computations which almost
+exactly match the results of :class:`PCA` while processing the data in a
+minibatch fashion. :class:`IncrementalPCA` makes it possible to implement
+out-of-core Principal Component Analysis either by:
+
+ * Using its ``partial_fit`` method on chunks of data fetched sequentially
+   from the local hard drive or a network database.
+
+ * Calling its ``fit`` method on a memory-mapped file using ``numpy.memmap``.
+
+:class:`IncrementalPCA` only stores estimates of component and noise variances,
+in order to update ``explained_variance_ratio_`` incrementally. This is why
+memory usage depends on the number of samples per batch, rather than the
+number of samples to be processed in the dataset.
+
+.. figure:: ../auto_examples/decomposition/images/plot_incremental_pca_001.png
+    :target: ../auto_examples/decomposition/plot_incremental_pca.html
+    :align: center
+    :scale: 75%
+
+.. figure:: ../auto_examples/decomposition/images/plot_incremental_pca_002.png
+    :target: ../auto_examples/decomposition/plot_incremental_pca.html
+    :align: center
+    :scale: 75%
+
+
+.. topic:: Examples:
+
+    * :ref:`example_decomposition_plot_incremental_pca.py`
+
+
 .. _RandomizedPCA:
 
 Approximate PCA
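
As a rough illustration of the two out-of-core routes described in the documentation hunk above, the sketch below feeds IncrementalPCA chunks through partial_fit. The synthetic array and the commented-out memmap path are stand-ins for data that would normally be streamed from disk or a database; they are not part of the documentation change itself.

import numpy as np
from sklearn.decomposition import IncrementalPCA

rng = np.random.RandomState(0)
X = rng.randn(1000, 50)               # stand-in for data too large for memory

ipca = IncrementalPCA(n_components=10)
for chunk in np.array_split(X, 10):   # each chunk: 100 samples (>= n_components)
    ipca.partial_fit(chunk)           # updates components_ and variance estimates

# Alternative route: calling fit() on a memory-mapped array processes it
# batch_size rows at a time ('data.dat' is a hypothetical file created elsewhere):
# X_mm = np.memmap('data.dat', dtype='float64', mode='r', shape=(1000, 50))
# IncrementalPCA(n_components=10, batch_size=100).fit(X_mm)

X_new = ipca.transform(X[:5])         # project samples with the fitted model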

doc/modules/scaling_strategies.rst

Lines changed: 1 addition & 0 deletions
@@ -69,6 +69,7 @@ Here is a list of incremental estimators for different tasks:
       + :class:`sklearn.cluster.MiniBatchKMeans`
   - Decomposition / feature Extraction
       + :class:`sklearn.decomposition.MiniBatchDictionaryLearning`
+      + :class:`sklearn.decomposition.IncrementalPCA`
       + :class:`sklearn.cluster.MiniBatchKMeans`
 
 For classification, a somewhat important thing to note is that although a

doc/whats_new.rst

Lines changed: 4 additions & 0 deletions
@@ -32,6 +32,10 @@ New features
      :class:`ensemble.GradientBoostingRegressor`. By
      `Peter Prettenhofer`_.
 
+   - Added :class:`decomposition.IncrementalPCA`, an implementation of the PCA
+     algorithm that supports out-of-core learning with a ``partial_fit``
+     method. By `Kyle Kastner`_.
+
 
 Enhancements
 ............
Lines changed: 58 additions & 0 deletions
@@ -0,0 +1,58 @@
"""

===============
Incremental PCA
===============

Incremental principal component analysis (IPCA) is typically used as a
replacement for principal component analysis (PCA) when the dataset to be
decomposed is too large to fit in memory. IPCA builds a low-rank approximation
for the input data using an amount of memory which is independent of the
number of input data samples. It is still dependent on the input data features,
but changing the batch size allows for control of memory usage.

This example serves as a visual check that IPCA is able to find a similar
projection of the data to PCA (up to a sign flip), while only processing a
few samples at a time. This can be considered a "toy example", as IPCA is
intended for large datasets which do not fit in main memory, requiring
incremental approaches.

"""
print(__doc__)

# Authors: Kyle Kastner
# License: BSD 3 clause

import numpy as np
import matplotlib.pyplot as plt

from sklearn.datasets import load_iris
from sklearn.decomposition import PCA, IncrementalPCA

iris = load_iris()
X = iris.data
y = iris.target

n_components = 2
ipca = IncrementalPCA(n_components=n_components, batch_size=10)
X_ipca = ipca.fit_transform(X)

pca = PCA(n_components=n_components)
X_pca = pca.fit_transform(X)

for X_transformed, title in [(X_ipca, "Incremental PCA"), (X_pca, "PCA")]:
    plt.figure(figsize=(8, 8))
    for c, i, target_name in zip("rgb", [0, 1, 2], iris.target_names):
        plt.scatter(X_transformed[y == i, 0], X_transformed[y == i, 1],
                    c=c, label=target_name)

    if "Incremental" in title:
        err = np.abs(np.abs(X_pca) - np.abs(X_ipca)).mean()
        plt.title(title + " of iris dataset\nMean absolute unsigned error "
                  "%.6f" % err)
    else:
        plt.title(title + " of iris dataset")
    plt.legend(loc="best")
    plt.axis([-4, 4, -1.5, 1.5])

plt.show()
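
As a small follow-up check of the "up to a sign flip" statement in the docstring, the snippet below aligns each IPCA component's sign with the corresponding PCA component before comparing them directly. It is an illustrative addition, not part of the original example, and reuses the ipca and pca objects fitted above.

# Flip IPCA component signs to match PCA, then measure the remaining gap.
signs = np.sign(np.sum(ipca.components_ * pca.components_, axis=1))
aligned = ipca.components_ * signs[:, np.newaxis]
print("max |IPCA - PCA| component difference: %.6f"
      % np.abs(aligned - pca.components_).max())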

sklearn/decomposition/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -6,6 +6,7 @@
 
 from .nmf import NMF, ProjectedGradientNMF
 from .pca import PCA, RandomizedPCA
+from .incremental_pca import IncrementalPCA
 from .kernel_pca import KernelPCA
 from .sparse_pca import SparsePCA, MiniBatchSparsePCA
 from .truncated_svd import TruncatedSVD
@@ -18,6 +19,7 @@
 
 __all__ = ['DictionaryLearning',
            'FastICA',
+           'IncrementalPCA',
            'KernelPCA',
            'MiniBatchDictionaryLearning',
            'MiniBatchSparsePCA',

0 commit comments
