Skip to content

MNT Use check_scalar in BIRCH and DBSCAN #20816

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 23 commits into from
Oct 10, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
68132df
use check_scalar to validate scalar inputs in DBSCAN algorithm
SanjayMarreddi Aug 23, 2021
358adbe
use check_scalar to validate scalar inputs in BIRCH algorithm
SanjayMarreddi Aug 23, 2021
d0781c2
use check_scalar to validate scalar inputs in DBSCAN algorithm
SanjayMarreddi Aug 23, 2021
07b022f
use check_scalar to validate scalar inputs in BIRCH algorithm
SanjayMarreddi Aug 23, 2021
efb2d21
Removed for loops and called check_scalar separately
SanjayMarreddi Sep 3, 2021
fee4af1
Merge branch 'main' of https://github.com/scikit-learn/scikit-learn i…
SanjayMarreddi Sep 3, 2021
4e36c4b
Removed redundant commented lines
SanjayMarreddi Sep 3, 2021
11a42ce
Made the suggested changes in DBSCAN & BIRCH algos
SanjayMarreddi Sep 20, 2021
db8ca61
Resolved merge conflicts
SanjayMarreddi Sep 20, 2021
2ba8d71
Removed Merge Conflicts
SanjayMarreddi Sep 20, 2021
714e5a0
Implemented all the suggested changes
SanjayMarreddi Sep 20, 2021
87c582f
Corrected the typo
SanjayMarreddi Sep 20, 2021
ce3604f
Merge branch 'main' of https://github.com/scikit-learn/scikit-learn i…
SanjayMarreddi Sep 23, 2021
4136d7b
Reformatted the file to resolve failing tests
SanjayMarreddi Sep 23, 2021
93cb702
Apply suggestions from code review
glemaitre Sep 23, 2021
3cfbd8c
Removed max_val check as suggested.
SanjayMarreddi Oct 4, 2021
c83ee0e
Corrected the value of min_val
SanjayMarreddi Oct 4, 2021
dae5f1e
Merge branch 'main' of https://github.com/scikit-learn/scikit-learn i…
SanjayMarreddi Oct 5, 2021
6575587
Added comments and used Linting
SanjayMarreddi Oct 5, 2021
6e4b649
Removed unnecessary imports
SanjayMarreddi Oct 5, 2021
545ebd0
Made corrections according to p limits
SanjayMarreddi Oct 5, 2021
cecd8a5
Merge branch 'check_scalar' of https://github.com/SanjayMarreddi/scik…
SanjayMarreddi Oct 9, 2021
468c5f2
Removed default arg of include_boundaries
SanjayMarreddi Oct 9, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 27 additions & 5 deletions sklearn/cluster/_birch.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from ..metrics.pairwise import euclidean_distances
from ..base import TransformerMixin, ClusterMixin, BaseEstimator
from ..utils.extmath import row_norms
from ..utils import deprecated
from ..utils import check_scalar, deprecated
from ..utils.validation import check_is_fitted
from ..exceptions import ConvergenceWarning
from . import AgglomerativeClustering
Expand Down Expand Up @@ -512,7 +512,31 @@ def fit(self, X, y=None):
self
Fitted estimator.
"""
# TODO: Remove deprecated flags in 1.2

# Validating the scalar parameters.
check_scalar(
self.threshold,
"threshold",
target_type=numbers.Real,
min_val=0.0,
include_boundaries="neither",
)
check_scalar(
self.branching_factor,
"branching_factor",
target_type=numbers.Integral,
min_val=1,
include_boundaries="neither",
)
if isinstance(self.n_clusters, numbers.Number):
check_scalar(
self.n_clusters,
"n_clusters",
target_type=numbers.Integral,
min_val=1,
)

# TODO: Remove deprected flags in 1.2
self._deprecated_fit, self._deprecated_partial_fit = True, False
return self._fit(X, partial=False)

Expand All @@ -526,8 +550,6 @@ def _fit(self, X, partial):
threshold = self.threshold
branching_factor = self.branching_factor

if branching_factor <= 1:
raise ValueError("Branching_factor should be greater than one.")
n_samples, n_features = X.shape

# If partial_fit is called for the first time or fit is called, we
Expand Down Expand Up @@ -700,7 +722,7 @@ def _global_clustering(self, X=None):
if len(centroids) < self.n_clusters:
not_enough_centroids = True
elif clusterer is not None and not hasattr(clusterer, "fit_predict"):
raise ValueError(
raise TypeError(
"n_clusters should be an instance of ClusterMixin or an int"
)

Expand Down
38 changes: 35 additions & 3 deletions sklearn/cluster/_dbscan.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,11 @@
# License: BSD 3 clause

import numpy as np
import numbers
import warnings
from scipy import sparse

from ..utils import check_scalar
from ..base import BaseEstimator, ClusterMixin
from ..utils.validation import _check_sample_weight
from ..neighbors import NearestNeighbors
Expand Down Expand Up @@ -345,9 +347,6 @@ def fit(self, X, y=None, sample_weight=None):
"""
X = self._validate_data(X, accept_sparse="csr")

if not self.eps > 0.0:
raise ValueError("eps must be positive.")

if sample_weight is not None:
sample_weight = _check_sample_weight(sample_weight, X)

Expand All @@ -361,6 +360,39 @@ def fit(self, X, y=None, sample_weight=None):
warnings.simplefilter("ignore", sparse.SparseEfficiencyWarning)
X.setdiag(X.diagonal()) # XXX: modifies X's internals in-place

