From 9fba9fac7bb78a14dfc0ad04e5ed8e74cd4ac3be Mon Sep 17 00:00:00 2001 From: Olivier Grisel Date: Thu, 28 Sep 2023 15:21:12 +0200 Subject: [PATCH 01/70] ENH new svd_solver='covariance_eigh' for PCA --- sklearn/decomposition/_base.py | 9 ++- sklearn/decomposition/_pca.py | 87 ++++++++++++++++--------- sklearn/decomposition/tests/test_pca.py | 69 +++++++++++++++----- sklearn/utils/extmath.py | 15 +++-- 4 files changed, 128 insertions(+), 52 deletions(-) diff --git a/sklearn/decomposition/_base.py b/sklearn/decomposition/_base.py index c4ccf92212fe9..dc27a6c1ac123 100644 --- a/sklearn/decomposition/_base.py +++ b/sklearn/decomposition/_base.py @@ -143,7 +143,14 @@ def transform(self, X): X = self._validate_data(X, dtype=[xp.float64, xp.float32], reset=False) if self.mean_ is not None: X = X - self.mean_ - X_transformed = X @ self.components_.T + return self._transform(X, xp) + + def _transform(self, X_centered, xp=None): + if xp is None: + xp, _ = get_namespace( + X_centered, self.components_, self.explained_variance_ + ) + X_transformed = X_centered @ self.components_.T if self.whiten: X_transformed /= xp.sqrt(self.explained_variance_) return X_transformed diff --git a/sklearn/decomposition/_pca.py b/sklearn/decomposition/_pca.py index 877baf4d4e81c..03ec4c0e19ee4 100644 --- a/sklearn/decomposition/_pca.py +++ b/sklearn/decomposition/_pca.py @@ -175,7 +175,8 @@ class PCA(_BasePCA): improve the predictive accuracy of the downstream estimators by making their data respect some hard-wired assumptions. - svd_solver : {'auto', 'full', 'arpack', 'randomized'}, default='auto' + svd_solver : {'auto', 'full', 'covariance_eigh', 'arpack', 'randomized'}, + default='auto' If auto : The solver is selected by a default policy based on `X.shape` and `n_components`: if the input data is larger than 500x500 and the @@ -186,6 +187,13 @@ class PCA(_BasePCA): If full : run exact full SVD calling the standard LAPACK solver via `scipy.linalg.svd` and select the components by postprocessing + If covariance_eigh : + precompute the covariance matrix (on centered data) and run a + classical eigenvalue decomposition on the covariance matrix + typically using LAPACK and select the components by postprocessing. + This solver is very efficient when the number of features is small + and not tractable otherwise (large memory footprint required to + materialize the covariance matrix). If arpack : run SVD truncated to n_components calling ARPACK solver via `scipy.sparse.linalg.svds`. It requires strictly @@ -195,6 +203,9 @@ class PCA(_BasePCA): .. versionadded:: 0.18.0 + .. versionchanged:: 1.4 + Added the 'covariance_eigh' solver. + tol : float, default=0.0 Tolerance for singular values computed by svd_solver == 'arpack'. Must be of range [0.0, infinity). @@ -372,7 +383,9 @@ class PCA(_BasePCA): ], "copy": ["boolean"], "whiten": ["boolean"], - "svd_solver": [StrOptions({"auto", "full", "arpack", "randomized"})], + "svd_solver": [ + StrOptions({"auto", "full", "covariance_eigh", "arpack", "randomized"}) + ], "tol": [Interval(Real, 0, None, closed="left")], "iterated_power": [ StrOptions({"auto"}), @@ -460,17 +473,20 @@ def fit_transform(self, X, y=None): This method returns a Fortran-ordered array. To convert it to a C-ordered array, use 'np.ascontiguousarray'. 
""" - U, S, Vt = self._fit(X) - U = U[:, : self.n_components_] + U, S, Vt, X_centered = self._fit(X) + if U is not None: + U = U[:, : self.n_components_] - if self.whiten: - # X_new = X * V / S * sqrt(n_samples) = U * sqrt(n_samples) - U *= sqrt(X.shape[0] - 1) - else: - # X_new = X * V = U * S * Vt * V = U * S - U *= S[: self.n_components_] + if self.whiten: + # X_new = X * V / S * sqrt(n_samples) = U * sqrt(n_samples) + U *= sqrt(X.shape[0] - 1) + else: + # X_new = X * V = U * S * Vt * V = U * S + U *= S[: self.n_components_] - return U + return U + else: + return self._transform(X_centered) def _fit(self, X): """Dispatch to the right submethod depending on the chosen solver.""" @@ -515,7 +531,7 @@ def _fit(self, X): self._fit_svd_solver = "full" # Call different fits for either full or truncated SVD - if self._fit_svd_solver == "full": + if self._fit_svd_solver in ("full", "covariance_eigh"): return self._fit_full(X, n_components) elif self._fit_svd_solver in ["arpack", "randomized"]: return self._fit_truncated(X, n_components, self._fit_svd_solver) @@ -533,28 +549,38 @@ def _fit_full(self, X, n_components): ) elif not 0 <= n_components <= min(n_samples, n_features): raise ValueError( - "n_components=%r must be between 0 and " - "min(n_samples, n_features)=%r with " - "svd_solver='full'" % (n_components, min(n_samples, n_features)) + f"n_components={n_components} must be between 0 and " + f"min(n_samples, n_features)={min(n_samples, n_features)} with " + f"svd_solver={self._fit_svd_solver!r}" ) # Center data self.mean_ = xp.mean(X, axis=0) X -= self.mean_ - if not is_array_api_compliant: - # Use scipy.linalg with NumPy/SciPy inputs for the sake of not - # introducing unanticipated behavior changes. In the long run we - # could instead decide to always use xp.linalg.svd for all inputs, - # but that would make this code rely on numpy's SVD instead of - # scipy's. It's not 100% clear whether they use the same LAPACK - # solver by default though (assuming both are built against the - # same BLAS). - U, S, Vt = linalg.svd(X, full_matrices=False) + if self._fit_svd_solver == "full": + if not is_array_api_compliant: + # Use scipy.linalg with NumPy/SciPy inputs for the sake of not + # introducing unanticipated behavior changes. In the long run we + # could instead decide to always use xp.linalg.svd for all inputs, + # but that would make this code rely on numpy's SVD instead of + # scipy's. It's not 100% clear whether they use the same LAPACK + # solver by default though (assuming both are built against the + # same BLAS). + U, S, Vt = linalg.svd(X, full_matrices=False) + else: + U, S, Vt = xp.linalg.svd(X, full_matrices=False) else: - U, S, Vt = xp.linalg.svd(X, full_matrices=False) + assert self._fit_svd_solver == "covariance_eigh" + C = X.T @ X + evals, Evecs = xp.linalg.eigh(C) + evals[evals < 0] = 0.0 + S = xp.sqrt(xp.flip(evals, axis=0)) + Vt = xp.flip(Evecs, axis=1).T + U = None + # flip eigenvectors' sign to enforce deterministic output - U, Vt = svd_flip(U, Vt) + U, Vt = svd_flip(U, Vt, u_based_decision=False) components_ = Vt @@ -589,7 +615,7 @@ def _fit_full(self, X, n_components): self.explained_variance_ratio_ = explained_variance_ratio_[:n_components] self.singular_values_ = singular_values_[:n_components] - return U, S, Vt + return U, S, Vt, X def _fit_truncated(self, X, n_components, svd_solver): """Fit the model by computing truncated SVD (by ARPACK or randomized) @@ -632,7 +658,7 @@ def _fit_truncated(self, X, n_components, svd_solver): # conventions, so reverse its outputs. 
S = S[::-1] # flip eigenvectors' sign to enforce deterministic output - U, Vt = svd_flip(U[:, ::-1], Vt[::-1]) + U, Vt = svd_flip(U[:, ::-1], Vt[::-1], u_based_decision=False) elif svd_solver == "randomized": # sign flipping is done inside @@ -642,9 +668,10 @@ def _fit_truncated(self, X, n_components, svd_solver): n_oversamples=self.n_oversamples, n_iter=self.iterated_power, power_iteration_normalizer=self.power_iteration_normalizer, - flip_sign=True, + flip_sign=False, random_state=random_state, ) + U, Vt = svd_flip(U, Vt, u_based_decision=False) self.n_samples_ = n_samples self.components_ = Vt @@ -668,7 +695,7 @@ def _fit_truncated(self, X, n_components, svd_solver): else: self.noise_variance_ = 0.0 - return U, S, Vt + return U, S, Vt, X def score_samples(self, X): """Return the log-likelihood of each sample. diff --git a/sklearn/decomposition/tests/test_pca.py b/sklearn/decomposition/tests/test_pca.py index d6e2516aef2f2..e02d7f97cc4a0 100644 --- a/sklearn/decomposition/tests/test_pca.py +++ b/sklearn/decomposition/tests/test_pca.py @@ -8,7 +8,7 @@ from sklearn import config_context, datasets from sklearn.base import clone -from sklearn.datasets import load_iris +from sklearn.datasets import load_iris, make_low_rank_matrix from sklearn.decomposition import PCA from sklearn.decomposition._pca import _assess_dimension, _infer_dimension from sklearn.utils._array_api import ( @@ -24,7 +24,7 @@ from sklearn.utils.fixes import CSR_CONTAINERS iris = datasets.load_iris() -PCA_SOLVERS = ["full", "arpack", "randomized", "auto"] +PCA_SOLVERS = ["full", "covariance_eigh", "arpack", "randomized", "auto"] @pytest.mark.parametrize("svd_solver", PCA_SOLVERS) @@ -115,26 +115,60 @@ def test_whitening(solver, copy): # we always center, so no test for non-centering. -@pytest.mark.parametrize("svd_solver", ["arpack", "randomized"]) -def test_pca_explained_variance_equivalence_solver(svd_solver): - rng = np.random.RandomState(0) - n_samples, n_features = 100, 80 - X = rng.randn(n_samples, n_features) - - pca_full = PCA(n_components=2, svd_solver="full") - pca_other = PCA(n_components=2, svd_solver=svd_solver, random_state=0) - +@pytest.mark.parametrize( + "n_samples, n_features", + [ + (100, 80), + (80, 100), + ], +) +@pytest.mark.parametrize("other_svd_solver", set(PCA_SOLVERS) - {"full", "auto"}) +def test_pca_solver_equivalence( + n_samples, n_features, other_svd_solver, global_random_seed +): + X = make_low_rank_matrix( + n_samples=n_samples, n_features=n_features, random_state=global_random_seed + ) + tols = dict(atol=1e-10, rtol=1e-12) + + extra_other_kwargs = {} + if other_svd_solver == "randomized": + # Only check for a truncated result with a large number of iterations + # to make sure that we can recover precise results. + n_components = 10 + extra_other_kwargs = {"iterated_power": 50} + elif other_svd_solver == "arpack": + # Test all components except the last one which cannot be estimated by + # arpack. + n_components = np.minimum(n_samples, n_features) - 1 + else: + # Test all components to high precision. 
+ n_components = None + + pca_full = PCA(n_components=n_components, svd_solver="full") + pca_other = PCA( + n_components=n_components, + svd_solver=other_svd_solver, + random_state=global_random_seed, + **extra_other_kwargs, + ) pca_full.fit(X) pca_other.fit(X) - assert_allclose( - pca_full.explained_variance_, pca_other.explained_variance_, rtol=5e-2 - ) + assert_allclose(pca_full.explained_variance_, pca_other.explained_variance_, **tols) assert_allclose( pca_full.explained_variance_ratio_, pca_other.explained_variance_ratio_, - rtol=5e-2, + **tols, ) + reference_components = pca_full.components_ + other_components = pca_other.components_ + if n_components is None and n_features > n_samples: + # The last component can be arbitrary because of the centering of the + # data. Let's ignore it: + reference_components = reference_components[:-1] + other_components = other_components[:-1] + assert_allclose(reference_components, other_components, **tols) @pytest.mark.parametrize( @@ -142,8 +176,9 @@ def test_pca_explained_variance_equivalence_solver(svd_solver): [ np.random.RandomState(0).randn(100, 80), datasets.make_classification(100, 80, n_informative=78, random_state=0)[0], + np.random.RandomState(0).randn(10, 100), ], - ids=["random-data", "correlated-data"], + ids=["random-tall", "correlated-tall", "random-wide"], ) @pytest.mark.parametrize("svd_solver", PCA_SOLVERS) def test_pca_explained_variance_empirical(X, svd_solver): @@ -743,6 +778,8 @@ def check_array_api_get_precision(name, estimator, array_namepsace, device, dtyp [ PCA(n_components=2, svd_solver="full"), PCA(n_components=2, svd_solver="full", whiten=True), + PCA(n_components=2, svd_solver="covariance_eigh"), + PCA(n_components=2, svd_solver="covariance_eigh", whiten=True), PCA( n_components=2, svd_solver="randomized", diff --git a/sklearn/utils/extmath.py b/sklearn/utils/extmath.py index c2aa9d07e6635..51f198fbf87f6 100644 --- a/sklearn/utils/extmath.py +++ b/sklearn/utils/extmath.py @@ -842,12 +842,14 @@ def svd_flip(u, v, u_based_decision=True): Parameters u and v are the output of `linalg.svd` or :func:`~sklearn.utils.extmath.randomized_svd`, with matching inner dimensions so one can compute `np.dot(u * s, v)`. + u can be None if u_based_decision is False. v : ndarray Parameters u and v are the output of `linalg.svd` or :func:`~sklearn.utils.extmath.randomized_svd`, with matching inner dimensions so one can compute `np.dot(u * s, v)`. The input v should really be called vt to be consistent with scipy's output. + v can be None if u_based_decision is True. u_based_decision : bool, default=True If True, use the columns of u as the basis for sign flipping. @@ -862,24 +864,27 @@ def svd_flip(u, v, u_based_decision=True): v_adjusted : ndarray Array v with adjusted rows and the same dimensions as v. 
""" - xp, _ = get_namespace(u, v) - device = getattr(u, "device", None) + xp, _ = get_namespace(*[a for a in [u, v] if a is not None]) if u_based_decision: # columns of u, rows of v, or equivalently rows of u.T and v + device = getattr(u, "device", None) max_abs_u_cols = xp.argmax(xp.abs(u.T), axis=1) shift = xp.arange(u.T.shape[0], device=device) indices = max_abs_u_cols + shift * u.T.shape[1] signs = xp.sign(xp.take(xp.reshape(u.T, (-1,)), indices, axis=0)) u *= signs[np.newaxis, :] - v *= signs[:, np.newaxis] + if v is not None: + v *= signs[:, np.newaxis] else: # rows of v, columns of u + device = getattr(v, "device", None) max_abs_v_rows = xp.argmax(xp.abs(v), axis=1) shift = xp.arange(v.shape[0], device=device) indices = max_abs_v_rows + shift * v.shape[1] - signs = xp.sign(xp.take(xp.reshape(v, (-1,)), indices)) - u *= signs[np.newaxis, :] + signs = xp.sign(xp.take(xp.reshape(v, (-1,)), indices, axis=0)) + if u is not None: + u *= signs[np.newaxis, :] v *= signs[:, np.newaxis] return u, v From f9b3a0534a83b9fbcf52ae9e83f8cc371d66da2d Mon Sep 17 00:00:00 2001 From: Olivier Grisel Date: Thu, 28 Sep 2023 15:45:46 +0200 Subject: [PATCH 02/70] Add changelog entry --- doc/whats_new/v1.4.rst | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/doc/whats_new/v1.4.rst b/doc/whats_new/v1.4.rst index a9ea738beca91..c00a9682c828f 100644 --- a/doc/whats_new/v1.4.rst +++ b/doc/whats_new/v1.4.rst @@ -196,9 +196,15 @@ Changelog from `None` to `auto` in version 1.6. :pr:`26634` by :user:`Alexandre Landeau ` and :user:`Alexandre Vigny `. +- |Enhancement| :class:`decomposition.PCA` now supports a new solver option + named `svd_solver="covariance_eigh"` which offers an order of magnitude + speed-up and reduced memory usage for datasets with a large number of data + points and a small number of features (say, less than 1000). + :pr:`27491` by :user:`Olivier Grisel `. + - |Enhancement| :class:`decomposition.PCA` now supports the Array API for the - `full` and `randomized` solvers (with QR power iterations). See - :ref:`array_api` for more details. + `full`, `covariance_eigh` and `randomized` solvers (with QR power + iterations). See :ref:`array_api` for more details. :pr:`26315` and :pr:`27098` by :user:`Mateusz Sokół `, :user:`Olivier Grisel ` and :user:`Edoardo Abati `. From cc570e45d886eb8920bdc5392f0752932a3524d6 Mon Sep 17 00:00:00 2001 From: Olivier Grisel Date: Thu, 28 Sep 2023 15:47:46 +0200 Subject: [PATCH 03/70] Change test parametrization to workaround pytest xdist bug --- sklearn/decomposition/tests/test_pca.py | 17 +++++++---------- 1 file changed, 7 insertions(+), 10 deletions(-) diff --git a/sklearn/decomposition/tests/test_pca.py b/sklearn/decomposition/tests/test_pca.py index e02d7f97cc4a0..0054c96f2a9c2 100644 --- a/sklearn/decomposition/tests/test_pca.py +++ b/sklearn/decomposition/tests/test_pca.py @@ -115,17 +115,14 @@ def test_whitening(solver, copy): # we always center, so no test for non-centering. 
-@pytest.mark.parametrize( - "n_samples, n_features", - [ - (100, 80), - (80, 100), - ], -) +@pytest.mark.parametrize("data_shape", ["tall", "wide"]) @pytest.mark.parametrize("other_svd_solver", set(PCA_SOLVERS) - {"full", "auto"}) -def test_pca_solver_equivalence( - n_samples, n_features, other_svd_solver, global_random_seed -): +def test_pca_solver_equivalence(data_shape, other_svd_solver, global_random_seed): + if data_shape == "tall": + n_samples, n_features = 100, 80 + else: + n_samples, n_features = 80, 100 + X = make_low_rank_matrix( n_samples=n_samples, n_features=n_features, random_state=global_random_seed ) From d1e428751a8aafa082e3b0008843479afbcb59ad Mon Sep 17 00:00:00 2001 From: Olivier Grisel Date: Thu, 28 Sep 2023 15:51:06 +0200 Subject: [PATCH 04/70] Avoid raising warning when dtype is None --- sklearn/decomposition/tests/test_pca.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sklearn/decomposition/tests/test_pca.py b/sklearn/decomposition/tests/test_pca.py index 0054c96f2a9c2..433ec0e6b14be 100644 --- a/sklearn/decomposition/tests/test_pca.py +++ b/sklearn/decomposition/tests/test_pca.py @@ -749,7 +749,7 @@ def check_array_api_get_precision(name, estimator, array_namepsace, device, dtyp assert_allclose( _convert_to_numpy(precision_xp, xp=xp), precision_np, - atol=_atol_for_type(dtype), + atol=_atol_for_type(iris_np.dtype), ) covariance_xp = estimator_xp.get_covariance() assert covariance_xp.shape == (4, 4) @@ -758,7 +758,7 @@ def check_array_api_get_precision(name, estimator, array_namepsace, device, dtyp assert_allclose( _convert_to_numpy(covariance_xp, xp=xp), covariance_np, - atol=_atol_for_type(dtype), + atol=_atol_for_type(iris_np.dtype), ) From 74a38ea2360336b97398b753546921d69b986d31 Mon Sep 17 00:00:00 2001 From: Olivier Grisel Date: Thu, 28 Sep 2023 16:04:52 +0200 Subject: [PATCH 05/70] Another try to workaround pytest xdist bug --- sklearn/decomposition/tests/test_pca.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/sklearn/decomposition/tests/test_pca.py b/sklearn/decomposition/tests/test_pca.py index 433ec0e6b14be..34a425bd1a374 100644 --- a/sklearn/decomposition/tests/test_pca.py +++ b/sklearn/decomposition/tests/test_pca.py @@ -116,8 +116,10 @@ def test_whitening(solver, copy): @pytest.mark.parametrize("data_shape", ["tall", "wide"]) -@pytest.mark.parametrize("other_svd_solver", set(PCA_SOLVERS) - {"full", "auto"}) -def test_pca_solver_equivalence(data_shape, other_svd_solver, global_random_seed): +@pytest.mark.parametrize( + "other_svd_solver", sorted(list(set(PCA_SOLVERS) - {"full", "auto"})) +) +def test_pca_solver_equivalence(other_svd_solver, global_random_seed, data_shape): if data_shape == "tall": n_samples, n_features = 100, 80 else: From fbb1575dbaf92e007f29cb3123750c3b53f88ed4 Mon Sep 17 00:00:00 2001 From: Olivier Grisel Date: Thu, 28 Sep 2023 16:43:24 +0200 Subject: [PATCH 06/70] Use the new svd_flip config consistently and update docstrings accordingly --- sklearn/decomposition/_kernel_pca.py | 4 +--- sklearn/decomposition/_truncated_svd.py | 3 ++- sklearn/pipeline.py | 4 ++-- 3 files changed, 5 insertions(+), 6 deletions(-) diff --git a/sklearn/decomposition/_kernel_pca.py b/sklearn/decomposition/_kernel_pca.py index ccf79e896f210..a69d1cdba5087 100644 --- a/sklearn/decomposition/_kernel_pca.py +++ b/sklearn/decomposition/_kernel_pca.py @@ -363,9 +363,7 @@ def _fit_transform(self, K): ) # flip eigenvectors' sign to enforce deterministic output - self.eigenvectors_, _ = svd_flip( - 
self.eigenvectors_, np.zeros_like(self.eigenvectors_).T - ) + self.eigenvectors_, _ = svd_flip(self.eigenvectors_, None) # sort eigenvectors in descending order indices = self.eigenvalues_.argsort()[::-1] diff --git a/sklearn/decomposition/_truncated_svd.py b/sklearn/decomposition/_truncated_svd.py index 725683e8d46c6..f5fa7629acc94 100644 --- a/sklearn/decomposition/_truncated_svd.py +++ b/sklearn/decomposition/_truncated_svd.py @@ -235,7 +235,8 @@ def fit_transform(self, X, y=None): # svds doesn't abide by scipy.linalg.svd/randomized_svd # conventions, so reverse its outputs. Sigma = Sigma[::-1] - U, VT = svd_flip(U[:, ::-1], VT[::-1]) + # u_based_decision=False is needed to be consistent with PCA. + U, VT = svd_flip(U[:, ::-1], VT[::-1], u_based_decision=False) elif self.algorithm == "randomized": if self.n_components > X.shape[1]: diff --git a/sklearn/pipeline.py b/sklearn/pipeline.py index 014ecad1c6de0..9f1c9402fbfec 100644 --- a/sklearn/pipeline.py +++ b/sklearn/pipeline.py @@ -1401,8 +1401,8 @@ class FeatureUnion(_RoutingNotSupportedMixin, TransformerMixin, _BaseComposition [-1.5 , 5.7..., -0.4...]]) >>> # An estimator's parameter can be set using '__' syntax >>> union.set_params(pca__n_components=1).fit_transform(X) - array([[ 1.5 , 3.0...], - [-1.5 , 5.7...]]) + array([[-1.5 , 3.0...], + [ 1.5 , 5.7...]]) For a more detailed example of usage, see :ref:`sphx_glr_auto_examples_compose_plot_feature_union.py`. From 1734878b945675afc07c4f3201c74f96997f4475 Mon Sep 17 00:00:00 2001 From: Olivier Grisel Date: Thu, 28 Sep 2023 16:47:25 +0200 Subject: [PATCH 07/70] One more missing sign update in the FeatureUnion doctest --- sklearn/pipeline.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sklearn/pipeline.py b/sklearn/pipeline.py index 9f1c9402fbfec..8c856108ae18d 100644 --- a/sklearn/pipeline.py +++ b/sklearn/pipeline.py @@ -1397,8 +1397,8 @@ class FeatureUnion(_RoutingNotSupportedMixin, TransformerMixin, _BaseComposition ... ("svd", TruncatedSVD(n_components=2))]) >>> X = [[0., 1., 3], [2., 2., 5]] >>> union.fit_transform(X) - array([[ 1.5 , 3.0..., 0.8...], - [-1.5 , 5.7..., -0.4...]]) + array([[-1.5 , 3.0..., 0.8...], + [ 1.5 , 5.7..., -0.4...]]) >>> # An estimator's parameter can be set using '__' syntax >>> union.set_params(pca__n_components=1).fit_transform(X) array([[-1.5 , 3.0...], From 62116d4582a7160741efd0251ed9ba93d0825188 Mon Sep 17 00:00:00 2001 From: Olivier Grisel Date: Thu, 28 Sep 2023 16:52:40 +0200 Subject: [PATCH 08/70] Document the changed component sign heuristic --- doc/whats_new/v1.4.rst | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/doc/whats_new/v1.4.rst b/doc/whats_new/v1.4.rst index c00a9682c828f..f2465d8b502c0 100644 --- a/doc/whats_new/v1.4.rst +++ b/doc/whats_new/v1.4.rst @@ -19,6 +19,13 @@ parameters, may produce different models from the previous version. This often occurs due to changes in the modelling logic (bug fixes or enhancements), or in random sampling procedures. +- :class:`decomposition.PCA` and :class:`decomposition.TruncatedSVD` now set + the sign of the `components_` attribute based on the components values + instead of using the transformed data as reference. This change is needed to + be able to offer consistent component signs across all `PCA` solvers, + including the new `svd_solver="covariance_eigh"` option introduced in this + release. 
+ Changes impacting all modules ----------------------------- From 7b071a6c6c05d9456688be8a858a4876da660dbf Mon Sep 17 00:00:00 2001 From: Olivier Grisel Date: Thu, 28 Sep 2023 17:37:06 +0200 Subject: [PATCH 09/70] More svd_flip update --- doc/whats_new/v1.4.rst | 12 ++++++------ sklearn/decomposition/_sparse_pca.py | 2 +- sklearn/decomposition/_truncated_svd.py | 2 ++ sklearn/decomposition/tests/test_sparse_pca.py | 6 +++++- 4 files changed, 14 insertions(+), 8 deletions(-) diff --git a/doc/whats_new/v1.4.rst b/doc/whats_new/v1.4.rst index f2465d8b502c0..af0f96513dfb2 100644 --- a/doc/whats_new/v1.4.rst +++ b/doc/whats_new/v1.4.rst @@ -19,12 +19,12 @@ parameters, may produce different models from the previous version. This often occurs due to changes in the modelling logic (bug fixes or enhancements), or in random sampling procedures. -- :class:`decomposition.PCA` and :class:`decomposition.TruncatedSVD` now set - the sign of the `components_` attribute based on the components values - instead of using the transformed data as reference. This change is needed to - be able to offer consistent component signs across all `PCA` solvers, - including the new `svd_solver="covariance_eigh"` option introduced in this - release. +- :class:`decomposition.PCA`, :class:`decomposition.SparsePCA` and + :class:`decomposition.TruncatedSVD` now set the sign of the `components_` + attribute based on the components values instead of using the transformed + data as reference. This change is needed to be able to offer consistent + component signs across all `PCA` solvers, including the new + `svd_solver="covariance_eigh"` option introduced in this release. Changes impacting all modules ----------------------------- diff --git a/sklearn/decomposition/_sparse_pca.py b/sklearn/decomposition/_sparse_pca.py index aa4dec2fb7ee9..f78084d284fb8 100644 --- a/sklearn/decomposition/_sparse_pca.py +++ b/sklearn/decomposition/_sparse_pca.py @@ -324,7 +324,7 @@ def _fit(self, X, n_components, random_state): return_n_iter=True, ) # flip eigenvectors' sign to enforce deterministic output - code, dictionary = svd_flip(code, dictionary, u_based_decision=False) + code, dictionary = svd_flip(code, dictionary, u_based_decision=True) self.components_ = code.T components_norm = np.linalg.norm(self.components_, axis=1)[:, np.newaxis] components_norm[components_norm == 0] = 1 diff --git a/sklearn/decomposition/_truncated_svd.py b/sklearn/decomposition/_truncated_svd.py index f5fa7629acc94..413e30b510192 100644 --- a/sklearn/decomposition/_truncated_svd.py +++ b/sklearn/decomposition/_truncated_svd.py @@ -251,7 +251,9 @@ def fit_transform(self, X, y=None): n_oversamples=self.n_oversamples, power_iteration_normalizer=self.power_iteration_normalizer, random_state=random_state, + flip_sign=False, ) + U, VT = svd_flip(U, VT, u_based_decision=False) self.components_ = VT diff --git a/sklearn/decomposition/tests/test_sparse_pca.py b/sklearn/decomposition/tests/test_sparse_pca.py index 848d5d9d7ee34..4b7834c7bfda9 100644 --- a/sklearn/decomposition/tests/test_sparse_pca.py +++ b/sklearn/decomposition/tests/test_sparse_pca.py @@ -14,6 +14,7 @@ assert_array_almost_equal, if_safe_multiprocessing_with_blas, ) +from sklearn.utils.extmath import svd_flip def generate_toy_data(n_components, n_samples, image_size, random_state=None): @@ -114,7 +115,10 @@ def test_initialization(): n_components=3, U_init=U_init, V_init=V_init, max_iter=0, random_state=rng ) model.fit(rng.randn(5, 4)) - assert_allclose(model.components_, V_init / np.linalg.norm(V_init, 
axis=1)[:, None]) + + expected_components = V_init / np.linalg.norm(V_init, axis=1, keepdims=True) + expected_components = svd_flip(expected_components.T, None)[0].T + assert_allclose(model.components_, expected_components) def test_mini_batch_correct_shapes(): From 973145f1d8ef2cdce1353d4044b0d4007a0b70fc Mon Sep 17 00:00:00 2001 From: Olivier Grisel Date: Thu, 28 Sep 2023 17:41:29 +0200 Subject: [PATCH 10/70] Attempt at fixing indentation problem for svd_solver docstring --- sklearn/decomposition/_pca.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sklearn/decomposition/_pca.py b/sklearn/decomposition/_pca.py index 03ec4c0e19ee4..b4b2dde4a76ca 100644 --- a/sklearn/decomposition/_pca.py +++ b/sklearn/decomposition/_pca.py @@ -175,7 +175,7 @@ class PCA(_BasePCA): improve the predictive accuracy of the downstream estimators by making their data respect some hard-wired assumptions. - svd_solver : {'auto', 'full', 'covariance_eigh', 'arpack', 'randomized'}, + svd_solver : {'auto', 'full', 'covariance_eigh', 'arpack', 'randomized'},\ default='auto' If auto : The solver is selected by a default policy based on `X.shape` and From 4ac89357505658554e6cc25132519700cd71fd10 Mon Sep 17 00:00:00 2001 From: Olivier Grisel Date: Thu, 28 Sep 2023 18:11:38 +0200 Subject: [PATCH 11/70] One more missing sign update in the FeatureUnion doctest --- sklearn/pipeline.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sklearn/pipeline.py b/sklearn/pipeline.py index 8c856108ae18d..43a4ba8a5bf0c 100644 --- a/sklearn/pipeline.py +++ b/sklearn/pipeline.py @@ -1397,8 +1397,8 @@ class FeatureUnion(_RoutingNotSupportedMixin, TransformerMixin, _BaseComposition ... ("svd", TruncatedSVD(n_components=2))]) >>> X = [[0., 1., 3], [2., 2., 5]] >>> union.fit_transform(X) - array([[-1.5 , 3.0..., 0.8...], - [ 1.5 , 5.7..., -0.4...]]) + array([[-1.5 , 3.0..., -0.8...], + [ 1.5 , 5.7..., 0.4...]]) >>> # An estimator's parameter can be set using '__' syntax >>> union.set_params(pca__n_components=1).fit_transform(X) array([[-1.5 , 3.0...], From 0b3a19218c6579cd1b6204dfc14927b35f3004ae Mon Sep 17 00:00:00 2001 From: Olivier Grisel Date: Thu, 28 Sep 2023 22:48:29 +0200 Subject: [PATCH 12/70] Avoid SparsePCA's internal DictionaryLearning instance generating code as a pandas dataframe --- sklearn/decomposition/_dict_learning.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sklearn/decomposition/_dict_learning.py b/sklearn/decomposition/_dict_learning.py index 7fc0915f2ea8e..7e181e40adfe4 100644 --- a/sklearn/decomposition/_dict_learning.py +++ b/sklearn/decomposition/_dict_learning.py @@ -1227,7 +1227,7 @@ def dict_learning( positive_code=positive_code, positive_dict=positive_dict, transform_max_iter=method_max_iter, - ) + ).set_output(transform="default") code = estimator.fit_transform(X) if return_n_iter: return ( From f6815e59c6ae9b28a36daa2d3ece1b5df9b5a324 Mon Sep 17 00:00:00 2001 From: Olivier Grisel Date: Fri, 29 Sep 2023 09:36:37 +0200 Subject: [PATCH 13/70] TST test fit_transform and transform equivalence --- sklearn/decomposition/tests/test_pca.py | 28 ++++++++++++++++++++----- 1 file changed, 23 insertions(+), 5 deletions(-) diff --git a/sklearn/decomposition/tests/test_pca.py b/sklearn/decomposition/tests/test_pca.py index 34a425bd1a374..5245b5c2e4517 100644 --- a/sklearn/decomposition/tests/test_pca.py +++ b/sklearn/decomposition/tests/test_pca.py @@ -121,13 +121,18 @@ def test_whitening(solver, copy): ) def test_pca_solver_equivalence(other_svd_solver, 
global_random_seed, data_shape): if data_shape == "tall": - n_samples, n_features = 100, 80 + n_samples, n_features = 100, 30 else: - n_samples, n_features = 80, 100 + n_samples, n_features = 30, 100 - X = make_low_rank_matrix( + X_train = make_low_rank_matrix( n_samples=n_samples, n_features=n_features, random_state=global_random_seed ) + X_test = make_low_rank_matrix( + 10, + n_features=n_features, + random_state=global_random_seed + 1, + ) tols = dict(atol=1e-10, rtol=1e-12) extra_other_kwargs = {} @@ -151,8 +156,8 @@ def test_pca_solver_equivalence(other_svd_solver, global_random_seed, data_shape random_state=global_random_seed, **extra_other_kwargs, ) - pca_full.fit(X) - pca_other.fit(X) + X_trans_full_train = pca_full.fit_transform(X_train) + X_trans_other_train = pca_other.fit_transform(X_train) assert_allclose(pca_full.explained_variance_, pca_other.explained_variance_, **tols) assert_allclose( @@ -169,6 +174,19 @@ def test_pca_solver_equivalence(other_svd_solver, global_random_seed, data_shape other_components = other_components[:-1] assert_allclose(reference_components, other_components, **tols) + # As a result the output of fit_transform should be the same: + assert_allclose(X_trans_other_train, X_trans_full_train, **tols) + + # And similarly for the output of transform on new data (except for the + # last component that can be underdetermined): + X_trans_full_test = pca_full.transform(X_test) + X_trans_other_test = pca_other.transform(X_test) + if n_components is None and n_features > n_samples: + X_trans_full_test = X_trans_full_test[:, :-1] + X_trans_other_test = X_trans_other_test[:, :-1] + + assert_allclose(X_trans_other_test, X_trans_full_test, **tols) + @pytest.mark.parametrize( "X", From bbd876196ed275c3f59a6b25a08d7cf334091ace Mon Sep 17 00:00:00 2001 From: Olivier Grisel Date: Fri, 29 Sep 2023 09:43:18 +0200 Subject: [PATCH 14/70] TST cleaner way to define X_train and X_test --- sklearn/decomposition/tests/test_pca.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/sklearn/decomposition/tests/test_pca.py b/sklearn/decomposition/tests/test_pca.py index 5245b5c2e4517..73318f61763f6 100644 --- a/sklearn/decomposition/tests/test_pca.py +++ b/sklearn/decomposition/tests/test_pca.py @@ -124,15 +124,15 @@ def test_pca_solver_equivalence(other_svd_solver, global_random_seed, data_shape n_samples, n_features = 100, 30 else: n_samples, n_features = 30, 100 + n_samples_test = 10 - X_train = make_low_rank_matrix( - n_samples=n_samples, n_features=n_features, random_state=global_random_seed - ) - X_test = make_low_rank_matrix( - 10, + X = make_low_rank_matrix( + n_samples=n_samples + n_samples_test, n_features=n_features, - random_state=global_random_seed + 1, + random_state=global_random_seed, ) + X_train, X_test = X[:n_samples], X[n_samples:] + tols = dict(atol=1e-10, rtol=1e-12) extra_other_kwargs = {} From 29798d5d5bfb9bcd2f629110fab4814515d5d89e Mon Sep 17 00:00:00 2001 From: Olivier Grisel Date: Fri, 29 Sep 2023 15:58:46 +0200 Subject: [PATCH 15/70] Add new benchmark script to check the robustness of the auto policy --- benchmarks/bench_pca_solvers.py | 162 ++++++++++++++++++++++++++++++++ 1 file changed, 162 insertions(+) create mode 100644 benchmarks/bench_pca_solvers.py diff --git a/benchmarks/bench_pca_solvers.py b/benchmarks/bench_pca_solvers.py new file mode 100644 index 0000000000000..5f84b635c255e --- /dev/null +++ b/benchmarks/bench_pca_solvers.py @@ -0,0 +1,162 @@ +# %% +# +# This benchmark compares the speed of PCA solvers on datasets 
of different +# sizes in order to determine the best solver to select by default via the +# "auto" heuristic. +# +# Note: we do not control for the accuracy of the solvers: we assume that all +# solvers yield transformed data with similar explained variance. This +# assumption is generally true, except for the randomized solver that might +# require more power iterations. +# +# We generate synthetic data with dimensions that are useful to plot: +# - time vs n_samples for a fixed n_features and, +# - time vs n_features for a fixed n_samples for a fixed n_features. +import itertools +from math import log10 +from time import perf_counter + +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd + +from sklearn import config_context +from sklearn.datasets import make_low_rank_matrix +from sklearn.decomposition import PCA + +REF_DIMS = [100, 1000, 10_000] +data_shapes = [] +for ref_dim in REF_DIMS: + data_shapes.extend([(ref_dim, 10**i) for i in range(1, 9 - int(log10(ref_dim)))]) + data_shapes.extend([(10**i, ref_dim) for i in range(1, 9 - int(log10(ref_dim)))]) + +# Remove duplicates: +data_shapes = sorted(set(data_shapes)) + +print("Generating test datasets...") +datasets = [ + make_low_rank_matrix(n_samples, n_features, random_state=0) + for n_samples, n_features in data_shapes +] + + +# %% +def measure_one(data, n_components, solver, method_name="fit"): + print( + f"Benchmarking {solver=!r}, {n_components=}, {method_name=!r} on data with" + f" shape {data.shape}" + ) + pca = PCA(n_components=n_components, svd_solver=solver, random_state=0) + timings = [] + elapsed = 0 + method = getattr(pca, method_name) + with config_context(assume_finite=True): + while elapsed < 0.5: + tic = perf_counter() + method(data) + duration = perf_counter() - tic + timings.append(duration) + elapsed += duration + return np.median(timings) + + +SOLVERS = ["full", "covariance_eigh", "arpack", "randomized", "auto"] +measurements = [] +for data, n_components, method_name in itertools.product( + datasets, [2, 50], ["fit", "fit_transform"] +): + if n_components >= min(data.shape): + continue + for solver in SOLVERS: + if solver == "covariance_eigh" and data.shape[1] > 1000: + # Too much memory and too slow. + continue + if solver in ["arpack", "full"] and log10(data.size) > 7: + # Too slow, in particular for the full solver. 
+ continue + time = measure_one(data, n_components, solver, method_name=method_name) + measurements.append( + { + "n_components": n_components, + "n_samples": data.shape[0], + "n_features": data.shape[1], + "time": time, + "solver": solver, + "method_name": method_name, + } + ) +measurements = pd.DataFrame(measurements) +measurements.to_csv("bench_pca_solvers.csv", index=False) + +# %% +all_method_names = measurements["method_name"].unique() +all_n_components = measurements["n_components"].unique() + +for method_name in all_method_names: + fig, axes = plt.subplots( + figsize=(16, 16), + nrows=len(REF_DIMS), + ncols=len(all_n_components), + sharey=True, + constrained_layout=True, + ) + fig.suptitle(f"Benchmarks for PCA.{method_name}", fontsize=16) + + for row_idx, ref_dim in enumerate(REF_DIMS): + for n_components, ax in zip(all_n_components, axes[row_idx]): + for solver in SOLVERS: + if solver == "auto": + style_kwargs = dict(linewidth=2, color="black", style="--") + else: + style_kwargs = dict(style="o-") + ax.set( + title=f"n_components={n_components}, n_features={ref_dim}", + ylabel="time (s)", + ) + measurements.query( + "n_components == @n_components and n_features == @ref_dim" + " and solver == @solver and method_name == @method_name" + ).plot.line( + x="n_samples", + y="time", + label=solver, + logx=True, + logy=True, + ax=ax, + **style_kwargs, + ) +# %% +for method_name in all_method_names: + fig, axes = plt.subplots( + figsize=(16, 16), + nrows=len(REF_DIMS), + ncols=len(all_n_components), + sharey=True, + ) + fig.suptitle(f"Benchmarks for PCA.{method_name}", fontsize=16) + + for row_idx, ref_dim in enumerate(REF_DIMS): + for n_components, ax in zip(all_n_components, axes[row_idx]): + for solver in SOLVERS: + if solver == "auto": + style_kwargs = dict(linewidth=2, color="black", style="--") + else: + style_kwargs = dict(style="o-") + ax.set( + title=f"n_components={n_components}, n_samples={ref_dim}", + ylabel="time (s)", + ) + measurements.query( + "n_components == @n_components and n_samples == @ref_dim " + " and solver == @solver and method_name == @method_name" + ).plot.line( + x="n_features", + y="time", + label=solver, + logx=True, + logy=True, + ax=ax, + **style_kwargs, + ) + +# %% From 00cd121374f28d3c09b80014deff9b350f7e3d4c Mon Sep 17 00:00:00 2001 From: Olivier Grisel Date: Fri, 29 Sep 2023 16:00:41 +0200 Subject: [PATCH 16/70] Ignore CSV file with collected benchmark data --- .gitignore | 1 + 1 file changed, 1 insertion(+) diff --git a/.gitignore b/.gitignore index 199c2bd85d997..7ecad1dc18067 100644 --- a/.gitignore +++ b/.gitignore @@ -53,6 +53,7 @@ nips2010_pdf/ examples/cluster/joblib reuters/ benchmarks/bench_covertype_data/ +bench_pca_solvers.csv *.prefs .pydevproject From 961cf6fae46dfb8f3312eb4b4dd387ab3f989081 Mon Sep 17 00:00:00 2001 From: Olivier Grisel Date: Fri, 29 Sep 2023 16:02:46 +0200 Subject: [PATCH 17/70] Adjust svd_solver='auto' --- sklearn/decomposition/_pca.py | 6 +++++- sklearn/decomposition/tests/test_pca.py | 21 +++++++++++++-------- 2 files changed, 18 insertions(+), 9 deletions(-) diff --git a/sklearn/decomposition/_pca.py b/sklearn/decomposition/_pca.py index b4b2dde4a76ca..aff4f3202d5f3 100644 --- a/sklearn/decomposition/_pca.py +++ b/sklearn/decomposition/_pca.py @@ -521,8 +521,12 @@ def _fit(self, X): # Handle svd_solver self._fit_svd_solver = self.svd_solver if self._fit_svd_solver == "auto": + # Tall and skinny problems are best handled by precomputing the + # covariance matrix. 
+ if X.shape[1] <= 500 and X.shape[0] >= 10 * X.shape[1]: + self._fit_svd_solver = "covariance_eigh" # Small problem or n_components == 'mle', just call full PCA - if max(X.shape) <= 500 or n_components == "mle": + elif max(X.shape) <= 500 or n_components == "mle": self._fit_svd_solver = "full" elif 1 <= n_components < 0.8 * min(X.shape): self._fit_svd_solver = "randomized" diff --git a/sklearn/decomposition/tests/test_pca.py b/sklearn/decomposition/tests/test_pca.py index 73318f61763f6..de68f2952df9f 100644 --- a/sklearn/decomposition/tests/test_pca.py +++ b/sklearn/decomposition/tests/test_pca.py @@ -533,23 +533,28 @@ def test_pca_zero_noise_variance_edge_cases(svd_solver): @pytest.mark.parametrize( - "data, n_components, expected_solver", - [ # case: n_components in (0,1) => 'full' - (np.random.RandomState(0).uniform(size=(1000, 50)), 0.5, "full"), - # case: max(X.shape) <= 500 => 'full' - (np.random.RandomState(0).uniform(size=(10, 50)), 5, "full"), + "n_samples, n_features, n_components, expected_solver", + [ + # case: n_samples < 10 * n_features and max(X.shape) <= 500 => 'full' + (10, 50, 5, "full"), + # case: n_samples > 10 * n_features and n_features < 500 => 'covariance_eigh' + (1000, 50, 50, "covariance_eigh"), # case: n_components >= .8 * min(X.shape) => 'full' - (np.random.RandomState(0).uniform(size=(1000, 50)), 50, "full"), + (1000, 500, 400, "full"), # n_components >= 1 and n_components < .8*min(X.shape) => 'randomized' - (np.random.RandomState(0).uniform(size=(1000, 50)), 10, "randomized"), + (1000, 500, 10, "randomized"), + # case: n_components in (0,1) => 'full' + (1000, 500, 0.5, "full"), ], ) -def test_pca_svd_solver_auto(data, n_components, expected_solver): +def test_pca_svd_solver_auto(n_samples, n_features, n_components, expected_solver): + data = np.random.RandomState(0).uniform(size=(n_samples, n_features)) pca_auto = PCA(n_components=n_components, random_state=0) pca_test = PCA( n_components=n_components, svd_solver=expected_solver, random_state=0 ) pca_auto.fit(data) + assert pca_auto._fit_svd_solver == expected_solver pca_test.fit(data) assert_allclose(pca_auto.components_, pca_test.components_) From c89cd131ee3364c2a49510a167d285446186a12e Mon Sep 17 00:00:00 2001 From: Olivier Grisel Date: Fri, 29 Sep 2023 17:05:30 +0200 Subject: [PATCH 18/70] Make test_pca_solver_equivalence more generic at dealing with unstable components --- sklearn/decomposition/tests/test_pca.py | 47 +++++++++++++++---------- 1 file changed, 28 insertions(+), 19 deletions(-) diff --git a/sklearn/decomposition/tests/test_pca.py b/sklearn/decomposition/tests/test_pca.py index de68f2952df9f..59e4960895c43 100644 --- a/sklearn/decomposition/tests/test_pca.py +++ b/sklearn/decomposition/tests/test_pca.py @@ -115,22 +115,33 @@ def test_whitening(solver, copy): # we always center, so no test for non-centering. 
-@pytest.mark.parametrize("data_shape", ["tall", "wide"]) @pytest.mark.parametrize( "other_svd_solver", sorted(list(set(PCA_SOLVERS) - {"full", "auto"})) ) -def test_pca_solver_equivalence(other_svd_solver, global_random_seed, data_shape): +@pytest.mark.parametrize("data_shape", ["tall", "wide"]) +@pytest.mark.parametrize("rank_deficient", [False, True]) +@pytest.mark.parametrize("whiten", [False, True]) +def test_pca_solver_equivalence( + other_svd_solver, data_shape, rank_deficient, whiten, global_random_seed +): if data_shape == "tall": n_samples, n_features = 100, 30 else: n_samples, n_features = 30, 100 n_samples_test = 10 - X = make_low_rank_matrix( - n_samples=n_samples + n_samples_test, - n_features=n_features, - random_state=global_random_seed, - ) + if rank_deficient: + rng = np.random.default_rng(global_random_seed) + rank = min(n_samples, n_features) // 2 + X = rng.standard_normal( + size=(n_samples + n_samples_test, rank) + ) @ rng.standard_normal(size=(rank, n_features)) + else: + X = make_low_rank_matrix( + n_samples=n_samples + n_samples_test, + n_features=n_features, + random_state=global_random_seed, + ) X_train, X_test = X[:n_samples], X[n_samples:] tols = dict(atol=1e-10, rtol=1e-12) @@ -167,25 +178,23 @@ def test_pca_solver_equivalence(other_svd_solver, global_random_seed, data_shape ) reference_components = pca_full.components_ other_components = pca_other.components_ - if n_components is None and n_features > n_samples: - # The last component can be arbitrary because of the centering of the - # data. Let's ignore it: - reference_components = reference_components[:-1] - other_components = other_components[:-1] - assert_allclose(reference_components, other_components, **tols) + + # For some choice of n_components and data distribution, some components + # might be pure noise, let's ignore them in the comparison: + stable = pca_full.explained_variance_ > 1e-12 + assert stable.sum() > 1 + assert_allclose(reference_components[stable], other_components[stable], **tols) # As a result the output of fit_transform should be the same: - assert_allclose(X_trans_other_train, X_trans_full_train, **tols) + assert_allclose( + X_trans_other_train[:, stable], X_trans_full_train[:, stable], **tols + ) # And similarly for the output of transform on new data (except for the # last component that can be underdetermined): X_trans_full_test = pca_full.transform(X_test) X_trans_other_test = pca_other.transform(X_test) - if n_components is None and n_features > n_samples: - X_trans_full_test = X_trans_full_test[:, :-1] - X_trans_other_test = X_trans_other_test[:, :-1] - - assert_allclose(X_trans_other_test, X_trans_full_test, **tols) + assert_allclose(X_trans_other_test[:, stable], X_trans_full_test[:, stable], **tols) @pytest.mark.parametrize( From 96e1f9b1b57d9dec55b102dab546441a41de78c7 Mon Sep 17 00:00:00 2001 From: Olivier Grisel Date: Fri, 29 Sep 2023 17:16:17 +0200 Subject: [PATCH 19/70] Fix sklearn/decomposition/tests/test_incremental_pca.py:test_whitening --- .../decomposition/tests/test_incremental_pca.py | 17 +++++++++++++++-- 1 file changed, 15 insertions(+), 2 deletions(-) diff --git a/sklearn/decomposition/tests/test_incremental_pca.py b/sklearn/decomposition/tests/test_incremental_pca.py index 5d7c8aa03f174..b6a082b3fb9da 100644 --- a/sklearn/decomposition/tests/test_incremental_pca.py +++ b/sklearn/decomposition/tests/test_incremental_pca.py @@ -389,14 +389,27 @@ def test_whitening(): 1000, 10, tail_strength=0.0, effective_rank=2, random_state=1999 ) prec = 3 - n_samples, 
n_features = X.shape
     for nc in [None, 9]:
         pca = PCA(whiten=True, n_components=nc).fit(X)
         ipca = IncrementalPCA(whiten=True, n_components=nc, batch_size=250).fit(X)
+        # Since the data is rank deficient, some components are pure noise. We
+        # should not expect those dimensions to carry any signal and their
+        # values might be arbitrarily changed by implementation details of the
+        # internal SVD solver. We therefore mask them out before comparison.
+        stable_mask = pca.explained_variance_ratio_ > 1e-12
+
         Xt_pca = pca.transform(X)
         Xt_ipca = ipca.transform(X)
-        assert_almost_equal(np.abs(Xt_pca), np.abs(Xt_ipca), decimal=prec)
+        assert_almost_equal(
+            np.abs(Xt_pca)[:, stable_mask],
+            np.abs(Xt_ipca)[:, stable_mask],
+            decimal=prec,
+        )
+
+        # The noisy dimensions are in the null space of the inverse transform,
+        # so they are not influencing the reconstruction. We therefore don't
+        # need to apply the mask here.
         Xinv_ipca = ipca.inverse_transform(Xt_ipca)
         Xinv_pca = pca.inverse_transform(Xt_pca)
         assert_almost_equal(X, Xinv_ipca, decimal=prec)

From 9552626cac144f627674666ede1525074c6c80d4 Mon Sep 17 00:00:00 2001
From: Olivier Grisel
Date: Fri, 29 Sep 2023 17:23:16 +0200
Subject: [PATCH 20/70] Update the docstring to reflect the new auto policy

---
 sklearn/decomposition/_pca.py | 32 +++++++++++++++++---------------
 1 file changed, 17 insertions(+), 15 deletions(-)

diff --git a/sklearn/decomposition/_pca.py b/sklearn/decomposition/_pca.py
index aff4f3202d5f3..137d3170fb4d5 100644
--- a/sklearn/decomposition/_pca.py
+++ b/sklearn/decomposition/_pca.py
@@ -177,29 +177,31 @@ class PCA(_BasePCA):
 
     svd_solver : {'auto', 'full', 'covariance_eigh', 'arpack', 'randomized'},\
             default='auto'
-        If auto :
+        "auto" :
             The solver is selected by a default policy based on `X.shape` and
-            `n_components`: if the input data is larger than 500x500 and the
-            number of components to extract is lower than 80% of the smallest
-            dimension of the data, then the more efficient 'randomized'
-            method is enabled. Otherwise the exact full SVD is computed and
-            optionally truncated afterwards.
-        If full :
-            run exact full SVD calling the standard LAPACK solver via
+            `n_components`: if the input data has fewer than 500 features and
+            more than 10 times as many samples, then the "covariance_eigh"
+            solver is used. Otherwise, if the input data is larger than 500x500
+            and the number of components to extract is lower than 80% of the
+            smallest dimension of the data, then the more efficient
+            'randomized' method is enabled. Otherwise the exact "full" SVD is
+            computed and optionally truncated afterwards.
+        "full" :
+            Run exact full SVD calling the standard LAPACK solver via
             `scipy.linalg.svd` and select the components by postprocessing
-        If covariance_eigh :
-            precompute the covariance matrix (on centered data) and run a
+        "covariance_eigh" :
+            Precompute the covariance matrix (on centered data) and run a
             classical eigenvalue decomposition on the covariance matrix
             typically using LAPACK and select the components by postprocessing.
             This solver is very efficient when the number of features is small
             and not tractable otherwise (large memory footprint required to
             materialize the covariance matrix).
-        If arpack :
-            run SVD truncated to n_components calling ARPACK solver via
+        "arpack" :
+            Run SVD truncated to `n_components` calling ARPACK solver via
             `scipy.sparse.linalg.svds`. It requires strictly
-            0 < n_components < min(X.shape)
-        If randomized :
-            run randomized SVD by the method of Halko et al.
+            `0 < n_components < min(X.shape)`
+        "randomized" :
+            Run randomized SVD by the method of Halko et al.
 
         .. versionadded:: 0.18.0
 

From 6e3a62bc55dd7a1c921c0d972310a72f93827c52 Mon Sep 17 00:00:00 2001
From: Olivier Grisel
Date: Fri, 29 Sep 2023 17:36:05 +0200
Subject: [PATCH 21/70] More explicit figure titles

---
 benchmarks/bench_pca_solvers.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/benchmarks/bench_pca_solvers.py b/benchmarks/bench_pca_solvers.py
index 5f84b635c255e..b76df1120cbb1 100644
--- a/benchmarks/bench_pca_solvers.py
+++ b/benchmarks/bench_pca_solvers.py
@@ -100,7 +100,7 @@ def measure_one(data, n_components, solver, method_name="fit"):
         sharey=True,
         constrained_layout=True,
     )
-    fig.suptitle(f"Benchmarks for PCA.{method_name}", fontsize=16)
+    fig.suptitle(f"Benchmarks for PCA.{method_name}, varying n_samples", fontsize=16)
 
     for row_idx, ref_dim in enumerate(REF_DIMS):
         for n_components, ax in zip(all_n_components, axes[row_idx]):
@@ -133,7 +133,7 @@ def measure_one(data, n_components, solver, method_name="fit"):
         ncols=len(all_n_components),
         sharey=True,
     )
-    fig.suptitle(f"Benchmarks for PCA.{method_name}", fontsize=16)
+    fig.suptitle(f"Benchmarks for PCA.{method_name}, varying n_features", fontsize=16)
 
     for row_idx, ref_dim in enumerate(REF_DIMS):
         for n_components, ax in zip(all_n_components, axes[row_idx]):

From 4f15943a6b5c8ea19bef6540480fd83daed0bd3c Mon Sep 17 00:00:00 2001
From: Olivier Grisel
Date: Fri, 29 Sep 2023 17:47:56 +0200
Subject: [PATCH 22/70] Update the docstring to mention the difference in numerical stability

---
 sklearn/decomposition/_pca.py | 7 +++++--
 1 file changed, 5 insertions(+), 2 deletions(-)

diff --git a/sklearn/decomposition/_pca.py b/sklearn/decomposition/_pca.py
index 137d3170fb4d5..6c99f8f3c6ec5 100644
--- a/sklearn/decomposition/_pca.py
+++ b/sklearn/decomposition/_pca.py
@@ -190,12 +190,15 @@ class PCA(_BasePCA):
             Run exact full SVD calling the standard LAPACK solver via
             `scipy.linalg.svd` and select the components by postprocessing
         "covariance_eigh" :
-            Precompute the covariance matrix (on centered data) and run a
+            Precompute the covariance matrix (on centered data), run a
             classical eigenvalue decomposition on the covariance matrix
             typically using LAPACK and select the components by postprocessing.
             This solver is very efficient when the number of features is small
             and not tractable otherwise (large memory footprint required to
-            materialize the covariance matrix).
+            materialize the covariance matrix). Also note that compared to the
+            "full" solver, this solver effectively doubles the condition number
+            and is therefore less numerically stable (e.g. on input data with a
+            large range of singular values).
         "arpack" :
             Run SVD truncated to `n_components` calling ARPACK solver via
             `scipy.sparse.linalg.svds`.
It requires strictly From 1de6f67f34f381ad5611ae8c33fee52c10c5d73c Mon Sep 17 00:00:00 2001 From: Olivier Grisel Date: Tue, 3 Oct 2023 08:42:23 +0200 Subject: [PATCH 23/70] Improve solver equivalence tests --- sklearn/decomposition/_pca.py | 15 ++++++++--- .../tests/test_incremental_pca.py | 4 +-- sklearn/decomposition/tests/test_pca.py | 27 ++++++++++++++++++- 3 files changed, 40 insertions(+), 6 deletions(-) diff --git a/sklearn/decomposition/_pca.py b/sklearn/decomposition/_pca.py index 6c99f8f3c6ec5..b1bcc8167176b 100644 --- a/sklearn/decomposition/_pca.py +++ b/sklearn/decomposition/_pca.py @@ -583,9 +583,18 @@ def _fit_full(self, X, n_components): assert self._fit_svd_solver == "covariance_eigh" C = X.T @ X evals, Evecs = xp.linalg.eigh(C) - evals[evals < 0] = 0.0 - S = xp.sqrt(xp.flip(evals, axis=0)) - Vt = xp.flip(Evecs, axis=1).T + evals = xp.flip(evals, axis=0) + Evecs = xp.flip(Evecs, axis=1) + + # Avoid numerical problems for zero or near-zero eigenvalues caused + # by rounding errors as they would lead to non-finite transformed + # values: the square root is undefined for near-zero negative + # values and furthermore, whitening divides the transformed data by + # the explained variance. + threshold = float(xp.finfo(evals.dtype).eps) + evals[evals < threshold] = threshold + S = xp.sqrt(evals) + Vt = Evecs.T U = None # flip eigenvectors' sign to enforce deterministic output diff --git a/sklearn/decomposition/tests/test_incremental_pca.py b/sklearn/decomposition/tests/test_incremental_pca.py index b6a082b3fb9da..030f9a101c47f 100644 --- a/sklearn/decomposition/tests/test_incremental_pca.py +++ b/sklearn/decomposition/tests/test_incremental_pca.py @@ -383,10 +383,10 @@ def test_singular_values(): assert_array_almost_equal(ipca.singular_values_, [3.142, 2.718, 1.0], 14) -def test_whitening(): +def test_whitening(global_random_seed): # Test that PCA and IncrementalPCA transforms match to sign flip. X = datasets.make_low_rank_matrix( - 1000, 10, tail_strength=0.0, effective_rank=2, random_state=1999 + 1000, 10, tail_strength=0.0, effective_rank=2, random_state=global_random_seed ) prec = 3 for nc in [None, 9]: diff --git a/sklearn/decomposition/tests/test_pca.py b/sklearn/decomposition/tests/test_pca.py index 59e4960895c43..f12ef6523e039 100644 --- a/sklearn/decomposition/tests/test_pca.py +++ b/sklearn/decomposition/tests/test_pca.py @@ -160,16 +160,20 @@ def test_pca_solver_equivalence( # Test all components to high precision. 
n_components = None - pca_full = PCA(n_components=n_components, svd_solver="full") + pca_full = PCA(n_components=n_components, svd_solver="full", whiten=whiten) pca_other = PCA( n_components=n_components, svd_solver=other_svd_solver, + whiten=whiten, random_state=global_random_seed, **extra_other_kwargs, ) X_trans_full_train = pca_full.fit_transform(X_train) + assert np.isfinite(X_trans_full_train).all() X_trans_other_train = pca_other.fit_transform(X_train) + assert np.isfinite(X_trans_other_train).all() + assert (pca_full.explained_variance_ >= 0).all() assert_allclose(pca_full.explained_variance_, pca_other.explained_variance_, **tols) assert_allclose( pca_full.explained_variance_ratio_, @@ -177,7 +181,9 @@ def test_pca_solver_equivalence( **tols, ) reference_components = pca_full.components_ + assert np.isfinite(reference_components).all() other_components = pca_other.components_ + assert np.isfinite(other_components).all() # For some choice of n_components and data distribution, some components # might be pure noise, let's ignore them in the comparison: @@ -193,9 +199,28 @@ def test_pca_solver_equivalence( # And similarly for the output of transform on new data (except for the # last component that can be underdetermined): X_trans_full_test = pca_full.transform(X_test) + assert np.isfinite(X_trans_full_test).all() X_trans_other_test = pca_other.transform(X_test) + assert np.isfinite(X_trans_other_test).all() assert_allclose(X_trans_other_test[:, stable], X_trans_full_test[:, stable], **tols) + # Check that inverse transform reconstructions for both solvers are + # compatible. + X_recons_full_test = pca_full.inverse_transform(X_trans_full_test) + assert np.isfinite(X_recons_full_test).all() + X_recons_other_test = pca_other.inverse_transform(X_trans_other_test) + assert np.isfinite(X_recons_other_test).all() + + if n_components is None and X_train.shape[0] > X_train.shape[1]: + # In this case, both models should be able to reconstruct the data, + # even in the presence of noisy components. + assert_allclose(X_recons_full_test, X_test, **tols) + assert_allclose(X_recons_other_test, X_test, **tols) + elif pca_full.explained_variance_.min() > 1e-12: + # In the absence of noisy components, both models should be able to + # reconstruct the same low-rank approximation of the original data. + assert_allclose(X_recons_full_test, X_recons_other_test, **tols) + @pytest.mark.parametrize( "X", From b832aea6d07fc1389b5ff5dd4eb9e01bad4a68f1 Mon Sep 17 00:00:00 2001 From: Olivier Grisel Date: Tue, 3 Oct 2023 10:13:30 +0200 Subject: [PATCH 24/70] DOC make compose doc more interesting and robust to rounding errors --- doc/modules/compose.rst | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/doc/modules/compose.rst b/doc/modules/compose.rst index 2ab77195d169f..76dc1da3a3b75 100644 --- a/doc/modules/compose.rst +++ b/doc/modules/compose.rst @@ -253,14 +253,14 @@ inspect the original instance such as:: >>> from sklearn.datasets import load_digits >>> X_digits, y_digits = load_digits(return_X_y=True) - >>> pca1 = PCA() + >>> pca1 = PCA(n_components=10) >>> svm1 = SVC() >>> pipe = Pipeline([('reduce_dim', pca1), ('clf', svm1)]) >>> pipe.fit(X_digits, y_digits) - Pipeline(steps=[('reduce_dim', PCA()), ('clf', SVC())]) + Pipeline(steps=[('reduce_dim', PCA(n_components=10)), ('clf', SVC())]) >>> # The pca instance can be inspected directly - >>> print(pca1.components_) - [[-1.77484909e-19 ... 
4.07058917e-18]] + >>> pca1.components_.shape + (10, 64) Enabling caching triggers a clone of the transformers before fitting. @@ -273,15 +273,15 @@ Instead, use the attribute ``named_steps`` to inspect estimators within the pipeline:: >>> cachedir = mkdtemp() - >>> pca2 = PCA() + >>> pca2 = PCA(n_components=10) >>> svm2 = SVC() >>> cached_pipe = Pipeline([('reduce_dim', pca2), ('clf', svm2)], ... memory=cachedir) >>> cached_pipe.fit(X_digits, y_digits) Pipeline(memory=..., - steps=[('reduce_dim', PCA()), ('clf', SVC())]) - >>> print(cached_pipe.named_steps['reduce_dim'].components_) - [[-1.77484909e-19 ... 4.07058917e-18]] + steps=[('reduce_dim', PCA(n_components=10)), ('clf', SVC())]) + >>> cached_pipe.named_steps['reduce_dim'].components_.shape + (10, 64) >>> # Remove the cache directory >>> rmtree(cachedir) From b3eec64b8a87ac2068e126f4bef77b1bec4491ae Mon Sep 17 00:00:00 2001 From: Olivier Grisel Date: Tue, 3 Oct 2023 15:15:25 +0200 Subject: [PATCH 25/70] Mark the stricter test_pca_solver_equivalence with the arpack solver xfail for old scipy versions --- sklearn/decomposition/tests/test_pca.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/sklearn/decomposition/tests/test_pca.py b/sklearn/decomposition/tests/test_pca.py index f12ef6523e039..cdbeb90321174 100644 --- a/sklearn/decomposition/tests/test_pca.py +++ b/sklearn/decomposition/tests/test_pca.py @@ -21,7 +21,7 @@ _get_check_estimator_ids, check_array_api_input_and_values, ) -from sklearn.utils.fixes import CSR_CONTAINERS +from sklearn.utils.fixes import CSR_CONTAINERS, parse_version, sp_version iris = datasets.load_iris() PCA_SOLVERS = ["full", "covariance_eigh", "arpack", "randomized", "auto"] @@ -124,6 +124,12 @@ def test_whitening(solver, copy): def test_pca_solver_equivalence( other_svd_solver, data_shape, rank_deficient, whiten, global_random_seed ): + if sp_version < parse_version("1.7") and other_svd_solver == "arpack": + pytest.xfail( + "Older scipy versions have a numerical stability problem that makes" + " `transform` output non-finite results." + ) + if data_shape == "tall": n_samples, n_features = 100, 30 else: From 4120581daaa09e9dd74d81c2b8e29e4824451d0a Mon Sep 17 00:00:00 2001 From: Olivier Grisel Date: Tue, 3 Oct 2023 18:15:11 +0200 Subject: [PATCH 26/70] Make PCA(whiten=True).transform robust to rank deficient training data --- doc/whats_new/v1.4.rst | 10 +++++++++- sklearn/decomposition/_base.py | 13 ++++++++++++- sklearn/decomposition/_pca.py | 9 +++------ sklearn/decomposition/tests/test_pca.py | 18 +++++++++++------- 4 files changed, 35 insertions(+), 15 deletions(-) diff --git a/doc/whats_new/v1.4.rst b/doc/whats_new/v1.4.rst index af0f96513dfb2..cd82320dc754a 100644 --- a/doc/whats_new/v1.4.rst +++ b/doc/whats_new/v1.4.rst @@ -206,7 +206,9 @@ Changelog - |Enhancement| :class:`decomposition.PCA` now supports a new solver option named `svd_solver="covariance_eigh"` which offers an order of magnitude speed-up and reduced memory usage for datasets with a large number of data - points and a small number of features (say, less than 1000). + points and a small number of features (say, less than 500). The + `svd_solver="auto"` option has been updated to use the new solver + automatically for such datasets. :pr:`27491` by :user:`Olivier Grisel `. - |Enhancement| :class:`decomposition.PCA` now supports the Array API for the @@ -215,6 +217,12 @@ Changelog :pr:`26315` and :pr:`27098` by :user:`Mateusz Sokół `, :user:`Olivier Grisel ` and :user:`Edoardo Abati `. 
+- |Fix| :class:`decomposition.PCA` fit with `svd_solver="arpack"`, + `whiten=True` and a value for `n_components` that is larger than the rank of + the training set, no longer returns infinite values when transforming + held-out data. + :pr:`27491` by :user:`Olivier Grisel `. + :mod:`sklearn.ensemble` ....................... diff --git a/sklearn/decomposition/_base.py b/sklearn/decomposition/_base.py index dc27a6c1ac123..052b7a282326b 100644 --- a/sklearn/decomposition/_base.py +++ b/sklearn/decomposition/_base.py @@ -152,7 +152,18 @@ def _transform(self, X_centered, xp=None): ) X_transformed = X_centered @ self.components_.T if self.whiten: - X_transformed /= xp.sqrt(self.explained_variance_) + # For some solvers (such as "arpack" and "covariance_eigh"), on + # rank deficient data, some components can have a variance + # arbitrarily to zero, leading to non-finite results when + # whitening. To avoid this problem we clip the variance below. + scale = xp.sqrt(self.explained_variance_) + min_scale = xp.asarray( + [xp.finfo(scale.dtype).eps], + dtype=scale.dtype, + device=device(scale), + ) + scale = xp.where(scale > min_scale, scale, min_scale) + X_transformed /= scale return X_transformed def inverse_transform(self, X): diff --git a/sklearn/decomposition/_pca.py b/sklearn/decomposition/_pca.py index b1bcc8167176b..e81c9df8939ad 100644 --- a/sklearn/decomposition/_pca.py +++ b/sklearn/decomposition/_pca.py @@ -587,12 +587,9 @@ def _fit_full(self, X, n_components): Evecs = xp.flip(Evecs, axis=1) # Avoid numerical problems for zero or near-zero eigenvalues caused - # by rounding errors as they would lead to non-finite transformed - # values: the square root is undefined for near-zero negative - # values and furthermore, whitening divides the transformed data by - # the explained variance. - threshold = float(xp.finfo(evals.dtype).eps) - evals[evals < threshold] = threshold + # by rounding errors: the square root is undefined for near-zero + # negative values. + evals[evals < 0.0] = 0.0 S = xp.sqrt(evals) Vt = Evecs.T U = None diff --git a/sklearn/decomposition/tests/test_pca.py b/sklearn/decomposition/tests/test_pca.py index cdbeb90321174..f0ef4f8f35428 100644 --- a/sklearn/decomposition/tests/test_pca.py +++ b/sklearn/decomposition/tests/test_pca.py @@ -21,7 +21,7 @@ _get_check_estimator_ids, check_array_api_input_and_values, ) -from sklearn.utils.fixes import CSR_CONTAINERS, parse_version, sp_version +from sklearn.utils.fixes import CSR_CONTAINERS iris = datasets.load_iris() PCA_SOLVERS = ["full", "covariance_eigh", "arpack", "randomized", "auto"] @@ -124,12 +124,6 @@ def test_whitening(solver, copy): def test_pca_solver_equivalence( other_svd_solver, data_shape, rank_deficient, whiten, global_random_seed ): - if sp_version < parse_version("1.7") and other_svd_solver == "arpack": - pytest.xfail( - "Older scipy versions have a numerical stability problem that makes" - " `transform` output non-finite results." - ) - if data_shape == "tall": n_samples, n_features = 100, 30 else: @@ -226,6 +220,16 @@ def test_pca_solver_equivalence( # In the absence of noisy components, both models should be able to # reconstruct the same low-rank approximation of the original data. assert_allclose(X_recons_full_test, X_recons_other_test, **tols) + else: + # When n_features > n_samples and n_components is larger than the rank + # of the training set, the output of the `inverse_transform` function + # is ill-defined. 
We can only check that we reach the same fixed point + # after another round of transform: + assert_allclose( + pca_full.transform(X_recons_full_test)[:, stable], + pca_other.transform(X_recons_other_test)[:, stable], + **tols, + ) @pytest.mark.parametrize( From ec87a1447b776c12f224901872dc6ef057b25d56 Mon Sep 17 00:00:00 2001 From: Olivier Grisel Date: Tue, 3 Oct 2023 18:22:17 +0200 Subject: [PATCH 27/70] Update inline comment about clipping negative eigenvalues --- sklearn/decomposition/_pca.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/sklearn/decomposition/_pca.py b/sklearn/decomposition/_pca.py index e81c9df8939ad..5e25c9c79690f 100644 --- a/sklearn/decomposition/_pca.py +++ b/sklearn/decomposition/_pca.py @@ -586,9 +586,10 @@ def _fit_full(self, X, n_components): evals = xp.flip(evals, axis=0) Evecs = xp.flip(Evecs, axis=1) - # Avoid numerical problems for zero or near-zero eigenvalues caused - # by rounding errors: the square root is undefined for near-zero - # negative values. + # The covariance matrix C is positive semi-definite by + # construction. However, the eigenvalues returned by xp.linalg.eigh + # can be slightly negative due to numerical errors. This would be + # an issue for the subsequent sqrt, hence the manual clipping. evals[evals < 0.0] = 0.0 S = xp.sqrt(evals) Vt = Evecs.T From 8be2f91976e3d4cd339846ccca7589bd9256c64d Mon Sep 17 00:00:00 2001 From: Olivier Grisel Date: Tue, 3 Oct 2023 18:24:54 +0200 Subject: [PATCH 28/70] Simpler way to clip the whitening scale --- sklearn/decomposition/_base.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/sklearn/decomposition/_base.py b/sklearn/decomposition/_base.py index 052b7a282326b..7a1db85502e4f 100644 --- a/sklearn/decomposition/_base.py +++ b/sklearn/decomposition/_base.py @@ -157,12 +157,8 @@ def _transform(self, X_centered, xp=None): # arbitrarily to zero, leading to non-finite results when # whitening. To avoid this problem we clip the variance below. 
scale = xp.sqrt(self.explained_variance_) - min_scale = xp.asarray( - [xp.finfo(scale.dtype).eps], - dtype=scale.dtype, - device=device(scale), - ) - scale = xp.where(scale > min_scale, scale, min_scale) + min_scale = xp.finfo(scale.dtype).eps + scale[scale < min_scale] = min_scale X_transformed /= scale return X_transformed From 29372a16ae7ccbec18a42af3e283b81de1624162 Mon Sep 17 00:00:00 2001 From: Olivier Grisel Date: Tue, 3 Oct 2023 18:39:01 +0200 Subject: [PATCH 29/70] Test solver equivalence with float32 data --- sklearn/decomposition/tests/test_pca.py | 23 ++++++++++++++++++++--- 1 file changed, 20 insertions(+), 3 deletions(-) diff --git a/sklearn/decomposition/tests/test_pca.py b/sklearn/decomposition/tests/test_pca.py index f0ef4f8f35428..8042124253a6a 100644 --- a/sklearn/decomposition/tests/test_pca.py +++ b/sklearn/decomposition/tests/test_pca.py @@ -122,7 +122,12 @@ def test_whitening(solver, copy): @pytest.mark.parametrize("rank_deficient", [False, True]) @pytest.mark.parametrize("whiten", [False, True]) def test_pca_solver_equivalence( - other_svd_solver, data_shape, rank_deficient, whiten, global_random_seed + other_svd_solver, + data_shape, + rank_deficient, + whiten, + global_random_seed, + global_dtype, ): if data_shape == "tall": n_samples, n_features = 100, 30 @@ -142,9 +147,15 @@ def test_pca_solver_equivalence( n_features=n_features, random_state=global_random_seed, ) + X = X.astype(global_dtype, copy=False) X_train, X_test = X[:n_samples], X[n_samples:] - tols = dict(atol=1e-10, rtol=1e-12) + if global_dtype == np.float32: + tols = dict(atol=5e-3, rtol=1e-5) + variance_threshold = 1e-5 + else: + tols = dict(atol=1e-10, rtol=1e-12) + variance_threshold = 1e-12 extra_other_kwargs = {} if other_svd_solver == "randomized": @@ -170,8 +181,10 @@ def test_pca_solver_equivalence( ) X_trans_full_train = pca_full.fit_transform(X_train) assert np.isfinite(X_trans_full_train).all() + assert X_trans_full_train.dtype == global_dtype X_trans_other_train = pca_other.fit_transform(X_train) assert np.isfinite(X_trans_other_train).all() + assert X_trans_other_train.dtype == global_dtype assert (pca_full.explained_variance_ >= 0).all() assert_allclose(pca_full.explained_variance_, pca_other.explained_variance_, **tols) @@ -187,7 +200,7 @@ def test_pca_solver_equivalence( # For some choice of n_components and data distribution, some components # might be pure noise, let's ignore them in the comparison: - stable = pca_full.explained_variance_ > 1e-12 + stable = pca_full.explained_variance_ > variance_threshold assert stable.sum() > 1 assert_allclose(reference_components[stable], other_components[stable], **tols) @@ -200,16 +213,20 @@ def test_pca_solver_equivalence( # last component that can be underdetermined): X_trans_full_test = pca_full.transform(X_test) assert np.isfinite(X_trans_full_test).all() + assert X_trans_full_test.dtype == global_dtype X_trans_other_test = pca_other.transform(X_test) assert np.isfinite(X_trans_other_test).all() + assert X_trans_other_test.dtype == global_dtype assert_allclose(X_trans_other_test[:, stable], X_trans_full_test[:, stable], **tols) # Check that inverse transform reconstructions for both solvers are # compatible. 
X_recons_full_test = pca_full.inverse_transform(X_trans_full_test) assert np.isfinite(X_recons_full_test).all() + assert X_recons_full_test.dtype == global_dtype X_recons_other_test = pca_other.inverse_transform(X_trans_other_test) assert np.isfinite(X_recons_other_test).all() + assert X_recons_other_test.dtype == global_dtype if n_components is None and X_train.shape[0] > X_train.shape[1]: # In this case, both models should be able to reconstruct the data, From efc5a42b24ab7eaf29dd4e3d128b9ac47575c328 Mon Sep 17 00:00:00 2001 From: Olivier Grisel Date: Tue, 3 Oct 2023 18:47:54 +0200 Subject: [PATCH 30/70] Forgot to update on condition based on variance_threshold --- sklearn/decomposition/tests/test_pca.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sklearn/decomposition/tests/test_pca.py b/sklearn/decomposition/tests/test_pca.py index 8042124253a6a..f0a45fb801bb6 100644 --- a/sklearn/decomposition/tests/test_pca.py +++ b/sklearn/decomposition/tests/test_pca.py @@ -233,7 +233,7 @@ def test_pca_solver_equivalence( # even in the presence of noisy components. assert_allclose(X_recons_full_test, X_test, **tols) assert_allclose(X_recons_other_test, X_test, **tols) - elif pca_full.explained_variance_.min() > 1e-12: + elif pca_full.explained_variance_.min() > variance_threshold: # In the absence of noisy components, both models should be able to # reconstruct the same low-rank approximation of the original data. assert_allclose(X_recons_full_test, X_recons_other_test, **tols) From f6e88b465faa09fb61684973c28fe34c624eb383 Mon Sep 17 00:00:00 2001 From: Olivier Grisel Date: Tue, 3 Oct 2023 18:58:32 +0200 Subject: [PATCH 31/70] Typo --- sklearn/decomposition/_pca.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sklearn/decomposition/_pca.py b/sklearn/decomposition/_pca.py index 5e25c9c79690f..fabec80484fd1 100644 --- a/sklearn/decomposition/_pca.py +++ b/sklearn/decomposition/_pca.py @@ -195,7 +195,7 @@ class PCA(_BasePCA): typically using LAPACK and select the components by postprocessing. This solver is very efficient when the number of features is small and not tractable otherwise (large memory footprint required to - materialize the covariance matrix). Also not that compared to the + materialize the covariance matrix). Also note that compared to the "full" solver, this solver effectively doubles the condition number and is therefore less numerical stable (e.g. on input data with a large range of singular values). 
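[Editorial note, not part of the patch series] The docstring remark above about the condition number can be made concrete with a minimal, self-contained NumPy sketch. Forming the Gram/covariance matrix X.T @ X squares the singular values of X, so the condition number of the eigenvalue problem is the square of that of the SVD problem solved by the "full" solver (i.e. roughly twice as many digits of precision are at risk):

    import numpy as np

    rng = np.random.default_rng(0)
    # Columns with very different scales give X a large condition number.
    X = rng.normal(size=(1000, 5)) * np.array([1e3, 1.0, 1.0, 1.0, 1e-3])
    X -= X.mean(axis=0)

    print(np.linalg.cond(X))        # roughly 1e6
    print(np.linalg.cond(X.T @ X))  # roughly 1e12, the square of the above
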
From 08ae5d54f77043d0f0b8c4efd27f0df741ef13b2 Mon Sep 17 00:00:00 2001 From: Olivier Grisel Date: Tue, 3 Oct 2023 21:03:53 +0200 Subject: [PATCH 32/70] [azure parallel] [all random seeds] test_pca_solver_equivalence test_whitening From ed83ed5bd30e54982dcccf3b89c89abc64dfbdd3 Mon Sep 17 00:00:00 2001 From: Olivier Grisel Date: Tue, 3 Oct 2023 21:22:13 +0200 Subject: [PATCH 33/70] Relax float32 test tolerance [azure parallel] [all random seeds] test_pca_solver_equivalence test_whitening --- sklearn/decomposition/tests/test_pca.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sklearn/decomposition/tests/test_pca.py b/sklearn/decomposition/tests/test_pca.py index f0a45fb801bb6..a780f551972f3 100644 --- a/sklearn/decomposition/tests/test_pca.py +++ b/sklearn/decomposition/tests/test_pca.py @@ -151,7 +151,7 @@ def test_pca_solver_equivalence( X_train, X_test = X[:n_samples], X[n_samples:] if global_dtype == np.float32: - tols = dict(atol=5e-3, rtol=1e-5) + tols = dict(atol=1e-2, rtol=1e-5) variance_threshold = 1e-5 else: tols = dict(atol=1e-10, rtol=1e-12) From 74638b78004921f7efcde314576a64d143a66880 Mon Sep 17 00:00:00 2001 From: Olivier Grisel Date: Wed, 4 Oct 2023 10:56:10 +0200 Subject: [PATCH 34/70] More explicit conditions for assertions in test_pca_solver_equivalence [all random seeds] [azure parallel] --- sklearn/decomposition/tests/test_pca.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/sklearn/decomposition/tests/test_pca.py b/sklearn/decomposition/tests/test_pca.py index a780f551972f3..3da7f3924f970 100644 --- a/sklearn/decomposition/tests/test_pca.py +++ b/sklearn/decomposition/tests/test_pca.py @@ -228,12 +228,14 @@ def test_pca_solver_equivalence( assert np.isfinite(X_recons_other_test).all() assert X_recons_other_test.dtype == global_dtype - if n_components is None and X_train.shape[0] > X_train.shape[1]: + effective_rank = np.linalg.matrix_rank(X_train) + effective_n_components = pca_full.n_components_ + if effective_n_components > effective_rank and X_train.shape[0] > effective_rank: # In this case, both models should be able to reconstruct the data, # even in the presence of noisy components. assert_allclose(X_recons_full_test, X_test, **tols) assert_allclose(X_recons_other_test, X_test, **tols) - elif pca_full.explained_variance_.min() > variance_threshold: + elif effective_n_components < effective_rank: # In the absence of noisy components, both models should be able to # reconstruct the same low-rank approximation of the original data. 
assert_allclose(X_recons_full_test, X_recons_other_test, **tols) From 3799e9256e67be438cc8dfd629beeeddb277ff6c Mon Sep 17 00:00:00 2001 From: Olivier Grisel Date: Wed, 4 Oct 2023 11:30:32 +0200 Subject: [PATCH 35/70] [azure parallel] [all random seeds] test_pca_solver_equivalence From 59aebb7033cd6e83c99dfc4bb1a687ef27d49979 Mon Sep 17 00:00:00 2001 From: Olivier Grisel Date: Wed, 4 Oct 2023 18:53:07 +0200 Subject: [PATCH 36/70] ENH do not center a priori to spare memory and make the new solver run even faster --- sklearn/decomposition/_base.py | 15 +++++----- sklearn/decomposition/_pca.py | 39 ++++++++++++++++--------- sklearn/decomposition/tests/test_pca.py | 31 +++++++++++--------- 3 files changed, 50 insertions(+), 35 deletions(-) diff --git a/sklearn/decomposition/_base.py b/sklearn/decomposition/_base.py index 7a1db85502e4f..0bc369f60dc93 100644 --- a/sklearn/decomposition/_base.py +++ b/sklearn/decomposition/_base.py @@ -140,17 +140,16 @@ def transform(self, X): check_is_fitted(self) - X = self._validate_data(X, dtype=[xp.float64, xp.float32], reset=False) - if self.mean_ is not None: - X = X - self.mean_ + X = self._validate_data( + X, dtype=[xp.float64, xp.float32], copy=False, reset=False + ) return self._transform(X, xp) - def _transform(self, X_centered, xp=None): + def _transform(self, X, xp=None): if xp is None: - xp, _ = get_namespace( - X_centered, self.components_, self.explained_variance_ - ) - X_transformed = X_centered @ self.components_.T + xp, _ = get_namespace(X, self.components_, self.explained_variance_) + X_transformed = X @ self.components_.T + X_transformed -= xp.reshape(self.mean_, (1, -1)) @ self.components_.T if self.whiten: # For some solvers (such as "arpack" and "covariance_eigh"), on # rank deficient data, some components can have a variance diff --git a/sklearn/decomposition/_pca.py b/sklearn/decomposition/_pca.py index fabec80484fd1..b23abbb383e03 100644 --- a/sklearn/decomposition/_pca.py +++ b/sklearn/decomposition/_pca.py @@ -478,7 +478,7 @@ def fit_transform(self, X, y=None): This method returns a Fortran-ordered array. To convert it to a C-ordered array, use 'np.ascontiguousarray'. """ - U, S, Vt, X_centered = self._fit(X) + U, S, Vt, X_validated = self._fit(X) if U is not None: U = U[:, : self.n_components_] @@ -491,7 +491,7 @@ def fit_transform(self, X, y=None): return U else: - return self._transform(X_centered) + return self._transform(X) def _fit(self, X): """Dispatch to the right submethod depending on the chosen solver.""" @@ -510,8 +510,11 @@ def _fit(self, X): "PCA with svd_solver='arpack' is not supported for Array API inputs." ) + # Validate the data, without forcing a copy as it's not required for + # the `covariance_eigh` dataset and would be wasteful for large + # datasets. X = self._validate_data( - X, dtype=[xp.float64, xp.float32], ensure_2d=True, copy=self.copy + X, dtype=[xp.float64, xp.float32], ensure_2d=True, copy=False ) # Handle n_components==None @@ -563,11 +566,11 @@ def _fit_full(self, X, n_components): f"svd_solver={self._fit_svd_solver!r}" ) - # Center data self.mean_ = xp.mean(X, axis=0) - X -= self.mean_ - if self._fit_svd_solver == "full": + X_centered = xp.asarray(X, copy=True) if self.copy else X + X_centered -= self.mean_ + if not is_array_api_compliant: # Use scipy.linalg with NumPy/SciPy inputs for the sake of not # introducing unanticipated behavior changes. In the long run we @@ -576,12 +579,20 @@ def _fit_full(self, X, n_components): # scipy's. 
It's not 100% clear whether they use the same LAPACK # solver by default though (assuming both are built against the # same BLAS). - U, S, Vt = linalg.svd(X, full_matrices=False) + U, S, Vt = linalg.svd(X_centered, full_matrices=False) else: - U, S, Vt = xp.linalg.svd(X, full_matrices=False) + U, S, Vt = xp.linalg.svd(X_centered, full_matrices=False) else: assert self._fit_svd_solver == "covariance_eigh" + # In the following, we center the covariance matrix C without + # centering the data X to avoid an unecessary copy of X. Note that + # the mean_ attribute is also needed by the transform method. C = X.T @ X + C -= ( + X.shape[0] + * xp.reshape(self.mean_, (-1, 1)) + * xp.reshape(self.mean_, (1, -1)) + ) evals, Evecs = xp.linalg.eigh(C) evals = xp.flip(evals, axis=0) Evecs = xp.flip(Evecs, axis=1) @@ -663,13 +674,13 @@ def _fit_truncated(self, X, n_components, svd_solver): random_state = check_random_state(self.random_state) - # Center data self.mean_ = xp.mean(X, axis=0) - X -= self.mean_ + X_centered = xp.asarray(X, copy=True) if self.copy else X + X_centered -= self.mean_ if svd_solver == "arpack": v0 = _init_arpack_v0(min(X.shape), random_state) - U, S, Vt = svds(X, k=n_components, tol=self.tol, v0=v0) + U, S, Vt = svds(X_centered, k=n_components, tol=self.tol, v0=v0) # svds doesn't abide by scipy.linalg.svd/randomized_svd # conventions, so reverse its outputs. S = S[::-1] @@ -679,7 +690,7 @@ def _fit_truncated(self, X, n_components, svd_solver): elif svd_solver == "randomized": # sign flipping is done inside U, S, Vt = randomized_svd( - X, + X_centered, n_components=n_components, n_oversamples=self.n_oversamples, n_iter=self.iterated_power, @@ -699,8 +710,8 @@ def _fit_truncated(self, X, n_components, svd_solver): # Workaround in-place variance calculation since at the time numpy # did not have a way to calculate variance in-place. N = X.shape[0] - 1 - X **= 2 - total_var = xp.sum(xp.sum(X, axis=0) / N) + X_centered **= 2 + total_var = xp.sum(xp.sum(X_centered, axis=0) / N) self.explained_variance_ratio_ = self.explained_variance_ / total_var self.singular_values_ = xp.asarray(S, copy=True) # Store the singular values. 
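[Editorial note, not part of the patch series] The a-posteriori centering of the covariance matrix introduced in the `_pca.py` hunk above relies on a standard algebraic identity: the scatter matrix of the centered data equals X.T @ X minus n_samples times the outer product of the column means, so no centered copy of X ever needs to be materialized. A minimal NumPy sketch checking this numerically:

    import numpy as np

    rng = np.random.default_rng(0)
    X = rng.normal(size=(200, 5))
    n_samples = X.shape[0]
    mean = X.mean(axis=0)

    # Scatter matrix computed the straightforward way, via a centered copy of X.
    C_centered = (X - mean).T @ (X - mean)
    # Equivalent result without centering X, as done by the "covariance_eigh" path.
    C_posthoc = X.T @ X - n_samples * np.outer(mean, mean)

    np.testing.assert_allclose(C_centered, C_posthoc, atol=1e-10)
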
diff --git a/sklearn/decomposition/tests/test_pca.py b/sklearn/decomposition/tests/test_pca.py index 3da7f3924f970..810de619d9331 100644 --- a/sklearn/decomposition/tests/test_pca.py +++ b/sklearn/decomposition/tests/test_pca.py @@ -647,28 +647,33 @@ def test_pca_deterministic_output(svd_solver): @pytest.mark.parametrize("svd_solver", PCA_SOLVERS) -def test_pca_dtype_preservation(svd_solver): - check_pca_float_dtype_preservation(svd_solver) +def test_pca_dtype_preservation(svd_solver, global_random_seed): + check_pca_float_dtype_preservation(svd_solver, global_random_seed) check_pca_int_dtype_upcast_to_double(svd_solver) -def check_pca_float_dtype_preservation(svd_solver): +def check_pca_float_dtype_preservation(svd_solver, seed): # Ensure that PCA does not upscale the dtype when input is float32 - X_64 = np.random.RandomState(0).rand(1000, 4).astype(np.float64, copy=False) - X_32 = X_64.astype(np.float32) + X = np.random.RandomState(seed).rand(1000, 4) + X_float64 = X.astype(np.float64, copy=False) + X_float32 = X.astype(np.float32) - pca_64 = PCA(n_components=3, svd_solver=svd_solver, random_state=0).fit(X_64) - pca_32 = PCA(n_components=3, svd_solver=svd_solver, random_state=0).fit(X_32) + pca_64 = PCA(n_components=3, svd_solver=svd_solver, random_state=seed).fit( + X_float64 + ) + pca_32 = PCA(n_components=3, svd_solver=svd_solver, random_state=seed).fit( + X_float32 + ) assert pca_64.components_.dtype == np.float64 assert pca_32.components_.dtype == np.float32 - assert pca_64.transform(X_64).dtype == np.float64 - assert pca_32.transform(X_32).dtype == np.float32 + assert pca_64.transform(X_float64).dtype == np.float64 + assert pca_32.transform(X_float32).dtype == np.float32 - # the rtol is set such that the test passes on all platforms tested on - # conda-forge: PR#15775 - # see: https://github.com/conda-forge/scikit-learn-feedstock/pull/113 - assert_allclose(pca_64.components_, pca_32.components_, rtol=2e-4) + # The atol and rtol are set such that the test passes for all random seeds + # on all supported platforms on our CI and conda-forge with the default + # random seed. 
+ assert_allclose(pca_64.components_, pca_32.components_, rtol=1e-3, atol=1e-3) def check_pca_int_dtype_upcast_to_double(svd_solver): From 8c017b433cff3a228be8ddb87c3c768d397b772e Mon Sep 17 00:00:00 2001 From: Olivier Grisel Date: Wed, 4 Oct 2023 19:56:17 +0200 Subject: [PATCH 37/70] [azure parallel] [all random seeds] test_pca_solver_equivalence" test_whitening test_pca_dtype_preservation From 45928ae52b2d03f3d0c4049b7e1982aae71f7335 Mon Sep 17 00:00:00 2001 From: Olivier Grisel Date: Wed, 4 Oct 2023 20:04:10 +0200 Subject: [PATCH 38/70] [azure parallel] [all random seeds] test_pca_solver_equivalence test_whitening test_pca_dtype_preservation From 01f6fd62fc733068f7443b8f5abc9a4bac933783 Mon Sep 17 00:00:00 2001 From: Olivier Grisel Date: Thu, 5 Oct 2023 08:04:04 +0200 Subject: [PATCH 39/70] Propagate xp namespace to private methods --- benchmarks/bench_pca_solvers.py | 2 +- sklearn/decomposition/_base.py | 6 ++---- sklearn/decomposition/_pca.py | 21 +++++++++------------ 3 files changed, 12 insertions(+), 17 deletions(-) diff --git a/benchmarks/bench_pca_solvers.py b/benchmarks/bench_pca_solvers.py index b76df1120cbb1..605406974af01 100644 --- a/benchmarks/bench_pca_solvers.py +++ b/benchmarks/bench_pca_solvers.py @@ -11,7 +11,7 @@ # # We generate synthetic data with dimensions that are useful to plot: # - time vs n_samples for a fixed n_features and, -# - time vs n_features for a fixed n_samples for a fixed n_features. +# - time vs n_features for a fixed n_samples for a fixed n_features. import itertools from math import log10 from time import perf_counter diff --git a/sklearn/decomposition/_base.py b/sklearn/decomposition/_base.py index 0bc369f60dc93..e4fecf6bf3fbc 100644 --- a/sklearn/decomposition/_base.py +++ b/sklearn/decomposition/_base.py @@ -136,7 +136,7 @@ def transform(self, X): Projection of X in the first principal components, where `n_samples` is the number of samples and `n_components` is the number of the components. """ - xp, _ = get_namespace(X) + xp, _ = get_namespace(X, self.components_, self.explained_variance_) check_is_fitted(self) @@ -145,9 +145,7 @@ def transform(self, X): ) return self._transform(X, xp) - def _transform(self, X, xp=None): - if xp is None: - xp, _ = get_namespace(X, self.components_, self.explained_variance_) + def _transform(self, X, xp): X_transformed = X @ self.components_.T X_transformed -= xp.reshape(self.mean_, (1, -1)) @ self.components_.T if self.whiten: diff --git a/sklearn/decomposition/_pca.py b/sklearn/decomposition/_pca.py index b23abbb383e03..8317c60ab3546 100644 --- a/sklearn/decomposition/_pca.py +++ b/sklearn/decomposition/_pca.py @@ -478,7 +478,7 @@ def fit_transform(self, X, y=None): This method returns a Fortran-ordered array. To convert it to a C-ordered array, use 'np.ascontiguousarray'. 
""" - U, S, Vt, X_validated = self._fit(X) + U, S, Vt, X_validated, xp = self._fit(X) if U is not None: U = U[:, : self.n_components_] @@ -491,7 +491,7 @@ def fit_transform(self, X, y=None): return U else: - return self._transform(X) + return self._transform(X, xp) def _fit(self, X): """Dispatch to the right submethod depending on the chosen solver.""" @@ -544,14 +544,12 @@ def _fit(self, X): # Call different fits for either full or truncated SVD if self._fit_svd_solver in ("full", "covariance_eigh"): - return self._fit_full(X, n_components) + return self._fit_full(X, n_components, xp, is_array_api_compliant) elif self._fit_svd_solver in ["arpack", "randomized"]: - return self._fit_truncated(X, n_components, self._fit_svd_solver) + return self._fit_truncated(X, n_components, xp) - def _fit_full(self, X, n_components): + def _fit_full(self, X, n_components, xp, is_array_api_compliant): """Fit the model by computing full SVD on X.""" - xp, is_array_api_compliant = get_namespace(X) - n_samples, n_features = X.shape if n_components == "mle": @@ -642,16 +640,15 @@ def _fit_full(self, X, n_components): self.explained_variance_ratio_ = explained_variance_ratio_[:n_components] self.singular_values_ = singular_values_[:n_components] - return U, S, Vt, X + return U, S, Vt, X, xp - def _fit_truncated(self, X, n_components, svd_solver): + def _fit_truncated(self, X, n_components, xp): """Fit the model by computing truncated SVD (by ARPACK or randomized) on X. """ - xp, _ = get_namespace(X) - n_samples, n_features = X.shape + svd_solver = self._fit_svd_solver if isinstance(n_components, str): raise ValueError( "n_components=%r cannot be a string with svd_solver='%s'" @@ -722,7 +719,7 @@ def _fit_truncated(self, X, n_components, svd_solver): else: self.noise_variance_ = 0.0 - return U, S, Vt, X + return U, S, Vt, X, xp def score_samples(self, X): """Return the log-likelihood of each sample. From f3c33dc3d6b54197cf5dfe6ffad6b632976a964f Mon Sep 17 00:00:00 2001 From: Olivier Grisel Date: Thu, 5 Oct 2023 08:12:51 +0200 Subject: [PATCH 40/70] Scale the covariance matrix before calling eigh to improve numerical stability --- sklearn/decomposition/_pca.py | 29 ++++++++++++++++++----------- 1 file changed, 18 insertions(+), 11 deletions(-) diff --git a/sklearn/decomposition/_pca.py b/sklearn/decomposition/_pca.py index 8317c60ab3546..bfecba98703af 100644 --- a/sklearn/decomposition/_pca.py +++ b/sklearn/decomposition/_pca.py @@ -580,28 +580,36 @@ def _fit_full(self, X, n_components, xp, is_array_api_compliant): U, S, Vt = linalg.svd(X_centered, full_matrices=False) else: U, S, Vt = xp.linalg.svd(X_centered, full_matrices=False) + explained_variance_ = (S**2) / (n_samples - 1) + else: assert self._fit_svd_solver == "covariance_eigh" - # In the following, we center the covariance matrix C without - # centering the data X to avoid an unecessary copy of X. Note that - # the mean_ attribute is also needed by the transform method. + # In the following, we center the covariance matrix C a-posteriori + # (without centering the data X first) to avoid an unecessary copy + # of X. Note that the mean_ attribute is still needed to center + # test data in the transform method. 
C = X.T @ X C -= ( - X.shape[0] + n_samples * xp.reshape(self.mean_, (-1, 1)) * xp.reshape(self.mean_, (1, -1)) ) - evals, Evecs = xp.linalg.eigh(C) - evals = xp.flip(evals, axis=0) - Evecs = xp.flip(Evecs, axis=1) + C /= n_samples - 1 + eigenvals, Eigenvecs = xp.linalg.eigh(C) + eigenvals = xp.flip(eigenvals, axis=0) + Eigenvecs = xp.flip(Eigenvecs, axis=1) # The covariance matrix C is positive semi-definite by # construction. However, the eigenvalues returned by xp.linalg.eigh # can be slightly negative due to numerical errors. This would be # an issue for the subsequent sqrt, hence the manual clipping. - evals[evals < 0.0] = 0.0 - S = xp.sqrt(evals) - Vt = Evecs.T + eigenvals[eigenvals < 0.0] = 0.0 + explained_variance_ = eigenvals + + # Re-construct synthetic SVD components to be consistent with the + # other solvers. + S = xp.sqrt(eigenvals * (n_samples - 1)) + Vt = Eigenvecs.T U = None # flip eigenvectors' sign to enforce deterministic output @@ -610,7 +618,6 @@ def _fit_full(self, X, n_components, xp, is_array_api_compliant): components_ = Vt # Get variance explained by singular values - explained_variance_ = (S**2) / (n_samples - 1) total_var = xp.sum(explained_variance_) explained_variance_ratio_ = explained_variance_ / total_var singular_values_ = xp.asarray(S, copy=True) # Store the singular values. From 64cbcc9e616864e8c4124b9eaa6855a8daa2c298 Mon Sep 17 00:00:00 2001 From: Olivier Grisel Date: Thu, 5 Oct 2023 09:55:25 +0200 Subject: [PATCH 41/70] Make it explicit when the X passed between private methods has been centered or not --- sklearn/decomposition/_base.py | 11 ++++++++--- sklearn/decomposition/_pca.py | 11 +++++++---- 2 files changed, 15 insertions(+), 7 deletions(-) diff --git a/sklearn/decomposition/_base.py b/sklearn/decomposition/_base.py index e4fecf6bf3fbc..47f0074f639cf 100644 --- a/sklearn/decomposition/_base.py +++ b/sklearn/decomposition/_base.py @@ -143,11 +143,16 @@ def transform(self, X): X = self._validate_data( X, dtype=[xp.float64, xp.float32], copy=False, reset=False ) - return self._transform(X, xp) + return self._transform( + X, + xp=xp, + x_is_centered=False, + ) - def _transform(self, X, xp): + def _transform(self, X, xp, x_is_centered=False): X_transformed = X @ self.components_.T - X_transformed -= xp.reshape(self.mean_, (1, -1)) @ self.components_.T + if not x_is_centered: + X_transformed -= xp.reshape(self.mean_, (1, -1)) @ self.components_.T if self.whiten: # For some solvers (such as "arpack" and "covariance_eigh"), on # rank deficient data, some components can have a variance diff --git a/sklearn/decomposition/_pca.py b/sklearn/decomposition/_pca.py index bfecba98703af..31b59fbe3da92 100644 --- a/sklearn/decomposition/_pca.py +++ b/sklearn/decomposition/_pca.py @@ -478,7 +478,7 @@ def fit_transform(self, X, y=None): This method returns a Fortran-ordered array. To convert it to a C-ordered array, use 'np.ascontiguousarray'. 
""" - U, S, Vt, X_validated, xp = self._fit(X) + U, S, Vt, X, x_is_centered, xp = self._fit(X) if U is not None: U = U[:, : self.n_components_] @@ -491,7 +491,7 @@ def fit_transform(self, X, y=None): return U else: - return self._transform(X, xp) + return self._transform(X, xp, x_is_centered=x_is_centered) def _fit(self, X): """Dispatch to the right submethod depending on the chosen solver.""" @@ -568,6 +568,7 @@ def _fit_full(self, X, n_components, xp, is_array_api_compliant): if self._fit_svd_solver == "full": X_centered = xp.asarray(X, copy=True) if self.copy else X X_centered -= self.mean_ + x_is_centered = not self.copy if not is_array_api_compliant: # Use scipy.linalg with NumPy/SciPy inputs for the sake of not @@ -588,6 +589,7 @@ def _fit_full(self, X, n_components, xp, is_array_api_compliant): # (without centering the data X first) to avoid an unecessary copy # of X. Note that the mean_ attribute is still needed to center # test data in the transform method. + x_is_centered = False C = X.T @ X C -= ( n_samples @@ -647,7 +649,7 @@ def _fit_full(self, X, n_components, xp, is_array_api_compliant): self.explained_variance_ratio_ = explained_variance_ratio_[:n_components] self.singular_values_ = singular_values_[:n_components] - return U, S, Vt, X, xp + return U, S, Vt, X, x_is_centered, xp def _fit_truncated(self, X, n_components, xp): """Fit the model by computing truncated SVD (by ARPACK or randomized) @@ -681,6 +683,7 @@ def _fit_truncated(self, X, n_components, xp): self.mean_ = xp.mean(X, axis=0) X_centered = xp.asarray(X, copy=True) if self.copy else X X_centered -= self.mean_ + x_is_centered = not self.copy if svd_solver == "arpack": v0 = _init_arpack_v0(min(X.shape), random_state) @@ -726,7 +729,7 @@ def _fit_truncated(self, X, n_components, xp): else: self.noise_variance_ = 0.0 - return U, S, Vt, X, xp + return U, S, Vt, X, x_is_centered, xp def score_samples(self, X): """Return the log-likelihood of each sample. From 8284d25386bfc9574ff04d91aa76097a7e8337fc Mon Sep 17 00:00:00 2001 From: Olivier Grisel Date: Thu, 5 Oct 2023 12:03:30 +0200 Subject: [PATCH 42/70] [azure parallel] [all random seeds] test_pca_solver_equivalence test_whitening test_pca_dtype_preservation From 828a0582295f6fca5faa84c2bbffe414263bdbc2 Mon Sep 17 00:00:00 2001 From: Olivier Grisel Date: Thu, 5 Oct 2023 14:16:38 +0200 Subject: [PATCH 43/70] Improve test [azure parallel] [all random seeds] test_pca_solver_equivalence --- sklearn/decomposition/tests/test_pca.py | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/sklearn/decomposition/tests/test_pca.py b/sklearn/decomposition/tests/test_pca.py index 810de619d9331..4bd44b7ecf899 100644 --- a/sklearn/decomposition/tests/test_pca.py +++ b/sklearn/decomposition/tests/test_pca.py @@ -145,8 +145,12 @@ def test_pca_solver_equivalence( X = make_low_rank_matrix( n_samples=n_samples + n_samples_test, n_features=n_features, + tail_strength=0.5, random_state=global_random_seed, ) + # With a non-zero tail strength, the data is actually full-rank. 
+ rank = min(n_samples, n_features) + X = X.astype(global_dtype, copy=False) X_train, X_test = X[:n_samples], X[n_samples:] @@ -228,16 +232,16 @@ def test_pca_solver_equivalence( assert np.isfinite(X_recons_other_test).all() assert X_recons_other_test.dtype == global_dtype - effective_rank = np.linalg.matrix_rank(X_train) - effective_n_components = pca_full.n_components_ - if effective_n_components > effective_rank and X_train.shape[0] > effective_rank: - # In this case, both models should be able to reconstruct the data, - # even in the presence of noisy components. + if pca_full.components_.shape[0] == pca_full.components_.shape[1]: + # In this case, the models should have learned the same invertible + # transform. They should therefore both be able to reconstruct the test + # data. assert_allclose(X_recons_full_test, X_test, **tols) assert_allclose(X_recons_other_test, X_test, **tols) - elif effective_n_components < effective_rank: + elif pca_full.components_.shape[0] < rank: # In the absence of noisy components, both models should be able to # reconstruct the same low-rank approximation of the original data. + assert pca_full.explained_variance_.min() > variance_threshold assert_allclose(X_recons_full_test, X_recons_other_test, **tols) else: # When n_features > n_samples and n_components is larger than the rank From 56a7793ed36d455d0b83387cbc7ca2d93738eb1d Mon Sep 17 00:00:00 2001 From: Olivier Grisel Date: Thu, 5 Oct 2023 16:33:22 +0200 Subject: [PATCH 44/70] OPTIM Ensure that components_ is contiguous --- doc/whats_new/v1.4.rst | 8 ++++++++ sklearn/decomposition/_pca.py | 19 +++++++++++++++---- 2 files changed, 23 insertions(+), 4 deletions(-) diff --git a/doc/whats_new/v1.4.rst b/doc/whats_new/v1.4.rst index b66787da244ea..c241db6801bea 100644 --- a/doc/whats_new/v1.4.rst +++ b/doc/whats_new/v1.4.rst @@ -217,6 +217,14 @@ Changelog :mod:`sklearn.decomposition` ............................ +- |Efficiency| :class:`decomposition.PCA` with `svd_solver="full"` now assigns + a contiguous `components_` attribute instead of an non-contiguous slice of + the singular vectors. When `n_components << n_features`, this can save some + memory and more importantly help speed-up subsequent calls to the `transform` + method by more than an order of magnitude by leveraging cache locality of + BLAS GEMM on contiguous arrays. + :pr:`27491` by :user:`Olivier Grisel `. 
+ - |Enhancement| An "auto" option was added to the `n_components` parameter of :func:`decomposition.non_negative_factorization`, :class:`decomposition.NMF` and :class:`decomposition.MiniBatchNMF` to automatically infer the number of components from W or H shapes diff --git a/sklearn/decomposition/_pca.py b/sklearn/decomposition/_pca.py index 31b59fbe3da92..bd6def31f7132 100644 --- a/sklearn/decomposition/_pca.py +++ b/sklearn/decomposition/_pca.py @@ -643,11 +643,22 @@ def _fit_full(self, X, n_components, xp, is_array_api_compliant): self.noise_variance_ = 0.0 self.n_samples_ = n_samples - self.components_ = components_[:n_components, :] self.n_components_ = n_components - self.explained_variance_ = explained_variance_[:n_components] - self.explained_variance_ratio_ = explained_variance_ratio_[:n_components] - self.singular_values_ = singular_values_[:n_components] + # Assign a copy of the result of the truncation of the components in + # order to: + # - release the memory used by the discarded components, + # - ensure that the kept components are allocated contiguously in memory + # to make the transform method faster by leveraging cache locality. + self.components_ = xp.asarray(components_[:n_components, :], copy=True) + + # We do the same for the other arrays for the sake of consistency. + self.explained_variance_ = xp.asarray( + explained_variance_[:n_components], copy=True + ) + self.explained_variance_ratio_ = xp.asarray( + explained_variance_ratio_[:n_components], copy=True + ) + self.singular_values_ = xp.asarray(singular_values_[:n_components], copy=True) return U, S, Vt, X, x_is_centered, xp From e797b53efd58f863355a27288e05578ce5ee69d8 Mon Sep 17 00:00:00 2001 From: Olivier Grisel Date: Thu, 5 Oct 2023 17:39:06 +0200 Subject: [PATCH 45/70] Update auto policy and bench_pca_solvers.py --- benchmarks/bench_pca_solvers.py | 19 +++++++++++-------- sklearn/decomposition/_pca.py | 4 ++-- 2 files changed, 13 insertions(+), 10 deletions(-) diff --git a/benchmarks/bench_pca_solvers.py b/benchmarks/bench_pca_solvers.py index 605406974af01..337af3a42e900 100644 --- a/benchmarks/bench_pca_solvers.py +++ b/benchmarks/bench_pca_solvers.py @@ -21,23 +21,26 @@ import pandas as pd from sklearn import config_context -from sklearn.datasets import make_low_rank_matrix from sklearn.decomposition import PCA REF_DIMS = [100, 1000, 10_000] data_shapes = [] for ref_dim in REF_DIMS: - data_shapes.extend([(ref_dim, 10**i) for i in range(1, 9 - int(log10(ref_dim)))]) - data_shapes.extend([(10**i, ref_dim) for i in range(1, 9 - int(log10(ref_dim)))]) + data_shapes.extend([(ref_dim, 10**i) for i in range(1, 8 - int(log10(ref_dim)))]) + data_shapes.extend( + [(ref_dim, 3 * 10**i) for i in range(1, 8 - int(log10(ref_dim)))] + ) + data_shapes.extend([(10**i, ref_dim) for i in range(1, 8 - int(log10(ref_dim)))]) + data_shapes.extend( + [(3 * 10**i, ref_dim) for i in range(1, 8 - int(log10(ref_dim)))] + ) # Remove duplicates: data_shapes = sorted(set(data_shapes)) print("Generating test datasets...") -datasets = [ - make_low_rank_matrix(n_samples, n_features, random_state=0) - for n_samples, n_features in data_shapes -] +rng = np.random.default_rng(0) +datasets = [rng.normal(size=shape) for shape in data_shapes] # %% @@ -68,7 +71,7 @@ def measure_one(data, n_components, solver, method_name="fit"): if n_components >= min(data.shape): continue for solver in SOLVERS: - if solver == "covariance_eigh" and data.shape[1] > 1000: + if solver == "covariance_eigh" and data.shape[1] > 5000: # Too much memory and too slow. 
continue if solver in ["arpack", "full"] and log10(data.size) > 7: diff --git a/sklearn/decomposition/_pca.py b/sklearn/decomposition/_pca.py index bd6def31f7132..4021a00c4a8ca 100644 --- a/sklearn/decomposition/_pca.py +++ b/sklearn/decomposition/_pca.py @@ -179,7 +179,7 @@ class PCA(_BasePCA): default='auto' "auto" : The solver is selected by a default policy based on `X.shape` and - `n_components`: if the input data has fewer than 500 features and + `n_components`: if the input data has fewer than 1000 features and more than 10 times as many samples, then the more "covariance_eigh" solver is used. Otherwise, if the input data is larger than 500x500 and the number of components to extract is lower than 80% of the @@ -531,7 +531,7 @@ def _fit(self, X): if self._fit_svd_solver == "auto": # Tall and skinny problems are best handled by precomputing the # covariance matrix. - if X.shape[1] <= 500 and X.shape[0] >= 10 * X.shape[1]: + if X.shape[1] <= 1000 and X.shape[0] >= 10 * X.shape[1]: self._fit_svd_solver = "covariance_eigh" # Small problem or n_components == 'mle', just call full PCA elif max(X.shape) <= 500 or n_components == "mle": From 8fcf2fffebc5ac10948f859119cd611881360e13 Mon Sep 17 00:00:00 2001 From: Olivier Grisel Date: Fri, 6 Oct 2023 10:30:46 +0200 Subject: [PATCH 46/70] Single call to xp.sum to compute total_var. Co-authored-by: Christian Lorentzen --- sklearn/decomposition/_pca.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sklearn/decomposition/_pca.py b/sklearn/decomposition/_pca.py index 4021a00c4a8ca..2cc7c1078507f 100644 --- a/sklearn/decomposition/_pca.py +++ b/sklearn/decomposition/_pca.py @@ -729,7 +729,7 @@ def _fit_truncated(self, X, n_components, xp): # did not have a way to calculate variance in-place. N = X.shape[0] - 1 X_centered **= 2 - total_var = xp.sum(xp.sum(X_centered, axis=0) / N) + total_var = xp.sum(X_centered) / N self.explained_variance_ratio_ = self.explained_variance_ / total_var self.singular_values_ = xp.asarray(S, copy=True) # Store the singular values. From a6db51348207377abc5f3009ea6bc420c1e88c9a Mon Sep 17 00:00:00 2001 From: Olivier Grisel Date: Fri, 13 Oct 2023 17:02:53 +0200 Subject: [PATCH 47/70] Apply suggestions from code review Co-authored-by: Tim Head --- sklearn/decomposition/_base.py | 2 +- sklearn/decomposition/_pca.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/sklearn/decomposition/_base.py b/sklearn/decomposition/_base.py index 47f0074f639cf..14a918642d424 100644 --- a/sklearn/decomposition/_base.py +++ b/sklearn/decomposition/_base.py @@ -156,7 +156,7 @@ def _transform(self, X, xp, x_is_centered=False): if self.whiten: # For some solvers (such as "arpack" and "covariance_eigh"), on # rank deficient data, some components can have a variance - # arbitrarily to zero, leading to non-finite results when + # arbitrarily close to zero, leading to non-finite results when # whitening. To avoid this problem we clip the variance below. 
scale = xp.sqrt(self.explained_variance_) min_scale = xp.finfo(scale.dtype).eps diff --git a/sklearn/decomposition/_pca.py b/sklearn/decomposition/_pca.py index 2cc7c1078507f..206b806d75347 100644 --- a/sklearn/decomposition/_pca.py +++ b/sklearn/decomposition/_pca.py @@ -180,11 +180,11 @@ class PCA(_BasePCA): "auto" : The solver is selected by a default policy based on `X.shape` and `n_components`: if the input data has fewer than 1000 features and - more than 10 times as many samples, then the more "covariance_eigh" + more than 10 times as many samples, then the "covariance_eigh" solver is used. Otherwise, if the input data is larger than 500x500 and the number of components to extract is lower than 80% of the smallest dimension of the data, then the more efficient - 'randomized' method is enabled. Otherwise the exact "full" SVD is + "randomized" method is enabled. Otherwise the exact "full" SVD is computed and optionally truncated afterwards. "full" : Run exact full SVD calling the standard LAPACK solver via From 4a2a0627a4f2b43b21e913deb771a1fdd719d709 Mon Sep 17 00:00:00 2001 From: Olivier Grisel Date: Fri, 20 Oct 2023 17:30:33 +0200 Subject: [PATCH 48/70] Use the device() utility function instead of getattr --- sklearn/utils/extmath.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/sklearn/utils/extmath.py b/sklearn/utils/extmath.py index 668f9e66e7488..6ec03c6db861b 100644 --- a/sklearn/utils/extmath.py +++ b/sklearn/utils/extmath.py @@ -869,9 +869,8 @@ def svd_flip(u, v, u_based_decision=True): if u_based_decision: # columns of u, rows of v, or equivalently rows of u.T and v - device = getattr(u, "device", None) max_abs_u_cols = xp.argmax(xp.abs(u.T), axis=1) - shift = xp.arange(u.T.shape[0], device=device) + shift = xp.arange(u.T.shape[0], device=device(u)) indices = max_abs_u_cols + shift * u.T.shape[1] signs = xp.sign(xp.take(xp.reshape(u.T, (-1,)), indices, axis=0)) u *= signs[np.newaxis, :] @@ -879,9 +878,8 @@ def svd_flip(u, v, u_based_decision=True): v *= signs[:, np.newaxis] else: # rows of v, columns of u - device = getattr(v, "device", None) max_abs_v_rows = xp.argmax(xp.abs(v), axis=1) - shift = xp.arange(v.shape[0], device=device) + shift = xp.arange(v.shape[0], device=device(v)) indices = max_abs_v_rows + shift * v.shape[1] signs = xp.sign(xp.take(xp.reshape(v, (-1,)), indices, axis=0)) if u is not None: From a6bd3b84ce63348c29d6ce3f01812d9c7b2c7b25 Mon Sep 17 00:00:00 2001 From: Olivier Grisel Date: Fri, 20 Oct 2023 17:54:46 +0200 Subject: [PATCH 49/70] Improve inline comment to explain copy=False --- sklearn/decomposition/_pca.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/sklearn/decomposition/_pca.py b/sklearn/decomposition/_pca.py index 206b806d75347..28f3a96664132 100644 --- a/sklearn/decomposition/_pca.py +++ b/sklearn/decomposition/_pca.py @@ -510,9 +510,12 @@ def _fit(self, X): "PCA with svd_solver='arpack' is not supported for Array API inputs." ) - # Validate the data, without forcing a copy as it's not required for - # the `covariance_eigh` dataset and would be wasteful for large - # datasets. + # Validate the data, without ever forcing a copy as the + # `covariance_eigh` solver is written in a way to avoid the need for + # any inplace modification of the input data contrary to the other + # solvers. Forcing a copy here would be wasteful when using large + # datasets. The copy will happen later, only if needed, once the solver + # negotiation below is done. 
X = self._validate_data( X, dtype=[xp.float64, xp.float32], ensure_2d=True, copy=False ) From 1446f10e2203dcc679928a0fe188e57676e0c2da Mon Sep 17 00:00:00 2001 From: Olivier Grisel Date: Fri, 20 Oct 2023 18:10:53 +0200 Subject: [PATCH 50/70] Improve comment to explain why we do not rely on numpy.cov --- sklearn/decomposition/_pca.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/sklearn/decomposition/_pca.py b/sklearn/decomposition/_pca.py index 28f3a96664132..6e75025c2f5ec 100644 --- a/sklearn/decomposition/_pca.py +++ b/sklearn/decomposition/_pca.py @@ -592,6 +592,16 @@ def _fit_full(self, X, n_components, xp, is_array_api_compliant): # (without centering the data X first) to avoid an unecessary copy # of X. Note that the mean_ attribute is still needed to center # test data in the transform method. + # + # Note: at the time of writing, `xp.cov` does not exist in the + # Array API standard: + # https://github.com/data-apis/array-api/issues/43 + # + # Besides, using `numpy.cov`, as of numpy 1.26.0, would not be + # memory efficient for our use case when `n_samples >> n_features`: + # `numpy.cov` centers a copy of the data before computing the + # matrix product instead of substracting a small `(n_features, + # n_features)` square matrix, a posteriori, as we do below. x_is_centered = False C = X.T @ X C -= ( From ff48edee3baaf2e5fe7bc211892c7fd698c554e3 Mon Sep 17 00:00:00 2001 From: Olivier Grisel Date: Sat, 4 Nov 2023 12:22:37 +0100 Subject: [PATCH 51/70] Apply suggestions from code review Co-authored-by: Guillaume Lemaitre --- sklearn/decomposition/_kernel_pca.py | 2 +- sklearn/decomposition/tests/test_sparse_pca.py | 2 +- sklearn/utils/extmath.py | 4 ++-- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/sklearn/decomposition/_kernel_pca.py b/sklearn/decomposition/_kernel_pca.py index f058a75116167..4622c1e0b7977 100644 --- a/sklearn/decomposition/_kernel_pca.py +++ b/sklearn/decomposition/_kernel_pca.py @@ -363,7 +363,7 @@ def _fit_transform(self, K): ) # flip eigenvectors' sign to enforce deterministic output - self.eigenvectors_, _ = svd_flip(self.eigenvectors_, None) + self.eigenvectors_, _ = svd_flip(u=self.eigenvectors_, v=None) # sort eigenvectors in descending order indices = self.eigenvalues_.argsort()[::-1] diff --git a/sklearn/decomposition/tests/test_sparse_pca.py b/sklearn/decomposition/tests/test_sparse_pca.py index 4b7834c7bfda9..1d8a6253edbad 100644 --- a/sklearn/decomposition/tests/test_sparse_pca.py +++ b/sklearn/decomposition/tests/test_sparse_pca.py @@ -117,7 +117,7 @@ def test_initialization(): model.fit(rng.randn(5, 4)) expected_components = V_init / np.linalg.norm(V_init, axis=1, keepdims=True) - expected_components = svd_flip(expected_components.T, None)[0].T + expected_components = svd_flip(u=expected_components.T, v=None)[0].T assert_allclose(model.components_, expected_components) diff --git a/sklearn/utils/extmath.py b/sklearn/utils/extmath.py index 6ec03c6db861b..f49e2943577f7 100644 --- a/sklearn/utils/extmath.py +++ b/sklearn/utils/extmath.py @@ -843,14 +843,14 @@ def svd_flip(u, v, u_based_decision=True): Parameters u and v are the output of `linalg.svd` or :func:`~sklearn.utils.extmath.randomized_svd`, with matching inner dimensions so one can compute `np.dot(u * s, v)`. - u can be None if u_based_decision is False. + u can be None if `u_based_decision` is False. 
v : ndarray Parameters u and v are the output of `linalg.svd` or :func:`~sklearn.utils.extmath.randomized_svd`, with matching inner dimensions so one can compute `np.dot(u * s, v)`. The input v should really be called vt to be consistent with scipy's output. - v can be None if u_based_decision is True. + v can be None if `u_based_decision` is True. u_based_decision : bool, default=True If True, use the columns of u as the basis for sign flipping. From 9e7f48ab4dba05e02071cfc29e5d799424fc8e02 Mon Sep 17 00:00:00 2001 From: Olivier Grisel Date: Sat, 4 Nov 2023 12:26:24 +0100 Subject: [PATCH 52/70] Apply suggestions from code review Co-authored-by: Guillaume Lemaitre --- sklearn/decomposition/_pca.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/sklearn/decomposition/_pca.py b/sklearn/decomposition/_pca.py index 6e75025c2f5ec..8831c9837e8d8 100644 --- a/sklearn/decomposition/_pca.py +++ b/sklearn/decomposition/_pca.py @@ -478,7 +478,7 @@ def fit_transform(self, X, y=None): This method returns a Fortran-ordered array. To convert it to a C-ordered array, use 'np.ascontiguousarray'. """ - U, S, Vt, X, x_is_centered, xp = self._fit(X) + U, S, _, X, x_is_centered, xp = self._fit(X) if U is not None: U = U[:, : self.n_components_] @@ -534,7 +534,7 @@ def _fit(self, X): if self._fit_svd_solver == "auto": # Tall and skinny problems are best handled by precomputing the # covariance matrix. - if X.shape[1] <= 1000 and X.shape[0] >= 10 * X.shape[1]: + if X.shape[1] <= 1_000 and X.shape[0] >= 10 * X.shape[1]: self._fit_svd_solver = "covariance_eigh" # Small problem or n_components == 'mle', just call full PCA elif max(X.shape) <= 500 or n_components == "mle": @@ -589,7 +589,7 @@ def _fit_full(self, X, n_components, xp, is_array_api_compliant): else: assert self._fit_svd_solver == "covariance_eigh" # In the following, we center the covariance matrix C a-posteriori - # (without centering the data X first) to avoid an unecessary copy + # (without centering the data X first) to avoid an unnecessary copy # of X. Note that the mean_ attribute is still needed to center # test data in the transform method. # @@ -600,7 +600,7 @@ def _fit_full(self, X, n_components, xp, is_array_api_compliant): # Besides, using `numpy.cov`, as of numpy 1.26.0, would not be # memory efficient for our use case when `n_samples >> n_features`: # `numpy.cov` centers a copy of the data before computing the - # matrix product instead of substracting a small `(n_features, + # matrix product instead of subtracting a small `(n_features, # n_features)` square matrix, a posteriori, as we do below. x_is_centered = False C = X.T @ X @@ -610,9 +610,9 @@ def _fit_full(self, X, n_components, xp, is_array_api_compliant): * xp.reshape(self.mean_, (1, -1)) ) C /= n_samples - 1 - eigenvals, Eigenvecs = xp.linalg.eigh(C) + eigenvals, eigenvecs = xp.linalg.eigh(C) eigenvals = xp.flip(eigenvals, axis=0) - Eigenvecs = xp.flip(Eigenvecs, axis=1) + eigenvecs = xp.flip(eigenvecs, axis=1) # The covariance matrix C is positive semi-definite by # construction. However, the eigenvalues returned by xp.linalg.eigh @@ -624,7 +624,7 @@ def _fit_full(self, X, n_components, xp, is_array_api_compliant): # Re-construct synthetic SVD components to be consistent with the # other solvers. 
S = xp.sqrt(eigenvals * (n_samples - 1)) - Vt = Eigenvecs.T + Vt = eigenvecs.T U = None # flip eigenvectors' sign to enforce deterministic output From 9c54050b94cfe1827a04097cd59c7f21c9f05501 Mon Sep 17 00:00:00 2001 From: Olivier Grisel Date: Fri, 10 Nov 2023 18:46:28 +0100 Subject: [PATCH 53/70] Fix typo --- sklearn/decomposition/_pca.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sklearn/decomposition/_pca.py b/sklearn/decomposition/_pca.py index 5eade43a4fdf5..b8f092a3d6b99 100644 --- a/sklearn/decomposition/_pca.py +++ b/sklearn/decomposition/_pca.py @@ -715,7 +715,7 @@ def _fit_truncated(self, X, n_components, xp): self.mean_, var = mean_variance_axis(X, axis=0) total_var = var.sum() * n_samples / (n_samples - 1) # ddof=1 X_centered = _implicit_column_offset(X, self.mean_) - x_is_centered = True + x_is_centered = False else: self.mean_ = xp.mean(X, axis=0) X_centered = xp.asarray(X, copy=True) if self.copy else X From 59ec249b84139b85b0e539dc8258b4edebecc1e9 Mon Sep 17 00:00:00 2001 From: Olivier Grisel Date: Sun, 12 Nov 2023 16:34:40 +0100 Subject: [PATCH 54/70] Fix broken tests --- sklearn/decomposition/_base.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/sklearn/decomposition/_base.py b/sklearn/decomposition/_base.py index 80c0c640f2ee8..8a95acc7dbb8b 100644 --- a/sklearn/decomposition/_base.py +++ b/sklearn/decomposition/_base.py @@ -12,11 +12,9 @@ import numpy as np from scipy import linalg -from scipy.sparse import issparse from ..base import BaseEstimator, ClassNamePrefixFeaturesOutMixin, TransformerMixin from ..utils._array_api import _add_to_diagonal, device, get_namespace -from ..utils.sparsefuncs import _implicit_column_offset from ..utils.validation import check_is_fitted @@ -152,12 +150,13 @@ def transform(self, X): ) def _transform(self, X, xp, x_is_centered=False): - if not x_is_centered and issparse(X): - X = _implicit_column_offset(X, self.mean_) X_transformed = X @ self.components_.T - if not x_is_centered and not issparse(X): + if not x_is_centered: # Apply the centering a posteriori for dense data so as to avoid a # copy of X or mutating the data passed by the caller. + # This also works for sparse X without having to wrap it into a + # linear operator a priori. 
+ X_transformed = X @ self.components_.T X_transformed -= xp.reshape(self.mean_, (1, -1)) @ self.components_.T if self.whiten: # For some solvers (such as "arpack" and "covariance_eigh"), on From 31fe74111168d7eaff450ab56015e7bfc47d34c9 Mon Sep 17 00:00:00 2001 From: Olivier Grisel Date: Sun, 12 Nov 2023 17:32:46 +0100 Subject: [PATCH 55/70] Increase rtol for array api compliance test with float32 pytorch --- sklearn/decomposition/tests/test_pca.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/sklearn/decomposition/tests/test_pca.py b/sklearn/decomposition/tests/test_pca.py index ae278f2cda20d..f09f7b2d9adab 100644 --- a/sklearn/decomposition/tests/test_pca.py +++ b/sklearn/decomposition/tests/test_pca.py @@ -954,6 +954,7 @@ def check_array_api_get_precision(name, estimator, array_namespace, device, dtyp precision_np = estimator.get_precision() covariance_np = estimator.get_covariance() + rtol = 2e-4 if iris_np.dtype == "float32" else 2e-7 with config_context(array_api_dispatch=True): estimator_xp = clone(estimator).fit(iris_xp) precision_xp = estimator_xp.get_precision() @@ -963,6 +964,7 @@ def check_array_api_get_precision(name, estimator, array_namespace, device, dtyp assert_allclose( _convert_to_numpy(precision_xp, xp=xp), precision_np, + rtol=rtol, atol=_atol_for_type(iris_np.dtype), ) covariance_xp = estimator_xp.get_covariance() @@ -972,6 +974,7 @@ def check_array_api_get_precision(name, estimator, array_namespace, device, dtyp assert_allclose( _convert_to_numpy(covariance_xp, xp=xp), covariance_np, + rtol=rtol, atol=_atol_for_type(iris_np.dtype), ) From ff3190c792bf2d8c3ef50950467ac6c343fa579a Mon Sep 17 00:00:00 2001 From: Olivier Grisel Date: Sun, 12 Nov 2023 18:21:28 +0100 Subject: [PATCH 56/70] Accept sparse input data for svd_solver='covariance_eigh' --- doc/whats_new/v1.4.rst | 5 +++-- sklearn/decomposition/_pca.py | 24 ++++++++++++++++++------ sklearn/decomposition/tests/test_pca.py | 2 +- 3 files changed, 22 insertions(+), 9 deletions(-) diff --git a/doc/whats_new/v1.4.rst b/doc/whats_new/v1.4.rst index 503c288c9ce80..08e3b6f3e5f39 100644 --- a/doc/whats_new/v1.4.rst +++ b/doc/whats_new/v1.4.rst @@ -301,9 +301,10 @@ Changelog - |Enhancement| :class:`decomposition.PCA` now supports a new solver option named `svd_solver="covariance_eigh"` which offers an order of magnitude speed-up and reduced memory usage for datasets with a large number of data - points and a small number of features (say, less than 500). The + points and a small number of features (say, less than 1000). The `svd_solver="auto"` option has been updated to use the new solver - automatically for such datasets. + automatically for such datasets. This solver also accepts sparse input + data. :pr:`27491` by :user:`Olivier Grisel `. - |Enhancement| :class:`decomposition.PCA` now supports the Array API for the diff --git a/sklearn/decomposition/_pca.py b/sklearn/decomposition/_pca.py index b8f092a3d6b99..41ab4aeef80ee 100644 --- a/sklearn/decomposition/_pca.py +++ b/sklearn/decomposition/_pca.py @@ -133,8 +133,9 @@ class PCA(_BasePCA): It can also use the scipy.sparse.linalg ARPACK implementation of the truncated SVD. - Notice that this class does not support sparse input. See - :class:`TruncatedSVD` for an alternative with sparse data. + Notice that this class only supports sparse inputs for some solvers such as + "arpack" and "covariance_eigh". See :class:`TruncatedSVD` for an + alternative with sparse data. Read more in the :ref:`User Guide `. 
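With this change, the covariance based solver accepts sparse input directly: centering is folded into the small (n_features, n_features) covariance matrix at fit time and applied after the projection at transform time, so the sparse matrix is never densified. A minimal usage sketch, assuming a scikit-learn build that includes this solver:

    from scipy.sparse import random as sparse_random
    from sklearn.decomposition import PCA

    X = sparse_random(50_000, 300, density=0.01, format="csr", random_state=0)

    pca = PCA(n_components=10, svd_solver="covariance_eigh").fit(X)
    X_proj = pca.transform(X)  # X stays sparse throughout
    print(X_proj.shape)        # (50000, 10)
    print(pca.explained_variance_ratio_.sum())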
@@ -499,11 +500,11 @@ def _fit(self, X): xp, is_array_api_compliant = get_namespace(X) # Raise an error for sparse input and unsupported svd_solver - if issparse(X) and self.svd_solver != "arpack": + if issparse(X) and self.svd_solver not in ["arpack", "covariance_eigh"]: raise TypeError( - 'PCA only support sparse inputs with the "arpack" solver, while ' - f'"{self.svd_solver}" was passed. See TruncatedSVD for a possible' - " alternative." + 'PCA only support sparse inputs with the "arpack" and' + f' "covariance_eigh" solvers, while "{self.svd_solver}" was passed. See' + " TruncatedSVD for a possible alternative." ) # Raise an error for non-Numpy input and arpack solver. if self.svd_solver == "arpack" and is_array_api_compliant: @@ -573,6 +574,10 @@ def _fit_full(self, X, n_components, xp, is_array_api_compliant): ) self.mean_ = xp.mean(X, axis=0) + if isinstance(self.mean_, np.matrix): + # This can happen when X is a scipy sparse matrix. + self.mean_ = np.asarray(self.mean_).ravel() + if self._fit_svd_solver == "full": X_centered = xp.asarray(X, copy=True) if self.copy else X X_centered -= self.mean_ @@ -616,6 +621,13 @@ def _fit_full(self, X, n_components, xp, is_array_api_compliant): ) C /= n_samples - 1 eigenvals, eigenvecs = xp.linalg.eigh(C) + + # The following can happen when X is a scipy sparse matrix. + if isinstance(eigenvals, np.matrix): + eigenvals = np.asarray(eigenvals).ravel() + if isinstance(eigenvecs, np.matrix): + eigenvecs = np.asarray(eigenvecs) + eigenvals = xp.flip(eigenvals, axis=0) eigenvecs = xp.flip(eigenvecs, axis=1) diff --git a/sklearn/decomposition/tests/test_pca.py b/sklearn/decomposition/tests/test_pca.py index f09f7b2d9adab..bafaf48cb5ff4 100644 --- a/sklearn/decomposition/tests/test_pca.py +++ b/sklearn/decomposition/tests/test_pca.py @@ -69,7 +69,7 @@ def test_pca(svd_solver, n_components): @pytest.mark.parametrize("density", [0.01, 0.1, 0.30]) @pytest.mark.parametrize("n_components", [1, 2, 10]) @pytest.mark.parametrize("sparse_container", CSR_CONTAINERS + CSC_CONTAINERS) -@pytest.mark.parametrize("svd_solver", ["arpack"]) +@pytest.mark.parametrize("svd_solver", ["arpack", "covariance_eigh"]) @pytest.mark.parametrize("scale", [1, 10, 100]) def test_pca_sparse( global_random_seed, svd_solver, sparse_container, n_components, density, scale From d38650d18c93e15be9348953a216128227e13f0a Mon Sep 17 00:00:00 2001 From: Olivier Grisel Date: Sun, 12 Nov 2023 19:27:12 +0100 Subject: [PATCH 57/70] Fix expected error message in test --- sklearn/decomposition/tests/test_pca.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sklearn/decomposition/tests/test_pca.py b/sklearn/decomposition/tests/test_pca.py index bafaf48cb5ff4..455a08b7a80b5 100644 --- a/sklearn/decomposition/tests/test_pca.py +++ b/sklearn/decomposition/tests/test_pca.py @@ -171,8 +171,8 @@ def test_sparse_pca_solver_error(global_random_seed, svd_solver, sparse_containe ) pca = PCA(n_components=30, svd_solver=svd_solver) error_msg_pattern = ( - f'PCA only support sparse inputs with the "arpack" solver, while "{svd_solver}"' - " was passed" + 'PCA only support sparse inputs with the "arpack" and "covariance_eigh"' + f' solvers, while "{svd_solver}" was passed' ) with pytest.raises(TypeError, match=error_msg_pattern): pca.fit(X) From 15db0220dc57dd64da4dd6fc461b4c0c18855665 Mon Sep 17 00:00:00 2001 From: Olivier Grisel Date: Thu, 4 Jan 2024 18:54:17 +0100 Subject: [PATCH 58/70] DOC move changelog entries to target 1.5 --- doc/whats_new/v1.4.rst | 30 ------------------------------ 
doc/whats_new/v1.5.rst | 41 +++++++++++++++++++++++++++++++++++++++++ 2 files changed, 41 insertions(+), 30 deletions(-) diff --git a/doc/whats_new/v1.4.rst b/doc/whats_new/v1.4.rst index ba341c930bfe7..679606d28db17 100644 --- a/doc/whats_new/v1.4.rst +++ b/doc/whats_new/v1.4.rst @@ -355,14 +355,6 @@ Changelog :mod:`sklearn.decomposition` ............................ -- |Efficiency| :class:`decomposition.PCA` with `svd_solver="full"` now assigns - a contiguous `components_` attribute instead of an non-contiguous slice of - the singular vectors. When `n_components << n_features`, this can save some - memory and more importantly help speed-up subsequent calls to the `transform` - method by more than an order of magnitude by leveraging cache locality of - BLAS GEMM on contiguous arrays. - :pr:`27491` by :user:`Olivier Grisel `. - - |Feature| :class:`decomposition.PCA` now supports :class:`scipy.sparse.sparray` and :class:`scipy.sparse.spmatrix` inputs when using the `arpack` solver. When used on sparse data like :func:`datasets.fetch_20newsgroups_vectorized` this @@ -379,28 +371,6 @@ Changelog parameter will change from `None` to `auto` in version 1.6. :pr:`26634` by :user:`Alexandre Landeau ` and :user:`Alexandre Vigny `. -- |Enhancement| :class:`decomposition.PCA` now supports a new solver option - named `svd_solver="covariance_eigh"` which offers an order of magnitude - speed-up and reduced memory usage for datasets with a large number of data - points and a small number of features (say, less than 1000). The - `svd_solver="auto"` option has been updated to use the new solver - automatically for such datasets. This solver also accepts sparse input - data. - :pr:`27491` by :user:`Olivier Grisel `. - -- |Enhancement| :class:`decomposition.PCA` now supports the Array API for the - `full`, `covariance_eigh` and `randomized` solvers (with QR power iterations). - See :ref:`array_api` for more details. - :pr:`26315`, :pr:`27098`, :pr:`27431` and :pr:`27491` - by :user:`Mateusz Sokół `, :user:`Olivier Grisel ` - and :user:`Edoardo Abati `. - -- |Fix| :class:`decomposition.PCA` fit with `svd_solver="arpack"`, - `whiten=True` and a value for `n_components` that is larger than the rank of - the training set, no longer returns infinite values when transforming - held-out data. - :pr:`27491` by :user:`Olivier Grisel `. - - |Fix| :func:`decomposition.dict_learning_online` does not ignore anymore the parameter `max_iter`. :pr:`27834` by :user:`Guillaume Lemaitre `. diff --git a/doc/whats_new/v1.5.rst b/doc/whats_new/v1.5.rst index f7a521ca4f0d0..7f504edafafa1 100644 --- a/doc/whats_new/v1.5.rst +++ b/doc/whats_new/v1.5.rst @@ -11,6 +11,21 @@ Version 1.5.0 .. include:: changelog_legend.inc +Changed models +-------------- + +The following estimators and functions, when fit with the same data and +parameters, may produce different models from the previous version. This often +occurs due to changes in the modelling logic (bug fixes or enhancements), or in +random sampling procedures. + +- |Enhancement| :class:`decomposition.PCA`, :class:`decomposition.SparsePCA` and + :class:`decomposition.TruncatedSVD` now set the sign of the `components_` + attribute based on the components values instead of using the transformed + data as reference. This change is needed to be able to offer consistent + component signs across all `PCA` solvers, including the new + `svd_solver="covariance_eigh"` option introduced in this release. + Changelog --------- @@ -44,3 +59,29 @@ TODO: update at the time of the release. 
- |Feature| A fitted :class:`compose.ColumnTransformer` now implements `__getitem__` which returns the fitted transformers by name. :pr:`27990` by `Thomas Fan`_. + +:mod:`sklearn.decomposition` +............................ + +- |Efficiency| :class:`decomposition.PCA` with `svd_solver="full"` now assigns + a contiguous `components_` attribute instead of an non-contiguous slice of + the singular vectors. When `n_components << n_features`, this can save some + memory and more importantly help speed-up subsequent calls to the `transform` + method by more than an order of magnitude by leveraging cache locality of + BLAS GEMM on contiguous arrays. + :pr:`27491` by :user:`Olivier Grisel `. + +- |Enhancement| :class:`decomposition.PCA` now supports a new solver option + named `svd_solver="covariance_eigh"` which offers an order of magnitude + speed-up and reduced memory usage for datasets with a large number of data + points and a small number of features (say, less than 1000). The + `svd_solver="auto"` option has been updated to use the new solver + automatically for such datasets. This solver also accepts sparse input + data. + :pr:`27491` by :user:`Olivier Grisel `. + +- |Fix| :class:`decomposition.PCA` fit with `svd_solver="arpack"`, + `whiten=True` and a value for `n_components` that is larger than the rank of + the training set, no longer returns infinite values when transforming + held-out data. + :pr:`27491` by :user:`Olivier Grisel `. From 4aade1ad4693ac9ca6e3b84e86776b7bb37fef77 Mon Sep 17 00:00:00 2001 From: Olivier Grisel Date: Thu, 4 Jan 2024 18:55:54 +0100 Subject: [PATCH 59/70] DOC cleanup left over entry that was meant to be moved to 1.5 --- doc/whats_new/v1.4.rst | 7 ------- 1 file changed, 7 deletions(-) diff --git a/doc/whats_new/v1.4.rst b/doc/whats_new/v1.4.rst index 679606d28db17..d2de5ee433f94 100644 --- a/doc/whats_new/v1.4.rst +++ b/doc/whats_new/v1.4.rst @@ -31,13 +31,6 @@ random sampling procedures. specified `tol`, for small values you will get more precise results. :pr:`26721` by :user:`Christian Lorentzen `. -- |Enhancement| :class:`decomposition.PCA`, :class:`decomposition.SparsePCA` and - :class:`decomposition.TruncatedSVD` now set the sign of the `components_` - attribute based on the components values instead of using the transformed - data as reference. This change is needed to be able to offer consistent - component signs across all `PCA` solvers, including the new - `svd_solver="covariance_eigh"` option introduced in this release. - - |Fix| fixes a memory leak seen in PyPy for estimators using the Cython loss functions. :pr:`27670` by :user:`Guillaume Lemaitre `. 
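The efficiency entry above about the contiguous `components_` attribute is about memory layout rather than arithmetic: LAPACK SVD drivers typically return Fortran-ordered singular vectors, so keeping a plain row slice of `Vt` as `components_` both pins the full matrix in memory and leaves strided rows for the `X @ components_.T` product. A small sketch of the difference (synthetic shapes, for illustration only):

    import numpy as np

    rng = np.random.default_rng(0)
    Vt = np.asfortranarray(rng.standard_normal((1_000, 1_000)))

    view = Vt[:2, :]                   # keeps all of Vt alive; rows are strided
    copy = np.ascontiguousarray(view)  # small and C-contiguous

    assert not view.flags["C_CONTIGUOUS"]
    assert copy.flags["C_CONTIGUOUS"]

    X = rng.standard_normal((100_000, 1_000))
    # X @ copy.T maps to a BLAS GEMM on contiguous data, which is what the
    # changelog entry refers to when mentioning cache locality.
    X_proj = X @ copy.T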
From d131a641ba42e77726a631cff90b153b36d91d85 Mon Sep 17 00:00:00 2001 From: Olivier Grisel Date: Fri, 5 Jan 2024 09:59:07 +0100 Subject: [PATCH 60/70] Fix bad conflict resolution --- sklearn/decomposition/tests/test_pca.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sklearn/decomposition/tests/test_pca.py b/sklearn/decomposition/tests/test_pca.py index e2964f96766eb..4c6c5aa260276 100644 --- a/sklearn/decomposition/tests/test_pca.py +++ b/sklearn/decomposition/tests/test_pca.py @@ -974,7 +974,7 @@ def check_array_api_get_precision(name, estimator, array_namespace, device, dtyp assert_allclose( _convert_to_numpy(covariance_xp, xp=xp), - covariance_np + covariance_np, rtol=rtol, atol=_atol_for_type(dtype_name), ) From b6ef8ba2eed27d733a3f59c6addccfb119d483c1 Mon Sep 17 00:00:00 2001 From: Olivier Grisel Date: Sun, 7 Apr 2024 16:48:10 +0200 Subject: [PATCH 61/70] Apply suggestions from code review Co-authored-by: Guillaume Lemaitre --- sklearn/decomposition/_base.py | 6 +----- sklearn/decomposition/_pca.py | 4 ++-- 2 files changed, 3 insertions(+), 7 deletions(-) diff --git a/sklearn/decomposition/_base.py b/sklearn/decomposition/_base.py index 8a95acc7dbb8b..abda8342ef584 100644 --- a/sklearn/decomposition/_base.py +++ b/sklearn/decomposition/_base.py @@ -143,11 +143,7 @@ def transform(self, X): X = self._validate_data( X, dtype=[xp.float64, xp.float32], accept_sparse=("csr", "csc"), reset=False ) - return self._transform( - X, - xp=xp, - x_is_centered=False, - ) + return self._transform(X, xp=xp, x_is_centered=False) def _transform(self, X, xp, x_is_centered=False): X_transformed = X @ self.components_.T diff --git a/sklearn/decomposition/_pca.py b/sklearn/decomposition/_pca.py index 4d7cd4e5a99c2..4590b8649fcf4 100644 --- a/sklearn/decomposition/_pca.py +++ b/sklearn/decomposition/_pca.py @@ -482,7 +482,7 @@ def fit_transform(self, X, y=None): U *= S[: self.n_components_] return U - else: + else: # solver="covariance_eigh" does not compute U at fit time. return self._transform(X, xp, x_is_centered=x_is_centered) def _fit(self, X): @@ -505,7 +505,7 @@ def _fit(self, X): # support sparse input data and the `covariance_eigh` solver are # written in a way to avoid the need for any inplace modification of # the input data contrary to the other solvers. Forcing a copy here - # would be wasteful when using large datasets. The copy will happen + # would be wasteful when using large datasets. The copy will happen # later, only if needed, once the solver negotiation below is done. 
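The `fit_transform` branch above relies on the scores being available in two equivalent forms: since X_c = U S V^T, one has U S = X_c V, so solvers that never materialize U (such as the covariance based one) can fall back to projecting the centered data. A quick check of that identity on synthetic data:

    import numpy as np

    rng = np.random.default_rng(0)
    X = rng.standard_normal((50, 8))
    Xc = X - X.mean(axis=0)

    U, S, Vt = np.linalg.svd(Xc, full_matrices=False)

    # U * S (returned by solvers that compute U) equals the plain projection
    # used when U is not available.
    np.testing.assert_allclose(U * S, Xc @ Vt.T)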
X = self._validate_data( X, From e12fe8182869d94471e03ae32be3fade5bb4e861 Mon Sep 17 00:00:00 2001 From: Olivier Grisel Date: Sun, 7 Apr 2024 16:52:53 +0200 Subject: [PATCH 62/70] Use assert_allclose in test_whitening --- .../decomposition/tests/test_incremental_pca.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/sklearn/decomposition/tests/test_incremental_pca.py b/sklearn/decomposition/tests/test_incremental_pca.py index 030f9a101c47f..2e330d94558d0 100644 --- a/sklearn/decomposition/tests/test_incremental_pca.py +++ b/sklearn/decomposition/tests/test_incremental_pca.py @@ -3,7 +3,7 @@ import numpy as np import pytest -from numpy.testing import assert_array_equal +from numpy.testing import assert_allclose, assert_array_equal from sklearn import datasets from sklearn.decomposition import PCA, IncrementalPCA @@ -388,7 +388,7 @@ def test_whitening(global_random_seed): X = datasets.make_low_rank_matrix( 1000, 10, tail_strength=0.0, effective_rank=2, random_state=global_random_seed ) - prec = 3 + atol = 1e-3 for nc in [None, 9]: pca = PCA(whiten=True, n_components=nc).fit(X) ipca = IncrementalPCA(whiten=True, n_components=nc, batch_size=250).fit(X) @@ -401,10 +401,10 @@ def test_whitening(global_random_seed): Xt_pca = pca.transform(X) Xt_ipca = ipca.transform(X) - assert_almost_equal( + assert_allclose( np.abs(Xt_pca)[:, stable_mask], np.abs(Xt_ipca)[:, stable_mask], - decimal=prec, + atol=atol, ) # The noisy dimensions are in the null space of the inverse transform, @@ -412,9 +412,9 @@ def test_whitening(global_random_seed): # need to apply the mask here. Xinv_ipca = ipca.inverse_transform(Xt_ipca) Xinv_pca = pca.inverse_transform(Xt_pca) - assert_almost_equal(X, Xinv_ipca, decimal=prec) - assert_almost_equal(X, Xinv_pca, decimal=prec) - assert_almost_equal(Xinv_pca, Xinv_ipca, decimal=prec) + assert_allclose(X, Xinv_ipca, atol=atol) + assert_allclose(X, Xinv_pca, atol=atol) + assert_allclose(Xinv_pca, Xinv_ipca, atol=atol) def test_incremental_pca_partial_fit_float_division(): From 85e6f240f1405570c4d7b3bb9b61b7ea2ba89a35 Mon Sep 17 00:00:00 2001 From: Olivier Grisel Date: Sun, 7 Apr 2024 16:55:00 +0200 Subject: [PATCH 63/70] Move changed models entry to the existing section of the change log --- doc/whats_new/v1.5.rst | 23 +++++++---------------- 1 file changed, 7 insertions(+), 16 deletions(-) diff --git a/doc/whats_new/v1.5.rst b/doc/whats_new/v1.5.rst index 004be59110457..7e71514f58bb3 100644 --- a/doc/whats_new/v1.5.rst +++ b/doc/whats_new/v1.5.rst @@ -31,6 +31,13 @@ Changed models properties). :pr:`27344` by :user:`Xuefeng Xu `. +- |Enhancement| :class:`decomposition.PCA`, :class:`decomposition.SparsePCA` and + :class:`decomposition.TruncatedSVD` now set the sign of the `components_` + attribute based on the components values instead of using the transformed + data as reference. This change is needed to be able to offer consistent + component signs across all `PCA` solvers, including the new + `svd_solver="covariance_eigh"` option introduced in this release. + Support for Array API --------------------- @@ -101,22 +108,6 @@ more details. transformers' ``fit`` and ``fit_transform``. :pr:`28205` by :user:`Stefanie Senger `. - -Changed models --------------- - -The following estimators and functions, when fit with the same data and -parameters, may produce different models from the previous version. This often -occurs due to changes in the modelling logic (bug fixes or enhancements), or in -random sampling procedures. 
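The tolerance based comparison in the whitening test works because whitening rescales each retained component to (near) unit variance, while the directions tied to numerically zero eigenvalues of this rank-deficient dataset carry only noise and must be filtered out before comparing solvers. Roughly, the test boils down to the following sketch (same dataset helper as in the test, fixed seed for illustration):

    import numpy as np
    from sklearn.datasets import make_low_rank_matrix
    from sklearn.decomposition import PCA, IncrementalPCA

    X = make_low_rank_matrix(
        1_000, 10, tail_strength=0.0, effective_rank=2, random_state=0
    )

    pca = PCA(whiten=True, n_components=9).fit(X)
    ipca = IncrementalPCA(whiten=True, n_components=9, batch_size=250).fit(X)

    stable = pca.explained_variance_ratio_ > 1e-12
    Xt_pca = pca.transform(X)
    Xt_ipca = ipca.transform(X)

    # Signs may differ between solvers, hence the comparison of absolute values.
    np.testing.assert_allclose(
        np.abs(Xt_pca)[:, stable], np.abs(Xt_ipca)[:, stable], atol=1e-3
    )
    # Whitened, non-degenerate components have (close to) unit variance.
    print(Xt_pca[:, stable].std(axis=0, ddof=1))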
- -- |Enhancement| :class:`decomposition.PCA`, :class:`decomposition.SparsePCA` and - :class:`decomposition.TruncatedSVD` now set the sign of the `components_` - attribute based on the components values instead of using the transformed - data as reference. This change is needed to be able to offer consistent - component signs across all `PCA` solvers, including the new - `svd_solver="covariance_eigh"` option introduced in this release. - Changelog --------- From b47b60379e62ec097747bc4176b2cc09a64b56d3 Mon Sep 17 00:00:00 2001 From: Olivier Grisel Date: Wed, 10 Apr 2024 16:52:29 +0200 Subject: [PATCH 64/70] Apply suggestions from code review Co-authored-by: Christian Lorentzen --- doc/whats_new/v1.5.rst | 8 ++++---- sklearn/decomposition/_base.py | 10 +++++----- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/doc/whats_new/v1.5.rst b/doc/whats_new/v1.5.rst index 45d3edb071072..ffea1be3637a7 100644 --- a/doc/whats_new/v1.5.rst +++ b/doc/whats_new/v1.5.rst @@ -33,7 +33,7 @@ Changed models - |Enhancement| :class:`decomposition.PCA`, :class:`decomposition.SparsePCA` and :class:`decomposition.TruncatedSVD` now set the sign of the `components_` - attribute based on the components values instead of using the transformed + attribute based on the component values instead of using the transformed data as reference. This change is needed to be able to offer consistent component signs across all `PCA` solvers, including the new `svd_solver="covariance_eigh"` option introduced in this release. @@ -175,7 +175,7 @@ Changelog - |Efficiency| :class:`decomposition.PCA` with `svd_solver="full"` now assigns a contiguous `components_` attribute instead of an non-contiguous slice of the singular vectors. When `n_components << n_features`, this can save some - memory and more importantly help speed-up subsequent calls to the `transform` + memory and, more importantly, help speed-up subsequent calls to the `transform` method by more than an order of magnitude by leveraging cache locality of BLAS GEMM on contiguous arrays. :pr:`27491` by :user:`Olivier Grisel `. @@ -187,7 +187,7 @@ Changelog - |Enhancement| :class:`decomposition.PCA` now supports a new solver option named `svd_solver="covariance_eigh"` which offers an order of magnitude speed-up and reduced memory usage for datasets with a large number of data - points and a small number of features (say, less than 1000). The + points and a small number of features (say, `n_samples >> 1000 > n_features`). The `svd_solver="auto"` option has been updated to use the new solver automatically for such datasets. This solver also accepts sparse input data. @@ -196,7 +196,7 @@ Changelog - |Fix| :class:`decomposition.PCA` fit with `svd_solver="arpack"`, `whiten=True` and a value for `n_components` that is larger than the rank of the training set, no longer returns infinite values when transforming - held-out data. + hold-out data. :pr:`27491` by :user:`Olivier Grisel `. :mod:`sklearn.dummy` diff --git a/sklearn/decomposition/_base.py b/sklearn/decomposition/_base.py index abda8342ef584..5c9d8419f675e 100644 --- a/sklearn/decomposition/_base.py +++ b/sklearn/decomposition/_base.py @@ -148,11 +148,11 @@ def transform(self, X): def _transform(self, X, xp, x_is_centered=False): X_transformed = X @ self.components_.T if not x_is_centered: - # Apply the centering a posteriori for dense data so as to avoid a - # copy of X or mutating the data passed by the caller. - # This also works for sparse X without having to wrap it into a - # linear operator a priori. 
- X_transformed = X @ self.components_.T + # Apply the centering after the projection. + # For dense X this avoids copying or mutating the data passed by + # the caller. + # For sparse X it keeps sparsity and avoids having to wrap X into + # a linear operator. X_transformed -= xp.reshape(self.mean_, (1, -1)) @ self.components_.T if self.whiten: # For some solvers (such as "arpack" and "covariance_eigh"), on From b43f6f963211cc78685ae59cc0b9f6850bd2f11d Mon Sep 17 00:00:00 2001 From: Olivier Grisel Date: Wed, 10 Apr 2024 17:30:16 +0200 Subject: [PATCH 65/70] Apply suggestions from code review Co-authored-by: Christian Lorentzen --- sklearn/decomposition/_pca.py | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/sklearn/decomposition/_pca.py b/sklearn/decomposition/_pca.py index 4590b8649fcf4..2f142445cbf83 100644 --- a/sklearn/decomposition/_pca.py +++ b/sklearn/decomposition/_pca.py @@ -183,13 +183,13 @@ class PCA(_BasePCA): svd_solver : {'auto', 'full', 'covariance_eigh', 'arpack', 'randomized'},\ default='auto' "auto" : - The solver is selected by a default policy based on `X.shape` and + The solver is selected by a default 'auto' policy is based on `X.shape` and `n_components`: if the input data has fewer than 1000 features and more than 10 times as many samples, then the "covariance_eigh" solver is used. Otherwise, if the input data is larger than 500x500 and the number of components to extract is lower than 80% of the smallest dimension of the data, then the more efficient - "randomized" method is enabled. Otherwise the exact "full" SVD is + "randomized" method is selected. Otherwise the exact "full" SVD is computed and optionally truncated afterwards. "full" : Run exact full SVD calling the standard LAPACK solver via @@ -198,8 +198,8 @@ class PCA(_BasePCA): Precompute the covariance matrix (on centered data), run a classical eigenvalue decomposition on the covariance matrix typically using LAPACK and select the components by postprocessing. - This solver is very efficient when the number of features is small - and not tractable otherwise (large memory footprint required to + This solver is very efficient for n_samples >> n_features and small n_features. + It is, however, not tractable otherwise for large n_features (large memory footprint required to materialize the covariance matrix). Also note that compared to the "full" solver, this solver effectively doubles the condition number and is therefore less numerical stable (e.g. on input data with a @@ -501,11 +501,11 @@ def _fit(self, X): "PCA with svd_solver='arpack' is not supported for Array API inputs." ) - # Validate the data, without ever forcing a copy as any solvers that - # support sparse input data and the `covariance_eigh` solver are + # Validate the data, without ever forcing a copy as any solver that + # supports sparse input data and the `covariance_eigh` solver are # written in a way to avoid the need for any inplace modification of - # the input data contrary to the other solvers. Forcing a copy here - # would be wasteful when using large datasets. The copy will happen + # the input data contrary to the other solvers. + # The copy will happen # later, only if needed, once the solver negotiation below is done. 
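Spelled out, the 'auto' policy from the updated docstring picks the covariance based solver for tall-and-skinny problems, keeps the exact solver for small problems or `n_components='mle'`, and otherwise switches to the randomized solver when comparatively few components are requested. A rough sketch of the decision rule with the thresholds quoted above (`select_svd_solver` is a hypothetical helper, not part of the patch; the actual rules live in `PCA._fit`):

    def select_svd_solver(n_samples, n_features, n_components):
        if n_features <= 1_000 and n_samples >= 10 * n_features:
            return "covariance_eigh"
        if max(n_samples, n_features) <= 500 or n_components == "mle":
            return "full"
        if 1 <= n_components < 0.8 * min(n_samples, n_features):
            return "randomized"
        return "full"

    assert select_svd_solver(100_000, 50, 10) == "covariance_eigh"
    assert select_svd_solver(400, 400, "mle") == "full"
    assert select_svd_solver(5_000, 3_000, 100) == "randomized"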
X = self._validate_data( X, @@ -587,7 +587,7 @@ def _fit_full(self, X, n_components, xp, is_array_api_compliant): else: assert self._fit_svd_solver == "covariance_eigh" - # In the following, we center the covariance matrix C a-posteriori + # In the following, we center the covariance matrix C afterwards # (without centering the data X first) to avoid an unnecessary copy # of X. Note that the mean_ attribute is still needed to center # test data in the transform method. @@ -600,7 +600,7 @@ def _fit_full(self, X, n_components, xp, is_array_api_compliant): # memory efficient for our use case when `n_samples >> n_features`: # `numpy.cov` centers a copy of the data before computing the # matrix product instead of subtracting a small `(n_features, - # n_features)` square matrix, a posteriori, as we do below. + # n_features)` square matrix from the gram matrix X.T @ X, as we do below. x_is_centered = False C = X.T @ X C -= ( @@ -627,7 +627,7 @@ def _fit_full(self, X, n_components, xp, is_array_api_compliant): eigenvals[eigenvals < 0.0] = 0.0 explained_variance_ = eigenvals - # Re-construct synthetic SVD components to be consistent with the + # Re-construct SVD of centered X indirectly and make it consistent with the # other solvers. S = xp.sqrt(eigenvals * (n_samples - 1)) Vt = eigenvecs.T From 310b6163a388568a2ad7fa95feeea3f69751e72a Mon Sep 17 00:00:00 2001 From: Olivier Grisel Date: Wed, 10 Apr 2024 17:36:54 +0200 Subject: [PATCH 66/70] Wrap paragraphs in multiline comments --- sklearn/decomposition/_pca.py | 25 ++++++++++++++----------- 1 file changed, 14 insertions(+), 11 deletions(-) diff --git a/sklearn/decomposition/_pca.py b/sklearn/decomposition/_pca.py index 2f142445cbf83..59dab17e67e69 100644 --- a/sklearn/decomposition/_pca.py +++ b/sklearn/decomposition/_pca.py @@ -198,12 +198,13 @@ class PCA(_BasePCA): Precompute the covariance matrix (on centered data), run a classical eigenvalue decomposition on the covariance matrix typically using LAPACK and select the components by postprocessing. - This solver is very efficient for n_samples >> n_features and small n_features. - It is, however, not tractable otherwise for large n_features (large memory footprint required to - materialize the covariance matrix). Also note that compared to the - "full" solver, this solver effectively doubles the condition number - and is therefore less numerical stable (e.g. on input data with a - large range of singular values). + This solver is very efficient for n_samples >> n_features and small + n_features. It is, however, not tractable otherwise for large + n_features (large memory footprint required to materialize the + covariance matrix). Also note that compared to the "full" solver, + this solver effectively doubles the condition number and is + therefore less numerical stable (e.g. on input data with a large + range of singular values). "arpack" : Run SVD truncated to `n_components` calling ARPACK solver via `scipy.sparse.linalg.svds`. It requires strictly @@ -600,7 +601,8 @@ def _fit_full(self, X, n_components, xp, is_array_api_compliant): # memory efficient for our use case when `n_samples >> n_features`: # `numpy.cov` centers a copy of the data before computing the # matrix product instead of subtracting a small `(n_features, - # n_features)` square matrix from the gram matrix X.T @ X, as we do below. + # n_features)` square matrix from the gram matrix X.T @ X, as we do + # below. 
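The numerical stability caveat mentioned in the docstring comes from working on the gram matrix: the eigenvalues of X_c^T X_c are the squared singular values of X_c, so the condition number is squared and the smallest singular values lose roughly half of their relative accuracy. A quick numerical illustration on synthetic data with a wide range of singular values:

    import numpy as np

    rng = np.random.default_rng(0)
    X = rng.standard_normal((5_000, 20)) @ np.diag(np.logspace(0, -4, 20))
    Xc = X - X.mean(axis=0)

    print(np.linalg.cond(Xc))         # roughly 1e4
    print(np.linalg.cond(Xc.T @ Xc))  # roughly 1e8, the square of the above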
x_is_centered = False C = X.T @ X C -= ( @@ -627,8 +629,8 @@ def _fit_full(self, X, n_components, xp, is_array_api_compliant): eigenvals[eigenvals < 0.0] = 0.0 explained_variance_ = eigenvals - # Re-construct SVD of centered X indirectly and make it consistent with the - # other solvers. + # Re-construct SVD of centered X indirectly and make it consistent + # with the other solvers. S = xp.sqrt(eigenvals * (n_samples - 1)) Vt = eigenvecs.T U = None @@ -682,8 +684,9 @@ def _fit_full(self, X, n_components, xp, is_array_api_compliant): # Assign a copy of the result of the truncation of the components in # order to: # - release the memory used by the discarded components, - # - ensure that the kept components are allocated contiguously in memory - # to make the transform method faster by leveraging cache locality. + # - ensure that the kept components are allocated contiguously in + # memory to make the transform method faster by leveraging cache + # locality. self.components_ = xp.asarray(components_[:n_components, :], copy=True) # We do the same for the other arrays for the sake of consistency. From 73a4d616835b0e4f489eefab9240fdf840578358 Mon Sep 17 00:00:00 2001 From: Olivier Grisel Date: Wed, 10 Apr 2024 20:41:15 +0200 Subject: [PATCH 67/70] Update sklearn/decomposition/tests/test_incremental_pca.py Co-authored-by: Christian Lorentzen --- sklearn/decomposition/tests/test_incremental_pca.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sklearn/decomposition/tests/test_incremental_pca.py b/sklearn/decomposition/tests/test_incremental_pca.py index d9ac271d6222f..50ddf39b04503 100644 --- a/sklearn/decomposition/tests/test_incremental_pca.py +++ b/sklearn/decomposition/tests/test_incremental_pca.py @@ -397,7 +397,7 @@ def test_whitening(global_random_seed): # Since the data is rank deficient, some components are pure noise. We # should not expect those dimensions to carry any signal and their # values might be arbitrarily changed by implementation details of the - # internal SVD solver. We therefore mask them out before comparison. + # internal SVD solver. We therefore filter them out before comparison. stable_mask = pca.explained_variance_ratio_ > 1e-12 Xt_pca = pca.transform(X) From 797e081e547a1c9f5bee0f827c57f9ad6c33e50d Mon Sep 17 00:00:00 2001 From: Olivier Grisel Date: Wed, 10 Apr 2024 23:39:53 +0200 Subject: [PATCH 68/70] Wrap paragraphs in changelog --- doc/whats_new/v1.5.rst | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/doc/whats_new/v1.5.rst b/doc/whats_new/v1.5.rst index 7193fa1d9e8d1..1a8c50e408a0b 100644 --- a/doc/whats_new/v1.5.rst +++ b/doc/whats_new/v1.5.rst @@ -31,11 +31,11 @@ Changed models properties). :pr:`27344` by :user:`Xuefeng Xu `. -- |Enhancement| :class:`decomposition.PCA`, :class:`decomposition.SparsePCA` and - :class:`decomposition.TruncatedSVD` now set the sign of the `components_` - attribute based on the component values instead of using the transformed - data as reference. This change is needed to be able to offer consistent - component signs across all `PCA` solvers, including the new +- |Enhancement| :class:`decomposition.PCA`, :class:`decomposition.SparsePCA` + and :class:`decomposition.TruncatedSVD` now set the sign of the `components_` + attribute based on the component values instead of using the transformed data + as reference. 
This change is needed to be able to offer consistent component + signs across all `PCA` solvers, including the new `svd_solver="covariance_eigh"` option introduced in this release. Support for Array API @@ -191,9 +191,9 @@ Changelog - |Enhancement| :class:`decomposition.PCA` now supports a new solver option named `svd_solver="covariance_eigh"` which offers an order of magnitude speed-up and reduced memory usage for datasets with a large number of data - points and a small number of features (say, `n_samples >> 1000 > n_features`). The - `svd_solver="auto"` option has been updated to use the new solver - automatically for such datasets. This solver also accepts sparse input + points and a small number of features (say, `n_samples >> 1000 > + n_features`). The `svd_solver="auto"` option has been updated to use the new + solver automatically for such datasets. This solver also accepts sparse input data. :pr:`27491` by :user:`Olivier Grisel `. From 47935f512a97744ecd2dc1249e44128abbefa38e Mon Sep 17 00:00:00 2001 From: Olivier Grisel Date: Thu, 11 Apr 2024 00:09:59 +0200 Subject: [PATCH 69/70] Do not test on isinstance(..., np.matrix) and only use Array API for consistency --- sklearn/decomposition/_pca.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/sklearn/decomposition/_pca.py b/sklearn/decomposition/_pca.py index 59dab17e67e69..73813e758a91c 100644 --- a/sklearn/decomposition/_pca.py +++ b/sklearn/decomposition/_pca.py @@ -613,11 +613,15 @@ def _fit_full(self, X, n_components, xp, is_array_api_compliant): C /= n_samples - 1 eigenvals, eigenvecs = xp.linalg.eigh(C) - # The following can happen when X is a scipy sparse matrix. - if isinstance(eigenvals, np.matrix): - eigenvals = np.asarray(eigenvals).ravel() - if isinstance(eigenvecs, np.matrix): - eigenvecs = np.asarray(eigenvecs) + # When X is a scipy sparse matrix, the following two datastructures + # are returned as instances of the soft-deprecated numpy.matrix + # class. Note that this problem does not occur when X is a scipy + # sparse array (or another other kind of supported array). + # TODO: remove the following two lines when scikit-learn only + # depends on scipy versions that no longer support scipy.sparse + # matrices. + eigenvals = xp.reshape(xp.asarray(eigenvals), (-1,)) + eigenvecs = xp.asarray(eigenvecs) eigenvals = xp.flip(eigenvals, axis=0) eigenvecs = xp.flip(eigenvecs, axis=1) From 9842711c1c0b1fa4d98780c41cf14c7816ab2490 Mon Sep 17 00:00:00 2001 From: Olivier Grisel Date: Thu, 11 Apr 2024 08:14:50 +0200 Subject: [PATCH 70/70] Do not test on isinstance(..., np.matrix) and only use Array API for consistency --- sklearn/decomposition/_pca.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/sklearn/decomposition/_pca.py b/sklearn/decomposition/_pca.py index 73813e758a91c..852547daab04d 100644 --- a/sklearn/decomposition/_pca.py +++ b/sklearn/decomposition/_pca.py @@ -564,9 +564,12 @@ def _fit_full(self, X, n_components, xp, is_array_api_compliant): ) self.mean_ = xp.mean(X, axis=0) - if isinstance(self.mean_, np.matrix): - # This can happen when X is a scipy sparse matrix. - self.mean_ = np.asarray(self.mean_).ravel() + # When X is a scipy sparse matrix, self.mean_ is a numpy matrix, so we need + # to transform it to a 1D array. Note that this is not the case when X + # is a scipy sparse array. + # TODO: remove the following two lines when scikit-learn only depends + # on scipy versions that no longer support scipy.sparse matrices. 
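For context, the `numpy.matrix` quirk worked around here only appears with the soft-deprecated `scipy.sparse` matrix classes (not the newer sparse arrays): reductions on them, and operations mixing them with dense arrays, return 2-D `numpy.matrix` instances, which the solver normalizes back to plain arrays with `asarray`/`reshape`. A small sketch of the behaviour:

    import numpy as np
    from scipy.sparse import csr_matrix

    X = csr_matrix(np.arange(12, dtype=np.float64).reshape(4, 3))

    m = X.mean(axis=0)
    print(type(m))                            # <class 'numpy.matrix'>, shape (1, 3)
    mean_ = np.reshape(np.asarray(m), (-1,))  # plain 1-D ndarray, shape (3,)
    print(mean_.shape)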
+ self.mean_ = xp.reshape(xp.asarray(self.mean_), (-1,)) if self._fit_svd_solver == "full": X_centered = xp.asarray(X, copy=True) if self.copy else X
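With the component-based sign convention in place, the different solvers are expected to agree on the fitted model up to numerical precision, which is what the updated tests exercise across solvers. A closing sketch of that equivalence on dense data (tolerances chosen for illustration):

    import numpy as np
    from sklearn.datasets import load_iris
    from sklearn.decomposition import PCA

    X = load_iris().data

    full = PCA(n_components=2, svd_solver="full").fit(X)
    cov = PCA(n_components=2, svd_solver="covariance_eigh").fit(X)

    np.testing.assert_allclose(
        full.components_, cov.components_, rtol=1e-6, atol=1e-9
    )
    np.testing.assert_allclose(
        full.explained_variance_, cov.explained_variance_, rtol=1e-6
    )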