From b286daf9e316e45fdd2177af33b17a60a9d3fd8a Mon Sep 17 00:00:00 2001 From: oussama er-rabie Date: Wed, 16 Apr 2025 23:02:01 +0200 Subject: [PATCH 1/6] add randomized_eigh(selection='value') --- benchmarks/bench_isomap_auto_vs_randomized.py | 98 ++++++++++++++++ ...lvers_n_samples_vs_reconstruction_error.py | 85 ++++++++++++++ ...kernel_pca_solvers_time_vs_n_components.py | 25 ++++ ...ch_kernel_pca_solvers_time_vs_n_samples.py | 26 ++++- sklearn/decomposition/_kernel_pca.py | 10 +- .../decomposition/tests/test_kernel_pca.py | 23 ++-- sklearn/manifold/_isomap.py | 4 +- sklearn/manifold/tests/test_isomap.py | 2 +- sklearn/utils/extmath.py | 107 +++++++++++++++++- sklearn/utils/tests/test_extmath.py | 48 +++++++- 10 files changed, 407 insertions(+), 21 deletions(-) create mode 100644 benchmarks/bench_isomap_auto_vs_randomized.py create mode 100644 benchmarks/bench_isomap_solvers_n_samples_vs_reconstruction_error.py diff --git a/benchmarks/bench_isomap_auto_vs_randomized.py b/benchmarks/bench_isomap_auto_vs_randomized.py new file mode 100644 index 0000000000000..95dff4f9e21cb --- /dev/null +++ b/benchmarks/bench_isomap_auto_vs_randomized.py @@ -0,0 +1,98 @@ +""" +====================================================================== +Benchmark: Comparing Isomap Solvers - Execution Time vs. Representation +====================================================================== + +This benchmark demonstrates how different eigenvalue solvers in Isomap +can affect execution time and embedding quality. + +Description: +------------ +We use a subset of handwritten digits (`load_digits` with 6 classes). +Each data point is projected into a lower-dimensional space (2D) using +two different solvers (`auto` and `randomized`). + +What you can observe: +---------------------- +- The `auto` solver provides a reference solution. +- The `randomized` solver is tested for comparison in terms of + representation quality and execution time. + +Further exploration: +--------------------- +You can modify the number of neighbors (`n_neighbors`) or experiment with +other Isomap solvers. 
+""" + +import time +import numpy as np +import matplotlib.pyplot as plt +from matplotlib import offsetbox +from sklearn.datasets import load_digits +from sklearn.preprocessing import MinMaxScaler +from sklearn.manifold import Isomap + +# 1- Data Loading +# --------------- +digits = load_digits(n_class=6) +X, y = digits.data, digits.target +n_neighbors = 30 # Number of neighbors for Isomap + +# 2- Visualization Function +# ------------------------- +def plot_embedding(ax, X, title): + """Displays projected points with image annotations.""" + X = MinMaxScaler().fit_transform(X) + + for digit in digits.target_names: + ax.scatter( + *X[y == digit].T, + marker=f"${digit}$", + s=60, + color=plt.cm.Dark2(digit), + alpha=0.425, + zorder=2, + ) + + # Add digit images in the projected space + shown_images = np.array([[1.0, 1.0]]) + for i in range(X.shape[0]): + dist = np.sum((X[i] - shown_images) ** 2, 1) + if np.min(dist) < 4e-3: + continue + shown_images = np.concatenate([shown_images, [X[i]]], axis=0) + imagebox = offsetbox.AnnotationBbox( + offsetbox.OffsetImage(digits.images[i], cmap=plt.cm.gray_r), X[i] + ) + imagebox.set(zorder=1) + ax.add_artist(imagebox) + + ax.set_title(title) + ax.axis("off") + +# 3- Define Embeddings and Benchmark +# ---------------------------------- +embeddings = { + "Isomap (auto solver)": Isomap(n_neighbors=n_neighbors, n_components=2, eigen_solver='auto'), + "Isomap (randomized solver)": Isomap(n_neighbors=n_neighbors, n_components=2, eigen_solver='randomized_value'), +} + +projections, timing = {}, {} + +# Compute embeddings +for name, transformer in embeddings.items(): + print(f"Computing {name}...") + start_time = time.time() + projections[name] = transformer.fit_transform(X, y) + timing[name] = time.time() - start_time + +# 4- Display Results +# ------------------ +fig, axes = plt.subplots(1, 2, figsize=(14, 6)) + +for ax, (name, proj) in zip(axes, projections.items()): + title = f"{name} (time: {timing[name]:.3f}s)" + plot_embedding(ax, proj, title) + +plt.tight_layout() +plt.show() diff --git a/benchmarks/bench_isomap_solvers_n_samples_vs_reconstruction_error.py b/benchmarks/bench_isomap_solvers_n_samples_vs_reconstruction_error.py new file mode 100644 index 0000000000000..0291b85c8696c --- /dev/null +++ b/benchmarks/bench_isomap_solvers_n_samples_vs_reconstruction_error.py @@ -0,0 +1,85 @@ +""" +======================================================================== +Benchmark: Isomap Reconstruction Error - Standard vs. Randomized Solver +======================================================================== + +This benchmark illustrates how the number of samples impacts the quality +of the Isomap embedding, using reconstruction error as a metric. + +Description: +------------ +We generate synthetic 2D non-linear data (two concentric circles) with +varying numbers of samples. For each subset, we compare the reconstruction +error of two Isomap solvers: + +- The `auto` solver (standard dense or arpack, selected automatically). +- The `randomized_value` solver . + +What you can observe: +--------------------- +- The difference in performance between the two solvers. + +Further exploration: +--------------------- +- Modify the number of neighbors or iterations. 
+""" + +import numpy as np +import matplotlib.pyplot as plt +from sklearn.manifold import Isomap +from sklearn.datasets import make_circles + +# 1- Experiment Configuration +# --------------------------- +min_n_samples, max_n_samples = 100, 4000 +n_samples_grid_size = 4 # Number of sample sizes to test + +n_samples_range = [ + int(min_n_samples + np.floor((x / (n_samples_grid_size - 1)) * (max_n_samples - min_n_samples))) + for x in range(0, n_samples_grid_size) +] + +n_components = 2 +n_iter = 3 # Number of repetitions per sample size +include_arpack = False # Reserved for further testing + +# 2- Data Generation +# ------------------ +n_features = 2 +X_full, y_full = make_circles(n_samples=max_n_samples, factor=0.3, noise=0.05, random_state=0) + +# 3- Benchmark Execution +# ---------------------- +errors_randomized = [] +errors_full = [] + +for n_samples in n_samples_range: + X, y = X_full[:n_samples], y_full[:n_samples] + print(f"Computing for n_samples = {n_samples}") + + # Instantiate Isomap solvers + isomap_randomized = Isomap(n_neighbors=50, n_components=n_components, eigen_solver='randomized_value') + isomap_auto = Isomap(n_neighbors=50, n_components=n_components, eigen_solver='auto') + + # Fit and record reconstruction error + isomap_randomized.fit(X) + err_rand = isomap_randomized.reconstruction_error() + errors_randomized.append(err_rand) + + isomap_auto.fit(X) + err_auto = isomap_auto.reconstruction_error() + errors_full.append(err_auto) + +# 4- Results Visualization +# ------------------------ +plt.figure(figsize=(10, 6)) +plt.scatter(n_samples_range, errors_full, label='Isomap (auto)', color='b', marker='*') +plt.scatter(n_samples_range, errors_randomized, label='Isomap (randomized)', color='r', marker='x') + +plt.title('Isomap Reconstruction Error vs. 
Number of Samples') +plt.xlabel('Number of Samples') +plt.ylabel('Reconstruction Error') +plt.legend() +plt.grid(True) +plt.tight_layout() +plt.show() diff --git a/benchmarks/bench_kernel_pca_solvers_time_vs_n_components.py b/benchmarks/bench_kernel_pca_solvers_time_vs_n_components.py index a468f7b3e1abf..4e3bf5a55f93d 100644 --- a/benchmarks/bench_kernel_pca_solvers_time_vs_n_components.py +++ b/benchmarks/bench_kernel_pca_solvers_time_vs_n_components.py @@ -78,6 +78,7 @@ ref_time = np.empty((len(n_compo_range), n_iter)) * np.nan a_time = np.empty((len(n_compo_range), n_iter)) * np.nan r_time = np.empty((len(n_compo_range), n_iter)) * np.nan +rv_time = np.empty((len(n_compo_range), n_iter)) * np.nan # loop for j, n_components in enumerate(n_compo_range): n_components = int(n_components) @@ -119,6 +120,19 @@ # check that the result is still correct despite the approximation assert_array_almost_equal(np.abs(r_pred), np.abs(ref_pred)) + # D- randomized_value + print(" - randomized_value solver") + for i in range(n_iter): + start_time = time.perf_counter() + rv_pred = ( + KernelPCA(n_components, eigen_solver="randomized_value") + .fit(X_train) + .transform(X_test) + ) + rv_time[j, i] = time.perf_counter() - start_time + # check that the result is still correct despite the approximation + assert_array_almost_equal(np.abs(rv_pred), np.abs(ref_pred)) + # Compute statistics for the 3 methods avg_ref_time = ref_time.mean(axis=1) std_ref_time = ref_time.std(axis=1) @@ -126,6 +140,8 @@ std_a_time = a_time.std(axis=1) avg_r_time = r_time.mean(axis=1) std_r_time = r_time.std(axis=1) +avg_rv_time = rv_time.mean(axis=1) +std_rv_time = rv_time.std(axis=1) # 4- Plots @@ -160,6 +176,15 @@ color="b", label="randomized", ) +ax.errorbar( + n_compo_range, + avg_rv_time, + yerr=std_rv_time, + marker="x", + linestyle="", + color="purple", + label="randomized_value", +) ax.legend(loc="upper left") # customize axes diff --git a/benchmarks/bench_kernel_pca_solvers_time_vs_n_samples.py b/benchmarks/bench_kernel_pca_solvers_time_vs_n_samples.py index cae74c6f442ff..8c956d3ec0c24 100644 --- a/benchmarks/bench_kernel_pca_solvers_time_vs_n_samples.py +++ b/benchmarks/bench_kernel_pca_solvers_time_vs_n_samples.py @@ -80,6 +80,7 @@ ref_time = np.empty((len(n_samples_range), n_iter)) * np.nan a_time = np.empty((len(n_samples_range), n_iter)) * np.nan r_time = np.empty((len(n_samples_range), n_iter)) * np.nan +rv_time = np.empty((len(n_samples_range), n_iter)) * np.nan # loop for j, n_samples in enumerate(n_samples_range): @@ -125,6 +126,19 @@ # check that the result is still correct despite the approximation assert_array_almost_equal(np.abs(r_pred), np.abs(ref_pred)) + # D- randomized_value + print(" - randomized_value") + for i in range(n_iter): + start_time = time.perf_counter() + rv_pred = ( + KernelPCA(n_components, eigen_solver="randomized_value") + .fit(X_train) + .transform(X_test) + ) + rv_time[j, i] = time.perf_counter() - start_time + # check that the result is still correct despite the approximation + assert_array_almost_equal(np.abs(rv_pred), np.abs(ref_pred)) + # Compute statistics for the 3 methods avg_ref_time = ref_time.mean(axis=1) std_ref_time = ref_time.std(axis=1) @@ -132,7 +146,8 @@ std_a_time = a_time.std(axis=1) avg_r_time = r_time.mean(axis=1) std_r_time = r_time.std(axis=1) - +avg_rv_time = rv_time.mean(axis=1) +std_rv_time = rv_time.std(axis=1) # 4- Plots # -------- @@ -167,6 +182,15 @@ color="b", label="randomized", ) +ax.errorbar( + n_samples_range, + avg_rv_time, + yerr=std_rv_time, + 
marker="x", + linestyle="", + color="purple", + label="randomized_value", +) ax.legend(loc="upper left") # customize axes diff --git a/sklearn/decomposition/_kernel_pca.py b/sklearn/decomposition/_kernel_pca.py index 37ff77c8d7c64..c4dd0a4b00f3a 100644 --- a/sklearn/decomposition/_kernel_pca.py +++ b/sklearn/decomposition/_kernel_pca.py @@ -263,7 +263,7 @@ class KernelPCA(ClassNamePrefixFeaturesOutMixin, TransformerMixin, BaseEstimator "kernel_params": [dict, None], "alpha": [Interval(Real, 0, None, closed="left")], "fit_inverse_transform": ["boolean"], - "eigen_solver": [StrOptions({"auto", "dense", "arpack", "randomized"})], + "eigen_solver": [StrOptions({"auto", "dense", "arpack", "randomized", "randomized_value"})], "tol": [Interval(Real, 0, None, closed="left")], "max_iter": [ Interval(Integral, 1, None, closed="left"), @@ -363,6 +363,14 @@ def _fit_transform_in_place(self, K): random_state=self.random_state, selection="module", ) + elif eigen_solver == "randomized_value": + self.eigenvalues_, self.eigenvectors_ = _randomized_eigsh( + K, + n_components=n_components, + n_iter=self.iterated_power, + random_state=self.random_state, + selection="value", + ) # make sure that the eigenvalues are ok and fix numerical issues self.eigenvalues_ = _check_psd_eigenvalues( diff --git a/sklearn/decomposition/tests/test_kernel_pca.py b/sklearn/decomposition/tests/test_kernel_pca.py index 57ae75c184622..40fc888e8e152 100644 --- a/sklearn/decomposition/tests/test_kernel_pca.py +++ b/sklearn/decomposition/tests/test_kernel_pca.py @@ -37,7 +37,7 @@ def histogram(x, y, **kwargs): assert kwargs == {} # no kernel_params that we didn't ask for return np.minimum(x, y).sum() - for eigen_solver in ("auto", "dense", "arpack", "randomized"): + for eigen_solver in ("auto", "dense", "arpack", "randomized", "randomized_value"): for kernel in ("linear", "rbf", "poly", histogram): # histogram kernel produces singular matrix inside linalg.solve # XXX use a least-squares approximation? @@ -128,7 +128,7 @@ def test_kernel_pca_sparse(csr_container, global_random_seed): X_fit = csr_container(rng.random_sample((5, 4))) X_pred = csr_container(rng.random_sample((2, 4))) - for eigen_solver in ("auto", "arpack", "randomized"): + for eigen_solver in ("auto", "arpack", "randomized", "randomized_value"): for kernel in ("linear", "rbf", "poly"): # transform fit data kpca = KernelPCA( @@ -191,7 +191,7 @@ def test_kernel_pca_n_components(): X_fit = rng.random_sample((5, 4)) X_pred = rng.random_sample((2, 4)) - for eigen_solver in ("dense", "arpack", "randomized"): + for eigen_solver in ("dense", "arpack", "randomized", "randomized_value"): for c in [1, 2, 4]: kpca = KernelPCA(n_components=c, eigen_solver=eigen_solver) shape = kpca.fit(X_fit).transform(X_pred).shape @@ -252,7 +252,7 @@ def test_kernel_pca_precomputed(global_random_seed): X_fit = rng.random_sample((5, 4)) X_pred = rng.random_sample((2, 4)) - for eigen_solver in ("dense", "arpack", "randomized"): + for eigen_solver in ("dense", "arpack", "randomized", "randomized_value"): X_kpca = ( KernelPCA(4, eigen_solver=eigen_solver, random_state=0) .fit(X_fit) @@ -284,7 +284,7 @@ def test_kernel_pca_precomputed(global_random_seed): assert_array_almost_equal(np.abs(X_kpca_train), np.abs(X_kpca_train2)) -@pytest.mark.parametrize("solver", ["auto", "dense", "arpack", "randomized"]) +@pytest.mark.parametrize("solver", ["auto", "dense", "arpack", "randomized", "randomized_value"]) def test_kernel_pca_precomputed_non_symmetric(solver): """Check that the kernel centerer works. 
@@ -386,7 +386,7 @@ def test_kernel_conditioning():
     assert np.all(kpca.eigenvalues_ == _check_psd_eigenvalues(kpca.eigenvalues_))
 
 
-@pytest.mark.parametrize("solver", ["auto", "dense", "arpack", "randomized"])
+@pytest.mark.parametrize("solver", ["auto", "dense", "arpack", "randomized", "randomized_value"])
 def test_precomputed_kernel_not_psd(solver):
     """Check how KernelPCA works with non-PSD kernels depending on n_components
 
@@ -440,7 +440,7 @@
 
 @pytest.mark.parametrize("n_components", [4, 10, 20])
 def test_kernel_pca_solvers_equivalence(n_components):
-    """Check that 'dense' 'arpack' & 'randomized' solvers give similar results"""
+    """Check that 'dense' 'arpack' & 'randomized' & 'randomized_value' solvers give similar results"""
 
     # Generate random data
     n_train, n_test = 1_000, 100
@@ -474,6 +474,15 @@
     # check that the result is still correct despite the approximation
     assert_array_almost_equal(np.abs(r_pred), np.abs(ref_pred))
 
+    # randomized_value
+    rv_pred = (
+        KernelPCA(n_components, eigen_solver="randomized_value", random_state=0)
+        .fit(X_fit)
+        .transform(X_pred)
+    )
+    # check that the result is still correct despite the approximation
+    assert_array_almost_equal(np.abs(rv_pred), np.abs(ref_pred))
+
 
 def test_kernel_pca_inverse_transform_reconstruction():
     """Test if the reconstruction is a good approximation.
diff --git a/sklearn/manifold/_isomap.py b/sklearn/manifold/_isomap.py
index 90154470c18a4..127a3d725d51e 100644
--- a/sklearn/manifold/_isomap.py
+++ b/sklearn/manifold/_isomap.py
@@ -57,6 +57,8 @@ class Isomap(ClassNamePrefixFeaturesOutMixin, TransformerMixin, BaseEstimator):
         'dense' : Use a direct solver (i.e. LAPACK)
             for the eigenvalue decomposition.
 
+        'randomized_value' : Use a randomized eigenvalue solver to reduce computational complexity.
+
     tol : float, default=0
         Convergence tolerance passed to arpack or lobpcg.
         not used if eigen_solver == 'dense'.
@@ -169,7 +171,7 @@ class Isomap(ClassNamePrefixFeaturesOutMixin, TransformerMixin, BaseEstimator):
         "n_neighbors": [Interval(Integral, 1, None, closed="left"), None],
         "radius": [Interval(Real, 0, None, closed="both"), None],
         "n_components": [Interval(Integral, 1, None, closed="left")],
-        "eigen_solver": [StrOptions({"auto", "arpack", "dense"})],
+        "eigen_solver": [StrOptions({"auto", "arpack", "dense", "randomized_value"})],
         "tol": [Interval(Real, 0, None, closed="left")],
         "max_iter": [Interval(Integral, 1, None, closed="left"), None],
         "path_method": [StrOptions({"auto", "FW", "D"})],
diff --git a/sklearn/manifold/tests/test_isomap.py b/sklearn/manifold/tests/test_isomap.py
index e38b92442e58d..ef6581dc973c2 100644
--- a/sklearn/manifold/tests/test_isomap.py
+++ b/sklearn/manifold/tests/test_isomap.py
@@ -15,7 +15,7 @@
 )
 from sklearn.utils.fixes import CSR_CONTAINERS
 
-eigen_solvers = ["auto", "dense", "arpack"]
+eigen_solvers = ["auto", "dense", "arpack", "randomized_value"]
 path_methods = ["auto", "FW", "D"]
 
 
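The estimator-level diffs above only register the new `eigen_solver` string; the actual decomposition lands in `sklearn/utils/extmath.py` below. A minimal usage sketch of the option wired up here, assuming this branch of the PR is installed (`'randomized_value'` does not exist in released scikit-learn):

from sklearn.datasets import make_swiss_roll
from sklearn.manifold import Isomap

X, _ = make_swiss_roll(n_samples=1500, random_state=0)

# 'randomized_value' asks Isomap for a randomized eigendecomposition of the
# dense kernel built from the graph shortest-path distances.
embedding = Isomap(n_neighbors=30, n_components=2, eigen_solver="randomized_value")
X_2d = embedding.fit_transform(X)
print(X_2d.shape)  # (1500, 2)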
diff --git a/sklearn/utils/extmath.py b/sklearn/utils/extmath.py
index b4af090344d74..b8e5b24ad7547 100644
--- a/sklearn/utils/extmath.py
+++ b/sklearn/utils/extmath.py
@@ -561,6 +561,92 @@ def randomized_svd(
     return U[:, :n_components], s[:n_components], Vt[:n_components, :]
 
 
+def randomized_eigen_decomposition(
+    A,
+    n_components,
+    n_oversamples=10,
+    n_iter="auto",
+    power_iteration_normalizer="auto",
+    random_state=None,
+):
+    """
+    Approximate eigenvalue decomposition of a Hermitian matrix A ≈ U Λ U*
+
+    Parameters
+    ----------
+    A : ndarray of shape (n, n)
+        Hermitian matrix to decompose.
+
+    n_components : int
+        Number of eigenvalues and eigenvectors to extract.
+
+    n_oversamples : int, default=10
+        Additional number of random vectors to sample the range of A so as to
+        ensure proper conditioning. The total number of random vectors used to
+        find the range of A is n_components + n_oversamples.
+
+    n_iter : int or "auto", default="auto"
+        Number of power iterations. Can be used to deal with slow singular
+        value decay of A. When "auto", it is set to 4 (or 7 if n_components is
+        small compared to n).
+
+    power_iteration_normalizer : {'auto', 'QR', 'LU', None}, default='auto'
+        Normalizer for power iterations. Used by randomized_range_finder.
+
+    random_state : int, RandomState instance or None, default=None
+        Pseudo-random number generator to control randomness.
+
+    Returns
+    -------
+    Lambda : ndarray of shape (n_components,)
+        Approximate eigenvalues.
+
+    U : ndarray of shape (n, n_components)
+        Approximate eigenvectors.
+
+    Notes
+    -----
+    This implementation follows Algorithm 5.3 (Direct Eigenvalue Decomposition)
+    from:
+
+    Halko, N., Martinsson, P.G., & Tropp, J.A. (2011).
+    Finding structure with randomness: Probabilistic algorithms for
+    constructing approximate matrix decompositions.
+    SIAM Review, 53(2), 217–288.
+    https://doi.org/10.1137/090771806
+    """
+    random_state = check_random_state(random_state)
+    n_random = n_components + n_oversamples
+    n = A.shape[0]
+
+    if n_iter == "auto":
+        n_iter = 7 if n_components < 0.1 * n else 4
+
+    # Step 1: compute an orthonormal matrix Q approximating the range of A
+    Q = randomized_range_finder(
+        A,
+        size=n_random,
+        n_iter=n_iter,
+        power_iteration_normalizer=power_iteration_normalizer,
+        random_state=random_state,
+    )
+
+    # Step 2: project A to the low-dimensional subspace
+    B = Q.T @ A @ Q
+
+    # Step 3: compute the eigenvalue decomposition of the small matrix
+    xp, is_array_api_compliant = get_namespace(B)
+    if is_array_api_compliant:
+        Lambda, V = xp.linalg.eigh(B)
+    else:
+        Lambda, V = linalg.eigh(B)
+
+    # Step 4: compute the approximate eigenvectors of A
+    U = Q @ V
+
+    return Lambda[-n_components:], U[:, -n_components:]
+
+
 def _randomized_eigsh(
     M,
     n_components,
@@ -647,9 +733,13 @@ def _randomized_eigsh(
     effective rank. Usually, `n_components` is chosen to be greater than k so
     increasing `n_oversamples` up to `n_components` should be enough.
 
-    Strategy 'value': not implemented yet.
-    Algorithms 5.3, 5.4 and 5.5 in the Halko et al paper should provide good
-    candidates for a future implementation.
+    Strategy 'value':
+      This randomized algorithm efficiently approximates eigendecompositions
+      of Hermitian matrices by projection onto a lower-dimensional subspace using basis Q.
+      Relying on Algorithm 5.3 from Halko et al.'s article, it computes B = Q*AQ, finds its
+      eigendecomposition B = VΛV*, and forms U = QV to yield A = UΛU*
+      with bounded error. Unlike the 'module' strategy, it works efficiently with
+      non-positive semidefinite matrices, handling both positive and negative eigenvalues directly.
 
     Strategy 'module':
     The principle is that for diagonalizable matrices, the singular values and
@@ -681,8 +771,15 @@ def _randomized_eigsh(
      Halko, et al. (2009)
     """
    if selection == "value":  # pragma: no cover
-        # to do : an algorithm can be found in the Halko et al reference
-        raise NotImplementedError()
+        # Call Halko et al.'s Algorithm 5.3 randomized eigendecomposition solver
+        eigvals, eigvecs = randomized_eigen_decomposition(
+            M,
+            n_components=n_components,
+            n_oversamples=n_oversamples,
+            n_iter=n_iter,
+            power_iteration_normalizer=power_iteration_normalizer,
+            random_state=random_state,
+        )
 
     elif selection == "module":
         # Note: no need for deterministic U and Vt (flip_sign=True),
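`randomized_eigen_decomposition` above is Halko et al.'s Algorithm 5.3: sample the range of A, project onto it, solve the small eigenproblem exactly, and lift the eigenvectors back. A self-contained NumPy sketch of those four steps, kept deliberately simpler than the helper above (no power iterations and no array API handling):

import numpy as np
from scipy import linalg

rng = np.random.default_rng(0)
n, k, p = 200, 10, 10                 # matrix size, components, oversampling

# Symmetric test matrix with a quickly decaying spectrum.
Q0, _ = np.linalg.qr(rng.standard_normal((n, n)))
w = 0.5 ** np.arange(n)
A = (Q0 * w) @ Q0.T

G = rng.standard_normal((n, k + p))   # Gaussian test matrix
Q, _ = np.linalg.qr(A @ G)            # Step 1: orthonormal basis for range(A)
B = Q.T @ A @ Q                       # Step 2: project to the small subspace
Lambda, V = linalg.eigh(B)            # Step 3: exact eigh on the small matrix
U = Q @ V                             # Step 4: lift eigenvectors back

# The top-k approximate eigenvalues track the exact ones when the
# spectrum decays quickly.
exact = linalg.eigh(A, eigvals_only=True)
print(np.abs(Lambda[-k:] - exact[-k:]).max())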
diff --git a/sklearn/utils/tests/test_extmath.py b/sklearn/utils/tests/test_extmath.py
index 74cb47388692f..8585d9b94653e 100644
--- a/sklearn/utils/tests/test_extmath.py
+++ b/sklearn/utils/tests/test_extmath.py
@@ -189,11 +189,7 @@ def test_randomized_eigsh(dtype):
     # eigenvectors
     assert eigvecs.shape == (4, 2)
 
-    # with 'value' selection method, the negative eigenvalue does not show up
-    with pytest.raises(NotImplementedError):
-        _randomized_eigsh(X, n_components=2, selection="value")
-
-
+    
 @pytest.mark.parametrize("k", (10, 50, 100, 199, 200))
 def test_randomized_eigsh_compared_to_others(k):
     """Check that `_randomized_eigsh` is similar to other `eigsh`
@@ -265,6 +261,48 @@ def test_randomized_eigsh_compared_to_others(k):
     eigvecs_arpack, _ = svd_flip(eigvecs_arpack, dummy_vecs)
     assert_array_almost_equal(eigvecs_arpack, eigvecs_lapack, decimal=8)
 
+
+@pytest.mark.parametrize("k", (10, 50, 100, 199, 200))
+def test_randomized_eigsh_value_compared_to_others(k):
+    """Check that `_randomized_eigsh(value)` is similar to other `eigsh`
+
+    Tests that for a random PSD matrix, `_randomized_eigsh(value)` provides results
+    comparable to LAPACK (scipy.linalg.eigh) and ARPACK
+    (scipy.sparse.linalg.eigsh).
+    """
+    n_features = 200
+    # make a random PSD matrix
+    X = make_sparse_spd_matrix(n_features, random_state=0)
+
+    # compare two versions of randomized
+    # rough and fast
+    eigvals, eigvecs = _randomized_eigsh(
+        X,
+        n_components=k,
+        n_oversamples=20,
+        selection="value",
+        n_iter=25,
+        random_state=0,
+    )
+
+    # more accurate but slow (TODO find realistic settings here)
+
+    # with LAPACK
+    eigvals_lapack, eigvecs_lapack = eigh(
+        X, subset_by_index=(n_features - k, n_features - 1)
+    )
+
+    # - eigenvalues comparison
+    assert eigvals.shape == (k,)
+    # comparison precision
+    assert_array_almost_equal(eigvals, eigvals_lapack, decimal=6)
+    # -- eigenvectors comparison
+    assert eigvecs_lapack.shape == (n_features, k)
+    dummy_vecs = np.zeros_like(eigvecs).T
+    eigvecs, _ = svd_flip(eigvecs, dummy_vecs)
+    eigvecs_lapack, _ = svd_flip(eigvecs_lapack, dummy_vecs)
+    assert_array_almost_equal(eigvecs, eigvecs_lapack, decimal=4)
+
 
 @pytest.mark.parametrize(
     "n,rank",

From 66573f0f639597db72c3a48f30635e40e91c0c10 Mon Sep 17 00:00:00 2001
From: oussama er-rabie
Date: Wed, 16 Apr 2025 23:31:49 +0200
Subject: [PATCH 2/6] add changelog

---
 .../sklearn.utils/XXXX.feature.rst            | 17 +++++++++++++++++
 1 file changed, 17 insertions(+)
 create mode 100644 doc/whats_new/upcoming_changes/sklearn.utils/XXXX.feature.rst

diff --git a/doc/whats_new/upcoming_changes/sklearn.utils/XXXX.feature.rst b/doc/whats_new/upcoming_changes/sklearn.utils/XXXX.feature.rst
new file mode 100644
index 0000000000000..91c89d6090f98
--- /dev/null
+++ b/doc/whats_new/upcoming_changes/sklearn.utils/XXXX.feature.rst
@@ -0,0 +1,17 @@
+Feature
+------------
+
+- Added :func:`randomized_eigen_decomposition`, an approximate eigendecomposition method based on the approach from
+  **Halko, Martinsson, and Tropp (2011)**: *Finding Structure with Randomness: Probabilistic Algorithms for Constructing
+  Approximate Matrix Decompositions*. This method provides a faster alternative to existing eigendecomposition techniques.
+
+- Integrated :func:`randomized_eigen_decomposition` into :class:`sklearn.manifold.Isomap` and
+  :class:`sklearn.decomposition.KernelPCA` as an additional option for eigendecomposition.
+
+- Added a test suite comparing the new method to existing solvers (`arpack`, `dense`, etc.), ensuring numerical
+  accuracy and stability.
+
+- Included benchmarks to analyze the **performance improvements** over traditional eigendecomposition approaches.
+
+by :user:`Sylvain Marié <smarie>`, :user:`Mohamed Yaich <yaichm>`, :user:`Oussama Er-rabie <eroussama>`, :user:`Mohamed Dlimi <Dlimim>`,
+:user:`Hamza Zeroual <HamzaLuffy>` and :user:`Amine Hannoun <AmineHannoun>`.
\ No newline at end of file
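The changelog entry above is the user-facing summary; the call path it describes is just a new value for an existing parameter. A hedged sketch of that path (only the `'randomized_value'` string is new to this series; the rest is the standard KernelPCA API):

from sklearn.datasets import make_circles
from sklearn.decomposition import KernelPCA

X, _ = make_circles(n_samples=400, factor=0.3, noise=0.05, random_state=0)

# Internally this routes to _randomized_eigsh(selection="value") as added above.
kpca = KernelPCA(n_components=2, kernel="rbf", gamma=10,
                 eigen_solver="randomized_value", random_state=0)
X_kpca = kpca.fit_transform(X)
print(X_kpca.shape)  # (400, 2)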
From ce755ebe9ee11c4ba54c082faa90a6bbd7a71694 Mon Sep 17 00:00:00 2001
From: Mohamed Yaich
Date: Sun, 20 Apr 2025 14:08:22 +0200
Subject: [PATCH 3/6] Add randomized_eigsh(selection='value') for fast
 eigendecomposition with tests and integration into Isomap and KernelPCA

---
 ...somap_execution_time_full_vs_randomized.py | 102 ++++++++++++++++
 ...lvers_n_samples_vs_reconstruction_error.py |   4 +-
 sklearn/utils/extmath.py                      |   1 -
 sklearn/utils/tests/test_extmath.py           |   8 +-
 4 files changed, 106 insertions(+), 9 deletions(-)
 create mode 100644 benchmarks/bench_isomap_execution_time_full_vs_randomized.py

diff --git a/benchmarks/bench_isomap_execution_time_full_vs_randomized.py b/benchmarks/bench_isomap_execution_time_full_vs_randomized.py
new file mode 100644
index 0000000000000..7aeb3f3b68eb2
--- /dev/null
+++ b/benchmarks/bench_isomap_execution_time_full_vs_randomized.py
@@ -0,0 +1,102 @@
+"""
+======================================================================
+Isomap Solvers Benchmark: Execution Time vs Number of Samples
+======================================================================
+
+This benchmark demonstrates how the choice of eigen_solver in Isomap
+can significantly affect computation time, especially as the dataset
+size increases.
+
+Description:
+------------
+Synthetic datasets are generated using `make_classification` with a
+fixed number of features. The number of samples is
+varied from 1000 to 4000.
+
+For each setting, Isomap is applied using two different solvers:
+- 'auto' (full eigendecomposition)
+- 'randomized_value'
+
+The execution time of each solver is measured for each number of
+samples, and the average time over multiple runs (default: 3) is
+plotted.
+
+What you can observe:
+---------------------
+If n_components < 10, the randomized and auto solvers produce similar
+results (in this case, the arpack solver is selected).
+However, when n_components > 10, the randomized solver becomes significantly
+faster, especially as the number of samples increases.
+ +""" + +import time +import numpy as np +import matplotlib.pyplot as plt +from sklearn.datasets import make_classification +from sklearn.manifold import Isomap + +# 1 - Experiment Setup +#-- -- -- -- -- -- -- -- -- - +n_samples_list = [1000, 2000, 3000, 4000] +n_neighbors = 30 +n_components_list = [2, 10] +n_features = 100 +n_iter = 3 # Number of repetitions for averaging execution time + +#Store timings for each value of n_components +timing_all = {} + +for n_components in n_components_list: +#Create containers for timing results + timing = { + "auto": np.zeros((len(n_samples_list), n_iter)), + "randomized_value": np.zeros((len(n_samples_list), n_iter)) + } + + for j, n in enumerate(n_samples_list): +#Generate synthetic classification dataset + X, _ = make_classification( + n_samples=n, + n_features=n_features, + n_redundant=0, + n_clusters_per_class=1, + n_classes=1, + random_state=42 + ) + +#Evaluate both solvers for multiple repetitions + for solver in ["auto", "randomized_value"]: + for i in range(n_iter): + model = Isomap( + n_neighbors=n_neighbors, + n_components=n_components, + eigen_solver=solver + ) + start = time.perf_counter() + model.fit(X) + elapsed = time.perf_counter() - start + timing[solver][j, i] = elapsed + + timing_all[n_components] = timing + +fig, axes = plt.subplots(1, 2, figsize=(14, 6)) + +for idx, n_components in enumerate(n_components_list): + ax = axes[idx] + timing = timing_all[n_components] + avg_full = timing["auto"].mean(axis=1) + std_full = timing["auto"].std(axis=1) + avg_rand = timing["randomized_value"].mean(axis=1) + std_rand = timing["randomized_value"].std(axis=1) + + ax.errorbar(n_samples_list, avg_full, yerr=std_full, label="Isomap (full)", marker="o", linestyle="-") + ax.errorbar(n_samples_list, avg_rand, yerr=std_rand, label="Isomap (randomized)", marker="x", linestyle="--") + ax.set_xlabel("Number of Samples") + ax.set_ylabel("Execution Time (seconds)") + ax.set_title(f"Isomap Execution Time (n_components = {n_components})") + ax.legend() + ax.grid(True) + +plt.tight_layout() +plt.show() diff --git a/benchmarks/bench_isomap_solvers_n_samples_vs_reconstruction_error.py b/benchmarks/bench_isomap_solvers_n_samples_vs_reconstruction_error.py index 0291b85c8696c..5513c7a332b27 100644 --- a/benchmarks/bench_isomap_solvers_n_samples_vs_reconstruction_error.py +++ b/benchmarks/bench_isomap_solvers_n_samples_vs_reconstruction_error.py @@ -17,7 +17,8 @@ What you can observe: --------------------- -- The difference in performance between the two solvers. +-The randomized and auto solvers yield the same errors for different samples, +which means that randomized provides the same projection quality as auto. Further exploration: --------------------- @@ -41,7 +42,6 @@ n_components = 2 n_iter = 3 # Number of repetitions per sample size -include_arpack = False # Reserved for further testing # 2- Data Generation # ------------------ diff --git a/sklearn/utils/extmath.py b/sklearn/utils/extmath.py index b8e5b24ad7547..4601567d11870 100644 --- a/sklearn/utils/extmath.py +++ b/sklearn/utils/extmath.py @@ -613,7 +613,6 @@ def randomized_eigen_decomposition( Finding structure with randomness: Probabilistic algorithms for constructing approximate matrix decompositions. SIAM Review, 53(2), 217–288. 
- https://doi.org/10.1137/090771806 """ random_state = check_random_state(random_state) n_random = n_components + n_oversamples diff --git a/sklearn/utils/tests/test_extmath.py b/sklearn/utils/tests/test_extmath.py index 8585d9b94653e..29d0e08bcdbde 100644 --- a/sklearn/utils/tests/test_extmath.py +++ b/sklearn/utils/tests/test_extmath.py @@ -267,15 +267,13 @@ def test_randomized_eigsh_value_compared_to_others(k): """Check that `_randomized_eigsh(value)` is similar to other `eigsh` Tests that for a random PSD matrix, `_randomized_eigsh(value)` provides results - comparable to LAPACK (scipy.linalg.eigh) and ARPACK - (scipy.sparse.linalg.eigsh). + comparable to LAPACK (scipy.linalg.eigh) """ n_features = 200 # make a random PSD matrix X = make_sparse_spd_matrix(n_features, random_state=0) - # compare two versions of randomized - # rough and fast + # with randomized_value eigvals, eigvecs = _randomized_eigsh( X, n_components=k, @@ -285,8 +283,6 @@ def test_randomized_eigsh_value_compared_to_others(k): random_state=0, ) - # more accurate but slow (TODO find realistic settings here) - # with LAPACK eigvals_lapack, eigvecs_lapack = eigh( X, subset_by_index=(n_features - k, n_features - 1) From 53e0102c57efeecb81d859d01a0ddcf7cd5438c5 Mon Sep 17 00:00:00 2001 From: Mohamed Yaich Date: Sun, 27 Apr 2025 12:05:40 +0200 Subject: [PATCH 4/6] Fix linting errors and rename the changelog to match the PR --- benchmarks/bench_isomap_auto_vs_randomized.py | 26 +++++++---- ...somap_execution_time_full_vs_randomized.py | 42 ++++++++++++------ ...lvers_n_samples_vs_reconstruction_error.py | 44 ++++++++++++------- .../{XXXX.feature.rst => 31247.feature.rst} | 0 sklearn/decomposition/_kernel_pca.py | 4 +- .../decomposition/tests/test_kernel_pca.py | 11 +++-- sklearn/utils/extmath.py | 7 +-- sklearn/utils/tests/test_extmath.py | 4 +- 8 files changed, 92 insertions(+), 46 deletions(-) rename doc/whats_new/upcoming_changes/sklearn.utils/{XXXX.feature.rst => 31247.feature.rst} (100%) diff --git a/benchmarks/bench_isomap_auto_vs_randomized.py b/benchmarks/bench_isomap_auto_vs_randomized.py index 95dff4f9e21cb..947e4e27752bc 100644 --- a/benchmarks/bench_isomap_auto_vs_randomized.py +++ b/benchmarks/bench_isomap_auto_vs_randomized.py @@ -3,34 +3,36 @@ Benchmark: Comparing Isomap Solvers - Execution Time vs. Representation ====================================================================== -This benchmark demonstrates how different eigenvalue solvers in Isomap +This benchmark demonstrates how different eigenvalue solvers in Isomap can affect execution time and embedding quality. Description: ------------ -We use a subset of handwritten digits (`load_digits` with 6 classes). -Each data point is projected into a lower-dimensional space (2D) using +We use a subset of handwritten digits (`load_digits` with 6 classes). +Each data point is projected into a lower-dimensional space (2D) using two different solvers (`auto` and `randomized`). What you can observe: ---------------------- - The `auto` solver provides a reference solution. -- The `randomized` solver is tested for comparison in terms of +- The `randomized` solver is tested for comparison in terms of representation quality and execution time. Further exploration: --------------------- -You can modify the number of neighbors (`n_neighbors`) or experiment with +You can modify the number of neighbors (`n_neighbors`) or experiment with other Isomap solvers. 
""" import time -import numpy as np + import matplotlib.pyplot as plt +import numpy as np from matplotlib import offsetbox + from sklearn.datasets import load_digits -from sklearn.preprocessing import MinMaxScaler from sklearn.manifold import Isomap +from sklearn.preprocessing import MinMaxScaler # 1- Data Loading # --------------- @@ -38,6 +40,7 @@ X, y = digits.data, digits.target n_neighbors = 30 # Number of neighbors for Isomap + # 2- Visualization Function # ------------------------- def plot_embedding(ax, X, title): @@ -70,11 +73,16 @@ def plot_embedding(ax, X, title): ax.set_title(title) ax.axis("off") + # 3- Define Embeddings and Benchmark # ---------------------------------- embeddings = { - "Isomap (auto solver)": Isomap(n_neighbors=n_neighbors, n_components=2, eigen_solver='auto'), - "Isomap (randomized solver)": Isomap(n_neighbors=n_neighbors, n_components=2, eigen_solver='randomized_value'), + "Isomap (auto solver)": Isomap( + n_neighbors=n_neighbors, n_components=2, eigen_solver="auto" + ), + "Isomap (randomized solver)": Isomap( + n_neighbors=n_neighbors, n_components=2, eigen_solver="randomized_value" + ), } projections, timing = {}, {} diff --git a/benchmarks/bench_isomap_execution_time_full_vs_randomized.py b/benchmarks/bench_isomap_execution_time_full_vs_randomized.py index 7aeb3f3b68eb2..68f2861f9aac0 100644 --- a/benchmarks/bench_isomap_execution_time_full_vs_randomized.py +++ b/benchmarks/bench_isomap_execution_time_full_vs_randomized.py @@ -15,7 +15,7 @@ For each setting, Isomap is applied using two different solvers: - 'auto' (full eigendecomposition) -- 'randomized_value' +- 'randomized_value' The execution time of each solver is measured for each number of samples, and the average time over multiple runs (default: 3) is @@ -23,7 +23,7 @@ What you can observe: --------------------- -If n_components < 10, the randomized and auto solvers produce similar +If n_components < 10, the randomized and auto solvers produce similar results (in this case, the arpack solver is selected). However, when n_components > 10, the randomized solver becomes significantly faster, especially as the number of samples increases. 
@@ -31,47 +31,49 @@ """ import time -import numpy as np + import matplotlib.pyplot as plt +import numpy as np + from sklearn.datasets import make_classification from sklearn.manifold import Isomap # 1 - Experiment Setup -#-- -- -- -- -- -- -- -- -- - +# -- -- -- -- -- -- -- -- -- - n_samples_list = [1000, 2000, 3000, 4000] n_neighbors = 30 n_components_list = [2, 10] n_features = 100 n_iter = 3 # Number of repetitions for averaging execution time -#Store timings for each value of n_components +# Store timings for each value of n_components timing_all = {} for n_components in n_components_list: -#Create containers for timing results + # Create containers for timing results timing = { "auto": np.zeros((len(n_samples_list), n_iter)), - "randomized_value": np.zeros((len(n_samples_list), n_iter)) + "randomized_value": np.zeros((len(n_samples_list), n_iter)), } for j, n in enumerate(n_samples_list): -#Generate synthetic classification dataset + # Generate synthetic classification dataset X, _ = make_classification( n_samples=n, n_features=n_features, n_redundant=0, n_clusters_per_class=1, n_classes=1, - random_state=42 + random_state=42, ) -#Evaluate both solvers for multiple repetitions + # Evaluate both solvers for multiple repetitions for solver in ["auto", "randomized_value"]: for i in range(n_iter): model = Isomap( n_neighbors=n_neighbors, n_components=n_components, - eigen_solver=solver + eigen_solver=solver, ) start = time.perf_counter() model.fit(X) @@ -90,8 +92,22 @@ avg_rand = timing["randomized_value"].mean(axis=1) std_rand = timing["randomized_value"].std(axis=1) - ax.errorbar(n_samples_list, avg_full, yerr=std_full, label="Isomap (full)", marker="o", linestyle="-") - ax.errorbar(n_samples_list, avg_rand, yerr=std_rand, label="Isomap (randomized)", marker="x", linestyle="--") + ax.errorbar( + n_samples_list, + avg_full, + yerr=std_full, + label="Isomap (full)", + marker="o", + linestyle="-", + ) + ax.errorbar( + n_samples_list, + avg_rand, + yerr=std_rand, + label="Isomap (randomized)", + marker="x", + linestyle="--", + ) ax.set_xlabel("Number of Samples") ax.set_ylabel("Execution Time (seconds)") ax.set_title(f"Isomap Execution Time (n_components = {n_components})") diff --git a/benchmarks/bench_isomap_solvers_n_samples_vs_reconstruction_error.py b/benchmarks/bench_isomap_solvers_n_samples_vs_reconstruction_error.py index 5513c7a332b27..bc518e570d29d 100644 --- a/benchmarks/bench_isomap_solvers_n_samples_vs_reconstruction_error.py +++ b/benchmarks/bench_isomap_solvers_n_samples_vs_reconstruction_error.py @@ -3,13 +3,13 @@ Benchmark: Isomap Reconstruction Error - Standard vs. Randomized Solver ======================================================================== -This benchmark illustrates how the number of samples impacts the quality +This benchmark illustrates how the number of samples impacts the quality of the Isomap embedding, using reconstruction error as a metric. Description: ------------ -We generate synthetic 2D non-linear data (two concentric circles) with -varying numbers of samples. For each subset, we compare the reconstruction +We generate synthetic 2D non-linear data (two concentric circles) with +varying numbers of samples. For each subset, we compare the reconstruction error of two Isomap solvers: - The `auto` solver (standard dense or arpack, selected automatically). @@ -25,10 +25,11 @@ - Modify the number of neighbors or iterations. 
""" -import numpy as np import matplotlib.pyplot as plt -from sklearn.manifold import Isomap +import numpy as np + from sklearn.datasets import make_circles +from sklearn.manifold import Isomap # 1- Experiment Configuration # --------------------------- @@ -36,7 +37,10 @@ n_samples_grid_size = 4 # Number of sample sizes to test n_samples_range = [ - int(min_n_samples + np.floor((x / (n_samples_grid_size - 1)) * (max_n_samples - min_n_samples))) + int( + min_n_samples + + np.floor((x / (n_samples_grid_size - 1)) * (max_n_samples - min_n_samples)) + ) for x in range(0, n_samples_grid_size) ] @@ -46,7 +50,9 @@ # 2- Data Generation # ------------------ n_features = 2 -X_full, y_full = make_circles(n_samples=max_n_samples, factor=0.3, noise=0.05, random_state=0) +X_full, y_full = make_circles( + n_samples=max_n_samples, factor=0.3, noise=0.05, random_state=0 +) # 3- Benchmark Execution # ---------------------- @@ -58,8 +64,10 @@ print(f"Computing for n_samples = {n_samples}") # Instantiate Isomap solvers - isomap_randomized = Isomap(n_neighbors=50, n_components=n_components, eigen_solver='randomized_value') - isomap_auto = Isomap(n_neighbors=50, n_components=n_components, eigen_solver='auto') + isomap_randomized = Isomap( + n_neighbors=50, n_components=n_components, eigen_solver="randomized_value" + ) + isomap_auto = Isomap(n_neighbors=50, n_components=n_components, eigen_solver="auto") # Fit and record reconstruction error isomap_randomized.fit(X) @@ -73,12 +81,18 @@ # 4- Results Visualization # ------------------------ plt.figure(figsize=(10, 6)) -plt.scatter(n_samples_range, errors_full, label='Isomap (auto)', color='b', marker='*') -plt.scatter(n_samples_range, errors_randomized, label='Isomap (randomized)', color='r', marker='x') - -plt.title('Isomap Reconstruction Error vs. Number of Samples') -plt.xlabel('Number of Samples') -plt.ylabel('Reconstruction Error') +plt.scatter(n_samples_range, errors_full, label="Isomap (auto)", color="b", marker="*") +plt.scatter( + n_samples_range, + errors_randomized, + label="Isomap (randomized)", + color="r", + marker="x", +) + +plt.title("Isomap Reconstruction Error vs. 
Number of Samples") +plt.xlabel("Number of Samples") +plt.ylabel("Reconstruction Error") plt.legend() plt.grid(True) plt.tight_layout() diff --git a/doc/whats_new/upcoming_changes/sklearn.utils/XXXX.feature.rst b/doc/whats_new/upcoming_changes/sklearn.utils/31247.feature.rst similarity index 100% rename from doc/whats_new/upcoming_changes/sklearn.utils/XXXX.feature.rst rename to doc/whats_new/upcoming_changes/sklearn.utils/31247.feature.rst diff --git a/sklearn/decomposition/_kernel_pca.py b/sklearn/decomposition/_kernel_pca.py index c4dd0a4b00f3a..9ec7d6a90b574 100644 --- a/sklearn/decomposition/_kernel_pca.py +++ b/sklearn/decomposition/_kernel_pca.py @@ -263,7 +263,9 @@ class KernelPCA(ClassNamePrefixFeaturesOutMixin, TransformerMixin, BaseEstimator "kernel_params": [dict, None], "alpha": [Interval(Real, 0, None, closed="left")], "fit_inverse_transform": ["boolean"], - "eigen_solver": [StrOptions({"auto", "dense", "arpack", "randomized", "randomized_value"})], + "eigen_solver": [ + StrOptions({"auto", "dense", "arpack", "randomized", "randomized_value"}) + ], "tol": [Interval(Real, 0, None, closed="left")], "max_iter": [ Interval(Integral, 1, None, closed="left"), diff --git a/sklearn/decomposition/tests/test_kernel_pca.py b/sklearn/decomposition/tests/test_kernel_pca.py index 40fc888e8e152..36165dc304c97 100644 --- a/sklearn/decomposition/tests/test_kernel_pca.py +++ b/sklearn/decomposition/tests/test_kernel_pca.py @@ -284,7 +284,9 @@ def test_kernel_pca_precomputed(global_random_seed): assert_array_almost_equal(np.abs(X_kpca_train), np.abs(X_kpca_train2)) -@pytest.mark.parametrize("solver", ["auto", "dense", "arpack", "randomized", "randomized_value"]) +@pytest.mark.parametrize( + "solver", ["auto", "dense", "arpack", "randomized", "randomized_value"] +) def test_kernel_pca_precomputed_non_symmetric(solver): """Check that the kernel centerer works. @@ -386,7 +388,9 @@ def test_kernel_conditioning(): assert np.all(kpca.eigenvalues_ == _check_psd_eigenvalues(kpca.eigenvalues_)) -@pytest.mark.parametrize("solver", ["auto", "dense", "arpack", "randomized", "randomized_value"]) +@pytest.mark.parametrize( + "solver", ["auto", "dense", "arpack", "randomized", "randomized_value"] +) def test_precomputed_kernel_not_psd(solver): """Check how KernelPCA works with non-PSD kernels depending on n_components @@ -440,7 +444,8 @@ def test_precomputed_kernel_not_psd(solver): @pytest.mark.parametrize("n_components", [4, 10, 20]) def test_kernel_pca_solvers_equivalence(n_components): - """Check that 'dense' 'arpack' & 'randomized' & 'randomized_value' solvers give similar results""" + """Check that 'dense' 'arpack' & 'randomized' & 'randomized_value' + solvers give similar results""" # Generate random data n_train, n_test = 1_000, 100 diff --git a/sklearn/utils/extmath.py b/sklearn/utils/extmath.py index fe8aab044bcef..bbbfc13368e67 100644 --- a/sklearn/utils/extmath.py +++ b/sklearn/utils/extmath.py @@ -778,10 +778,11 @@ def _randomized_eigsh( Strategy 'value': This randomized algorithm efficiently approximates eigendecompositions of Hermitian matrices by projection onto a lower-dimensional subspace using basis Q. - Relying on Algorithm 5.3 from Halko et al.'s article, it computes B = Q*AQ, finds its - eigendecomposition B = VΛV*, and forms U = QV to yield A = UΛU* + Relying on Algorithm 5.3 from Halko et al.'s article, it computes B = Q*AQ, + finds its eigendecomposition B = VΛV*, and forms U = QV to yield A = UΛU* with bounded error. 
Unlike the 'module' strategy, it works efficiently with - non-positive semidefinite matrices, handling both positive and negative eigenvalues directly. + non-positive semidefinite matrices, handling both positive and negative + eigenvalues directly. Strategy 'module': The principle is that for diagonalizable matrices, the singular values and diff --git a/sklearn/utils/tests/test_extmath.py b/sklearn/utils/tests/test_extmath.py index 15f9d337fc8fd..b4a5e0dbff026 100644 --- a/sklearn/utils/tests/test_extmath.py +++ b/sklearn/utils/tests/test_extmath.py @@ -198,7 +198,7 @@ def test_randomized_eigsh(dtype): # eigenvectors assert eigvecs.shape == (4, 2) - + @pytest.mark.parametrize("k", (10, 50, 100, 199, 200)) def test_randomized_eigsh_compared_to_others(k): """Check that `_randomized_eigsh` is similar to other `eigsh` @@ -270,7 +270,7 @@ def test_randomized_eigsh_compared_to_others(k): eigvecs_arpack, _ = svd_flip(eigvecs_arpack, dummy_vecs) assert_array_almost_equal(eigvecs_arpack, eigvecs_lapack, decimal=8) - + @pytest.mark.parametrize("k", (10, 50, 100, 199, 200)) def test_randomized_eigsh_value_compared_to_others(k): """Check that `_randomized_eigsh(value)` is similar to other `eigsh` From 80f8ee7d85550ac5543a0e7576b652d13b765ebd Mon Sep 17 00:00:00 2001 From: Mohamed Yaich Date: Sun, 27 Apr 2025 15:48:53 +0200 Subject: [PATCH 5/6] Fix docstring of randomized_eigen_decomposition --- sklearn/utils/extmath.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sklearn/utils/extmath.py b/sklearn/utils/extmath.py index bbbfc13368e67..a5f0e0a46695a 100644 --- a/sklearn/utils/extmath.py +++ b/sklearn/utils/extmath.py @@ -613,7 +613,7 @@ def randomized_eigen_decomposition( random_state=None, ): """ - Approximate eigenvalue decomposition of a Hermitian matrix A ≈ U Λ U* + Approximate eigenvalue decomposition of a Hermitian matrix A ≈ U Λ U*. 
Parameters ---------- From 91654c8956cb1f2dee9ca2646173655677ae35eb Mon Sep 17 00:00:00 2001 From: Mohamed Yaich Date: Tue, 29 Apr 2025 00:46:00 +0200 Subject: [PATCH 6/6] Added test for array API compliance --- sklearn/utils/tests/test_extmath.py | 30 +++++++++++++++++++++++++++++ 1 file changed, 30 insertions(+) diff --git a/sklearn/utils/tests/test_extmath.py b/sklearn/utils/tests/test_extmath.py index b4a5e0dbff026..b34a53f22ba9f 100644 --- a/sklearn/utils/tests/test_extmath.py +++ b/sklearn/utils/tests/test_extmath.py @@ -1153,3 +1153,33 @@ def test_randomized_range_finder_array_api_compliance(array_namespace, device, d assert get_namespace(Q_xp)[0].__name__ == xp.__name__ assert_allclose(_convert_to_numpy(Q_xp, xp), Q_np, atol=atol) + + +@pytest.mark.parametrize( + "array_namespace, device, dtype", + yield_namespace_device_dtype_combinations(), + ids=_get_namespace_device_dtype_ids, +) +def test_randomized_eigsh_value_array_api_compliance(array_namespace, device, dtype): + xp = _array_api_for_tests(array_namespace, device) + + rng = np.random.RandomState(0) + X = rng.normal(size=(30, 10)).astype(dtype) + X = X @ X.T + X_xp = xp.asarray(X, device=device) + n_components = 5 + atol = 1e-5 if dtype == "float32" else 0 + + with config_context(array_api_dispatch=True): + l_np, u_np = _randomized_eigsh( + X, n_components=n_components, selection="value", random_state=0 + ) + l_xp, u_xp = _randomized_eigsh( + X_xp, n_components=n_components, selection="value", random_state=0 + ) + + assert get_namespace(u_xp)[0].__name__ == xp.__name__ + assert get_namespace(l_xp)[0].__name__ == xp.__name__ + + assert_allclose(_convert_to_numpy(u_xp, xp), u_np, atol=atol) + assert_allclose(_convert_to_numpy(l_xp, xp), l_np, atol=atol)
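The array API test above exercises the new code path generically; the same check can be mirrored with an explicit backend. A minimal sketch, assuming this branch plus `torch` and scikit-learn's array API dependencies are installed (`_randomized_eigsh` is a private helper, so this is illustration only, not supported public API):

import numpy as np
import torch
from sklearn import config_context
from sklearn.utils.extmath import _randomized_eigsh

rng = np.random.RandomState(0)
X = rng.normal(size=(30, 10))
X = X @ X.T  # symmetric PSD, as in the test above

with config_context(array_api_dispatch=True):
    eigvals, eigvecs = _randomized_eigsh(
        torch.asarray(X), n_components=5, selection="value", random_state=0
    )

# Results come back in the input namespace, here as torch tensors.
print(type(eigvals).__name__, tuple(eigvals.shape))  # Tensor (5,)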