# Validating the scalar parameters.
check_scalar(
self.eps,
"eps",
target_type=numbers.Real,
min_val=0.0,
include_boundaries="neither",
)
check_scalar(
self.min_samples,
"min_samples",
target_type=numbers.Integral,
min_val=1,
include_boundaries="left",
)
check_scalar(
self.leaf_size,
"leaf_size",
target_type=numbers.Integral,
min_val=1,
include_boundaries="left",
)
if self.p is not None:
check_scalar(
self.p,
"p",
target_type=numbers.Real,
min_val=0.0,
include_boundaries="left",
)
if self.n_jobs is not None:
check_scalar(self.n_jobs, "n_jobs", target_type=numbers.Integral)

neighbors_model = NearestNeighbors(
radius=self.eps,
algorithm=self.algorithm,
Expand Down
44 changes: 38 additions & 6 deletions sklearn/cluster/tests/test_birch.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,8 @@ def test_n_clusters():
# Test that the wrong global clustering step raises an Error.
clf = ElasticNet()
brc3 = Birch(n_clusters=clf)
with pytest.raises(ValueError):
err_msg = "n_clusters should be an instance of ClusterMixin or an int"
with pytest.raises(TypeError, match=err_msg):
brc3.fit(X)

# Test that a small number of clusters raises a warning.
Expand Down Expand Up @@ -141,11 +142,6 @@ def test_branching_factor():
brc.fit(X)
check_branching_factor(brc.root_, branching_factor)

# Raises error when branching_factor is set to one.
brc = Birch(n_clusters=None, branching_factor=1, threshold=0.01)
with pytest.raises(ValueError):
brc.fit(X)


def check_threshold(birch_instance, threshold):
"""Use the leaf linked list for traversal"""
Expand Down Expand Up @@ -187,3 +183,39 @@ def test_birch_fit_attributes_deprecated(attribute):

with pytest.warns(FutureWarning, match=msg):
getattr(brc, attribute)


@pytest.mark.parametrize(
"params, err_type, err_msg",
[
({"threshold": -1.0}, ValueError, "threshold == -1.0, must be > 0.0."),
({"threshold": 0.0}, ValueError, "threshold == 0.0, must be > 0.0."),
({"branching_factor": 0}, ValueError, "branching_factor == 0, must be > 1."),
({"branching_factor": 1}, ValueError, "branching_factor == 1, must be > 1."),
(
{"branching_factor": 1.5},
TypeError,
"branching_factor must be an instance of <class 'numbers.Integral'>, not"
" <class 'float'>.",
),
({"branching_factor": -2}, ValueError, "branching_factor == -2, must be > 1."),
({"n_clusters": 0}, ValueError, "n_clusters == 0, must be >= 1."),
(
{"n_clusters": 2.5},
TypeError,
"n_clusters must be an instance of <class 'numbers.Integral'>, not <class"
" 'float'>.",
),
(
{"n_clusters": "whatever"},
TypeError,
"n_clusters should be an instance of ClusterMixin or an int",
),
({"n_clusters": -3}, ValueError, "n_clusters == -3, must be >= 1."),
],
)
def test_birch_params_validation(params, err_type, err_msg):
"""Check the parameters validation in `Birch`."""
X, _ = make_blobs(n_samples=80, centers=4)
with pytest.raises(err_type, match=err_msg):
Birch(**params).fit(X)
39 changes: 36 additions & 3 deletions sklearn/cluster/tests/test_dbscan.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,11 +272,8 @@ def test_input_validation():
@pytest.mark.parametrize(
"args",
[
{"eps": -1.0},
{"algorithm": "blah"},
{"metric": "blah"},
{"leaf_size": -1},
{"p": -1},
],
)
def test_dbscan_badargs(args):
Expand Down Expand Up @@ -428,3 +425,39 @@ def test_dbscan_precomputed_metric_with_initial_rows_zero():
matrix = sparse.csr_matrix(ar)
labels = DBSCAN(eps=0.2, metric="precomputed", min_samples=2).fit(matrix).labels_
assert_array_equal(labels, [-1, -1, 0, 0, 0, 1, 1])


@pytest.mark.parametrize(
"params, err_type, err_msg",
[
({"eps": -1.0}, ValueError, "eps == -1.0, must be > 0.0."),
({"eps": 0.0}, ValueError, "eps == 0.0, must be > 0.0."),
({"min_samples": 0}, ValueError, "min_samples == 0, must be >= 1."),
(
{"min_samples": 1.5},
TypeError,
"min_samples must be an instance of <class 'numbers.Integral'>, not <class"
" 'float'>.",
),
({"min_samples": -2}, ValueError, "min_samples == -2, must be >= 1."),
({"leaf_size": 0}, ValueError, "leaf_size == 0, must be >= 1."),
(
{"leaf_size": 2.5},
TypeError,
"leaf_size must be an instance of <class 'numbers.Integral'>, not <class"
" 'float'>.",
),
({"leaf_size": -3}, ValueError, "leaf_size == -3, must be >= 1."),
({"p": -2}, ValueError, "p == -2, must be >= 0.0."),
(
{"n_jobs": 2.5},
TypeError,
"n_jobs must be an instance of <class 'numbers.Integral'>, not <class"
" 'float'>.",
),
],
)
def test_dbscan_params_validation(params, err_type, err_msg):
"""Check the parameters validation in `DBSCAN`."""
with pytest.raises(err_type, match=err_msg):
DBSCAN(**params).fit(X)