Skip to content

Commit 26e2c38

Browse files
hvassardhugovassardjjerphan
authored
[MRG] MNT use check_scalar to validate scalar in SpectralClustering (#21881)
* use check_scalar in SpectralClustering * Add check_scalar parameters validation for cluster.SpectralClustering * fix missing comma * tiny changelog update to relauch CI * errors are raised at fit time solely Co-authored-by: Julien Jerphanion <git@jjerphan.xyz> * fix typos Co-authored-by: Julien Jerphanion <git@jjerphan.xyz> * merge ..utils imports Co-authored-by: Julien Jerphanion <git@jjerphan.xyz> Co-authored-by: hvassard <hugo.vassard@insa-rouen.fr> Co-authored-by: Julien Jerphanion <git@jjerphan.xyz>
1 parent 6a371ae commit 26e2c38

File tree

3 files changed

+109
-1
lines changed

3 files changed

+109
-1
lines changed

doc/whats_new/v1.1.rst

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,11 @@ Changelog
103103
is faster on many datasets, and its results are identical, hence the change.
104104
:pr:`21735` by :user:`Aurélien Geron <ageron>`.
105105

106+
- |Enhancement| :class:`cluster.SpectralClustering` now raises consistent
107+
error messages when passed invalid values for `n_clusters`, `n_init`,
108+
`gamma`, `n_neighbors`, `eigen_tol` or `degree`.
109+
:pr:`21881` by :user:`Hugo Vassard <hvassard>`.
110+
106111
:mod:`sklearn.cross_decomposition`
107112
..................................
108113

sklearn/cluster/_spectral.py

Lines changed: 52 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66
# Wei LI <kuantkid@gmail.com>
77
# Andrew Knyazev <Andrew.Knyazev@ucdenver.edu>
88
# License: BSD 3 clause
9+
10+
import numbers
911
import warnings
1012

1113
import numpy as np
@@ -14,7 +16,7 @@
1416
from scipy.sparse import csc_matrix
1517

1618
from ..base import BaseEstimator, ClusterMixin
17-
from ..utils import check_random_state, as_float_array
19+
from ..utils import as_float_array, check_random_state, check_scalar
1820
from ..utils.deprecation import deprecated
1921
from ..metrics.pairwise import pairwise_kernels
2022
from ..neighbors import kneighbors_graph, NearestNeighbors
@@ -662,6 +664,55 @@ def fit(self, X, y=None):
662664
"set ``affinity=precomputed``."
663665
)
664666

667+
check_scalar(
668+
self.n_clusters,
669+
"n_clusters",
670+
target_type=numbers.Integral,
671+
min_val=1,
672+
include_boundaries="left",
673+
)
674+
675+
check_scalar(
676+
self.n_init,
677+
"n_init",
678+
target_type=numbers.Integral,
679+
min_val=1,
680+
include_boundaries="left",
681+
)
682+
683+
check_scalar(
684+
self.gamma,
685+
"gamma",
686+
target_type=numbers.Real,
687+
min_val=1.0,
688+
include_boundaries="left",
689+
)
690+
691+
check_scalar(
692+
self.n_neighbors,
693+
"n_neighbors",
694+
target_type=numbers.Integral,
695+
min_val=1,
696+
include_boundaries="left",
697+
)
698+
699+
if self.eigen_solver == "arpack":
700+
check_scalar(
701+
self.eigen_tol,
702+
"eigen_tol",
703+
target_type=numbers.Real,
704+
min_val=0,
705+
include_boundaries="left",
706+
)
707+
708+
check_scalar(
709+
self.degree,
710+
"degree",
711+
target_type=numbers.Integral,
712+
min_val=1,
713+
include_boundaries="left",
714+
)
715+
665716
if self.affinity == "nearest_neighbors":
666717
connectivity = kneighbors_graph(
667718
X, n_neighbors=self.n_neighbors, include_self=True, n_jobs=self.n_jobs

sklearn/cluster/tests/test_spectral.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,16 @@
2828
except ImportError:
2929
amg_loaded = False
3030

31+
centers = np.array([[1, 1], [-1, -1], [1, -1]]) + 10
32+
X, _ = make_blobs(
33+
n_samples=60,
34+
n_features=2,
35+
centers=centers,
36+
cluster_std=0.4,
37+
shuffle=True,
38+
random_state=0,
39+
)
40+
3141

3242
@pytest.mark.parametrize("eigen_solver", ("arpack", "lobpcg"))
3343
@pytest.mark.parametrize("assign_labels", ("kmeans", "discretize", "cluster_qr"))
@@ -102,6 +112,48 @@ def test_spectral_unknown_assign_labels():
102112
spectral_clustering(S, n_clusters=2, random_state=0, assign_labels="<unknown>")
103113

104114

115+
@pytest.mark.parametrize(
116+
"input, params, err_type, err_msg",
117+
[
118+
(X, {"n_clusters": -1}, ValueError, "n_clusters == -1, must be >= 1"),
119+
(X, {"n_clusters": 0}, ValueError, "n_clusters == 0, must be >= 1"),
120+
(
121+
X,
122+
{"n_clusters": 1.5},
123+
TypeError,
124+
"n_clusters must be an instance of <class 'numbers.Integral'>,"
125+
" not <class 'float'>",
126+
),
127+
(X, {"n_init": -1}, ValueError, "n_init == -1, must be >= 1"),
128+
(X, {"n_init": 0}, ValueError, "n_init == 0, must be >= 1"),
129+
(
130+
X,
131+
{"n_init": 1.5},
132+
TypeError,
133+
"n_init must be an instance of <class 'numbers.Integral'>,"
134+
" not <class 'float'>",
135+
),
136+
(X, {"gamma": -1}, ValueError, "gamma == -1, must be >= 1"),
137+
(X, {"gamma": 0}, ValueError, "gamma == 0, must be >= 1"),
138+
(X, {"n_neighbors": -1}, ValueError, "n_neighbors == -1, must be >= 1"),
139+
(X, {"n_neighbors": 0}, ValueError, "n_neighbors == 0, must be >= 1"),
140+
(
141+
X,
142+
{"eigen_tol": -1, "eigen_solver": "arpack"},
143+
ValueError,
144+
"eigen_tol == -1, must be >= 0",
145+
),
146+
(X, {"degree": -1}, ValueError, "degree == -1, must be >= 1"),
147+
(X, {"degree": 0}, ValueError, "degree == 0, must be >= 1"),
148+
],
149+
)
150+
def test_spectral_params_validation(input, params, err_type, err_msg):
151+
"""Check the parameters validation in `SpectralClustering`."""
152+
est = SpectralClustering(**params)
153+
with pytest.raises(err_type, match=err_msg):
154+
est.fit(input)
155+
156+
105157
@pytest.mark.parametrize("assign_labels", ("kmeans", "discretize", "cluster_qr"))
106158
def test_spectral_clustering_sparse(assign_labels):
107159
X, y = make_blobs(

0 commit comments

Comments
 (0)