From b18f2951e4be3ce293a9d0bb61d4ca3be19496ca Mon Sep 17 00:00:00 2001 From: giorgiop Date: Wed, 19 Aug 2015 16:23:07 +0200 Subject: [PATCH] randomized_svd: power iter, normalization, benchmark --- benchmarks/bench_plot_randomized_svd.py | 457 ++++++++++++++++++++++++ doc/whats_new.rst | 37 +- sklearn/decomposition/pca.py | 4 +- sklearn/decomposition/tests/test_pca.py | 20 +- sklearn/decomposition/truncated_svd.py | 4 +- sklearn/utils/extmath.py | 81 +++-- sklearn/utils/tests/test_extmath.py | 119 ++++-- 7 files changed, 645 insertions(+), 77 deletions(-) create mode 100644 benchmarks/bench_plot_randomized_svd.py diff --git a/benchmarks/bench_plot_randomized_svd.py b/benchmarks/bench_plot_randomized_svd.py new file mode 100644 index 0000000000000..c2a347e186a85 --- /dev/null +++ b/benchmarks/bench_plot_randomized_svd.py @@ -0,0 +1,457 @@ +""" +Benchmarks on the power iterations phase in randomized SVD. + +We test on various synthetic and real datasets the effect of increasing +the number of power iterations in terms of quality of approximation +and running time. A number greater than 0 should help with noisy matrices, +which are characterized by a slow spectral decay. + +We test several policy for normalizing the power iterations. Normalization +is crucial to avoid numerical issues. + +The quality of the approximation is measured by the spectral norm discrepancy +between the original input matrix and the reconstructed one (by multiplying +the randomized_svd's outputs). The spectral norm is always equivalent to the +largest singular value of a matrix. (3) justifies this choice. However, one can +notice in these experiments that Frobenius and spectral norms behave +very similarly in a qualitative sense. Therefore, we suggest to run these +benchmarks with `enable_spectral_norm = False`, as Frobenius' is MUCH faster to +compute. + +The benchmarks follow. + +(a) plot: time vs norm, varying number of power iterations + data: many datasets + goal: compare normalization policies and study how the number of power + iterations affect time and norm + +(b) plot: n_iter vs norm, varying rank of data and number of components for + randomized_SVD + data: low-rank matrices on which we control the rank + goal: study whether the rank of the matrix and the number of components + extracted by randomized SVD affect "the optimal" number of power iterations + +(c) plot: time vs norm, varing datasets + data: many datasets + goal: compare default configurations + +We compare the following algorithms: +- randomized_svd(..., power_iteration_normalizer='none') +- randomized_svd(..., power_iteration_normalizer='LU') +- randomized_svd(..., power_iteration_normalizer='QR') +- randomized_svd(..., power_iteration_normalizer='auto') +- fbpca.pca() from https://github.com/facebook/fbpca (if installed) + +Conclusion +---------- +- n_iter=2 appears to be a good default value +- power_iteration_normalizer='none' is OK if n_iter is small, otherwise LU + gives similar errors to QR but is cheaper. That's what 'auto' implements. + +References +---------- +(1) Finding structure with randomness: Stochastic algorithms for constructing + approximate matrix decompositions + Halko, et al., 2009 http://arxiv.org/abs/arXiv:0909.4061 + +(2) A randomized algorithm for the decomposition of matrices + Per-Gunnar Martinsson, Vladimir Rokhlin and Mark Tygert + +(3) An implementation of a randomized algorithm for principal component + analysis + A. Szlam et al. 2014 +""" + +# Author: Giorgio Patrini + +import numpy as np +import scipy as sp +import matplotlib.pyplot as plt + +import gc +import pickle +from time import time +from collections import defaultdict +import os.path + +from sklearn.utils import gen_batches +from sklearn.utils.validation import check_random_state +from sklearn.utils.extmath import randomized_svd +from sklearn.datasets.samples_generator import (make_low_rank_matrix, + make_sparse_uncorrelated) +from sklearn.datasets import (fetch_lfw_people, + fetch_mldata, + fetch_20newsgroups_vectorized, + fetch_olivetti_faces, + fetch_rcv1) + +try: + import fbpca + fbpca_available = True +except ImportError: + fbpca_available = False + +# If this is enabled, tests are much slower and will crash with the large data +enable_spectral_norm = False + +# TODO: compute approximate spectral norms with the power method as in +# Estimating the largest eigenvalues by the power and Lanczos methods with +# a random start, Jacek Kuczynski and Henryk Wozniakowski, SIAM Journal on +# Matrix Analysis and Applications, 13 (4): 1094-1122, 1992. +# This approximation is a very fast estimate of the spectral norm, but depends +# on starting random vectors. + +# Determine when to switch to batch computation for matrix norms, +# in case the reconstructed (dense) matrix is too large +MAX_MEMORY = np.int(2e9) + +# The following datasets can be dowloaded manually from: +# CIFAR 10: http://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz +# SVHN: http://ufldl.stanford.edu/housenumbers/train_32x32.mat +CIFAR_FOLDER = "./cifar-10-batches-py/" +SVHN_FOLDER = "./SVHN/" + +datasets = ['low rank matrix', 'lfw_people', 'olivetti_faces', '20newsgroups', + 'MNIST original', 'CIFAR', 'a1a', 'SVHN', 'uncorrelated matrix'] + +big_sparse_datasets = ['big sparse matrix', 'rcv1'] + + +def unpickle(file): + fo = open(file, 'rb') + dict = pickle.load(fo, encoding='latin1') + fo.close() + return dict['data'] + + +def handle_missing_dataset(file_folder): + if not os.path.isdir(file_folder): + print("%s file folder not found. Test skipped." % file_folder) + return 0 + + +def get_data(dataset_name): + print("Getting dataset: %s" % dataset_name) + + if dataset_name == 'lfw_people': + X = fetch_lfw_people().data + elif dataset_name == '20newsgroups': + X = fetch_20newsgroups_vectorized().data[:, :100000] + elif dataset_name == 'olivetti_faces': + X = fetch_olivetti_faces().data + elif dataset_name == 'rcv1': + X = fetch_rcv1().data + elif dataset_name == 'CIFAR': + if handle_missing_dataset(CIFAR_FOLDER) == "skip": + return + X1 = [unpickle("%sdata_batch_%d" % (CIFAR_FOLDER, i + 1)) + for i in range(5)] + X = np.vstack(X1) + del X1 + elif dataset_name == 'SVHN': + if handle_missing_dataset(SVHN_FOLDER) == 0: + return + X1 = sp.io.loadmat("%strain_32x32.mat" % SVHN_FOLDER)['X'] + X2 = [X1[:, :, :, i].reshape(32 * 32 * 3) for i in range(X1.shape[3])] + X = np.vstack(X2) + del X1 + del X2 + elif dataset_name == 'low rank matrix': + X = make_low_rank_matrix(n_samples=500, n_features=np.int(1e4), + effective_rank=100, tail_strength=.5, + random_state=random_state) + elif dataset_name == 'uncorrelated matrix': + X, _ = make_sparse_uncorrelated(n_samples=500, n_features=10000, + random_state=random_state) + elif dataset_name == 'big sparse matrix': + sparsity = np.int(1e6) + size = np.int(1e6) + small_size = np.int(1e4) + data = np.random.normal(0, 1, np.int(sparsity/10)) + data = np.repeat(data, 10) + row = np.random.uniform(0, small_size, sparsity) + col = np.random.uniform(0, small_size, sparsity) + X = sp.sparse.csr_matrix((data, (row, col)), shape=(size, small_size)) + del data + del row + del col + else: + X = fetch_mldata(dataset_name).data + return X + + +def plot_time_vs_s(time, norm, point_labels, title): + plt.figure() + colors = ['g', 'b', 'y'] + for i, l in enumerate(sorted(norm.keys())): + if l is not "fbpca": + plt.plot(time[l], norm[l], label=l, marker='o', c=colors.pop()) + else: + plt.plot(time[l], norm[l], label=l, marker='^', c='red') + + for label, x, y in zip(point_labels, list(time[l]), list(norm[l])): + plt.annotate(label, xy=(x, y), xytext=(0, -20), + textcoords='offset points', ha='right', va='bottom') + plt.legend(loc="upper right") + plt.suptitle(title) + plt.ylabel("norm discrepancy") + plt.xlabel("running time [s]") + + +def scatter_time_vs_s(time, norm, point_labels, title): + plt.figure() + size = 100 + for i, l in enumerate(sorted(norm.keys())): + if l is not "fbpca": + plt.scatter(time[l], norm[l], label=l, marker='o', c='b', s=size) + for label, x, y in zip(point_labels, list(time[l]), list(norm[l])): + plt.annotate(label, xy=(x, y), xytext=(0, -80), + textcoords='offset points', ha='right', + arrowprops=dict(arrowstyle="->", + connectionstyle="arc3"), + va='bottom', size=11, rotation=90) + else: + plt.scatter(time[l], norm[l], label=l, marker='^', c='red', s=size) + for label, x, y in zip(point_labels, list(time[l]), list(norm[l])): + plt.annotate(label, xy=(x, y), xytext=(0, 30), + textcoords='offset points', ha='right', + arrowprops=dict(arrowstyle="->", + connectionstyle="arc3"), + va='bottom', size=11, rotation=90) + + plt.legend(loc="best") + plt.suptitle(title) + plt.ylabel("norm discrepancy") + plt.xlabel("running time [s]") + + +def plot_power_iter_vs_s(power_iter, s, title): + plt.figure() + for l in sorted(s.keys()): + plt.plot(power_iter, s[l], label=l, marker='o') + plt.legend(loc="lower right", prop={'size': 10}) + plt.suptitle(title) + plt.ylabel("norm discrepancy") + plt.xlabel("n_iter") + + +def svd_timing(X, n_comps, n_iter, n_oversamples, + power_iteration_normalizer='auto', method=None): + """ + Measure time for decomposition + """ + print("... running SVD ...") + if method is not 'fbpca': + gc.collect() + t0 = time() + U, mu, V = randomized_svd(X, n_comps, n_oversamples, n_iter, + power_iteration_normalizer, + random_state=random_state, transpose=False) + call_time = time() - t0 + else: + gc.collect() + t0 = time() + # There is a different convention for l here + U, mu, V = fbpca.pca(X, n_comps, raw=True, n_iter=n_iter, + l=n_oversamples+n_comps) + call_time = time() - t0 + + return U, mu, V, call_time + + +def norm_diff(A, norm=2, msg=True): + """ + Compute the norm diff with the original matrix, when randomized + SVD is called with *params. + + norm: 2 => spectral; 'fro' => Frobenius + """ + + if msg: + print("... computing %s norm ..." % norm) + if norm == 2: + # s = sp.linalg.norm(A, ord=2) # slow + value = sp.sparse.linalg.svds(A, k=1, return_singular_vectors=False) + else: + if sp.sparse.issparse(A): + value = sp.sparse.linalg.norm(A, ord=norm) + else: + value = sp.linalg.norm(A, ord=norm) + return value + + +def scalable_frobenius_norm_discrepancy(X, U, s, V): + # if the input is not too big, just call scipy + if X.shape[0] * X.shape[1] < MAX_MEMORY: + A = X - U.dot(np.diag(s).dot(V)) + return norm_diff(A, norm='fro') + + print("... computing fro norm by batches...") + batch_size = 1000 + Vhat = np.diag(s).dot(V) + cum_norm = .0 + for batch in gen_batches(X.shape[0], batch_size): + M = X[batch, :] - U[batch, :].dot(Vhat) + cum_norm += norm_diff(M, norm='fro', msg=False) + return np.sqrt(cum_norm) + + +def bench_a(X, dataset_name, power_iter, n_oversamples, n_comps): + + all_time = defaultdict(list) + if enable_spectral_norm: + all_spectral = defaultdict(list) + X_spectral_norm = norm_diff(X, norm=2, msg=False) + all_frobenius = defaultdict(list) + X_fro_norm = norm_diff(X, norm='fro', msg=False) + + for pi in power_iter: + for pm in ['none', 'LU', 'QR']: + print("n_iter = %d on sklearn - %s" % (pi, pm)) + U, s, V, time = svd_timing(X, n_comps, n_iter=pi, + power_iteration_normalizer=pm, + n_oversamples=n_oversamples) + label = "sklearn - %s" % pm + all_time[label].append(time) + if enable_spectral_norm: + A = U.dot(np.diag(s).dot(V)) + all_spectral[label].append(norm_diff(X - A, norm=2) / + X_spectral_norm) + f = scalable_frobenius_norm_discrepancy(X, U, s, V) + all_frobenius[label].append(f / X_fro_norm) + + if fbpca_available: + print("n_iter = %d on fbca" % (pi)) + U, s, V, time = svd_timing(X, n_comps, n_iter=pi, + power_iteration_normalizer=pm, + n_oversamples=n_oversamples, + method='fbpca') + label = "fbpca" + all_time[label].append(time) + if enable_spectral_norm: + A = U.dot(np.diag(s).dot(V)) + all_spectral[label].append(norm_diff(X - A, norm=2) / + X_spectral_norm) + f = scalable_frobenius_norm_discrepancy(X, U, s, V) + all_frobenius[label].append(f / X_fro_norm) + + if enable_spectral_norm: + title = "%s: spectral norm diff vs running time" % (dataset_name) + plot_time_vs_s(all_time, all_spectral, power_iter, title) + title = "%s: Frobenius norm diff vs running time" % (dataset_name) + plot_time_vs_s(all_time, all_frobenius, power_iter, title) + + +def bench_b(power_list): + + n_samples, n_features = 1000, 10000 + data_params = {'n_samples': n_samples, 'n_features': n_features, + 'tail_strength': .7, 'random_state': random_state} + dataset_name = "low rank matrix %d x %d" % (n_samples, n_features) + ranks = [10, 50, 100] + + if enable_spectral_norm: + all_spectral = defaultdict(list) + all_frobenius = defaultdict(list) + for rank in ranks: + X = make_low_rank_matrix(effective_rank=rank, **data_params) + if enable_spectral_norm: + X_spectral_norm = norm_diff(X, norm=2, msg=False) + X_fro_norm = norm_diff(X, norm='fro', msg=False) + + for n_comp in [np.int(rank/2), rank, rank*2]: + label = "rank=%d, n_comp=%d" % (rank, n_comp) + print(label) + for pi in power_list: + U, s, V, _ = svd_timing(X, n_comp, n_iter=pi, n_oversamples=2, + power_iteration_normalizer='LU') + if enable_spectral_norm: + A = U.dot(np.diag(s).dot(V)) + all_spectral[label].append(norm_diff(X - A, norm=2) / + X_spectral_norm) + f = scalable_frobenius_norm_discrepancy(X, U, s, V) + all_frobenius[label].append(f / X_fro_norm) + + if enable_spectral_norm: + title = "%s: spectral norm diff vs n power iteration" % (dataset_name) + plot_power_iter_vs_s(power_iter, all_spectral, title) + title = "%s: frobenius norm diff vs n power iteration" % (dataset_name) + plot_power_iter_vs_s(power_iter, all_frobenius, title) + + +def bench_c(datasets, n_comps): + all_time = defaultdict(list) + if enable_spectral_norm: + all_spectral = defaultdict(list) + all_frobenius = defaultdict(list) + + for dataset_name in datasets: + X = get_data(dataset_name) + if X is None: + continue + + if enable_spectral_norm: + X_spectral_norm = norm_diff(X, norm=2, msg=False) + X_fro_norm = norm_diff(X, norm='fro', msg=False) + n_comps = np.minimum(n_comps, np.min(X.shape)) + + label = "sklearn" + print("%s %d x %d - %s" % + (dataset_name, X.shape[0], X.shape[1], label)) + U, s, V, time = svd_timing(X, n_comps, n_iter=2, n_oversamples=10, + method=label) + + all_time[label].append(time) + if enable_spectral_norm: + A = U.dot(np.diag(s).dot(V)) + all_spectral[label].append(norm_diff(X - A, norm=2) / + X_spectral_norm) + f = scalable_frobenius_norm_discrepancy(X, U, s, V) + all_frobenius[label].append(f / X_fro_norm) + + if fbpca_available: + label = "fbpca" + print("%s %d x %d - %s" % + (dataset_name, X.shape[0], X.shape[1], label)) + U, s, V, time = svd_timing(X, n_comps, n_iter=2, n_oversamples=2, + method=label) + all_time[label].append(time) + if enable_spectral_norm: + A = U.dot(np.diag(s).dot(V)) + all_spectral[label].append(norm_diff(X - A, norm=2) / + X_spectral_norm) + f = scalable_frobenius_norm_discrepancy(X, U, s, V) + all_frobenius[label].append(f / X_fro_norm) + + if len(all_time) == 0: + raise ValueError("No tests ran. Aborting.") + + if enable_spectral_norm: + title = "normalized spectral norm diff vs running time" + scatter_time_vs_s(all_time, all_spectral, datasets, title) + title = "normalized Frobenius norm diff vs running time" + scatter_time_vs_s(all_time, all_frobenius, datasets, title) + + +if __name__ == '__main__': + random_state = check_random_state(1234) + + power_iter = np.linspace(0, 6, 7, dtype=int) + n_comps = 50 + + for dataset_name in datasets: + X = get_data(dataset_name) + if X is None: + continue + print(" >>>>>> Benching sklearn and fbpca on %s %d x %d" % + (dataset_name, X.shape[0], X.shape[1])) + bench_a(X, dataset_name, power_iter, n_oversamples=2, + n_comps=np.minimum(n_comps, np.min(X.shape))) + + print(" >>>>>> Benching on simulated low rank matrix with variable rank") + bench_b(power_iter) + + print(" >>>>>> Benching sklearn and fbpca default configurations") + bench_c(datasets + big_sparse_datasets, n_comps) + + plt.show() diff --git a/doc/whats_new.rst b/doc/whats_new.rst index 7eb070d093aa9..9ae61a3f434b5 100644 --- a/doc/whats_new.rst +++ b/doc/whats_new.rst @@ -15,26 +15,40 @@ Changelog New features ............ - - The Gaussian Process module has been reimplemented and now offers classification - and regression estimators through :class:`gaussian_process.GaussianProcessClassifier` - and :class:`gaussian_process.GaussianProcessRegressor`. Among other things, the new - implementation supports kernel engineering, gradient-based hyperparameter optimization or - sampling of functions from GP prior and GP posterior. Extensive documentation and + - The Gaussian Process module has been reimplemented and now offers classification + and regression estimators through :class:`gaussian_process.GaussianProcessClassifier` + and :class:`gaussian_process.GaussianProcessRegressor`. Among other things, the new + implementation supports kernel engineering, gradient-based hyperparameter optimization or + sampling of functions from GP prior and GP posterior. Extensive documentation and examples are provided. By `Jan Hendrik Metzen`_. - + Enhancements ............ - + Bug fixes ......... - - Fixed bug in :func:`manifold.spectral_embedding` where diagonal of unnormalized + + - :class:`RandomizedPCA` default number of `iterated_power` is 2 instead of 3. + This is a speed up with a minor precision decrease. By `Giorgio Patrini`_. + + - :func:`randomized_svd` performs 2 power iterations by default, instead or 0. + In practice this is often enough for obtaining a good approximation of the + true eigenvalues/vectors in the presence of noise. By `Giorgio Patrini`_. + + - :func:`randomized_range_finder` is more numerically stable when many + power iterations are requested, since it applies LU normalization by default. + If `n_iter<2` numerical issues are unlikely, thus no normalization is applied. + Other normalization options are available: 'none', 'LU' and 'QR'. By + `Giorgio Patrini`_. + + - Fixed bug in :func:`manifold.spectral_embedding` where diagonal of unnormalized Laplacian matrix was incorrectly set to 1. By `Peter Fischer`_. API changes summary ------------------- - - + + .. _changes_0_17: Version 0.17 @@ -271,7 +285,7 @@ Bug fixes in the final fit. By `Manoj Kumar`_. - Fixed bug in :class:`ensemble.forest.ForestClassifier` while computing - oob_score and X is a sparse.csc_matrix. By `Ankur Ankan`_. + oob_score and X is a sparse.csc_matrix. By `Ankur Ankan`_. - All regressors now consistently handle and warn when given ``y`` that is of shape ``(n_samples, 1)``. By `Andreas Müller`_. @@ -3799,3 +3813,4 @@ David Huard, Dave Morrill, Ed Schofield, Travis Oliphant, Pearu Peterson. .. _Jean Kossaifi: https://github.com/JeanKossaifi .. _Andrew Lamb: https://github.com/andylamb .. _Graham Clenaghan: https://github.com/gclenaghan +.. _Giorgio Patrini: https://github.com/giorgiop diff --git a/sklearn/decomposition/pca.py b/sklearn/decomposition/pca.py index e9993ad45985c..84aa4874dff46 100644 --- a/sklearn/decomposition/pca.py +++ b/sklearn/decomposition/pca.py @@ -524,7 +524,7 @@ class RandomizedPCA(BaseEstimator, TransformerMixin): >>> X = np.array([[-1, -1], [-2, -1], [-3, -2], [1, 1], [2, 1], [3, 2]]) >>> pca = RandomizedPCA(n_components=2) >>> pca.fit(X) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE - RandomizedPCA(copy=True, iterated_power=3, n_components=2, + RandomizedPCA(copy=True, iterated_power=2, n_components=2, random_state=None, whiten=False) >>> print(pca.explained_variance_ratio_) # doctest: +ELLIPSIS [ 0.99244... 0.00755...] @@ -546,7 +546,7 @@ class RandomizedPCA(BaseEstimator, TransformerMixin): """ - def __init__(self, n_components=None, copy=True, iterated_power=3, + def __init__(self, n_components=None, copy=True, iterated_power=2, whiten=False, random_state=None): self.n_components = n_components self.copy = copy diff --git a/sklearn/decomposition/tests/test_pca.py b/sklearn/decomposition/tests/test_pca.py index eedbc35b1ae9b..f94c4bb9e90b4 100644 --- a/sklearn/decomposition/tests/test_pca.py +++ b/sklearn/decomposition/tests/test_pca.py @@ -82,6 +82,8 @@ def test_whitening(): # whiten the data while projecting to the lower dim subspace X_ = X.copy() # make sure we keep an original across iterations. pca = this_PCA(n_components=n_components, whiten=True, copy=copy) + if hasattr(pca, 'random_state'): + pca.random_state = rng # test fit_transform X_whitened = pca.fit_transform(X_.copy()) assert_equal(X_whitened.shape, (n_samples, n_components)) @@ -89,7 +91,7 @@ def test_whitening(): assert_array_almost_equal(X_whitened, X_whitened2) assert_almost_equal(X_whitened.std(axis=0), np.ones(n_components), - decimal=6) + decimal=4) assert_almost_equal(X_whitened.mean(axis=0), np.zeros(n_components)) X_ = X.copy() @@ -112,11 +114,9 @@ def test_explained_variance(): X = rng.randn(n_samples, n_features) pca = PCA(n_components=2).fit(X) - rpca = RandomizedPCA(n_components=2, random_state=42).fit(X) - assert_array_almost_equal(pca.explained_variance_, - rpca.explained_variance_, 1) + rpca = RandomizedPCA(n_components=2, random_state=rng).fit(X) assert_array_almost_equal(pca.explained_variance_ratio_, - rpca.explained_variance_ratio_, 3) + rpca.explained_variance_ratio_, 1) # compare to empirical variances X_pca = pca.transform(X) @@ -127,6 +127,16 @@ def test_explained_variance(): assert_array_almost_equal(rpca.explained_variance_, np.var(X_rpca, axis=0), decimal=1) + # Same with correlated data + X = datasets.make_classification(n_samples, n_features, + n_informative=n_features-2, + random_state=rng)[0] + + pca = PCA(n_components=2).fit(X) + rpca = RandomizedPCA(n_components=2, random_state=rng).fit(X) + assert_array_almost_equal(pca.explained_variance_ratio_, + rpca.explained_variance_ratio_, 5) + def test_pca_check_projection(): # Test that the projection of data is correct diff --git a/sklearn/decomposition/truncated_svd.py b/sklearn/decomposition/truncated_svd.py index 343a012f04c81..0e449e5a41d8e 100644 --- a/sklearn/decomposition/truncated_svd.py +++ b/sklearn/decomposition/truncated_svd.py @@ -53,8 +53,10 @@ class TruncatedSVD(BaseEstimator, TransformerMixin): (scipy.sparse.linalg.svds), or "randomized" for the randomized algorithm due to Halko (2009). - n_iter : int, optional + n_iter : int, optional (default 5) Number of iterations for randomized SVD solver. Not used by ARPACK. + The default is larger than the default in `randomized_svd` to handle + sparse matrices that may have large slowly decaying spectrum. random_state : int or RandomState, optional (Seed for) pseudo-random number generator. If not given, the diff --git a/sklearn/utils/extmath.py b/sklearn/utils/extmath.py index 9187b219142e6..e71a2e5f1241b 100644 --- a/sklearn/utils/extmath.py +++ b/sklearn/utils/extmath.py @@ -187,7 +187,9 @@ def safe_sparse_dot(a, b, dense_output=False): return fast_dot(a, b) -def randomized_range_finder(A, size, n_iter, random_state=None): +def randomized_range_finder(A, size, n_iter=2, + power_iteration_normalizer='auto', + random_state=None): """Computes an orthonormal matrix whose range approximates the range of A. Parameters @@ -198,6 +200,13 @@ def randomized_range_finder(A, size, n_iter, random_state=None): Size of the return array n_iter: integer Number of power iterations used to stabilize the result + power_iteration_normalizer: 'auto' (default), 'QR', 'LU', 'none' + Whether the power iterations are normalized with step-by-step + QR factorization (the slowest but most accurate), 'none' + (the fastest but numerically unstable when `n_iter` is large, e.g. + typically 5 or larger), or 'LU' factorization (numerically stable + but can lose slightly in accuracy). The 'auto' mode applies no + normalization if `n_iter`<=2 and switches to LU otherwise. random_state: RandomState or an int seed (0 by default) A random number generator instance @@ -214,28 +223,45 @@ def randomized_range_finder(A, size, n_iter, random_state=None): Finding structure with randomness: Stochastic algorithms for constructing approximate matrix decompositions Halko, et al., 2009 (arXiv:909) http://arxiv.org/pdf/0909.4061 + + An implementation of a randomized algorithm for principal component + analysis + A. Szlam et al. 2014 """ random_state = check_random_state(random_state) - # generating random gaussian vectors r with shape: (A.shape[1], size) - R = random_state.normal(size=(A.shape[1], size)) - - # sampling the range of A using by linear projection of r - Y = safe_sparse_dot(A, R) - del R - - # perform power iterations with Y to further 'imprint' the top - # singular vectors of A in Y - for i in xrange(n_iter): - Y = safe_sparse_dot(A, safe_sparse_dot(A.T, Y)) + # Generating normal random vectors with shape: (A.shape[1], size) + Q = random_state.normal(size=(A.shape[1], size)) - # extracting an orthonormal basis of the A range samples - Q, R = linalg.qr(Y, mode='economic') + # Deal with "auto" mode + if power_iteration_normalizer == 'auto': + if n_iter <= 2: + power_iteration_normalizer = 'none' + else: + power_iteration_normalizer = 'LU' + + # Perform power iterations with Q to further 'imprint' the top + # singular vectors of A in Q + for i in range(n_iter): + if power_iteration_normalizer == 'none': + Q = safe_sparse_dot(A, Q) + Q = safe_sparse_dot(A.T, Q) + elif power_iteration_normalizer == 'LU': + Q, _ = linalg.lu(safe_sparse_dot(A, Q), permute_l=True) + Q, _ = linalg.lu(safe_sparse_dot(A.T, Q), permute_l=True) + elif power_iteration_normalizer == 'QR': + Q, _ = linalg.qr(safe_sparse_dot(A, Q), mode='economic') + Q, _ = linalg.qr(safe_sparse_dot(A.T, Q), mode='economic') + + # Sample the range of A using by linear projection of Q + # Extract an orthonormal basis + Q, _ = linalg.qr(safe_sparse_dot(A, Q), mode='economic') return Q -def randomized_svd(M, n_components, n_oversamples=10, n_iter=0, - transpose='auto', flip_sign=True, random_state=0): +def randomized_svd(M, n_components, n_oversamples=10, n_iter=2, + power_iteration_normalizer='auto', transpose='auto', + flip_sign=True, random_state=0): """Computes a truncated randomized SVD Parameters @@ -249,18 +275,28 @@ def randomized_svd(M, n_components, n_oversamples=10, n_iter=0, n_oversamples: int (default is 10) Additional number of random vectors to sample the range of M so as to ensure proper conditioning. The total number of random vectors - used to find the range of M is n_components + n_oversamples. + used to find the range of M is n_components + n_oversamples. Smaller + number can improve speed but can negatively impact the quality of + approximation of singular vectors and singular values. - n_iter: int (default is 0) + n_iter: int (default is 2) Number of power iterations (can be used to deal with very noisy problems). + power_iteration_normalizer: 'auto' (default), 'QR', 'LU', 'none' + Whether the power iterations are normalized with step-by-step + QR factorization (the slowest but most accurate), 'none' + (the fastest but numerically unstable when `n_iter` is large, e.g. + typically 5 or larger), or 'LU' factorization (numerically stable + but can lose slightly in accuracy). The 'auto' mode applies no + normalization if `n_iter`<=2 and switches to LU otherwise. + transpose: True, False or 'auto' (default) Whether the algorithm should be applied to M.T instead of M. The result should approximately be the same. The 'auto' mode will trigger the transposition if M.shape[1] > M.shape[0] since this implementation of randomized SVD tend to be a little faster in that - case). + case. flip_sign: boolean, (True by default) The output of a singular value decomposition is only unique up to a @@ -286,6 +322,10 @@ def randomized_svd(M, n_components, n_oversamples=10, n_iter=0, * A randomized algorithm for the decomposition of matrices Per-Gunnar Martinsson, Vladimir Rokhlin and Mark Tygert + + * An implementation of a randomized algorithm for principal component + analysis + A. Szlam et al. 2014 """ random_state = check_random_state(random_state) n_random = n_components + n_oversamples @@ -297,7 +337,8 @@ def randomized_svd(M, n_components, n_oversamples=10, n_iter=0, # this implementation is a bit faster with smaller shape[1] M = M.T - Q = randomized_range_finder(M, n_random, n_iter, random_state) + Q = randomized_range_finder(M, n_random, n_iter, + power_iteration_normalizer, random_state) # project M to the (k + p) dimensional space using the basis vectors B = safe_sparse_dot(Q.T, M) diff --git a/sklearn/utils/tests/test_extmath.py b/sklearn/utils/tests/test_extmath.py index 107f5b845e1d9..71872892a5307 100644 --- a/sklearn/utils/tests/test_extmath.py +++ b/sklearn/utils/tests/test_extmath.py @@ -108,25 +108,28 @@ def test_randomized_svd_low_rank(): # compute the singular values of X using the slow exact method U, s, V = linalg.svd(X, full_matrices=False) - # compute the singular values of X using the fast approximate method - Ua, sa, Va = randomized_svd(X, k) - assert_equal(Ua.shape, (n_samples, k)) - assert_equal(sa.shape, (k,)) - assert_equal(Va.shape, (k, n_features)) + for normalizer in ['auto', 'none', 'LU', 'QR']: + # compute the singular values of X using the fast approximate method + Ua, sa, Va = \ + randomized_svd(X, k, power_iteration_normalizer=normalizer) + assert_equal(Ua.shape, (n_samples, k)) + assert_equal(sa.shape, (k,)) + assert_equal(Va.shape, (k, n_features)) - # ensure that the singular values of both methods are equal up to the real - # rank of the matrix - assert_almost_equal(s[:k], sa) + # ensure that the singular values of both methods are equal up to the + # real rank of the matrix + assert_almost_equal(s[:k], sa) - # check the singular vectors too (while not checking the sign) - assert_almost_equal(np.dot(U[:, :k], V[:k, :]), np.dot(Ua, Va)) + # check the singular vectors too (while not checking the sign) + assert_almost_equal(np.dot(U[:, :k], V[:k, :]), np.dot(Ua, Va)) - # check the sparse matrix representation - X = sparse.csr_matrix(X) + # check the sparse matrix representation + X = sparse.csr_matrix(X) - # compute the singular values of X using the fast approximate method - Ua, sa, Va = randomized_svd(X, k) - assert_almost_equal(s[:rank], sa[:rank]) + # compute the singular values of X using the fast approximate method + Ua, sa, Va = \ + randomized_svd(X, k, power_iteration_normalizer=normalizer) + assert_almost_equal(s[:rank], sa[:rank]) def test_norm_squared_norm(): @@ -161,26 +164,29 @@ def test_randomized_svd_low_rank_with_noise(): # generate a matrix X wity structure approximate rank `rank` and an # important noisy component X = make_low_rank_matrix(n_samples=n_samples, n_features=n_features, - effective_rank=rank, tail_strength=0.5, + effective_rank=rank, tail_strength=0.1, random_state=0) assert_equal(X.shape, (n_samples, n_features)) # compute the singular values of X using the slow exact method _, s, _ = linalg.svd(X, full_matrices=False) - # compute the singular values of X using the fast approximate method - # without the iterated power method - _, sa, _ = randomized_svd(X, k, n_iter=0) + for normalizer in ['auto', 'none', 'LU', 'QR']: + # compute the singular values of X using the fast approximate + # method without the iterated power method + _, sa, _ = randomized_svd(X, k, n_iter=0, + power_iteration_normalizer=normalizer) - # the approximation does not tolerate the noise: - assert_greater(np.abs(s[:k] - sa).max(), 0.05) + # the approximation does not tolerate the noise: + assert_greater(np.abs(s[:k] - sa).max(), 0.01) - # compute the singular values of X using the fast approximate method with - # iterated power method - _, sap, _ = randomized_svd(X, k, n_iter=5) + # compute the singular values of X using the fast approximate + # method with iterated power method + _, sap, _ = randomized_svd(X, k, + power_iteration_normalizer=normalizer) - # the iterated power method is helping getting rid of the noise: - assert_almost_equal(s[:k], sap, decimal=3) + # the iterated power method is helping getting rid of the noise: + assert_almost_equal(s[:k], sap, decimal=3) def test_randomized_svd_infinite_rank(): @@ -199,21 +205,23 @@ def test_randomized_svd_infinite_rank(): # compute the singular values of X using the slow exact method _, s, _ = linalg.svd(X, full_matrices=False) + for normalizer in ['auto', 'none', 'LU', 'QR']: + # compute the singular values of X using the fast approximate method + # without the iterated power method + _, sa, _ = randomized_svd(X, k, n_iter=0, + power_iteration_normalizer=normalizer) - # compute the singular values of X using the fast approximate method - # without the iterated power method - _, sa, _ = randomized_svd(X, k, n_iter=0) - - # the approximation does not tolerate the noise: - assert_greater(np.abs(s[:k] - sa).max(), 0.1) + # the approximation does not tolerate the noise: + assert_greater(np.abs(s[:k] - sa).max(), 0.1) - # compute the singular values of X using the fast approximate method with - # iterated power method - _, sap, _ = randomized_svd(X, k, n_iter=5) + # compute the singular values of X using the fast approximate method + # with iterated power method + _, sap, _ = randomized_svd(X, k, n_iter=5, + power_iteration_normalizer=normalizer) - # the iterated power method is still managing to get most of the structure - # at the requested rank - assert_almost_equal(s[:k], sap, decimal=3) + # the iterated power method is still managing to get most of the + # structure at the requested rank + assert_almost_equal(s[:k], sap, decimal=3) def test_randomized_svd_transpose_consistency(): @@ -249,6 +257,41 @@ def test_randomized_svd_transpose_consistency(): assert_almost_equal(s2, s3) +def test_randomized_svd_power_iteration_normalizer(): + # randomized_svd with power_iteration_normalized='none' diverges for + # large number of power iterations on this dataset + rng = np.random.RandomState(42) + X = make_low_rank_matrix(300, 1000, effective_rank=50, random_state=rng) + X += 3 * rng.randint(0, 2, size=X.shape) + n_components = 50 + + # Check that it diverges with many (non-normalized) power iterations + U, s, V = randomized_svd(X, n_components, n_iter=2, + power_iteration_normalizer='none') + A = X - U.dot(np.diag(s).dot(V)) + error_2 = linalg.norm(A, ord='fro') + U, s, V = randomized_svd(X, n_components, n_iter=20, + power_iteration_normalizer='none') + A = X - U.dot(np.diag(s).dot(V)) + error_20 = linalg.norm(A, ord='fro') + print(error_2 - error_20) + assert_greater(np.abs(error_2 - error_20), 100) + + for normalizer in ['LU', 'QR', 'auto']: + U, s, V = randomized_svd(X, n_components, n_iter=2, + power_iteration_normalizer=normalizer) + A = X - U.dot(np.diag(s).dot(V)) + error_2 = linalg.norm(A, ord='fro') + + for i in [5, 10, 50]: + U, s, V = randomized_svd(X, n_components, n_iter=i, + power_iteration_normalizer=normalizer) + A = X - U.dot(np.diag(s).dot(V)) + error = linalg.norm(A, ord='fro') + print(error_2 - error) + assert_greater(15, np.abs(error_2 - error)) + + def test_svd_flip(): # Check that svd_flip works in both situations, and reconstructs input. rs = np.random.RandomState(1999)