Merged

87 commits
9fba9fa
ENH new svd_solver='covariance_eigh' for PCA
ogrisel Sep 28, 2023
12373df
Merge branch 'main' into pca-covariance-eigh
ogrisel Sep 28, 2023
f9b3a05
Add changelog entry
ogrisel Sep 28, 2023
cc570e4
Change test parametrization to workaround pytest xdist bug
ogrisel Sep 28, 2023
d1e4287
Avoid raising warning when dtype is None
ogrisel Sep 28, 2023
74a38ea
Another try to workaround pytest xdist bug
ogrisel Sep 28, 2023
fbb1575
Use the new svd_flip config consistently and update docstrings accord…
ogrisel Sep 28, 2023
1734878
One more missing sign update in the FeatureUnion doctest
ogrisel Sep 28, 2023
62116d4
Document the changed component sign heuristic
ogrisel Sep 28, 2023
7b071a6
More svd_flip update
ogrisel Sep 28, 2023
973145f
Attempt at fixing indentation problem for svd_solver docstring
ogrisel Sep 28, 2023
4ac8935
One more missing sign update in the FeatureUnion doctest
ogrisel Sep 28, 2023
0b3a192
Avoid SparsePCA's internal DictionaryLearning instance generating cod…
ogrisel Sep 28, 2023
f6815e5
TST test fit_transform and transform equivalence
ogrisel Sep 29, 2023
97ce07c
Merge branch 'main' into pca-covariance-eigh
ogrisel Sep 29, 2023
bbd8761
TST cleaner way to define X_train and X_test
ogrisel Sep 29, 2023
29798d5
Add new benchmark script to check the robustness of the auto policy
ogrisel Sep 29, 2023
00cd121
Ignore CSV file with collected benchmark data
ogrisel Sep 29, 2023
961cf6f
Adjust svd_solver='auto'
ogrisel Sep 29, 2023
c89cd13
Make test_pca_solver_equivalence more generic at dealing with unstabl…
ogrisel Sep 29, 2023
96e1f9b
Fix sklearn/decomposition/tests/test_incremental_pca.py:test_whitening
ogrisel Sep 29, 2023
19b778e
Merge branch 'main' into pca-covariance-eigh
ogrisel Sep 29, 2023
9552626
Update the docstring to reflect the new auto policy
ogrisel Sep 29, 2023
6e3a62b
More explicit figure titles
ogrisel Sep 29, 2023
4f15943
Update the docstring to mention the difference in numerical stability
ogrisel Sep 29, 2023
1de6f67
Improve solver equivalence tests
ogrisel Oct 3, 2023
b832aea
DOC make compose doc more interesting and robust to rounding errors
ogrisel Oct 3, 2023
b3eec64
Mark the stricter test_pca_solver_equivalence with the arpack solver …
ogrisel Oct 3, 2023
4120581
Make PCA(whiten=True).transform robust to rank deficient training data
ogrisel Oct 3, 2023
4a91dec
Merge branch 'main' into pca-covariance-eigh
ogrisel Oct 3, 2023
ec87a14
Update inline comment about clipping negative eigenvalues
ogrisel Oct 3, 2023
8be2f91
Simpler way to clip the whitening scale
ogrisel Oct 3, 2023
29372a1
Test solver equivalence with float32 data
ogrisel Oct 3, 2023
efc5a42
Forgot to update on condition based on variance_threshold
ogrisel Oct 3, 2023
f6e88b4
Typo
ogrisel Oct 3, 2023
08ae5d5
[azure parallel] [all random seeds]
ogrisel Oct 3, 2023
ed83ed5
Relax float32 test tolerance [azure parallel] [all random seeds]
ogrisel Oct 3, 2023
74638b7
More explicit conditions for assertions in test_pca_solver_equivalenc…
ogrisel Oct 4, 2023
3799e92
[azure parallel] [all random seeds] test_pca_solver_equivalence
ogrisel Oct 4, 2023
59aebb7
ENH do not center a priori to spare memory and make the new solver ru…
ogrisel Oct 4, 2023
29f1ff0
Merge branch 'main' into pca-covariance-eigh
ogrisel Oct 4, 2023
8c017b4
[azure parallel] [all random seeds]
ogrisel Oct 4, 2023
45928ae
[azure parallel] [all random seeds]
ogrisel Oct 4, 2023
01f6fd6
Propagate xp namespace to private methods
ogrisel Oct 5, 2023
f3c33dc
Scale the covariance matrix before calling eigh to improve numerical …
ogrisel Oct 5, 2023
64cbcc9
Make it explicit when the X passed between private methods has been c…
ogrisel Oct 5, 2023
8284d25
[azure parallel] [all random seeds]
ogrisel Oct 5, 2023
828a058
Improve test [azure parallel] [all random seeds] test_pca_solver_equi…
ogrisel Oct 5, 2023
2208539
Merge branch 'main' into pca-covariance-eigh
ogrisel Oct 5, 2023
56a7793
OPTIM Ensure that components_ is contiguous
ogrisel Oct 5, 2023
e797b53
Update auto policy and bench_pca_solvers.py
ogrisel Oct 5, 2023
8fcf2ff
Single call to xp.sum to compute total_var.
ogrisel Oct 6, 2023
a6db513
Apply suggestions from code review
ogrisel Oct 13, 2023
abd3282
Merge branch 'main' into pca-covariance-eigh
ogrisel Oct 20, 2023
4a2a062
Use the device() utility function instead of getattr
ogrisel Oct 20, 2023
a6bd3b8
Improve inline comment to explain copy=False
ogrisel Oct 20, 2023
1446f10
Improve comment to explain why we do not rely on numpy.cov
ogrisel Oct 20, 2023
ff48ede
Apply suggestions from code review
ogrisel Nov 4, 2023
bef7e0a
Merge branch 'main' into pca-covariance-eigh
ogrisel Nov 4, 2023
9e7f48a
Apply suggestions from code review
ogrisel Nov 4, 2023
abe95f1
Attempt at merging main with sparse input support for PCA
ogrisel Nov 10, 2023
9c54050
Fix typo
ogrisel Nov 10, 2023
2070f0f
Merge branch 'main' into pca-covariance-eigh
ogrisel Nov 12, 2023
59ec249
Fix broken tests
ogrisel Nov 12, 2023
31fe741
Increase rtol for array api compliance test with float32 pytorch
ogrisel Nov 12, 2023
ff3190c
Accept sparse input data for svd_solver='covariance_eigh'
ogrisel Nov 12, 2023
d38650d
Fix expected error message in test
ogrisel Nov 12, 2023
f27ca6f
Merge branch 'main' into pca-covariance-eigh
ogrisel Nov 17, 2023
181d852
Merge branch 'main' into pca-covariance-eigh
ogrisel Dec 13, 2023
c50ab47
Merge branch 'main' into pca-covariance-eigh
ogrisel Jan 4, 2024
15db022
DOC move changelog entries to target 1.5
ogrisel Jan 4, 2024
4aade1a
DOC cleanup left over entry that was meant to be moved to 1.5
ogrisel Jan 4, 2024
d131a64
Fix bad conflict resolution
ogrisel Jan 5, 2024
3bf57b1
Merge branch 'main' into pca-covariance-eigh
ogrisel Jan 5, 2024
6670604
Merge branch 'main' into pca-covariance-eigh
ogrisel Apr 7, 2024
b6ef8ba
Apply suggestions from code review
ogrisel Apr 7, 2024
e12fe81
Use assert_allclose in test_whitening
ogrisel Apr 7, 2024
85e6f24
Move changed models entry to the existing section of the change log
ogrisel Apr 7, 2024
ca899a0
Merge branch 'main' into pca-covariance-eigh
ogrisel Apr 9, 2024
b47b603
Apply suggestions from code review
ogrisel Apr 10, 2024
b43f6f9
Apply suggestions from code review
ogrisel Apr 10, 2024
c0b89b2
Merge branch 'main' into pca-covariance-eigh
ogrisel Apr 10, 2024
310b616
Wrap paragraphs in multiline comments
ogrisel Apr 10, 2024
73a4d61
Update sklearn/decomposition/tests/test_incremental_pca.py
ogrisel Apr 10, 2024
797e081
Wrap paragraphs in changelog
ogrisel Apr 10, 2024
47935f5
Do not test on isinstance(..., np.matrix) and only use Array API for …
ogrisel Apr 10, 2024
9842711
Do not test on isinstance(..., np.matrix) and only use Array API for …
ogrisel Apr 11, 2024
1 change: 1 addition & 0 deletions .gitignore
@@ -55,6 +55,7 @@ examples/cluster/joblib
 reuters/
 benchmarks/bench_covertype_data/
 benchmarks/HIGGS.csv.gz
+bench_pca_solvers.csv

 *.prefs
 .pydevproject
165 changes: 165 additions & 0 deletions benchmarks/bench_pca_solvers.py
@@ -0,0 +1,165 @@
# %%
#
# This benchmark compares the speed of PCA solvers on datasets of different
# sizes in order to determine the best solver to select by default via the
# "auto" heuristic.
#
# Note: we do not control for the accuracy of the solvers: we assume that all
# solvers yield transformed data with similar explained variance. This
# assumption is generally true, except for the randomized solver that might
# require more power iterations.
#
# We generate synthetic data with dimensions that are useful to plot:
# - time vs n_samples for a fixed n_features and,
# - time vs n_features for a fixed n_samples.
import itertools
from math import log10
from time import perf_counter

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

from sklearn import config_context
from sklearn.decomposition import PCA

REF_DIMS = [100, 1000, 10_000]
data_shapes = []
for ref_dim in REF_DIMS:
    data_shapes.extend([(ref_dim, 10**i) for i in range(1, 8 - int(log10(ref_dim)))])
    data_shapes.extend(
        [(ref_dim, 3 * 10**i) for i in range(1, 8 - int(log10(ref_dim)))]
    )
    data_shapes.extend([(10**i, ref_dim) for i in range(1, 8 - int(log10(ref_dim)))])
    data_shapes.extend(
        [(3 * 10**i, ref_dim) for i in range(1, 8 - int(log10(ref_dim)))]
    )

# Remove duplicates:
data_shapes = sorted(set(data_shapes))

print("Generating test datasets...")
rng = np.random.default_rng(0)
datasets = [rng.normal(size=shape) for shape in data_shapes]


# %%
def measure_one(data, n_components, solver, method_name="fit"):
    print(
        f"Benchmarking {solver=!r}, {n_components=}, {method_name=!r} on data with"
        f" shape {data.shape}"
    )
    pca = PCA(n_components=n_components, svd_solver=solver, random_state=0)
    timings = []
    elapsed = 0
    method = getattr(pca, method_name)
    with config_context(assume_finite=True):
        while elapsed < 0.5:
            tic = perf_counter()
            method(data)
            duration = perf_counter() - tic
            timings.append(duration)
            elapsed += duration
    return np.median(timings)


SOLVERS = ["full", "covariance_eigh", "arpack", "randomized", "auto"]
measurements = []
for data, n_components, method_name in itertools.product(
    datasets, [2, 50], ["fit", "fit_transform"]
):
    if n_components >= min(data.shape):
        continue
    for solver in SOLVERS:
        if solver == "covariance_eigh" and data.shape[1] > 5000:
            # Too much memory and too slow.
            continue
        if solver in ["arpack", "full"] and log10(data.size) > 7:
            # Too slow, in particular for the full solver.
            continue
        time = measure_one(data, n_components, solver, method_name=method_name)
        measurements.append(
            {
                "n_components": n_components,
                "n_samples": data.shape[0],
                "n_features": data.shape[1],
                "time": time,
                "solver": solver,
                "method_name": method_name,
            }
        )
measurements = pd.DataFrame(measurements)
measurements.to_csv("bench_pca_solvers.csv", index=False)

# %%
all_method_names = measurements["method_name"].unique()
all_n_components = measurements["n_components"].unique()

for method_name in all_method_names:
    fig, axes = plt.subplots(
        figsize=(16, 16),
        nrows=len(REF_DIMS),
        ncols=len(all_n_components),
        sharey=True,
        constrained_layout=True,
    )
    fig.suptitle(f"Benchmarks for PCA.{method_name}, varying n_samples", fontsize=16)

    for row_idx, ref_dim in enumerate(REF_DIMS):
        for n_components, ax in zip(all_n_components, axes[row_idx]):
            for solver in SOLVERS:
                if solver == "auto":
                    style_kwargs = dict(linewidth=2, color="black", style="--")
                else:
                    style_kwargs = dict(style="o-")
                ax.set(
                    title=f"n_components={n_components}, n_features={ref_dim}",
                    ylabel="time (s)",
                )
                measurements.query(
                    "n_components == @n_components and n_features == @ref_dim"
                    " and solver == @solver and method_name == @method_name"
                ).plot.line(
                    x="n_samples",
                    y="time",
                    label=solver,
                    logx=True,
                    logy=True,
                    ax=ax,
                    **style_kwargs,
                )
# %%
for method_name in all_method_names:
    fig, axes = plt.subplots(
        figsize=(16, 16),
        nrows=len(REF_DIMS),
        ncols=len(all_n_components),
        sharey=True,
    )
    fig.suptitle(f"Benchmarks for PCA.{method_name}, varying n_features", fontsize=16)

    for row_idx, ref_dim in enumerate(REF_DIMS):
        for n_components, ax in zip(all_n_components, axes[row_idx]):
            for solver in SOLVERS:
                if solver == "auto":
                    style_kwargs = dict(linewidth=2, color="black", style="--")
                else:
                    style_kwargs = dict(style="o-")
                ax.set(
                    title=f"n_components={n_components}, n_samples={ref_dim}",
                    ylabel="time (s)",
                )
                measurements.query(
                    "n_components == @n_components and n_samples == @ref_dim "
                    " and solver == @solver and method_name == @method_name"
                ).plot.line(
                    x="n_features",
                    y="time",
                    label=solver,
                    logx=True,
                    logy=True,
                    ax=ax,
                    **style_kwargs,
                )

# %%
16 changes: 8 additions & 8 deletions doc/modules/compose.rst
@@ -254,14 +254,14 @@ inspect the original instance such as::

     >>> from sklearn.datasets import load_digits
     >>> X_digits, y_digits = load_digits(return_X_y=True)
-    >>> pca1 = PCA()
+    >>> pca1 = PCA(n_components=10)
     >>> svm1 = SVC()
     >>> pipe = Pipeline([('reduce_dim', pca1), ('clf', svm1)])
     >>> pipe.fit(X_digits, y_digits)
-    Pipeline(steps=[('reduce_dim', PCA()), ('clf', SVC())])
+    Pipeline(steps=[('reduce_dim', PCA(n_components=10)), ('clf', SVC())])
     >>> # The pca instance can be inspected directly
-    >>> print(pca1.components_)
-    [[-1.77484909e-19 ... 4.07058917e-18]]
+    >>> pca1.components_.shape
+    (10, 64)


 Enabling caching triggers a clone of the transformers before fitting.
@@ -274,15 +274,15 @@ Instead, use the attribute ``named_steps`` to inspect estimators within
 the pipeline::

     >>> cachedir = mkdtemp()
-    >>> pca2 = PCA()
+    >>> pca2 = PCA(n_components=10)
     >>> svm2 = SVC()
     >>> cached_pipe = Pipeline([('reduce_dim', pca2), ('clf', svm2)],
     ...                        memory=cachedir)
     >>> cached_pipe.fit(X_digits, y_digits)
     Pipeline(memory=...,
-             steps=[('reduce_dim', PCA()), ('clf', SVC())])
-    >>> print(cached_pipe.named_steps['reduce_dim'].components_)
-    [[-1.77484909e-19 ... 4.07058917e-18]]
+             steps=[('reduce_dim', PCA(n_components=10)), ('clf', SVC())])
+    >>> cached_pipe.named_steps['reduce_dim'].components_.shape
+    (10, 64)
     >>> # Remove the cache directory
     >>> rmtree(cachedir)
30 changes: 30 additions & 0 deletions doc/whats_new/v1.5.rst
@@ -31,6 +31,13 @@ Changed models
   properties).
   :pr:`27344` by :user:`Xuefeng Xu <xuefeng-xu>`.

+- |Enhancement| :class:`decomposition.PCA`, :class:`decomposition.SparsePCA`
+  and :class:`decomposition.TruncatedSVD` now set the sign of the `components_`
+  attribute based on the component values instead of using the transformed data
+  as reference. This change is needed to be able to offer consistent component
+  signs across all `PCA` solvers, including the new
+  `svd_solver="covariance_eigh"` option introduced in this release.
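
A minimal sketch of the new convention (an illustration of the idea, not the
exact scikit-learn implementation): flip each row of `components_` so that its
largest-magnitude entry is positive, which makes the sign deterministic and
independent of the transformed data.

import numpy as np

def flip_signs_by_component_values(components):
    # Illustrative only: make the largest-magnitude entry of each
    # component row positive.
    max_abs_cols = np.argmax(np.abs(components), axis=1)
    signs = np.sign(components[np.arange(components.shape[0]), max_abs_cols])
    signs[signs == 0] = 1.0  # keep all-zero components unchanged
    return components * signs[:, np.newaxis]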

 Support for Array API
 ---------------------

@@ -169,10 +176,33 @@ Changelog
 :mod:`sklearn.decomposition`
 ............................

+- |Efficiency| :class:`decomposition.PCA` with `svd_solver="full"` now assigns
+  a contiguous `components_` attribute instead of a non-contiguous slice of
+  the singular vectors. When `n_components << n_features`, this can save some
+  memory and, more importantly, help speed-up subsequent calls to the `transform`
+  method by more than an order of magnitude by leveraging cache locality of
+  BLAS GEMM on contiguous arrays.
+  :pr:`27491` by :user:`Olivier Grisel <ogrisel>`.
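
The `transform` speed-up is a memory-layout effect: `X @ components_.T`
dispatches to an efficient BLAS GEMM when `components_` is a compact,
contiguous buffer rather than a strided view into a larger matrix. A rough
illustration of the distinction (not the actual PCA code):

import numpy as np

rng = np.random.default_rng(0)
Vt = rng.standard_normal((1_000, 1_000))
view = Vt.T[:10]                         # strided, non-contiguous view
components = np.ascontiguousarray(view)  # compact, C-contiguous (10, 1000) copy
assert not view.flags["C_CONTIGUOUS"]
assert components.flags["C_CONTIGUOUS"]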

 - |Enhancement| :class:`~decomposition.PCA` now automatically selects the ARPACK solver
   for sparse inputs when `svd_solver="auto"` instead of raising an error.
   :pr:`28498` by :user:`Thanh Lam Dang <lamdang2k>`.

+- |Enhancement| :class:`decomposition.PCA` now supports a new solver option
+  named `svd_solver="covariance_eigh"` which offers an order of magnitude
+  speed-up and reduced memory usage for datasets with a large number of data
+  points and a small number of features (say, `n_samples >> 1000 >
+  n_features`). The `svd_solver="auto"` option has been updated to use the new
+  solver automatically for such datasets. This solver also accepts sparse input
+  data.
+  :pr:`27491` by :user:`Olivier Grisel <ogrisel>`.
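
The core idea, sketched below for dense float64 input (the actual
implementation also handles sparse matrices and Array API inputs, and rescales
the covariance matrix before calling `eigh` for numerical stability): when
`n_samples >> n_features`, build the small `(n_features, n_features)`
covariance matrix in one pass and eigendecompose it instead of running an SVD
on the full data matrix.

import numpy as np
from scipy import linalg

def pca_via_covariance_eigh(X, n_components):
    # Fold the centering into the Gram matrix so X itself is never copied
    # or densified: sum((x_i - mu)(x_i - mu)^T) = X^T X - n * mu mu^T.
    n_samples = X.shape[0]
    mu = X.mean(axis=0)
    C = (X.T @ X - n_samples * np.outer(mu, mu)) / (n_samples - 1)
    eigenvalues, eigenvectors = linalg.eigh(C)  # returned in ascending order
    eigenvalues = eigenvalues[::-1][:n_components]
    components = eigenvectors[:, ::-1].T[:n_components]
    # Clip tiny negative eigenvalues caused by rounding errors.
    explained_variance = np.clip(eigenvalues, 0.0, None)
    return components, explained_variance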

+- |Fix| :class:`decomposition.PCA` fit with `svd_solver="arpack"`,
+  `whiten=True` and a value for `n_components` that is larger than the rank of
+  the training set, no longer returns infinite values when transforming
+  hold-out data.
+  :pr:`27491` by :user:`Olivier Grisel <ogrisel>`.
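
A hypothetical reproduction of the fixed scenario (shapes and seeds chosen for
illustration): the training data has rank 5 but 20 features, and 10 whitened
components are requested, so components 6-10 have near-zero variance.

import numpy as np
from sklearn.decomposition import PCA

rng = np.random.default_rng(0)
X_train = rng.standard_normal((100, 5)) @ rng.standard_normal((5, 20))  # rank 5
X_test = rng.standard_normal((10, 20))
pca = PCA(n_components=10, svd_solver="arpack", whiten=True).fit(X_train)
assert np.isfinite(pca.transform(X_test)).all()  # holds with the fix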

 :mod:`sklearn.dummy`
 ....................
30 changes: 20 additions & 10 deletions sklearn/decomposition/_base.py
@@ -12,11 +12,9 @@

 import numpy as np
 from scipy import linalg
-from scipy.sparse import issparse

 from ..base import BaseEstimator, ClassNamePrefixFeaturesOutMixin, TransformerMixin
 from ..utils._array_api import _add_to_diagonal, device, get_namespace
-from ..utils.sparsefuncs import _implicit_column_offset
 from ..utils.validation import check_is_fitted
@@ -138,21 +136,33 @@ def transform(self, X):
             Projection of X in the first principal components, where `n_samples`
             is the number of samples and `n_components` is the number of the components.
         """
-        xp, _ = get_namespace(X)
+        xp, _ = get_namespace(X, self.components_, self.explained_variance_)

         check_is_fitted(self)

         X = self._validate_data(
-            X, accept_sparse=("csr", "csc"), dtype=[xp.float64, xp.float32], reset=False
+            X, dtype=[xp.float64, xp.float32], accept_sparse=("csr", "csc"), reset=False
         )
-        if self.mean_ is not None:
-            if issparse(X):
-                X = _implicit_column_offset(X, self.mean_)
-            else:
-                X = X - self.mean_
+        return self._transform(X, xp=xp, x_is_centered=False)
+
+    def _transform(self, X, xp, x_is_centered=False):
         X_transformed = X @ self.components_.T
+        if not x_is_centered:
+            # Apply the centering after the projection.
+            # For dense X this avoids copying or mutating the data passed by
+            # the caller.
+            # For sparse X it keeps sparsity and avoids having to wrap X into
+            # a linear operator.
+            X_transformed -= xp.reshape(self.mean_, (1, -1)) @ self.components_.T
         if self.whiten:
-            X_transformed /= xp.sqrt(self.explained_variance_)
+            # For some solvers (such as "arpack" and "covariance_eigh"), on
+            # rank deficient data, some components can have a variance
+            # arbitrarily close to zero, leading to non-finite results when
+            # whitening. To avoid this problem we clip the variance below.
+            scale = xp.sqrt(self.explained_variance_)
+            min_scale = xp.finfo(scale.dtype).eps
+            scale[scale < min_scale] = min_scale
+            X_transformed /= scale
         return X_transformed

     def inverse_transform(self, X):
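
The deferred centering in `_transform` relies on the projection being linear:
`(X - mean) @ W.T == X @ W.T - (mean @ W.T)`, so the projected mean can be
subtracted from the result instead of centering (and thus copying or
densifying) `X` itself. A quick numerical sanity check of that identity:

import numpy as np

rng = np.random.default_rng(0)
X = rng.standard_normal((5, 3))
mean = X.mean(axis=0)
W = rng.standard_normal((2, 3))  # stand-in for components_
np.testing.assert_allclose(
    (X - mean) @ W.T,
    X @ W.T - mean.reshape(1, -1) @ W.T,
)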
4 changes: 1 addition & 3 deletions sklearn/decomposition/_kernel_pca.py
@@ -366,9 +366,7 @@ def _fit_transform(self, K):
         )

         # flip eigenvectors' sign to enforce deterministic output
-        self.eigenvectors_, _ = svd_flip(
-            self.eigenvectors_, np.zeros_like(self.eigenvectors_).T
-        )
+        self.eigenvectors_, _ = svd_flip(u=self.eigenvectors_, v=None)

         # sort eigenvectors in descending order
         indices = self.eigenvalues_.argsort()[::-1]
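
This assumes the `svd_flip` helper updated in this PR accepts `v=None` and
derives the signs from `u` alone, so the zero-matrix placeholder previously
passed as `v` is no longer needed:

import numpy as np
from sklearn.utils.extmath import svd_flip

u = np.array([[1.0, -2.0], [-3.0, 4.0]])
# With v=None, signs are decided from u and None is passed through for v.
u_flipped, v = svd_flip(u=u, v=None)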