diff --git a/doc/whats_new/v1.1.rst b/doc/whats_new/v1.1.rst index 83395c4180c44..ef08c7d7dd454 100644 --- a/doc/whats_new/v1.1.rst +++ b/doc/whats_new/v1.1.rst @@ -89,6 +89,11 @@ Changelog is faster on many datasets, and its results are identical, hence the change. :pr:`21735` by :user:`Aurélien Geron `. +- |Enhancement| :class:`cluster.SpectralClustering` now raises consistent + error messages when passed invalid values for `n_clusters`, `n_init`, + `gamma`, `n_neighbors`, `eigen_tol` or `degree`. + :pr:`21881` by :user:`Hugo Vassard `. + :mod:`sklearn.cross_decomposition` .................................. diff --git a/sklearn/cluster/_spectral.py b/sklearn/cluster/_spectral.py index 9e755769f6294..f33aed91ed24f 100644 --- a/sklearn/cluster/_spectral.py +++ b/sklearn/cluster/_spectral.py @@ -6,6 +6,8 @@ # Wei LI # Andrew Knyazev # License: BSD 3 clause + +import numbers import warnings import numpy as np @@ -14,7 +16,7 @@ from scipy.sparse import csc_matrix from ..base import BaseEstimator, ClusterMixin -from ..utils import check_random_state, as_float_array +from ..utils import as_float_array, check_random_state, check_scalar from ..utils.deprecation import deprecated from ..metrics.pairwise import pairwise_kernels from ..neighbors import kneighbors_graph, NearestNeighbors @@ -662,6 +664,55 @@ def fit(self, X, y=None): "set ``affinity=precomputed``." ) + check_scalar( + self.n_clusters, + "n_clusters", + target_type=numbers.Integral, + min_val=1, + include_boundaries="left", + ) + + check_scalar( + self.n_init, + "n_init", + target_type=numbers.Integral, + min_val=1, + include_boundaries="left", + ) + + check_scalar( + self.gamma, + "gamma", + target_type=numbers.Real, + min_val=1.0, + include_boundaries="left", + ) + + check_scalar( + self.n_neighbors, + "n_neighbors", + target_type=numbers.Integral, + min_val=1, + include_boundaries="left", + ) + + if self.eigen_solver == "arpack": + check_scalar( + self.eigen_tol, + "eigen_tol", + target_type=numbers.Real, + min_val=0, + include_boundaries="left", + ) + + check_scalar( + self.degree, + "degree", + target_type=numbers.Integral, + min_val=1, + include_boundaries="left", + ) + if self.affinity == "nearest_neighbors": connectivity = kneighbors_graph( X, n_neighbors=self.n_neighbors, include_self=True, n_jobs=self.n_jobs diff --git a/sklearn/cluster/tests/test_spectral.py b/sklearn/cluster/tests/test_spectral.py index 702906b3fa0e7..29785c0869fed 100644 --- a/sklearn/cluster/tests/test_spectral.py +++ b/sklearn/cluster/tests/test_spectral.py @@ -28,6 +28,16 @@ except ImportError: amg_loaded = False +centers = np.array([[1, 1], [-1, -1], [1, -1]]) + 10 +X, _ = make_blobs( + n_samples=60, + n_features=2, + centers=centers, + cluster_std=0.4, + shuffle=True, + random_state=0, +) + @pytest.mark.parametrize("eigen_solver", ("arpack", "lobpcg")) @pytest.mark.parametrize("assign_labels", ("kmeans", "discretize", "cluster_qr")) @@ -102,6 +112,48 @@ def test_spectral_unknown_assign_labels(): spectral_clustering(S, n_clusters=2, random_state=0, assign_labels="") +@pytest.mark.parametrize( + "input, params, err_type, err_msg", + [ + (X, {"n_clusters": -1}, ValueError, "n_clusters == -1, must be >= 1"), + (X, {"n_clusters": 0}, ValueError, "n_clusters == 0, must be >= 1"), + ( + X, + {"n_clusters": 1.5}, + TypeError, + "n_clusters must be an instance of ," + " not ", + ), + (X, {"n_init": -1}, ValueError, "n_init == -1, must be >= 1"), + (X, {"n_init": 0}, ValueError, "n_init == 0, must be >= 1"), + ( + X, + {"n_init": 1.5}, + TypeError, + "n_init must be an instance of ," + " not ", + ), + (X, {"gamma": -1}, ValueError, "gamma == -1, must be >= 1"), + (X, {"gamma": 0}, ValueError, "gamma == 0, must be >= 1"), + (X, {"n_neighbors": -1}, ValueError, "n_neighbors == -1, must be >= 1"), + (X, {"n_neighbors": 0}, ValueError, "n_neighbors == 0, must be >= 1"), + ( + X, + {"eigen_tol": -1, "eigen_solver": "arpack"}, + ValueError, + "eigen_tol == -1, must be >= 0", + ), + (X, {"degree": -1}, ValueError, "degree == -1, must be >= 1"), + (X, {"degree": 0}, ValueError, "degree == 0, must be >= 1"), + ], +) +def test_spectral_params_validation(input, params, err_type, err_msg): + """Check the parameters validation in `SpectralClustering`.""" + est = SpectralClustering(**params) + with pytest.raises(err_type, match=err_msg): + est.fit(input) + + @pytest.mark.parametrize("assign_labels", ("kmeans", "discretize", "cluster_qr")) def test_spectral_clustering_sparse(assign_labels): X, y = make_blobs(