
Add svd_solver="lobpcg" option to PCA #30075


Closed · wants to merge 2 commits
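For context, a minimal usage sketch of the option this PR adds (hypothetical until the branch is merged; dense arrays work the same way, and sparse input is accepted because `svds` operates on it directly):

```python
# Minimal usage sketch for svd_solver="lobpcg" (requires this branch; the
# option does not exist in released scikit-learn).
import numpy as np
from scipy.sparse import random as sparse_random
from sklearn.decomposition import PCA

rng = np.random.RandomState(0)
X = sparse_random(300, 80, density=0.05, format="csr", random_state=rng)

# Like "arpack", "lobpcg" requires 0 < n_components < min(X.shape).
pca = PCA(n_components=10, svd_solver="lobpcg", random_state=0)
X_reduced = pca.fit_transform(X)
print(X_reduced.shape)  # (300, 10)
```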
sklearn/decomposition/_pca.py (27 additions & 15 deletions)
@@ -152,7 +152,7 @@ class PCA(_BasePCA):
         number of components such that the amount of variance that needs to be
         explained is greater than the percentage specified by n_components.

-        If ``svd_solver == 'arpack'``, the number of components must be
+        If ``svd_solver in {'arpack', 'lobpcg'}``, the number of components must be
         strictly less than the minimum of n_features and n_samples.

         Hence, the None case results in::
@@ -174,7 +174,7 @@ class PCA(_BasePCA):
         improve the predictive accuracy of the downstream estimators by
         making their data respect some hard-wired assumptions.

-    svd_solver : {'auto', 'full', 'covariance_eigh', 'arpack', 'randomized'},\
+    svd_solver : {'auto', 'full', 'covariance_eigh', 'arpack', 'lobpcg', 'randomized'},\
             default='auto'
         "auto" :
             The solver is selected by a default 'auto' policy is based on `X.shape` and
@@ -200,7 +200,8 @@ class PCA(_BasePCA):
             therefore less numerical stable (e.g. on input data with a large
             range of singular values).
         "arpack" :
-            Run SVD truncated to `n_components` calling ARPACK solver via
+        "lobpcg" :
+            Run SVD truncated to `n_components` calling ARPACK or LOBPCG solver via
             `scipy.sparse.linalg.svds`. It requires strictly
             `0 < n_components < min(X.shape)`
         "randomized" :
@@ -332,7 +333,7 @@ class PCA(_BasePCA):
    <http://www.miketipping.com/papers/met-mppca.pdf>`_
    via the score and score_samples methods.

-    For svd_solver == 'arpack', refer to `scipy.sparse.linalg.svds`.
+    For ``svd_solver in {'arpack', 'lobpcg'}``, refer to `scipy.sparse.linalg.svds`.

    For svd_solver == 'randomized', see:
    :doi:`Halko, N., Martinsson, P. G., and Tropp, J. A. (2011).
@@ -386,7 +387,9 @@ class PCA(_BasePCA):
        "copy": ["boolean"],
        "whiten": ["boolean"],
        "svd_solver": [
-            StrOptions({"auto", "full", "covariance_eigh", "arpack", "randomized"})
+            StrOptions(
+                {"auto", "full", "covariance_eigh", "arpack", "lobpcg", "randomized"}
+            )
        ],
        "tol": [Interval(Real, 0, None, closed="left")],
        "iterated_power": [
@@ -485,15 +488,20 @@ def _fit(self, X):
        xp, is_array_api_compliant = get_namespace(X)

        # Raise an error for sparse input and unsupported svd_solver
-        if issparse(X) and self.svd_solver not in ["auto", "arpack", "covariance_eigh"]:
+        if issparse(X) and self.svd_solver not in {
+            "auto",
+            "arpack",
+            "lobpcg",
+            "covariance_eigh",
+        }:
            raise TypeError(
-                'PCA only support sparse inputs with the "arpack" and'
+                'PCA only support sparse inputs with the "arpack", "lobpcg", and'
                f' "covariance_eigh" solvers, while "{self.svd_solver}" was passed. See'
                " TruncatedSVD for a possible alternative."
            )
-        if self.svd_solver == "arpack" and is_array_api_compliant:
+        if self.svd_solver in {"arpack", "lobpcg"} and is_array_api_compliant:
            raise ValueError(
-                "PCA with svd_solver='arpack' is not supported for Array API inputs."
+                f"PCA with {self.svd_solver=} is not supported for Array API inputs."
            )

        # Validate the data, without ever forcing a copy as any solver that
@@ -516,7 +524,7 @@ def _fit(self, X):
                self._fit_svd_solver = "arpack"

        if self.n_components is None:
-            if self._fit_svd_solver != "arpack":
+            if self._fit_svd_solver not in {"arpack", "lobpcg"}:
                n_components = min(X.shape)
            else:
                n_components = min(X.shape) - 1
@@ -540,7 +548,7 @@ def _fit(self, X):
        # Call different fits for either full or truncated SVD
        if self._fit_svd_solver in ("full", "covariance_eigh"):
            return self._fit_full(X, n_components, xp, is_array_api_compliant)
-        elif self._fit_svd_solver in ["arpack", "randomized"]:
+        elif self._fit_svd_solver in {"arpack", "lobpcg", "randomized"}:
            return self._fit_truncated(X, n_components, xp)

    def _fit_full(self, X, n_components, xp, is_array_api_compliant):
@@ -704,7 +712,7 @@ def _fit_full(self, X, n_components, xp, is_array_api_compliant):
        return U, S, Vt, X, x_is_centered, xp

    def _fit_truncated(self, X, n_components, xp):
-        """Fit the model by computing truncated SVD (by ARPACK or randomized)
+        """Fit the model by computing truncated SVD (by ARPACK, LOBPCG or randomized)
        on X.
        """
        n_samples, n_features = X.shape
@@ -722,7 +730,9 @@ def _fit_truncated(self, X, n_components, xp):
                "svd_solver='%s'"
                % (n_components, min(n_samples, n_features), svd_solver)
            )
-        elif svd_solver == "arpack" and n_components == min(n_samples, n_features):
+        elif svd_solver in {"arpack", "lobpcg"} and n_components == min(
+            n_samples, n_features
+        ):
            raise ValueError(
                "n_components=%r must be strictly less than "
                "min(n_samples, n_features)=%r with "
@@ -745,9 +755,11 @@ def _fit_truncated(self, X, n_components, xp):
                X_centered -= self.mean_
            x_is_centered = not self.copy

-        if svd_solver == "arpack":
+        if svd_solver in {"arpack", "lobpcg"}:
            v0 = _init_arpack_v0(min(X.shape), random_state)
-            U, S, Vt = svds(X_centered, k=n_components, tol=self.tol, v0=v0)
+            U, S, Vt = svds(
+                X_centered, k=n_components, tol=self.tol, v0=v0, solver=svd_solver
+            )
            # svds doesn't abide by scipy.linalg.svd/randomized_svd
            # conventions, so reverse its outputs.
            S = S[::-1]
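For reference, a sketch (not part of the diff) of the SciPy call that `_fit_truncated` now forwards to. `scipy.sparse.linalg.svds` accepts `solver="arpack"` and `solver="lobpcg"` with the same signature, which is why the PR can pass the solver name straight through; `random_state=` (SciPy >= 1.8) is used here instead of `v0` for brevity:

```python
# Demonstrates the ascending-order convention of svds that motivates the
# S[::-1] reversal (and matching U/Vt reversals) in the last hunk above.
import numpy as np
from scipy.sparse.linalg import svds

rng = np.random.RandomState(0)
X = rng.randn(100, 30)
X_centered = X - X.mean(axis=0)  # PCA centers the data before the SVD

U, S, Vt = svds(X_centered, k=5, solver="lobpcg", random_state=0)

print(S)        # singular values in ascending order
print(S[::-1])  # descending, the convention PCA's attributes follow
```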
sklearn/decomposition/tests/test_pca.py (5 additions & 4 deletions)
@@ -25,7 +25,7 @@
 from sklearn.utils.fixes import CSC_CONTAINERS, CSR_CONTAINERS

 iris = datasets.load_iris()
-PCA_SOLVERS = ["full", "covariance_eigh", "arpack", "randomized", "auto"]
+PCA_SOLVERS = ["full", "covariance_eigh", "arpack", "lobpcg", "randomized", "auto"]

 # `SPARSE_M` and `SPARSE_N` could be larger, but be aware:
 # * SciPy's generation of random sparse matrix can be costly
@@ -73,7 +73,7 @@ def test_pca(svd_solver, n_components):
 @pytest.mark.parametrize("density", [0.01, 0.1, 0.30])
 @pytest.mark.parametrize("n_components", [1, 2, 10])
 @pytest.mark.parametrize("sparse_container", CSR_CONTAINERS + CSC_CONTAINERS)
-@pytest.mark.parametrize("svd_solver", ["arpack", "covariance_eigh"])
+@pytest.mark.parametrize("svd_solver", ["arpack", "lobpcg", "covariance_eigh"])
 @pytest.mark.parametrize("scale", [1, 10, 100])
 def test_pca_sparse(
     global_random_seed, svd_solver, sparse_container, n_components, density, scale
@@ -429,7 +429,7 @@ def test_pca_explained_variance_empirical(X, svd_solver):
     assert_allclose(pca.explained_variance_, expected_result, rtol=5e-3)


-@pytest.mark.parametrize("svd_solver", ["arpack", "randomized"])
+@pytest.mark.parametrize("svd_solver", ["arpack", "lobpcg", "randomized"])
 def test_pca_singular_values_consistency(svd_solver):
     rng = np.random.RandomState(0)
     n_samples, n_features = 100, 80
@@ -564,6 +564,7 @@ def test_pca_validation(svd_solver, data, n_components, err_msg):
     [
         ("full", min(iris.data.shape)),
         ("arpack", min(iris.data.shape) - 1),
+        ("lobpcg", min(iris.data.shape) - 1),
         ("randomized", min(iris.data.shape)),
     ],
 )
@@ -719,7 +720,7 @@ def test_pca_sanity_noise_variance(svd_solver):
     assert np.all((pca.explained_variance_ - pca.noise_variance_) >= 0)


-@pytest.mark.parametrize("svd_solver", ["arpack", "randomized"])
+@pytest.mark.parametrize("svd_solver", ["arpack", "lobpcg", "randomized"])
 def test_pca_score_consistency_solvers(svd_solver):
     # Check the consistency of score between solvers
     X, _ = datasets.load_digits(return_X_y=True)
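In the spirit of the consistency tests above, a quick end-to-end check one could run against this branch. A sketch only: it assumes the branch is installed, and the tolerances are illustrative rather than taken from the test suite.

```python
# The truncated LOBPCG solve should agree with the exact "full" solver on the
# leading components. Singular values and explained variances are sign-free,
# so they can be compared directly across solvers.
import numpy as np
from sklearn import datasets
from sklearn.decomposition import PCA

X, _ = datasets.load_digits(return_X_y=True)

pca_full = PCA(n_components=10, svd_solver="full").fit(X)
pca_lobpcg = PCA(n_components=10, svd_solver="lobpcg", random_state=0).fit(X)

np.testing.assert_allclose(
    pca_lobpcg.singular_values_, pca_full.singular_values_, rtol=1e-3
)
np.testing.assert_allclose(
    pca_lobpcg.explained_variance_, pca_full.explained_variance_, rtol=1e-3
)
```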