diff --git a/.gitignore b/.gitignore index 8a31fc8f542c4..9f3b453bbfd74 100644 --- a/.gitignore +++ b/.gitignore @@ -55,6 +55,7 @@ examples/cluster/joblib reuters/ benchmarks/bench_covertype_data/ benchmarks/HIGGS.csv.gz +bench_pca_solvers.csv *.prefs .pydevproject diff --git a/benchmarks/bench_pca_solvers.py b/benchmarks/bench_pca_solvers.py new file mode 100644 index 0000000000000..337af3a42e900 --- /dev/null +++ b/benchmarks/bench_pca_solvers.py @@ -0,0 +1,165 @@ +# %% +# +# This benchmark compares the speed of PCA solvers on datasets of different +# sizes in order to determine the best solver to select by default via the +# "auto" heuristic. +# +# Note: we do not control for the accuracy of the solvers: we assume that all +# solvers yield transformed data with similar explained variance. This +# assumption is generally true, except for the randomized solver that might +# require more power iterations. +# +# We generate synthetic data with dimensions that are useful to plot: +# - time vs n_samples for a fixed n_features and, +# - time vs n_features for a fixed n_samples. +import itertools +from math import log10 +from time import perf_counter + +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd + +from sklearn import config_context +from sklearn.decomposition import PCA + +REF_DIMS = [100, 1000, 10_000] data_shapes = [] for ref_dim in REF_DIMS: data_shapes.extend([(ref_dim, 10**i) for i in range(1, 8 - int(log10(ref_dim)))]) data_shapes.extend( [(ref_dim, 3 * 10**i) for i in range(1, 8 - int(log10(ref_dim)))] ) data_shapes.extend([(10**i, ref_dim) for i in range(1, 8 - int(log10(ref_dim)))]) data_shapes.extend( [(3 * 10**i, ref_dim) for i in range(1, 8 - int(log10(ref_dim)))] ) +# Remove duplicates: +data_shapes = sorted(set(data_shapes)) + +print("Generating test datasets...") +rng = np.random.default_rng(0) +datasets = [rng.normal(size=shape) for shape in data_shapes] + + +# %% +def measure_one(data, n_components, solver, method_name="fit"): + print( + f"Benchmarking {solver=!r}, {n_components=}, {method_name=!r} on data with" + f" shape {data.shape}" + ) + pca = PCA(n_components=n_components, svd_solver=solver, random_state=0) + timings = [] + elapsed = 0 + method = getattr(pca, method_name) + with config_context(assume_finite=True): + while elapsed < 0.5: + tic = perf_counter() + method(data) + duration = perf_counter() - tic + timings.append(duration) + elapsed += duration + return np.median(timings) + + +SOLVERS = ["full", "covariance_eigh", "arpack", "randomized", "auto"] measurements = [] for data, n_components, method_name in itertools.product( datasets, [2, 50], ["fit", "fit_transform"] ): + if n_components >= min(data.shape): + continue + for solver in SOLVERS: + if solver == "covariance_eigh" and data.shape[1] > 5000: + # Too much memory and too slow. + continue + if solver in ["arpack", "full"] and log10(data.size) > 7: + # Too slow, in particular for the full solver.
+ continue + time = measure_one(data, n_components, solver, method_name=method_name) + measurements.append( + { + "n_components": n_components, + "n_samples": data.shape[0], + "n_features": data.shape[1], + "time": time, + "solver": solver, + "method_name": method_name, + } + ) +measurements = pd.DataFrame(measurements) +measurements.to_csv("bench_pca_solvers.csv", index=False) + +# %% +all_method_names = measurements["method_name"].unique() +all_n_components = measurements["n_components"].unique() + +for method_name in all_method_names: + fig, axes = plt.subplots( + figsize=(16, 16), + nrows=len(REF_DIMS), + ncols=len(all_n_components), + sharey=True, + constrained_layout=True, + ) + fig.suptitle(f"Benchmarks for PCA.{method_name}, varying n_samples", fontsize=16) + + for row_idx, ref_dim in enumerate(REF_DIMS): + for n_components, ax in zip(all_n_components, axes[row_idx]): + for solver in SOLVERS: + if solver == "auto": + style_kwargs = dict(linewidth=2, color="black", style="--") + else: + style_kwargs = dict(style="o-") + ax.set( + title=f"n_components={n_components}, n_features={ref_dim}", + ylabel="time (s)", + ) + measurements.query( + "n_components == @n_components and n_features == @ref_dim" + " and solver == @solver and method_name == @method_name" + ).plot.line( + x="n_samples", + y="time", + label=solver, + logx=True, + logy=True, + ax=ax, + **style_kwargs, + ) +# %% +for method_name in all_method_names: + fig, axes = plt.subplots( + figsize=(16, 16), + nrows=len(REF_DIMS), + ncols=len(all_n_components), + sharey=True, + ) + fig.suptitle(f"Benchmarks for PCA.{method_name}, varying n_features", fontsize=16) + + for row_idx, ref_dim in enumerate(REF_DIMS): + for n_components, ax in zip(all_n_components, axes[row_idx]): + for solver in SOLVERS: + if solver == "auto": + style_kwargs = dict(linewidth=2, color="black", style="--") + else: + style_kwargs = dict(style="o-") + ax.set( + title=f"n_components={n_components}, n_samples={ref_dim}", + ylabel="time (s)", + ) + measurements.query( + "n_components == @n_components and n_samples == @ref_dim " + " and solver == @solver and method_name == @method_name" + ).plot.line( + x="n_features", + y="time", + label=solver, + logx=True, + logy=True, + ax=ax, + **style_kwargs, + ) + +# %% diff --git a/doc/modules/compose.rst b/doc/modules/compose.rst index 0047ec7d8a2f0..28931cf52f283 100644 --- a/doc/modules/compose.rst +++ b/doc/modules/compose.rst @@ -254,14 +254,14 @@ inspect the original instance such as:: >>> from sklearn.datasets import load_digits >>> X_digits, y_digits = load_digits(return_X_y=True) - >>> pca1 = PCA() + >>> pca1 = PCA(n_components=10) >>> svm1 = SVC() >>> pipe = Pipeline([('reduce_dim', pca1), ('clf', svm1)]) >>> pipe.fit(X_digits, y_digits) - Pipeline(steps=[('reduce_dim', PCA()), ('clf', SVC())]) + Pipeline(steps=[('reduce_dim', PCA(n_components=10)), ('clf', SVC())]) >>> # The pca instance can be inspected directly - >>> print(pca1.components_) - [[-1.77484909e-19 ... 4.07058917e-18]] + >>> pca1.components_.shape + (10, 64) Enabling caching triggers a clone of the transformers before fitting. @@ -274,15 +274,15 @@ Instead, use the attribute ``named_steps`` to inspect estimators within the pipeline:: >>> cachedir = mkdtemp() - >>> pca2 = PCA() + >>> pca2 = PCA(n_components=10) >>> svm2 = SVC() >>> cached_pipe = Pipeline([('reduce_dim', pca2), ('clf', svm2)], ... 
memory=cachedir) >>> cached_pipe.fit(X_digits, y_digits) Pipeline(memory=..., - steps=[('reduce_dim', PCA()), ('clf', SVC())]) - >>> print(cached_pipe.named_steps['reduce_dim'].components_) - [[-1.77484909e-19 ... 4.07058917e-18]] + steps=[('reduce_dim', PCA(n_components=10)), ('clf', SVC())]) + >>> cached_pipe.named_steps['reduce_dim'].components_.shape + (10, 64) >>> # Remove the cache directory >>> rmtree(cachedir) diff --git a/doc/whats_new/v1.5.rst b/doc/whats_new/v1.5.rst index 9e13171e88528..1a8c50e408a0b 100644 --- a/doc/whats_new/v1.5.rst +++ b/doc/whats_new/v1.5.rst @@ -31,6 +31,13 @@ Changed models properties). :pr:`27344` by :user:`Xuefeng Xu `. +- |Enhancement| :class:`decomposition.PCA`, :class:`decomposition.SparsePCA` + and :class:`decomposition.TruncatedSVD` now set the sign of the `components_` + attribute based on the component values instead of using the transformed data + as reference. This change is needed to be able to offer consistent component + signs across all `PCA` solvers, including the new + `svd_solver="covariance_eigh"` option introduced in this release. + Support for Array API --------------------- @@ -169,10 +176,33 @@ Changelog :mod:`sklearn.decomposition` ............................ +- |Efficiency| :class:`decomposition.PCA` with `svd_solver="full"` now assigns + a contiguous `components_` attribute instead of a non-contiguous slice of + the singular vectors. When `n_components << n_features`, this can save some + memory and, more importantly, help speed up subsequent calls to the `transform` + method by more than an order of magnitude by leveraging cache locality of + BLAS GEMM on contiguous arrays. + :pr:`27491` by :user:`Olivier Grisel `. + - |Enhancement| :class:`~decomposition.PCA` now automatically selects the ARPACK solver for sparse inputs when `svd_solver="auto"` instead of raising an error. :pr:`28498` by :user:`Thanh Lam Dang `. +- |Enhancement| :class:`decomposition.PCA` now supports a new solver option + named `svd_solver="covariance_eigh"` which offers an order of magnitude + speed-up and reduced memory usage for datasets with a large number of data + points and a small number of features (say, `n_samples >> 1000 > + n_features`). The `svd_solver="auto"` option has been updated to use the new + solver automatically for such datasets. This solver also accepts sparse input + data. + :pr:`27491` by :user:`Olivier Grisel `. + +- |Fix| :class:`decomposition.PCA` fit with `svd_solver="arpack"`, + `whiten=True` and a value for `n_components` that is larger than the rank of + the training set, no longer returns infinite values when transforming + hold-out data. + :pr:`27491` by :user:`Olivier Grisel `. + :mod:`sklearn.dummy` .................... diff --git a/sklearn/decomposition/_base.py b/sklearn/decomposition/_base.py index 9fa720751774f..5c9d8419f675e 100644 --- a/sklearn/decomposition/_base.py +++ b/sklearn/decomposition/_base.py @@ -12,11 +12,9 @@ import numpy as np from scipy import linalg -from scipy.sparse import issparse from ..base import BaseEstimator, ClassNamePrefixFeaturesOutMixin, TransformerMixin from ..utils._array_api import _add_to_diagonal, device, get_namespace -from ..utils.sparsefuncs import _implicit_column_offset from ..utils.validation import check_is_fitted @@ -138,21 +136,33 @@ def transform(self, X): Projection of X in the first principal components, where `n_samples` is the number of samples and `n_components` is the number of the components.
""" - xp, _ = get_namespace(X) + xp, _ = get_namespace(X, self.components_, self.explained_variance_) check_is_fitted(self) X = self._validate_data( - X, accept_sparse=("csr", "csc"), dtype=[xp.float64, xp.float32], reset=False + X, dtype=[xp.float64, xp.float32], accept_sparse=("csr", "csc"), reset=False ) - if self.mean_ is not None: - if issparse(X): - X = _implicit_column_offset(X, self.mean_) - else: - X = X - self.mean_ + return self._transform(X, xp=xp, x_is_centered=False) + + def _transform(self, X, xp, x_is_centered=False): X_transformed = X @ self.components_.T + if not x_is_centered: + # Apply the centering after the projection. + # For dense X this avoids copying or mutating the data passed by + # the caller. + # For sparse X it keeps sparsity and avoids having to wrap X into + # a linear operator. + X_transformed -= xp.reshape(self.mean_, (1, -1)) @ self.components_.T if self.whiten: - X_transformed /= xp.sqrt(self.explained_variance_) + # For some solvers (such as "arpack" and "covariance_eigh"), on + # rank deficient data, some components can have a variance + # arbitrarily close to zero, leading to non-finite results when + # whitening. To avoid this problem we clip the variance below. + scale = xp.sqrt(self.explained_variance_) + min_scale = xp.finfo(scale.dtype).eps + scale[scale < min_scale] = min_scale + X_transformed /= scale return X_transformed def inverse_transform(self, X): diff --git a/sklearn/decomposition/_kernel_pca.py b/sklearn/decomposition/_kernel_pca.py index 8fc4aa26a6dfb..edfd49c2e87a0 100644 --- a/sklearn/decomposition/_kernel_pca.py +++ b/sklearn/decomposition/_kernel_pca.py @@ -366,9 +366,7 @@ def _fit_transform(self, K): ) # flip eigenvectors' sign to enforce deterministic output - self.eigenvectors_, _ = svd_flip( - self.eigenvectors_, np.zeros_like(self.eigenvectors_).T - ) + self.eigenvectors_, _ = svd_flip(u=self.eigenvectors_, v=None) # sort eigenvectors in descending order indices = self.eigenvalues_.argsort()[::-1] diff --git a/sklearn/decomposition/_pca.py b/sklearn/decomposition/_pca.py index 4c49337e88093..852547daab04d 100644 --- a/sklearn/decomposition/_pca.py +++ b/sklearn/decomposition/_pca.py @@ -133,6 +133,10 @@ class PCA(_BasePCA): used (i.e. through :func:`scipy.sparse.linalg.svds`). Alternatively, one may consider :class:`TruncatedSVD` where the data are not centered. + Notice that this class only supports sparse inputs for some solvers such as + "arpack" and "covariance_eigh". See :class:`TruncatedSVD` for an + alternative with sparse data. + For a usage example, see :ref:`sphx_glr_auto_examples_decomposition_plot_pca_iris.py` @@ -176,26 +180,43 @@ class PCA(_BasePCA): improve the predictive accuracy of the downstream estimators by making their data respect some hard-wired assumptions. - svd_solver : {'auto', 'full', 'arpack', 'randomized'}, default='auto' - If auto : - The solver is selected by a default policy based on `X.shape` and - `n_components`: if the input data is larger than 500x500 and the - number of components to extract is lower than 80% of the smallest - dimension of the data, then the more efficient 'randomized' - method is enabled. Otherwise the exact full SVD is computed and - optionally truncated afterwards. 
- If full : - run exact full SVD calling the standard LAPACK solver via + svd_solver : {'auto', 'full', 'covariance_eigh', 'arpack', 'randomized'},\ + default='auto' + "auto" : + The solver is selected by a default 'auto' policy based on `X.shape` and + `n_components`: if the input data has fewer than 1000 features and + more than 10 times as many samples, then the "covariance_eigh" + solver is used. Otherwise, if the input data is larger than 500x500 + and the number of components to extract is lower than 80% of the + smallest dimension of the data, then the more efficient + "randomized" method is selected. Otherwise the exact "full" SVD is + computed and optionally truncated afterwards. + "full" : + Run exact full SVD calling the standard LAPACK solver via `scipy.linalg.svd` and select the components by postprocessing - If arpack : - run SVD truncated to n_components calling ARPACK solver via + "covariance_eigh" : + Precompute the covariance matrix (on centered data), run a + classical eigenvalue decomposition on the covariance matrix + typically using LAPACK and select the components by postprocessing. + This solver is very efficient for n_samples >> n_features and small + n_features. It is, however, not tractable otherwise for large + n_features (large memory footprint required to materialize the + covariance matrix). Also note that compared to the "full" solver, + this solver effectively doubles the condition number and is + therefore less numerically stable (e.g. on input data with a large + range of singular values). + "arpack" : + Run SVD truncated to `n_components` calling ARPACK solver via `scipy.sparse.linalg.svds`. It requires strictly - 0 < n_components < min(X.shape) - If randomized : - run randomized SVD by the method of Halko et al. + `0 < n_components < min(X.shape)` + "randomized" : + Run randomized SVD by the method of Halko et al. .. versionadded:: 0.18.0 + .. versionchanged:: 1.5 + Added the 'covariance_eigh' solver. + tol : float, default=0.0 Tolerance for singular values computed by svd_solver == 'arpack'. Must be of range [0.0, infinity). @@ -370,7 +391,9 @@ class PCA(_BasePCA): ], "copy": ["boolean"], "whiten": ["boolean"], - "svd_solver": [StrOptions({"auto", "full", "arpack", "randomized"})], + "svd_solver": [ + StrOptions({"auto", "full", "covariance_eigh", "arpack", "randomized"}) + ], "tol": [Interval(Real, 0, None, closed="left")], "iterated_power": [ StrOptions({"auto"}), @@ -448,39 +471,49 @@ def fit_transform(self, X, y=None): This method returns a Fortran-ordered array. To convert it to a C-ordered array, use 'np.ascontiguousarray'. """ - U, S, Vt = self._fit(X) - U = U[:, : self.n_components_] + U, S, _, X, x_is_centered, xp = self._fit(X) + if U is not None: + U = U[:, : self.n_components_] - if self.whiten: - # X_new = X * V / S * sqrt(n_samples) = U * sqrt(n_samples) - U *= sqrt(X.shape[0] - 1) - else: - # X_new = X * V = U * S * Vt * V = U * S - U *= S[: self.n_components_] + if self.whiten: + # X_new = X * V / S * sqrt(n_samples) = U * sqrt(n_samples) + U *= sqrt(X.shape[0] - 1) + else: + # X_new = X * V = U * S * Vt * V = U * S + U *= S[: self.n_components_] - return U + return U + else: # solver="covariance_eigh" does not compute U at fit time.
+ return self._transform(X, xp, x_is_centered=x_is_centered) def _fit(self, X): """Dispatch to the right submethod depending on the chosen solver.""" xp, is_array_api_compliant = get_namespace(X) - if issparse(X) and self.svd_solver not in {"arpack", "auto"}: + # Raise an error for sparse input and unsupported svd_solver + if issparse(X) and self.svd_solver not in ["auto", "arpack", "covariance_eigh"]: raise TypeError( - 'PCA only support sparse inputs with the "arpack" solver, while ' - f'"{self.svd_solver}" was passed. See TruncatedSVD for a possible' - " alternative." + 'PCA only support sparse inputs with the "arpack" and' + f' "covariance_eigh" solvers, while "{self.svd_solver}" was passed. See' + " TruncatedSVD for a possible alternative." ) if self.svd_solver == "arpack" and is_array_api_compliant: raise ValueError( "PCA with svd_solver='arpack' is not supported for Array API inputs." ) + # Validate the data, without ever forcing a copy as any solver that + # supports sparse input data and the `covariance_eigh` solver are + # written in a way to avoid the need for any inplace modification of + # the input data contrary to the other solvers. + # The copy will happen + # later, only if needed, once the solver negotiation below is done. X = self._validate_data( X, dtype=[xp.float64, xp.float32], accept_sparse=("csr", "csc"), ensure_2d=True, - copy=self.copy, + copy=False, ) self._fit_svd_solver = self.svd_solver if self._fit_svd_solver == "auto" and issparse(X): @@ -495,8 +528,12 @@ def _fit(self, X): n_components = self.n_components if self._fit_svd_solver == "auto": + # Tall and skinny problems are best handled by precomputing the + # covariance matrix. + if X.shape[1] <= 1_000 and X.shape[0] >= 10 * X.shape[1]: + self._fit_svd_solver = "covariance_eigh" # Small problem or n_components == 'mle', just call full PCA - if max(X.shape) <= 500 or n_components == "mle": + elif max(X.shape) <= 500 or n_components == "mle": self._fit_svd_solver = "full" elif 1 <= n_components < 0.8 * min(X.shape): self._fit_svd_solver = "randomized" @@ -504,15 +541,14 @@ def _fit(self, X): else: self._fit_svd_solver = "full" - if self._fit_svd_solver == "full": - return self._fit_full(X, n_components) + # Call different fits for either full or truncated SVD + if self._fit_svd_solver in ("full", "covariance_eigh"): + return self._fit_full(X, n_components, xp, is_array_api_compliant) elif self._fit_svd_solver in ["arpack", "randomized"]: - return self._fit_truncated(X, n_components, self._fit_svd_solver) + return self._fit_truncated(X, n_components, xp) - def _fit_full(self, X, n_components): + def _fit_full(self, X, n_components, xp, is_array_api_compliant): """Fit the model by computing full SVD on X.""" - xp, is_array_api_compliant = get_namespace(X) - n_samples, n_features = X.shape if n_components == "mle": @@ -522,33 +558,96 @@ def _fit_full(self, X, n_components): ) elif not 0 <= n_components <= min(n_samples, n_features): raise ValueError( - "n_components=%r must be between 0 and " - "min(n_samples, n_features)=%r with " - "svd_solver='full'" % (n_components, min(n_samples, n_features)) + f"n_components={n_components} must be between 0 and " + f"min(n_samples, n_features)={min(n_samples, n_features)} with " + f"svd_solver={self._fit_svd_solver!r}" ) - # Center data self.mean_ = xp.mean(X, axis=0) - X -= self.mean_ - - if not is_array_api_compliant: - # Use scipy.linalg with NumPy/SciPy inputs for the sake of not - # introducing unanticipated behavior changes. 
In the long run we - # could instead decide to always use xp.linalg.svd for all inputs, - # but that would make this code rely on numpy's SVD instead of - # scipy's. It's not 100% clear whether they use the same LAPACK - # solver by default though (assuming both are built against the - # same BLAS). - U, S, Vt = linalg.svd(X, full_matrices=False) + # When X is a scipy sparse matrix, self.mean_ is a numpy matrix, so we need + # to transform it to a 1D array. Note that this is not the case when X + # is a scipy sparse array. + # TODO: remove the following two lines when scikit-learn only depends + # on scipy versions that no longer support scipy.sparse matrices. + self.mean_ = xp.reshape(xp.asarray(self.mean_), (-1,)) + + if self._fit_svd_solver == "full": + X_centered = xp.asarray(X, copy=True) if self.copy else X + X_centered -= self.mean_ + x_is_centered = not self.copy + + if not is_array_api_compliant: + # Use scipy.linalg with NumPy/SciPy inputs for the sake of not + # introducing unanticipated behavior changes. In the long run we + # could instead decide to always use xp.linalg.svd for all inputs, + # but that would make this code rely on numpy's SVD instead of + # scipy's. It's not 100% clear whether they use the same LAPACK + # solver by default though (assuming both are built against the + # same BLAS). + U, S, Vt = linalg.svd(X_centered, full_matrices=False) + else: + U, S, Vt = xp.linalg.svd(X_centered, full_matrices=False) + explained_variance_ = (S**2) / (n_samples - 1) + else: - U, S, Vt = xp.linalg.svd(X, full_matrices=False) + assert self._fit_svd_solver == "covariance_eigh" + # In the following, we center the covariance matrix C afterwards + # (without centering the data X first) to avoid an unnecessary copy + # of X. Note that the mean_ attribute is still needed to center + # test data in the transform method. + # + # Note: at the time of writing, `xp.cov` does not exist in the + # Array API standard: + # https://github.com/data-apis/array-api/issues/43 + # + # Besides, using `numpy.cov`, as of numpy 1.26.0, would not be + # memory efficient for our use case when `n_samples >> n_features`: + # `numpy.cov` centers a copy of the data before computing the + # matrix product instead of subtracting a small `(n_features, + # n_features)` square matrix from the Gram matrix X.T @ X, as we do + # below. + x_is_centered = False + C = X.T @ X + C -= ( + n_samples + * xp.reshape(self.mean_, (-1, 1)) + * xp.reshape(self.mean_, (1, -1)) + ) + C /= n_samples - 1 + eigenvals, eigenvecs = xp.linalg.eigh(C) + + # When X is a scipy sparse matrix, the following two data structures + # are returned as instances of the soft-deprecated numpy.matrix + # class. Note that this problem does not occur when X is a scipy + # sparse array (or any other kind of supported array). + # TODO: remove the following two lines when scikit-learn only + # depends on scipy versions that no longer support scipy.sparse + # matrices. + eigenvals = xp.reshape(xp.asarray(eigenvals), (-1,)) + eigenvecs = xp.asarray(eigenvecs) + + eigenvals = xp.flip(eigenvals, axis=0) + eigenvecs = xp.flip(eigenvecs, axis=1) + + # The covariance matrix C is positive semi-definite by + # construction. However, the eigenvalues returned by xp.linalg.eigh + # can be slightly negative due to numerical errors. This would be + # an issue for the subsequent sqrt, hence the manual clipping.
+ eigenvals[eigenvals < 0.0] = 0.0 + explained_variance_ = eigenvals + + # Re-construct SVD of centered X indirectly and make it consistent + # with the other solvers. + S = xp.sqrt(eigenvals * (n_samples - 1)) + Vt = eigenvecs.T + U = None + # flip eigenvectors' sign to enforce deterministic output - U, Vt = svd_flip(U, Vt) + U, Vt = svd_flip(U, Vt, u_based_decision=False) components_ = Vt # Get variance explained by singular values - explained_variance_ = (S**2) / (n_samples - 1) total_var = xp.sum(explained_variance_) explained_variance_ratio_ = explained_variance_ / total_var singular_values_ = xp.asarray(S, copy=True) # Store the singular values. @@ -588,22 +687,33 @@ def _fit_full(self, X, n_components): self.noise_variance_ = 0.0 self.n_samples_ = n_samples - self.components_ = components_[:n_components, :] self.n_components_ = n_components - self.explained_variance_ = explained_variance_[:n_components] - self.explained_variance_ratio_ = explained_variance_ratio_[:n_components] - self.singular_values_ = singular_values_[:n_components] + # Assign a copy of the result of the truncation of the components in + # order to: + # - release the memory used by the discarded components, + # - ensure that the kept components are allocated contiguously in + # memory to make the transform method faster by leveraging cache + # locality. + self.components_ = xp.asarray(components_[:n_components, :], copy=True) + + # We do the same for the other arrays for the sake of consistency. + self.explained_variance_ = xp.asarray( + explained_variance_[:n_components], copy=True + ) + self.explained_variance_ratio_ = xp.asarray( + explained_variance_ratio_[:n_components], copy=True + ) + self.singular_values_ = xp.asarray(singular_values_[:n_components], copy=True) - return U, S, Vt + return U, S, Vt, X, x_is_centered, xp - def _fit_truncated(self, X, n_components, svd_solver): + def _fit_truncated(self, X, n_components, xp): """Fit the model by computing truncated SVD (by ARPACK or randomized) on X. """ - xp, _ = get_namespace(X) - n_samples, n_features = X.shape + svd_solver = self._fit_svd_solver if isinstance(n_components, str): raise ValueError( "n_components=%r cannot be a string with svd_solver='%s'" @@ -631,31 +741,35 @@ def _fit_truncated(self, X, n_components, svd_solver): if issparse(X): self.mean_, var = mean_variance_axis(X, axis=0) total_var = var.sum() * n_samples / (n_samples - 1) # ddof=1 - X = _implicit_column_offset(X, self.mean_) + X_centered = _implicit_column_offset(X, self.mean_) + x_is_centered = False else: self.mean_ = xp.mean(X, axis=0) - X -= self.mean_ + X_centered = xp.asarray(X, copy=True) if self.copy else X + X_centered -= self.mean_ + x_is_centered = not self.copy if svd_solver == "arpack": v0 = _init_arpack_v0(min(X.shape), random_state) - U, S, Vt = svds(X, k=n_components, tol=self.tol, v0=v0) + U, S, Vt = svds(X_centered, k=n_components, tol=self.tol, v0=v0) # svds doesn't abide by scipy.linalg.svd/randomized_svd # conventions, so reverse its outputs. 
S = S[::-1] # flip eigenvectors' sign to enforce deterministic output - U, Vt = svd_flip(U[:, ::-1], Vt[::-1]) + U, Vt = svd_flip(U[:, ::-1], Vt[::-1], u_based_decision=False) elif svd_solver == "randomized": # sign flipping is done inside U, S, Vt = randomized_svd( - X, + X_centered, n_components=n_components, n_oversamples=self.n_oversamples, n_iter=self.iterated_power, power_iteration_normalizer=self.power_iteration_normalizer, - flip_sign=True, + flip_sign=False, random_state=random_state, ) + U, Vt = svd_flip(U, Vt, u_based_decision=False) self.n_samples_ = n_samples self.components_ = Vt @@ -673,8 +787,8 @@ def _fit_truncated(self, X, n_components, svd_solver): # See: https://github.com/scikit-learn/scikit-learn/pull/18689#discussion_r1335540991 if total_var is None: N = X.shape[0] - 1 - X **= 2 - total_var = xp.sum(X) / N + X_centered **= 2 + total_var = xp.sum(X_centered) / N self.explained_variance_ratio_ = self.explained_variance_ / total_var self.singular_values_ = xp.asarray(S, copy=True) # Store the singular values. @@ -685,7 +799,7 @@ def _fit_truncated(self, X, n_components, svd_solver): else: self.noise_variance_ = 0.0 - return U, S, Vt + return U, S, Vt, X, x_is_centered, xp def score_samples(self, X): """Return the log-likelihood of each sample. diff --git a/sklearn/decomposition/_sparse_pca.py b/sklearn/decomposition/_sparse_pca.py index fa711ce8c0703..b284e784d4466 100644 --- a/sklearn/decomposition/_sparse_pca.py +++ b/sklearn/decomposition/_sparse_pca.py @@ -325,7 +325,7 @@ def _fit(self, X, n_components, random_state): return_n_iter=True, ) # flip eigenvectors' sign to enforce deterministic output - code, dictionary = svd_flip(code, dictionary, u_based_decision=False) + code, dictionary = svd_flip(code, dictionary, u_based_decision=True) self.components_ = code.T components_norm = np.linalg.norm(self.components_, axis=1)[:, np.newaxis] components_norm[components_norm == 0] = 1 diff --git a/sklearn/decomposition/_truncated_svd.py b/sklearn/decomposition/_truncated_svd.py index d238f35cb2167..d978191f104f7 100644 --- a/sklearn/decomposition/_truncated_svd.py +++ b/sklearn/decomposition/_truncated_svd.py @@ -234,7 +234,8 @@ def fit_transform(self, X, y=None): # svds doesn't abide by scipy.linalg.svd/randomized_svd # conventions, so reverse its outputs. Sigma = Sigma[::-1] - U, VT = svd_flip(U[:, ::-1], VT[::-1]) + # u_based_decision=False is needed to be consistent with PCA. 
+ U, VT = svd_flip(U[:, ::-1], VT[::-1], u_based_decision=False) elif self.algorithm == "randomized": if self.n_components > X.shape[1]: @@ -249,7 +250,9 @@ def fit_transform(self, X, y=None): n_oversamples=self.n_oversamples, power_iteration_normalizer=self.power_iteration_normalizer, random_state=random_state, + flip_sign=False, ) + U, VT = svd_flip(U, VT, u_based_decision=False) self.components_ = VT diff --git a/sklearn/decomposition/tests/test_incremental_pca.py b/sklearn/decomposition/tests/test_incremental_pca.py index 646aad2db795d..50ddf39b04503 100644 --- a/sklearn/decomposition/tests/test_incremental_pca.py +++ b/sklearn/decomposition/tests/test_incremental_pca.py @@ -4,7 +4,7 @@ import numpy as np import pytest -from numpy.testing import assert_array_equal +from numpy.testing import assert_allclose, assert_array_equal from sklearn import datasets from sklearn.decomposition import PCA, IncrementalPCA @@ -384,25 +384,38 @@ def test_singular_values(): assert_array_almost_equal(ipca.singular_values_, [3.142, 2.718, 1.0], 14) -def test_whitening(): +def test_whitening(global_random_seed): # Test that PCA and IncrementalPCA transforms match to sign flip. X = datasets.make_low_rank_matrix( - 1000, 10, tail_strength=0.0, effective_rank=2, random_state=1999 + 1000, 10, tail_strength=0.0, effective_rank=2, random_state=global_random_seed ) - prec = 3 - n_samples, n_features = X.shape + atol = 1e-3 for nc in [None, 9]: pca = PCA(whiten=True, n_components=nc).fit(X) ipca = IncrementalPCA(whiten=True, n_components=nc, batch_size=250).fit(X) + # Since the data is rank deficient, some components are pure noise. We + # should not expect those dimensions to carry any signal and their + # values might be arbitrarily changed by implementation details of the + # internal SVD solver. We therefore filter them out before comparison. + stable_mask = pca.explained_variance_ratio_ > 1e-12 + Xt_pca = pca.transform(X) Xt_ipca = ipca.transform(X) - assert_almost_equal(np.abs(Xt_pca), np.abs(Xt_ipca), decimal=prec) + assert_allclose( + np.abs(Xt_pca)[:, stable_mask], + np.abs(Xt_ipca)[:, stable_mask], + atol=atol, + ) + + # The noisy dimensions are in the null space of the inverse transform, + # so they are not influencing the reconstruction. We therefore don't + # need to apply the mask here. 
Xinv_ipca = ipca.inverse_transform(Xt_ipca) Xinv_pca = pca.inverse_transform(Xt_pca) - assert_almost_equal(X, Xinv_ipca, decimal=prec) - assert_almost_equal(X, Xinv_pca, decimal=prec) - assert_almost_equal(Xinv_pca, Xinv_ipca, decimal=prec) + assert_allclose(X, Xinv_ipca, atol=atol) + assert_allclose(X, Xinv_pca, atol=atol) + assert_allclose(Xinv_pca, Xinv_ipca, atol=atol) def test_incremental_pca_partial_fit_float_division(): diff --git a/sklearn/decomposition/tests/test_pca.py b/sklearn/decomposition/tests/test_pca.py index b0fd32d1cbf62..d099bf9a91e00 100644 --- a/sklearn/decomposition/tests/test_pca.py +++ b/sklearn/decomposition/tests/test_pca.py @@ -8,7 +8,7 @@ from sklearn import config_context, datasets from sklearn.base import clone -from sklearn.datasets import load_iris, make_classification +from sklearn.datasets import load_iris, make_classification, make_low_rank_matrix from sklearn.decomposition import PCA from sklearn.decomposition._pca import _assess_dimension, _infer_dimension from sklearn.utils._array_api import ( @@ -25,7 +25,7 @@ from sklearn.utils.fixes import CSC_CONTAINERS, CSR_CONTAINERS iris = datasets.load_iris() -PCA_SOLVERS = ["full", "arpack", "randomized", "auto"] +PCA_SOLVERS = ["full", "covariance_eigh", "arpack", "randomized", "auto"] # `SPARSE_M` and `SPARSE_N` could be larger, but be aware: # * SciPy's generation of random sparse matrix can be costly @@ -70,7 +70,7 @@ def test_pca(svd_solver, n_components): @pytest.mark.parametrize("density", [0.01, 0.1, 0.30]) @pytest.mark.parametrize("n_components", [1, 2, 10]) @pytest.mark.parametrize("sparse_container", CSR_CONTAINERS + CSC_CONTAINERS) -@pytest.mark.parametrize("svd_solver", ["arpack"]) +@pytest.mark.parametrize("svd_solver", ["arpack", "covariance_eigh"]) @pytest.mark.parametrize("scale", [1, 10, 100]) def test_pca_sparse( global_random_seed, svd_solver, sparse_container, n_components, density, scale @@ -172,8 +172,8 @@ def test_sparse_pca_solver_error(global_random_seed, svd_solver, sparse_containe ) pca = PCA(n_components=30, svd_solver=svd_solver) error_msg_pattern = ( - f'PCA only support sparse inputs with the "arpack" solver, while "{svd_solver}"' - " was passed" + 'PCA only support sparse inputs with the "arpack" and "covariance_eigh"' + f' solvers, while "{svd_solver}" was passed' ) with pytest.raises(TypeError, match=error_msg_pattern): pca.fit(X) @@ -263,35 +263,154 @@ def test_whitening(solver, copy): # we always center, so no test for non-centering. 
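# --- Illustrative sketch (editorial addition, not part of the patch) ---
# The tests above and below compare solvers only on "stable" components: on
# rank-deficient data the trailing principal directions explain ~zero variance,
# so their orientation is an arbitrary implementation detail of the SVD backend
# and has to be masked out before asserting solver equivalence.
import numpy as np
from sklearn.decomposition import PCA

rng = np.random.default_rng(0)
rank = 5
X = rng.standard_normal((100, rank)) @ rng.standard_normal((rank, 10))

pca = PCA(whiten=True).fit(X)
stable_mask = pca.explained_variance_ratio_ > 1e-12
print(stable_mask.sum())  # matches the true rank of X, i.e. 5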
-@pytest.mark.parametrize("svd_solver", ["arpack", "randomized"]) -def test_pca_explained_variance_equivalence_solver(svd_solver): - rng = np.random.RandomState(0) - n_samples, n_features = 100, 80 - X = rng.randn(n_samples, n_features) - - pca_full = PCA(n_components=2, svd_solver="full") - pca_other = PCA(n_components=2, svd_solver=svd_solver, random_state=0) - - pca_full.fit(X) - pca_other.fit(X) - - assert_allclose( - pca_full.explained_variance_, pca_other.explained_variance_, rtol=5e-2 +@pytest.mark.parametrize( + "other_svd_solver", sorted(list(set(PCA_SOLVERS) - {"full", "auto"})) +) +@pytest.mark.parametrize("data_shape", ["tall", "wide"]) +@pytest.mark.parametrize("rank_deficient", [False, True]) +@pytest.mark.parametrize("whiten", [False, True]) +def test_pca_solver_equivalence( + other_svd_solver, + data_shape, + rank_deficient, + whiten, + global_random_seed, + global_dtype, +): + if data_shape == "tall": + n_samples, n_features = 100, 30 + else: + n_samples, n_features = 30, 100 + n_samples_test = 10 + + if rank_deficient: + rng = np.random.default_rng(global_random_seed) + rank = min(n_samples, n_features) // 2 + X = rng.standard_normal( + size=(n_samples + n_samples_test, rank) + ) @ rng.standard_normal(size=(rank, n_features)) + else: + X = make_low_rank_matrix( + n_samples=n_samples + n_samples_test, + n_features=n_features, + tail_strength=0.5, + random_state=global_random_seed, + ) + # With a non-zero tail strength, the data is actually full-rank. + rank = min(n_samples, n_features) + + X = X.astype(global_dtype, copy=False) + X_train, X_test = X[:n_samples], X[n_samples:] + + if global_dtype == np.float32: + tols = dict(atol=1e-2, rtol=1e-5) + variance_threshold = 1e-5 + else: + tols = dict(atol=1e-10, rtol=1e-12) + variance_threshold = 1e-12 + + extra_other_kwargs = {} + if other_svd_solver == "randomized": + # Only check for a truncated result with a large number of iterations + # to make sure that we can recover precise results. + n_components = 10 + extra_other_kwargs = {"iterated_power": 50} + elif other_svd_solver == "arpack": + # Test all components except the last one which cannot be estimated by + # arpack. + n_components = np.minimum(n_samples, n_features) - 1 + else: + # Test all components to high precision. 
+ n_components = None + + pca_full = PCA(n_components=n_components, svd_solver="full", whiten=whiten) + pca_other = PCA( + n_components=n_components, + svd_solver=other_svd_solver, + whiten=whiten, + random_state=global_random_seed, + **extra_other_kwargs, ) + X_trans_full_train = pca_full.fit_transform(X_train) + assert np.isfinite(X_trans_full_train).all() + assert X_trans_full_train.dtype == global_dtype + X_trans_other_train = pca_other.fit_transform(X_train) + assert np.isfinite(X_trans_other_train).all() + assert X_trans_other_train.dtype == global_dtype + + assert (pca_full.explained_variance_ >= 0).all() + assert_allclose(pca_full.explained_variance_, pca_other.explained_variance_, **tols) assert_allclose( pca_full.explained_variance_ratio_, pca_other.explained_variance_ratio_, - rtol=5e-2, + **tols, + ) + reference_components = pca_full.components_ + assert np.isfinite(reference_components).all() + other_components = pca_other.components_ + assert np.isfinite(other_components).all() + + # For some choice of n_components and data distribution, some components + # might be pure noise, let's ignore them in the comparison: + stable = pca_full.explained_variance_ > variance_threshold + assert stable.sum() > 1 + assert_allclose(reference_components[stable], other_components[stable], **tols) + + # As a result the output of fit_transform should be the same: + assert_allclose( + X_trans_other_train[:, stable], X_trans_full_train[:, stable], **tols ) + # And similarly for the output of transform on new data (except for the + # last component that can be underdetermined): + X_trans_full_test = pca_full.transform(X_test) + assert np.isfinite(X_trans_full_test).all() + assert X_trans_full_test.dtype == global_dtype + X_trans_other_test = pca_other.transform(X_test) + assert np.isfinite(X_trans_other_test).all() + assert X_trans_other_test.dtype == global_dtype + assert_allclose(X_trans_other_test[:, stable], X_trans_full_test[:, stable], **tols) + + # Check that inverse transform reconstructions for both solvers are + # compatible. + X_recons_full_test = pca_full.inverse_transform(X_trans_full_test) + assert np.isfinite(X_recons_full_test).all() + assert X_recons_full_test.dtype == global_dtype + X_recons_other_test = pca_other.inverse_transform(X_trans_other_test) + assert np.isfinite(X_recons_other_test).all() + assert X_recons_other_test.dtype == global_dtype + + if pca_full.components_.shape[0] == pca_full.components_.shape[1]: + # In this case, the models should have learned the same invertible + # transform. They should therefore both be able to reconstruct the test + # data. + assert_allclose(X_recons_full_test, X_test, **tols) + assert_allclose(X_recons_other_test, X_test, **tols) + elif pca_full.components_.shape[0] < rank: + # In the absence of noisy components, both models should be able to + # reconstruct the same low-rank approximation of the original data. + assert pca_full.explained_variance_.min() > variance_threshold + assert_allclose(X_recons_full_test, X_recons_other_test, **tols) + else: + # When n_features > n_samples and n_components is larger than the rank + # of the training set, the output of the `inverse_transform` function + # is ill-defined. 
We can only check that we reach the same fixed point + # after another round of transform: + assert_allclose( + pca_full.transform(X_recons_full_test)[:, stable], + pca_other.transform(X_recons_other_test)[:, stable], + **tols, + ) + @pytest.mark.parametrize( "X", [ np.random.RandomState(0).randn(100, 80), datasets.make_classification(100, 80, n_informative=78, random_state=0)[0], + np.random.RandomState(0).randn(10, 100), ], - ids=["random-data", "correlated-data"], + ids=["random-tall", "correlated-tall", "random-wide"], ) @pytest.mark.parametrize("svd_solver", PCA_SOLVERS) def test_pca_explained_variance_empirical(X, svd_solver): @@ -629,23 +748,28 @@ def test_pca_zero_noise_variance_edge_cases(svd_solver): @pytest.mark.parametrize( - "data, n_components, expected_solver", - [ # case: n_components in (0,1) => 'full' - (np.random.RandomState(0).uniform(size=(1000, 50)), 0.5, "full"), - # case: max(X.shape) <= 500 => 'full' - (np.random.RandomState(0).uniform(size=(10, 50)), 5, "full"), + "n_samples, n_features, n_components, expected_solver", + [ + # case: n_samples < 10 * n_features and max(X.shape) <= 500 => 'full' + (10, 50, 5, "full"), + # case: n_samples > 10 * n_features and n_features < 500 => 'covariance_eigh' + (1000, 50, 50, "covariance_eigh"), # case: n_components >= .8 * min(X.shape) => 'full' - (np.random.RandomState(0).uniform(size=(1000, 50)), 50, "full"), + (1000, 500, 400, "full"), # n_components >= 1 and n_components < .8*min(X.shape) => 'randomized' - (np.random.RandomState(0).uniform(size=(1000, 50)), 10, "randomized"), + (1000, 500, 10, "randomized"), + # case: n_components in (0,1) => 'full' + (1000, 500, 0.5, "full"), ], ) -def test_pca_svd_solver_auto(data, n_components, expected_solver): +def test_pca_svd_solver_auto(n_samples, n_features, n_components, expected_solver): + data = np.random.RandomState(0).uniform(size=(n_samples, n_features)) pca_auto = PCA(n_components=n_components, random_state=0) pca_test = PCA( n_components=n_components, svd_solver=expected_solver, random_state=0 ) pca_auto.fit(data) + assert pca_auto._fit_svd_solver == expected_solver pca_test.fit(data) assert_allclose(pca_auto.components_, pca_test.components_) @@ -663,28 +787,33 @@ def test_pca_deterministic_output(svd_solver): @pytest.mark.parametrize("svd_solver", PCA_SOLVERS) -def test_pca_dtype_preservation(svd_solver): - check_pca_float_dtype_preservation(svd_solver) +def test_pca_dtype_preservation(svd_solver, global_random_seed): + check_pca_float_dtype_preservation(svd_solver, global_random_seed) check_pca_int_dtype_upcast_to_double(svd_solver) -def check_pca_float_dtype_preservation(svd_solver): +def check_pca_float_dtype_preservation(svd_solver, seed): # Ensure that PCA does not upscale the dtype when input is float32 - X_64 = np.random.RandomState(0).rand(1000, 4).astype(np.float64, copy=False) - X_32 = X_64.astype(np.float32) + X = np.random.RandomState(seed).rand(1000, 4) + X_float64 = X.astype(np.float64, copy=False) + X_float32 = X.astype(np.float32) - pca_64 = PCA(n_components=3, svd_solver=svd_solver, random_state=0).fit(X_64) - pca_32 = PCA(n_components=3, svd_solver=svd_solver, random_state=0).fit(X_32) + pca_64 = PCA(n_components=3, svd_solver=svd_solver, random_state=seed).fit( + X_float64 + ) + pca_32 = PCA(n_components=3, svd_solver=svd_solver, random_state=seed).fit( + X_float32 + ) assert pca_64.components_.dtype == np.float64 assert pca_32.components_.dtype == np.float32 - assert pca_64.transform(X_64).dtype == np.float64 - assert pca_32.transform(X_32).dtype == 
np.float32 + assert pca_64.transform(X_float64).dtype == np.float64 + assert pca_32.transform(X_float32).dtype == np.float32 - # the rtol is set such that the test passes on all platforms tested on - # conda-forge: PR#15775 - # see: https://github.com/conda-forge/scikit-learn-feedstock/pull/113 - assert_allclose(pca_64.components_, pca_32.components_, rtol=2e-4) + # The atol and rtol are set such that the test passes for all random seeds + # on all supported platforms on our CI and conda-forge with the default + # random seed. + assert_allclose(pca_64.components_, pca_32.components_, rtol=1e-3, atol=1e-3) def check_pca_int_dtype_upcast_to_double(svd_solver): @@ -844,6 +973,7 @@ def check_array_api_get_precision(name, estimator, array_namespace, device, dtyp precision_np = estimator.get_precision() covariance_np = estimator.get_covariance() + rtol = 2e-4 if iris_np.dtype == "float32" else 2e-7 with config_context(array_api_dispatch=True): estimator_xp = clone(estimator).fit(iris_xp) precision_xp = estimator_xp.get_precision() @@ -853,6 +983,7 @@ def check_array_api_get_precision(name, estimator, array_namespace, device, dtyp assert_allclose( _convert_to_numpy(precision_xp, xp=xp), precision_np, + rtol=rtol, atol=_atol_for_type(dtype_name), ) covariance_xp = estimator_xp.get_covariance() @@ -862,6 +993,7 @@ def check_array_api_get_precision(name, estimator, array_namespace, device, dtyp assert_allclose( _convert_to_numpy(covariance_xp, xp=xp), covariance_np, + rtol=rtol, atol=_atol_for_type(dtype_name), ) @@ -878,7 +1010,10 @@ def check_array_api_get_precision(name, estimator, array_namespace, device, dtyp "estimator", [ PCA(n_components=2, svd_solver="full"), + PCA(n_components=2, svd_solver="full", whiten=True), PCA(n_components=0.1, svd_solver="full", whiten=True), + PCA(n_components=2, svd_solver="covariance_eigh"), + PCA(n_components=2, svd_solver="covariance_eigh", whiten=True), PCA( n_components=2, svd_solver="randomized", diff --git a/sklearn/decomposition/tests/test_sparse_pca.py b/sklearn/decomposition/tests/test_sparse_pca.py index 3797970e3d6ba..532d8dbd5e82f 100644 --- a/sklearn/decomposition/tests/test_sparse_pca.py +++ b/sklearn/decomposition/tests/test_sparse_pca.py @@ -14,6 +14,7 @@ assert_array_almost_equal, if_safe_multiprocessing_with_blas, ) +from sklearn.utils.extmath import svd_flip def generate_toy_data(n_components, n_samples, image_size, random_state=None): @@ -114,7 +115,10 @@ def test_initialization(): n_components=3, U_init=U_init, V_init=V_init, max_iter=0, random_state=rng ) model.fit(rng.randn(5, 4)) - assert_allclose(model.components_, V_init / np.linalg.norm(V_init, axis=1)[:, None]) + + expected_components = V_init / np.linalg.norm(V_init, axis=1, keepdims=True) + expected_components = svd_flip(u=expected_components.T, v=None)[0].T + assert_allclose(model.components_, expected_components) def test_mini_batch_correct_shapes(): diff --git a/sklearn/pipeline.py b/sklearn/pipeline.py index b26b83e66510f..1b17599068d7a 100644 --- a/sklearn/pipeline.py +++ b/sklearn/pipeline.py @@ -1412,12 +1412,12 @@ class FeatureUnion(TransformerMixin, _BaseComposition): ... 
("svd", TruncatedSVD(n_components=2))]) >>> X = [[0., 1., 3], [2., 2., 5]] >>> union.fit_transform(X) - array([[ 1.5 , 3.0..., 0.8...], - [-1.5 , 5.7..., -0.4...]]) + array([[-1.5 , 3.0..., -0.8...], + [ 1.5 , 5.7..., 0.4...]]) >>> # An estimator's parameter can be set using '__' syntax >>> union.set_params(svd__n_components=1).fit_transform(X) - array([[ 1.5 , 3.0...], - [-1.5 , 5.7...]]) + array([[-1.5 , 3.0...], + [ 1.5 , 5.7...]]) For a more detailed example of usage, see :ref:`sphx_glr_auto_examples_compose_plot_feature_union.py`. diff --git a/sklearn/utils/extmath.py b/sklearn/utils/extmath.py index 2fe7dbc3cc179..44f70deaa3f18 100644 --- a/sklearn/utils/extmath.py +++ b/sklearn/utils/extmath.py @@ -864,12 +864,14 @@ def svd_flip(u, v, u_based_decision=True): Parameters u and v are the output of `linalg.svd` or :func:`~sklearn.utils.extmath.randomized_svd`, with matching inner dimensions so one can compute `np.dot(u * s, v)`. + u can be None if `u_based_decision` is False. v : ndarray Parameters u and v are the output of `linalg.svd` or :func:`~sklearn.utils.extmath.randomized_svd`, with matching inner dimensions so one can compute `np.dot(u * s, v)`. The input v should really be called vt to be consistent with scipy's output. + v can be None if `u_based_decision` is True. u_based_decision : bool, default=True If True, use the columns of u as the basis for sign flipping. @@ -884,24 +886,25 @@ def svd_flip(u, v, u_based_decision=True): v_adjusted : ndarray Array v with adjusted rows and the same dimensions as v. """ - xp, _ = get_namespace(u, v) - device = getattr(u, "device", None) + xp, _ = get_namespace(*[a for a in [u, v] if a is not None]) if u_based_decision: # columns of u, rows of v, or equivalently rows of u.T and v max_abs_u_cols = xp.argmax(xp.abs(u.T), axis=1) - shift = xp.arange(u.T.shape[0], device=device) + shift = xp.arange(u.T.shape[0], device=device(u)) indices = max_abs_u_cols + shift * u.T.shape[1] signs = xp.sign(xp.take(xp.reshape(u.T, (-1,)), indices, axis=0)) u *= signs[np.newaxis, :] - v *= signs[:, np.newaxis] + if v is not None: + v *= signs[:, np.newaxis] else: # rows of v, columns of u max_abs_v_rows = xp.argmax(xp.abs(v), axis=1) - shift = xp.arange(v.shape[0], device=device) + shift = xp.arange(v.shape[0], device=device(v)) indices = max_abs_v_rows + shift * v.shape[1] - signs = xp.sign(xp.take(xp.reshape(v, (-1,)), indices)) - u *= signs[np.newaxis, :] + signs = xp.sign(xp.take(xp.reshape(v, (-1,)), indices, axis=0)) + if u is not None: + u *= signs[np.newaxis, :] v *= signs[:, np.newaxis] return u, v