diff --git a/sklearn/manifold/_spectral_embedding.py b/sklearn/manifold/_spectral_embedding.py index b43329a4f89ca..13181a653d1dc 100644 --- a/sklearn/manifold/_spectral_embedding.py +++ b/sklearn/manifold/_spectral_embedding.py @@ -5,6 +5,7 @@ # License: BSD 3 clause +from numbers import Integral, Real import warnings import numpy as np @@ -22,6 +23,7 @@ ) from ..utils._arpack import _init_arpack_v0 from ..utils.extmath import _deterministic_vector_sign_flip +from ..utils._param_validation import Interval, StrOptions from ..utils.fixes import lobpcg from ..metrics.pairwise import rbf_kernel from ..neighbors import kneighbors_graph, NearestNeighbors @@ -542,6 +544,27 @@ class SpectralEmbedding(BaseEstimator): (100, 2) """ + _parameter_constraints: dict = { + "n_components": [Interval(Integral, 1, None, closed="left")], + "affinity": [ + StrOptions( + { + "nearest_neighbors", + "rbf", + "precomputed", + "precomputed_nearest_neighbors", + }, + ), + callable, + ], + "gamma": [Interval(Real, 0, None, closed="left"), None], + "random_state": ["random_state"], + "eigen_solver": [StrOptions({"arpack", "lobpcg", "amg"}), None], + "eigen_tol": [Interval(Real, 0, None, closed="left"), StrOptions({"auto"})], + "n_neighbors": [Interval(Integral, 1, None, closed="left"), None], + "n_jobs": [None, Integral], + } + def __init__( self, n_components=2, @@ -649,28 +672,11 @@ def fit(self, X, y=None): self : object Returns the instance itself. """ + self._validate_params() X = self._validate_data(X, accept_sparse="csr", ensure_min_samples=2) random_state = check_random_state(self.random_state) - if isinstance(self.affinity, str): - if self.affinity not in { - "nearest_neighbors", - "rbf", - "precomputed", - "precomputed_nearest_neighbors", - }: - raise ValueError( - "%s is not a valid affinity. Expected " - "'precomputed', 'rbf', 'nearest_neighbors' " - "or a callable." - % self.affinity - ) - elif not callable(self.affinity): - raise ValueError( - "'affinity' is expected to be an affinity name or a callable. Got: %s" - % self.affinity - ) affinity_matrix = self._get_affinity_matrix(X) self.embedding_ = spectral_embedding( diff --git a/sklearn/manifold/tests/test_spectral_embedding.py b/sklearn/manifold/tests/test_spectral_embedding.py index 37412eac14490..cf7f253b66de1 100644 --- a/sklearn/manifold/tests/test_spectral_embedding.py +++ b/sklearn/manifold/tests/test_spectral_embedding.py @@ -351,17 +351,6 @@ def test_spectral_embedding_unknown_eigensolver(seed=36): se.fit(S) -def test_spectral_embedding_unknown_affinity(seed=36): - # Test that SpectralClustering fails with an unknown affinity type - se = SpectralEmbedding( - n_components=1, - affinity="", - random_state=np.random.RandomState(seed), - ) - with pytest.raises(ValueError): - se.fit(S) - - def test_connectivity(seed=36): # Test that graph connectivity test works as expected graph = np.array( diff --git a/sklearn/tests/test_common.py b/sklearn/tests/test_common.py index f0ee188c431a9..b4b41f5c88254 100644 --- a/sklearn/tests/test_common.py +++ b/sklearn/tests/test_common.py @@ -485,7 +485,6 @@ def test_estimators_do_not_raise_errors_in_init_or_set_params(Estimator): "SelectFromModel", "SpectralBiclustering", "SpectralCoclustering", - "SpectralEmbedding", ]