diff --git a/benchmarks/bench_plot_randomized_svd.py b/benchmarks/bench_plot_randomized_svd.py index 081842231039e..ecc1bbb92ce61 100644 --- a/benchmarks/bench_plot_randomized_svd.py +++ b/benchmarks/bench_plot_randomized_svd.py @@ -107,7 +107,7 @@ # Determine when to switch to batch computation for matrix norms, # in case the reconstructed (dense) matrix is too large -MAX_MEMORY = int(2e9) +MAX_MEMORY = int(4e9) # The following datasets can be downloaded manually from: # CIFAR 10: https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz @@ -323,8 +323,11 @@ def norm_diff(A, norm=2, msg=True, random_state=None): def scalable_frobenius_norm_discrepancy(X, U, s, V): - # if the input is not too big, just call scipy - if X.shape[0] * X.shape[1] < MAX_MEMORY: + if not sp.sparse.issparse(X) or ( + X.shape[0] * X.shape[1] * X.dtype.itemsize < MAX_MEMORY + ): + # if the input is not sparse or sparse but not too big, + # U.dot(np.diag(s).dot(V)) will fit in RAM A = X - U.dot(np.diag(s).dot(V)) return norm_diff(A, norm="fro") @@ -498,7 +501,7 @@ def bench_c(datasets, n_comps): if __name__ == "__main__": random_state = check_random_state(1234) - power_iter = np.linspace(0, 6, 7, dtype=int) + power_iter = np.arange(0, 6) n_comps = 50 for dataset_name in datasets: diff --git a/sklearn/utils/extmath.py b/sklearn/utils/extmath.py index 35e392f6e4540..e4513a62bf07e 100644 --- a/sklearn/utils/extmath.py +++ b/sklearn/utils/extmath.py @@ -216,6 +216,9 @@ def randomized_range_finder( # Generating normal random vectors with shape: (A.shape[1], size) Q = random_state.normal(size=(A.shape[1], size)) + if hasattr(A, "dtype") and A.dtype.kind == "f": + # Ensure f32 is preserved as f32 + Q = Q.astype(A.dtype, copy=False) # Deal with "auto" mode if power_iteration_normalizer == "auto": @@ -241,10 +244,6 @@ def randomized_range_finder( # Extract an orthonormal basis Q, _ = linalg.qr(safe_sparse_dot(A, Q), mode="economic") - if hasattr(A, "dtype") and A.dtype.kind == "f": - # Ensure f32 is preserved as f32 - Q = Q.astype(A.dtype, copy=False) - return Q