52 changes: 30 additions & 22 deletions sklearn/decomposition/pca.py
@@ -24,9 +24,10 @@
 from .base import _BasePCA
 from ..utils import check_random_state
 from ..utils import check_array
-from ..utils.extmath import fast_logdet, randomized_svd, svd_flip
+from ..utils.extmath import fast_logdet, randomized_pca, svd_flip
 from ..utils.extmath import stable_cumsum
 from ..utils.validation import check_is_fitted
+from ..utils.sparsefuncs import mean_variance_axis


 def _assess_dimension_(spectrum, rank, n_samples, n_features):
@@ -116,9 +117,6 @@ class PCA(_BasePCA):
     It can also use the scipy.sparse.linalg ARPACK implementation of the
     truncated SVD.

-    Notice that this class does not support sparse input. See
-    :class:`TruncatedSVD` for an alternative with sparse data.
-
     Read more in the :ref:`User Guide <PCA>`.

     Parameters
@@ -161,21 +159,25 @@ class PCA(_BasePCA):

     svd_solver : string {'auto', 'full', 'arpack', 'randomized'}
         auto :
-            the solver is selected by a default policy based on `X.shape` and
+            The solver is selected by a default policy based on `X.shape` and
             `n_components`: if the input data is larger than 500x500 and the
             number of components to extract is lower than 80% of the smallest
             dimension of the data, then the more efficient 'randomized'
             method is enabled. Otherwise the exact full SVD is computed and
             optionally truncated afterwards.
+
+            If the input data is sparse, 'randomized' is selected, as it is
+            the only method that supports sparse data.
         full :
-            run exact full SVD calling the standard LAPACK solver via
+            Run exact full SVD calling the standard LAPACK solver via
             `scipy.linalg.svd` and select the components by postprocessing
         arpack :
-            run SVD truncated to n_components calling ARPACK solver via
+            Run SVD truncated to n_components calling ARPACK solver via
             `scipy.sparse.linalg.svds`. It requires strictly
             0 < n_components < min(X.shape)
         randomized :
-            run randomized SVD by the method of Halko et al.
+            Run randomized SVD by the method of Halko et al. This is the only
+            method that supports sparse data.

         .. versionadded:: 0.18.0
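As a usage sketch of the 'auto' policy documented above (shapes and seeds are illustrative, not from this PR):

    import numpy as np
    from sklearn.decomposition import PCA

    rng = np.random.RandomState(0)

    # max(X.shape) <= 500, so 'auto' falls back to the exact full SVD
    PCA(n_components=5, svd_solver='auto').fit(rng.rand(100, 30))

    # Large data and n_components < 0.8 * min(X.shape): 'auto' picks 'randomized'
    PCA(n_components=10, svd_solver='auto').fit(rng.rand(2000, 1000))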

@@ -370,14 +372,8 @@ def fit_transform(self, X, y=None):

     def _fit(self, X):
         """Dispatch to the right submethod depending on the chosen solver."""
-
-        # Raise an error for sparse input.
-        # This is more informative than the generic one raised by check_array.
-        if issparse(X):
-            raise TypeError('PCA does not support sparse input. See '
-                            'TruncatedSVD for a possible alternative.')
-
-        X = check_array(X, dtype=[np.float64, np.float32], ensure_2d=True,
+        X = check_array(X, accept_sparse=['csr', 'csc'],
+                        dtype=[np.float64, np.float32], ensure_2d=True,
                         copy=self.copy)

         # Handle n_components==None
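For reference, a small sketch of what accept_sparse changes in the call above (check_array is sklearn's public validation helper; behaviour sketched under the assumption that CSR/CSC are the accepted formats, as in this patch):

    import scipy.sparse as sp
    from sklearn.utils import check_array

    X = sp.random(10, 4, density=0.5, format='coo', random_state=0)
    # Sparse input is no longer rejected; formats outside the accepted list
    # are converted to the first accepted format ('csr' here).
    X_checked = check_array(X, accept_sparse=['csr', 'csc'])
    print(X_checked.format)  # -> 'csr'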
@@ -392,15 +388,24 @@ def _fit(self, X):
         # Handle svd_solver
         self._fit_svd_solver = self.svd_solver
         if self._fit_svd_solver == 'auto':
+            # Sparse data can only be handled with the randomized solver
+            if issparse(X):
+                self._fit_svd_solver = 'randomized'
             # Small problem or n_components == 'mle', just call full PCA
-            if max(X.shape) <= 500 or n_components == 'mle':
+            elif max(X.shape) <= 500 or n_components == 'mle':
                 self._fit_svd_solver = 'full'
             elif n_components >= 1 and n_components < .8 * min(X.shape):
                 self._fit_svd_solver = 'randomized'
             # This is also the case of n_components in (0,1)
             else:
                 self._fit_svd_solver = 'full'

+        # Ensure we don't try to call arpack or full on a sparse matrix
+        if issparse(X) and self._fit_svd_solver != 'randomized':
+            raise ValueError(
+                'Only the randomized solver supports sparse matrices'
+            )
+
         # Call different fits for either full or truncated SVD
         if self._fit_svd_solver == 'full':
             return self._fit_full(X, n_components)
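Assuming this patch is applied, the dispatch above should behave roughly as in this sketch (hypothetical shapes):

    import scipy.sparse as sp
    from sklearn.decomposition import PCA

    X = sp.random(100, 80, density=0.1, format='csr', random_state=0)

    # 'auto' silently routes sparse input to the randomized solver
    pca = PCA(n_components=3, svd_solver='auto').fit(X)

    # 'full' (or 'arpack') on sparse input hits the new guard
    try:
        PCA(n_components=3, svd_solver='full').fit(X)
    except ValueError as exc:
        print(exc)  # Only the randomized solver supports sparse matrices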
@@ -503,11 +508,15 @@ def _fit_truncated(self, X, n_components, svd_solver):

         random_state = check_random_state(self.random_state)

-        # Center data
-        self.mean_ = np.mean(X, axis=0)
-        X -= self.mean_
+        if issparse(X):
+            self.mean_, total_var = mean_variance_axis(X, axis=0, ddof=1)
+        else:
+            self.mean_ = np.mean(X, axis=0)
+            total_var = np.var(X, axis=0, ddof=1)

         if svd_solver == 'arpack':
+            # Center data
+            X -= self.mean_
             # random init solution, as ARPACK does it internally
             v0 = random_state.uniform(-1, 1, size=min(X.shape))
             U, S, V = svds(X, k=n_components, tol=self.tol, v0=v0)
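For intuition, a standalone sketch of what mean_variance_axis(X, axis=0, ddof=1) computes above, without ever densifying the matrix (the ddof argument is assumed to be part of this PR):

    import numpy as np
    import scipy.sparse as sp

    X = sp.random(100, 80, density=0.1, format='csr', random_state=0)
    n = X.shape[0]
    mean = np.asarray(X.mean(axis=0)).ravel()                 # sparse-aware mean
    sq_mean = np.asarray(X.multiply(X).mean(axis=0)).ravel()  # E[x**2], sparse ops only
    total_var = (sq_mean - mean ** 2) * n / (n - 1)           # unbiased variance (ddof=1)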
@@ -519,7 +528,7 @@

         elif svd_solver == 'randomized':
             # sign flipping is done inside
-            U, S, V = randomized_svd(X, n_components=n_components,
+            U, S, V = randomized_pca(X, n_components=n_components,
                                      n_iter=self.iterated_power,
                                      flip_sign=True,
                                      random_state=random_state)

Review comment (Member), on the randomized_pca call: I think we should play it safe here and only call randomized_pca when X is sparse. Who knows, there might be bugs in randomized_pca, and randomized_svd is probably faster for dense matrices.
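A sketch of the reviewer's suggestion as it might look at this point in _fit_truncated (not runnable on its own; randomized_pca is the helper added by this PR, and its signature is assumed to mirror randomized_svd):

    from scipy.sparse import issparse
    from sklearn.utils.extmath import randomized_svd
    # from sklearn.utils.extmath import randomized_pca  # added by this PR

    # Play it safe: keep the battle-tested randomized_svd for dense input and
    # use randomized_pca only where it is actually needed (sparse input).
    solver = randomized_pca if issparse(X) else randomized_svd
    U, S, V = solver(X, n_components=n_components,
                     n_iter=self.iterated_power,
                     flip_sign=True, random_state=random_state)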
@@ -530,7 +539,6 @@

         # Get variance explained by singular values
         self.explained_variance_ = (S ** 2) / (n_samples - 1)
-        total_var = np.var(X, ddof=1, axis=0)
         self.explained_variance_ratio_ = \
             self.explained_variance_ / total_var.sum()
         self.singular_values_ = S.copy()  # Store the singular values.
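A quick standalone check of the ratio computed above: explained_variance_ratio_ divides S**2 / (n_samples - 1) by the total ddof=1 variance, so for a truncated fit the ratios sum to less than one (dense example; the sparse path uses the same total_var):

    import numpy as np
    from sklearn.decomposition import PCA

    X = np.random.RandomState(0).randn(100, 20)
    pca = PCA(n_components=3, svd_solver='randomized', random_state=0).fit(X)
    total_var = np.var(X, axis=0, ddof=1).sum()
    assert np.allclose(pca.explained_variance_ratio_,
                       pca.explained_variance_ / total_var)
    assert pca.explained_variance_ratio_.sum() < 1.0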
92 changes: 64 additions & 28 deletions sklearn/decomposition/tests/test_pca.py
@@ -14,6 +14,7 @@
 from sklearn.utils.testing import assert_no_warnings
 from sklearn.utils.testing import ignore_warnings
 from sklearn.utils.testing import assert_less
+from sklearn.utils.testing import assert_allclose

 from sklearn import datasets
 from sklearn.decomposition import PCA
@@ -256,34 +257,33 @@ def test_singular_values():

     X = rng.randn(n_samples, n_features)

-    pca = PCA(n_components=2, svd_solver='full',
-              random_state=rng).fit(X)
-    apca = PCA(n_components=2, svd_solver='arpack',
+    pca = PCA(n_components=2, svd_solver='full', random_state=rng).fit(X)
+    apca = PCA(n_components=2, svd_solver='arpack', random_state=rng).fit(X)
+    # Increase the number of power iterations to get greater accuracy in tests
+    rpca = PCA(n_components=2, svd_solver='randomized', iterated_power=40,
               random_state=rng).fit(X)
-    rpca = PCA(n_components=2, svd_solver='randomized',
-               random_state=rng).fit(X)
-    assert_array_almost_equal(pca.singular_values_, apca.singular_values_, 12)
-    assert_array_almost_equal(pca.singular_values_, rpca.singular_values_, 1)
-    assert_array_almost_equal(apca.singular_values_, rpca.singular_values_, 1)
+    assert_allclose(pca.singular_values_, apca.singular_values_)
+    assert_allclose(pca.singular_values_, rpca.singular_values_)
+    assert_allclose(apca.singular_values_, rpca.singular_values_)

Review comment (Member), on the iterated_power=40 line: This whole test_singular_values() test should not be changed IMO. If you have to set iterated_power=40 for the test to pass now, that's a regression, and that confirms my suggestion: randomized_pca should only be called for sparse matrices. Dense matrices should still use randomized_svd. Also in general, please avoid style fixes that are not related to the issue (like spaces between **, etc). I know you mean well, but it makes it harder to determine which changes are important during the review.

     # Compare to the Frobenius norm
     X_pca = pca.transform(X)
     X_apca = apca.transform(X)
     X_rpca = rpca.transform(X)
-    assert_array_almost_equal(np.sum(pca.singular_values_**2.0),
-                              np.linalg.norm(X_pca, "fro")**2.0, 12)
-    assert_array_almost_equal(np.sum(apca.singular_values_**2.0),
-                              np.linalg.norm(X_apca, "fro")**2.0, 9)
-    assert_array_almost_equal(np.sum(rpca.singular_values_**2.0),
-                              np.linalg.norm(X_rpca, "fro")**2.0, 0)
+    assert_array_almost_equal(np.sum(pca.singular_values_ ** 2.0),
+                              np.linalg.norm(X_pca, "fro") ** 2.0, 12)
+    assert_array_almost_equal(np.sum(apca.singular_values_ ** 2.0),
+                              np.linalg.norm(X_apca, "fro") ** 2.0, 9)
+    assert_array_almost_equal(np.sum(rpca.singular_values_ ** 2.0),
+                              np.linalg.norm(X_rpca, "fro") ** 2.0, 0)

     # Compare to the 2-norms of the score vectors
-    assert_array_almost_equal(pca.singular_values_,
-                              np.sqrt(np.sum(X_pca**2.0, axis=0)), 12)
-    assert_array_almost_equal(apca.singular_values_,
-                              np.sqrt(np.sum(X_apca**2.0, axis=0)), 12)
-    assert_array_almost_equal(rpca.singular_values_,
-                              np.sqrt(np.sum(X_rpca**2.0, axis=0)), 2)
+    assert_allclose(pca.singular_values_,
+                    np.sqrt(np.sum(X_pca ** 2.0, axis=0)))
+    assert_allclose(apca.singular_values_,
+                    np.sqrt(np.sum(X_apca ** 2.0, axis=0)))
+    assert_allclose(rpca.singular_values_,
+                    np.sqrt(np.sum(X_rpca ** 2.0, axis=0)))
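The identities these assertions exercise can be checked directly: for centered X with SVD X = U S V^T, the scores X V = U S satisfy ||U S||_F^2 = sum(S**2), and each score column's 2-norm is the corresponding singular value (a standalone NumPy sketch):

    import numpy as np

    rng = np.random.RandomState(0)
    X = rng.randn(100, 80)
    Xc = X - X.mean(axis=0)                      # PCA centers the data
    U, S, Vt = np.linalg.svd(Xc, full_matrices=False)
    scores = Xc @ Vt.T                           # equals U * S
    assert np.allclose(np.linalg.norm(scores, 'fro') ** 2, np.sum(S ** 2))
    assert np.allclose(np.sqrt(np.sum(scores ** 2, axis=0)), S)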

# Set the singular values and see what we get back
rng = np.random.RandomState(0)
@@ -297,17 +297,18 @@ def test_singular_values():
     rpca = PCA(n_components=3, svd_solver='randomized', random_state=rng)
     X_pca = pca.fit_transform(X)

-    X_pca /= np.sqrt(np.sum(X_pca**2.0, axis=0))
+    X_pca /= np.sqrt(np.sum(X_pca ** 2.0, axis=0))
     X_pca[:, 0] *= 3.142
     X_pca[:, 1] *= 2.718

     X_hat = np.dot(X_pca, pca.components_)
     pca.fit(X_hat)
     apca.fit(X_hat)
     rpca.fit(X_hat)
-    assert_array_almost_equal(pca.singular_values_, [3.142, 2.718, 1.0], 14)
-    assert_array_almost_equal(apca.singular_values_, [3.142, 2.718, 1.0], 14)
-    assert_array_almost_equal(rpca.singular_values_, [3.142, 2.718, 1.0], 14)
+
+    assert_allclose(pca.singular_values_, [3.142, 2.718, 1.0])
+    assert_allclose(apca.singular_values_, [3.142, 2.718, 1.0])
+    assert_allclose(rpca.singular_values_, [3.142, 2.718, 1.0])


def test_pca_check_projection():
@@ -683,15 +684,50 @@ def test_svd_solver_auto():
     assert_array_almost_equal(pca.components_, pca_test.components_)


-@pytest.mark.parametrize('svd_solver', solver_list)
-def test_pca_sparse_input(svd_solver):
+def test_pca_sparse_input_randomized_solver():
+    rng = np.random.RandomState(0)
+    n_samples = 100
+    n_features = 80
+
+    X = rng.binomial(1, 0.1, (n_samples, n_features))
Review comment (Member), on the rng.binomial call: nitpick: use size=

+    X_sp = sp.sparse.csr_matrix(X)
+
+    # Compute the randomized decomposition on the dense matrix
+    pca = PCA(n_components=3, svd_solver='randomized',
+              random_state=0).fit(X)
+    # And compute the randomized decomposition on the sparse matrix.
+    pca_sp = PCA(n_components=3, svd_solver='randomized',
+                 random_state=0).fit(X_sp)
+
+    # Ensure the singular values are close to the exact singular values
+    assert_allclose(pca_sp.singular_values_, pca.singular_values_)
+
+    # Ensure that the basis is close to the true basis
+    X_pca = pca.transform(X)
+    X_sppca = pca_sp.transform(X)
+    assert_allclose(X_sppca, X_pca)
+
+
+@pytest.mark.parametrize('svd_solver', ['full', 'arpack'])
+def test_pca_sparse_input_bad_solvers(svd_solver):
     X = np.random.RandomState(0).rand(5, 4)
     X = sp.sparse.csr_matrix(X)
     assert(sp.sparse.issparse(X))

     pca = PCA(n_components=3, svd_solver=svd_solver)
Review comment (Member), suggested change: pca = PCA(n_components=3, svd_solver=svd_solver) -> pca = PCA(svd_solver=svd_solver)

-    assert_raises(TypeError, pca.fit, X)
+    with pytest.raises(ValueError, match='Only the randomized solver supports '
+                                         'sparse matrices'):
+        pca.fit(X)


+def test_pca_auto_solver_selects_randomized_solver_for_sparse_matrices():
+    X = np.random.RandomState(0).rand(5, 4)
+    X = sp.sparse.csr_matrix(X)
+
+    pca = PCA(n_components=3, svd_solver='auto')

Review comment (Member), suggested change: pca = PCA(n_components=3, svd_solver='auto') -> pca = PCA(svd_solver='auto')

+    pca.fit(X)
+
+    assert pca._fit_svd_solver == 'randomized'


def test_pca_bad_solver():