-
-
Notifications
You must be signed in to change notification settings - Fork 25.8k
MNT Use check_scalar in BIRCH and DBSCAN #20816
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
Could you merge main into your branch? You should also avoid using a for loop: we found it clearer to make the function calls directly. |
@glemaitre Sure, I will do it. I will also try to finish the remaining parts of this PR (such as testing and some modifications) and make it ready to merge by this weekend, i.e. within 1-2 days. |
@glemaitre I have made the necessary changes. Kindly review it. Thanks! |
There is a bug in |
You should check the CIs because the linter is failing. You might have forgotten to use black on the changed file. |
Concretely, this is the patch to correct the failure: diff --git a/sklearn/cluster/_birch.py b/sklearn/cluster/_birch.py
index c688bc2f95..b418e0b5b7 100644
--- a/sklearn/cluster/_birch.py
+++ b/sklearn/cluster/_birch.py
@@ -13,9 +13,8 @@ from ..metrics import pairwise_distances_argmin
from ..metrics.pairwise import euclidean_distances
from ..base import TransformerMixin, ClusterMixin, BaseEstimator
from ..utils.extmath import row_norms
-from ..utils import deprecated
-from ..utils import check_scalar
-from ..utils.validation import check_is_fitted
+from ..utils import check_scalar, deprecated
+from ..utils.validation import _num_samples, check_is_fitted
from ..exceptions import ConvergenceWarning
from . import AgglomerativeClustering
from .._config import config_context
@@ -481,7 +480,7 @@ class Birch(ClusterMixin, TransformerMixin, BaseEstimator):
# TODO: Remove in 1.2
# mypy error: Decorated property not supported
@deprecated( # type: ignore
- "`fit_` is deprecated in 1.0 and will be removed in 1.2."
+ "`fit_` is deprecated in 1.0 and will be removed in 1.2"
)
@property
def fit_(self):
@@ -490,7 +489,7 @@ class Birch(ClusterMixin, TransformerMixin, BaseEstimator):
# TODO: Remove in 1.2
# mypy error: Decorated property not supported
@deprecated( # type: ignore
- "`partial_fit_` is deprecated in 1.0 and will be removed in 1.2."
+ "`partial_fit_` is deprecated in 1.0 and will be removed in 1.2"
)
@property
def partial_fit_(self):
@@ -519,22 +518,24 @@ class Birch(ClusterMixin, TransformerMixin, BaseEstimator):
"threshold",
target_type=numbers.Real,
min_val=0.0,
- closed="neither"
+ include_boundaries="neither",
)
check_scalar(
self.branching_factor,
"branching_factor",
target_type=numbers.Integral,
min_val=1,
- closed="neither"
- )
- check_scalar(
- self.n_clusters,
- "n_clusters",
- target_type=numbers.Integral,
- min_val=1,
- closed="left"
+ include_boundaries="neither",
)
+ if isinstance(self.n_clusters, numbers.Number):
+ check_scalar(
+ self.n_clusters,
+ "n_clusters",
+ target_type=numbers.Integral,
+ min_val=1,
+ max_val=_num_samples(X),
+ include_boundaries="both",
+ )
# TODO: Remove deprected flags in 1.2
self._deprecated_fit, self._deprecated_partial_fit = True, False
@@ -722,7 +723,7 @@ class Birch(ClusterMixin, TransformerMixin, BaseEstimator):
if len(centroids) < self.n_clusters:
not_enough_centroids = True
elif clusterer is not None and not hasattr(clusterer, "fit_predict"):
- raise ValueError(
+ raise TypeError(
"n_clusters should be an instance of ClusterMixin or an int"
)
diff --git a/sklearn/cluster/_dbscan.py b/sklearn/cluster/_dbscan.py
index f248d069f1..3422a20742 100644
--- a/sklearn/cluster/_dbscan.py
+++ b/sklearn/cluster/_dbscan.py
@@ -365,34 +365,32 @@ class DBSCAN(ClusterMixin, BaseEstimator):
"eps",
target_type=numbers.Real,
min_val=0.0,
- closed="neither"
+ include_boundaries="neither",
)
check_scalar(
self.min_samples,
"min_samples",
target_type=numbers.Integral,
min_val=1,
- closed="left"
+ include_boundaries="left",
)
check_scalar(
self.leaf_size,
"leaf_size",
target_type=numbers.Integral,
min_val=1,
- closed="left"
- )
- check_scalar(
- self.p,
- "p",
- target_type=numbers.Real,
- min_val=1.0,
- closed="left"
- )
- check_scalar(
- self.n_jobs,
- "n_jobs",
- target_type=numbers.Integral
+ include_boundaries="left",
)
+ if self.p is not None:
+ check_scalar(
+ self.p,
+ "p",
+ target_type=numbers.Real,
+ min_val=1.0,
+ include_boundaries="left",
+ )
+ if self.n_jobs is not None:
+ check_scalar(self.n_jobs, "n_jobs", target_type=numbers.Integral)
neighbors_model = NearestNeighbors(
radius=self.eps,
diff --git a/sklearn/cluster/tests/test_birch.py b/sklearn/cluster/tests/test_birch.py
index fdc14e8560..5d8a3222ef 100644
--- a/sklearn/cluster/tests/test_birch.py
+++ b/sklearn/cluster/tests/test_birch.py
@@ -19,35 +19,6 @@ from sklearn.utils._testing import assert_array_equal
from sklearn.utils._testing import assert_array_almost_equal
-@pytest.mark.parametrize(
- "input, params, err_type, err_msg",
- [
- (X, {"threshold": -1.0}, ValueError,
- "threshold == -1.0, must be a positive real number."),
- (X, {"threshold": 0.0}, ValueError,
- "threshold == 0.0, must be a positive real number."),
-
- (X, {"branching_factor": 0}, ValueError,
- "branching_factor == 0, must be a positive integer greater than 1."),
- (X, {"branching_factor": 1}, ValueError,
- "branching_factor == 1, must be a positive integer greater than 1."),
- (X, {"branching_factor": 1.5}, ValueError,
- "min_samples == 1.5, must be an integer."),
- (X, {"branching_factor": -2}, ValueError,
- "branching_factor == -2, must be a positive integer."),
-
- (X, {"n_clusters": 0}, ValueError, "n_clusters == 0, must be a positive integer."),
- (X, {"n_clusters": 2.5}, ValueError, "n_clusters == 2.5, must be an integer."),
- (X, {"n_clusters": -3}, ValueError,
- "n_clusters == -2, must be a positive integer."),
- ],
-)
-def test_birch_params_validation(input, params, err_type, err_msg):
- """Check the parameters validation in `Birch`."""
- with pytest.raises(err_type, match=err_msg):
- Birch(**params).fit(input)
-
-
def test_n_samples_leaves_roots():
# Sanity check for the number of samples in leaves and roots
X, y = make_blobs(n_samples=10)
@@ -114,7 +85,8 @@ def test_n_clusters():
# Test that the wrong global clustering step raises an Error.
clf = ElasticNet()
brc3 = Birch(n_clusters=clf)
- with pytest.raises(ValueError):
+ err_msg = "n_clusters should be an instance of ClusterMixin or an int"
+ with pytest.raises(TypeError, match=err_msg):
brc3.fit(X)
# Test that a small number of clusters raises a warning.
@@ -211,3 +183,39 @@ def test_birch_fit_attributes_deprecated(attribute):
with pytest.warns(FutureWarning, match=msg):
getattr(brc, attribute)
+
+
+@pytest.mark.parametrize(
+ "params, err_type, err_msg",
+ [
+ ({"threshold": -1.0}, ValueError, "threshold == -1.0, must be > 0.0."),
+ ({"threshold": 0.0}, ValueError, "threshold == 0.0, must be > 0.0."),
+ ({"branching_factor": 0}, ValueError, "branching_factor == 0, must be > 1."),
+ ({"branching_factor": 1}, ValueError, "branching_factor == 1, must be > 1."),
+ (
+ {"branching_factor": 1.5},
+ TypeError,
+ "branching_factor must be an instance of <class 'numbers.Integral'>, not"
+ " <class 'float'>.",
+ ),
+ ({"branching_factor": -2}, ValueError, "branching_factor == -2, must be > 1."),
+ ({"n_clusters": 0}, ValueError, "n_clusters == 0, must be >= 1."),
+ (
+ {"n_clusters": 2.5},
+ TypeError,
+ "n_clusters must be an instance of <class 'numbers.Integral'>, not <class"
+ " 'float'>.",
+ ),
+ (
+ {"n_clusters": "whatever"},
+ TypeError,
+ "n_clusters should be an instance of ClusterMixin or an int",
+ ),
+ ({"n_clusters": -3}, ValueError, "n_clusters == -3, must be >= 1."),
+ ],
+)
+def test_birch_params_validation(params, err_type, err_msg):
+ """Check the parameters validation in `Birch`."""
+ X, _ = make_blobs(n_samples=80, centers=4)
+ with pytest.raises(err_type, match=err_msg):
+ Birch(**params).fit(X)
diff --git a/sklearn/cluster/tests/test_dbscan.py b/sklearn/cluster/tests/test_dbscan.py
index 592745196d..b593d83d45 100644
--- a/sklearn/cluster/tests/test_dbscan.py
+++ b/sklearn/cluster/tests/test_dbscan.py
@@ -25,35 +25,6 @@ n_clusters = 3
X = generate_clustered_data(n_clusters=n_clusters)
-@pytest.mark.parametrize(
- "input, params, err_type, err_msg",
- [
- (X, {"eps": -1.0}, ValueError, "eps == -1.0, must be a positive real number."),
- (X, {"eps": 0.0}, ValueError, "eps == 0.0, must be a positive real number."),
-
- (X, {"min_samples": 0}, ValueError,
- "min_samples == 0, must be a positive integer."),
- (X, {"min_samples": 1.5}, ValueError, "min_samples == 1.5, must be an integer."),
- (X, {"min_samples": -2}, ValueError,
- "min_samples == -2, must be a positive integer."),
-
- (X, {"leaf_size": 0}, ValueError, "leaf_size == 0, must be a positive integer."),
- (X, {"leaf_size": 2.5}, ValueError, "leaf_size == 1.5, must be an integer."),
- (X, {"leaf_size": -3}, ValueError,
- "leaf_size == -2, must be a positive integer."),
-
- (X, {"p": 0}, ValueError, "p == 0, must be >= 1"),
- (X, {"p": -2}, ValueError, "p == -2, must be a positive real number."),
-
- (X, {"n_jobs": 2.5}, ValueError, "n_jobs == 2.5, must be an integer."),
- ],
-)
-def test_dbscan_params_validation(input, params, err_type, err_msg):
- """Check the parameters validation in `DBSCAN`."""
- with pytest.raises(err_type, match=err_msg):
- dbscan(**params).fit(input)
-
-
def test_dbscan_similarity():
# Tests the DBSCAN algorithm with a similarity array.
# Parameters chosen specifically for this task.
@@ -454,3 +425,40 @@ def test_dbscan_precomputed_metric_with_initial_rows_zero():
matrix = sparse.csr_matrix(ar)
labels = DBSCAN(eps=0.2, metric="precomputed", min_samples=2).fit(matrix).labels_
assert_array_equal(labels, [-1, -1, 0, 0, 0, 1, 1])
+
+
+@pytest.mark.parametrize(
+ "params, err_type, err_msg",
+ [
+ ({"eps": -1.0}, ValueError, "eps == -1.0, must be > 0.0."),
+ ({"eps": 0.0}, ValueError, "eps == 0.0, must be > 0.0."),
+ ({"min_samples": 0}, ValueError, "min_samples == 0, must be >= 1."),
+ (
+ {"min_samples": 1.5},
+ TypeError,
+ "min_samples must be an instance of <class 'numbers.Integral'>, not <class"
+ " 'float'>.",
+ ),
+ ({"min_samples": -2}, ValueError, "min_samples == -2, must be >= 1."),
+ ({"leaf_size": 0}, ValueError, "leaf_size == 0, must be >= 1."),
+ (
+ {"leaf_size": 2.5},
+ TypeError,
+ "leaf_size must be an instance of <class 'numbers.Integral'>, not <class"
+ " 'float'>.",
+ ),
+ ({"leaf_size": -3}, ValueError, "leaf_size == -3, must be >= 1."),
+ ({"p": 0}, ValueError, "p == 0, must be >= 1.0."),
+ ({"p": -2}, ValueError, "p == -2, must be >= 1.0."),
+ (
+ {"n_jobs": 2.5},
+ TypeError,
+ "n_jobs must be an instance of <class 'numbers.Integral'>, not <class"
+ " 'float'>.",
+ ),
+ ],
+)
+def test_dbscan_params_validation(params, err_type, err_msg):
+ """Check the parameters validation in `DBSCAN`."""
+ with pytest.raises(err_type, match=err_msg):
+ DBSCAN(**params).fit(X)
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Tagging as "Request changes" to acknowledge that the PR has been reviewed
@glemaitre I will try to resolve all the issues by this weekend. Thanks! |
@glemaitre I have done the suggested changes. Could you guide me in resolving the failing tests? I am using auto formatting with autopep8. Thanks! |
We are using `black`. Basically, you should reformat the file with `black`. An automatic way to do the linting is to install |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
Thanks a lot, @glemaitre for the review and also for guiding me throughout my first PR at scikit-learn! |
Kind request for a review from @jeremiedbb and @ogrisel so that my PR can be merged! |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nothing to change: LGTM! 👌
Thank you @SanjayMarreddi.
Thanks for the approval @jjerphan! |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thank you for PR @SanjayMarreddi !
Co-authored-by: Thomas J. Fan <thomasjpfan@gmail.com>
Co-authored-by: Thomas J. Fan <thomasjpfan@gmail.com>
@thomasjpfan Thanks a lot for the review. I modified the code to implement the suggestions on the GitHub UI itself. Will do the linting by tonight. Thanks! |
@thomasjpfan I have fixed the linting errors, but there are new tests that are failing. I went through the details but could not understand the failures. Can you help me resolve this? Thanks! |
sklearn/cluster/tests/test_dbscan.py
Outdated
({"p": 0}, ValueError, "p == 0, must be >= 1.0."), | ||
({"p": -2}, ValueError, "p == -2, must be >= 1.0."), |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@SanjayMarreddi: This should fix tests failures.
({"p": 0}, ValueError, "p == 0, must be >= 1.0."), | |
({"p": -2}, ValueError, "p == -2, must be >= 1.0."), | |
({"p": -2}, ValueError, "p == -2, must be >= 0.0."), |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@jjerphan Thanks! I got the bug.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
One small comment. Otherwise LGTM
@thomasjpfan Done with the changes. Thanks for the review. Waiting for the PR to get merged!! |
Co-authored-by: Thomas J. Fan <thomasjpfan@gmail.com> Co-authored-by: Guillaume Lemaitre <g.lemaitre58@gmail.com>
Reference Issues/PRs
Solves a part of the Issue #20724
See also PR #20723
What does this implement/fix? Explain your changes.
Using the helper function
check_scalar
from sklearn.utils
to validate scalar parameters and to get consistent error types and messages across different clustering algorithms:
References:
Any other comments?
Really excited to contribute to the Machine Learning library
scikit-learn
that I have been using in my projects for the past year!