Skip to content

Commit 1438ded

Browse files
committed
Use _validate_params in Birch
1 parent 2f787f4 commit 1438ded

File tree

3 files changed

+12
-59
lines changed

3 files changed

+12
-59
lines changed

sklearn/cluster/_birch.py

Lines changed: 12 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,8 @@
44
# License: BSD 3 clause
55

66
import warnings
7-
import numbers
87
import numpy as np
8+
from numbers import Integral, Real
99
from scipy import sparse
1010
from math import sqrt
1111

@@ -19,6 +19,7 @@
1919
)
2020
from ..utils.extmath import row_norms
2121
from ..utils import check_scalar, deprecated
22+
from ..utils._param_validation import Interval
2223
from ..utils.validation import check_is_fitted
2324
from ..exceptions import ConvergenceWarning
2425
from . import AgglomerativeClustering
@@ -478,6 +479,14 @@ class Birch(
478479
array([0, 0, 0, 1, 1, 1])
479480
"""
480481

482+
_parameter_constraints = {
483+
"threshold": [Interval(Real, 0.0, 1.0, closed="neither")],
484+
"branching_factor": [Interval(Integral, 1, None, closed="left")],
485+
"n_clusters": [None, ClusterMixin, Interval(Integral, 1, None, closed="left")],
486+
"compute_labels": [bool],
487+
"copy": [bool],
488+
}
489+
481490
def __init__(
482491
self,
483492
*,
@@ -529,28 +538,7 @@ def fit(self, X, y=None):
529538
Fitted estimator.
530539
"""
531540

532-
# Validating the scalar parameters.
533-
check_scalar(
534-
self.threshold,
535-
"threshold",
536-
target_type=numbers.Real,
537-
min_val=0.0,
538-
include_boundaries="neither",
539-
)
540-
check_scalar(
541-
self.branching_factor,
542-
"branching_factor",
543-
target_type=numbers.Integral,
544-
min_val=1,
545-
include_boundaries="neither",
546-
)
547-
if isinstance(self.n_clusters, numbers.Number):
548-
check_scalar(
549-
self.n_clusters,
550-
"n_clusters",
551-
target_type=numbers.Integral,
552-
min_val=1,
553-
)
541+
self._validate_params()
554542

555543
# TODO: Remove deprecated flags in 1.2
556544
self._deprecated_fit, self._deprecated_partial_fit = True, False
@@ -744,7 +732,7 @@ def _global_clustering(self, X=None):
744732

745733
# Preprocessing for the global clustering.
746734
not_enough_centroids = False
747-
if isinstance(clusterer, numbers.Integral):
735+
if isinstance(clusterer, Integral):
748736
clusterer = AgglomerativeClustering(n_clusters=self.n_clusters)
749737
# There is no need to perform the global clustering step.
750738
if len(centroids) < self.n_clusters:

sklearn/cluster/tests/test_birch.py

Lines changed: 0 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -186,40 +186,6 @@ def test_birch_fit_attributes_deprecated(attribute):
186186
getattr(brc, attribute)
187187

188188

189-
@pytest.mark.parametrize(
190-
"params, err_type, err_msg",
191-
[
192-
({"threshold": -1.0}, ValueError, "threshold == -1.0, must be > 0.0."),
193-
({"threshold": 0.0}, ValueError, "threshold == 0.0, must be > 0.0."),
194-
({"branching_factor": 0}, ValueError, "branching_factor == 0, must be > 1."),
195-
({"branching_factor": 1}, ValueError, "branching_factor == 1, must be > 1."),
196-
(
197-
{"branching_factor": 1.5},
198-
TypeError,
199-
"branching_factor must be an instance of int, not float.",
200-
),
201-
({"branching_factor": -2}, ValueError, "branching_factor == -2, must be > 1."),
202-
({"n_clusters": 0}, ValueError, "n_clusters == 0, must be >= 1."),
203-
(
204-
{"n_clusters": 2.5},
205-
TypeError,
206-
"n_clusters must be an instance of int, not float.",
207-
),
208-
(
209-
{"n_clusters": "whatever"},
210-
TypeError,
211-
"n_clusters should be an instance of ClusterMixin or an int",
212-
),
213-
({"n_clusters": -3}, ValueError, "n_clusters == -3, must be >= 1."),
214-
],
215-
)
216-
def test_birch_params_validation(params, err_type, err_msg):
217-
"""Check the parameters validation in `Birch`."""
218-
X, _ = make_blobs(n_samples=80, centers=4)
219-
with pytest.raises(err_type, match=err_msg):
220-
Birch(**params).fit(X)
221-
222-
223189
def test_feature_names_out():
224190
"""Check `get_feature_names_out` for `Birch`."""
225191
X, _ = make_blobs(n_samples=80, n_features=4, random_state=0)

sklearn/tests/test_common.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -459,7 +459,6 @@ def test_estimators_do_not_raise_errors_in_init_or_set_params(Estimator):
459459
"BernoulliNB",
460460
"BernoulliRBM",
461461
"Binarizer",
462-
"Birch",
463462
"CCA",
464463
"CalibratedClassifierCV",
465464
"CategoricalNB",

0 commit comments

Comments
 (0)