Skip to content

MAINT Parameters validation for SpectralEmbedding #24103

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 11 commits into from
Aug 31, 2022
42 changes: 24 additions & 18 deletions sklearn/manifold/_spectral_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
# License: BSD 3 clause


from numbers import Integral, Real
import warnings

import numpy as np
Expand All @@ -22,6 +23,7 @@
)
from ..utils._arpack import _init_arpack_v0
from ..utils.extmath import _deterministic_vector_sign_flip
from ..utils._param_validation import Interval, StrOptions
from ..utils.fixes import lobpcg
from ..metrics.pairwise import rbf_kernel
from ..neighbors import kneighbors_graph, NearestNeighbors
Expand Down Expand Up @@ -542,6 +544,27 @@ class SpectralEmbedding(BaseEstimator):
(100, 2)
"""

# Declarative constraints on the estimator's constructor parameters.
# Each key is a parameter name; each value lists the accepted alternatives.
# NOTE(review): presumably a value is valid when it satisfies any one entry of
# its list, and this table is what `self._validate_params()` (called at the
# start of `fit`) checks against -- confirm in sklearn.utils._param_validation.
_parameter_constraints: dict = {
# Integral value with lower bound 1 and no upper bound.
"n_components": [Interval(Integral, 1, None, closed="left")],
# Either one of the four named affinity strings or any callable.
"affinity": [
StrOptions(
{
"nearest_neighbors",
"rbf",
"precomputed",
"precomputed_nearest_neighbors",
},
),
callable,
],
# Real value with lower bound 0, or None.
"gamma": [Interval(Real, 0, None, closed="left"), None],
"random_state": ["random_state"],
# One of the three named solvers, or None.
"eigen_solver": [StrOptions({"arpack", "lobpcg", "amg"}), None],
# Real value with lower bound 0, or the literal string "auto".
"eigen_tol": [Interval(Real, 0, None, closed="left"), StrOptions({"auto"})],
# Integral value with lower bound 1, or None.
"n_neighbors": [Interval(Integral, 1, None, closed="left"), None],
# None or any Integral; no range restriction is declared here.
"n_jobs": [None, Integral],
}

def __init__(
self,
n_components=2,
Expand Down Expand Up @@ -649,28 +672,11 @@ def fit(self, X, y=None):
self : object
Returns the instance itself.
"""
self._validate_params()

X = self._validate_data(X, accept_sparse="csr", ensure_min_samples=2)

random_state = check_random_state(self.random_state)
if isinstance(self.affinity, str):
if self.affinity not in {
"nearest_neighbors",
"rbf",
"precomputed",
"precomputed_nearest_neighbors",
}:
raise ValueError(
"%s is not a valid affinity. Expected "
"'precomputed', 'rbf', 'nearest_neighbors' "
"or a callable."
% self.affinity
)
elif not callable(self.affinity):
raise ValueError(
"'affinity' is expected to be an affinity name or a callable. Got: %s"
% self.affinity
)

affinity_matrix = self._get_affinity_matrix(X)
self.embedding_ = spectral_embedding(
Expand Down
11 changes: 0 additions & 11 deletions sklearn/manifold/tests/test_spectral_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -351,17 +351,6 @@ def test_spectral_embedding_unknown_eigensolver(seed=36):
se.fit(S)


def test_spectral_embedding_unknown_affinity(seed=36):
# Test that SpectralEmbedding.fit raises ValueError for an unknown affinity
# string. `S` is not defined in this view -- presumably a module-level
# fixture of the test file.
se = SpectralEmbedding(
n_components=1,
affinity="<unknown>",
random_state=np.random.RandomState(seed),
)
with pytest.raises(ValueError):
se.fit(S)


def test_connectivity(seed=36):
# Test that graph connectivity test works as expected
graph = np.array(
Expand Down
1 change: 0 additions & 1 deletion sklearn/tests/test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -485,7 +485,6 @@ def test_estimators_do_not_raise_errors_in_init_or_set_params(Estimator):
"SelectFromModel",
"SpectralBiclustering",
"SpectralCoclustering",
"SpectralEmbedding",
]


Expand Down