From 06e9171915702a315dfbac871b8b248f64af4892 Mon Sep 17 00:00:00 2001 From: LongBao Date: Sun, 15 May 2022 23:51:42 +0900 Subject: [PATCH 1/2] fix benchmark randonmizd svd --- benchmarks/bench_plot_randomized_svd.py | 2 +- sklearn/utils/extmath.py | 3 --- 2 files changed, 1 insertion(+), 4 deletions(-) diff --git a/benchmarks/bench_plot_randomized_svd.py b/benchmarks/bench_plot_randomized_svd.py index c7d67fa2a545d..081842231039e 100644 --- a/benchmarks/bench_plot_randomized_svd.py +++ b/benchmarks/bench_plot_randomized_svd.py @@ -153,7 +153,7 @@ def get_data(dataset_name): elif dataset_name == "rcv1": X = fetch_rcv1().data elif dataset_name == "CIFAR": - if handle_missing_dataset(CIFAR_FOLDER) == "skip": + if handle_missing_dataset(CIFAR_FOLDER) == 0: return X1 = [unpickle("%sdata_batch_%d" % (CIFAR_FOLDER, i + 1)) for i in range(5)] X = np.vstack(X1) diff --git a/sklearn/utils/extmath.py b/sklearn/utils/extmath.py index 2521990e6cc68..8b75628ca61c4 100644 --- a/sklearn/utils/extmath.py +++ b/sklearn/utils/extmath.py @@ -216,9 +216,6 @@ def randomized_range_finder( # Generating normal random vectors with shape: (A.shape[1], size) Q = random_state.normal(size=(A.shape[1], size)) - if A.dtype.kind == "f": - # Ensure f32 is preserved as f32 - Q = Q.astype(A.dtype, copy=False) # Deal with "auto" mode if power_iteration_normalizer == "auto": From 0c9d48afba513b5fae6a70ba7733f951b0a60dee Mon Sep 17 00:00:00 2001 From: LongBao Date: Mon, 16 May 2022 10:53:09 +0900 Subject: [PATCH 2/2] addback type casting --- sklearn/utils/extmath.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/sklearn/utils/extmath.py b/sklearn/utils/extmath.py index 8b75628ca61c4..4438f67fb5729 100644 --- a/sklearn/utils/extmath.py +++ b/sklearn/utils/extmath.py @@ -240,6 +240,11 @@ def randomized_range_finder( # Sample the range of A using by linear projection of Q # Extract an orthonormal basis Q, _ = linalg.qr(safe_sparse_dot(A, Q), mode="economic") + + if hasattr(A, "dtype") and A.dtype.kind == "f": + # Ensure f32 is preserved as f32 + Q = Q.astype(A.dtype, copy=False) + return Q