Skip to content

MNT Use check_scalar in BIRCH and DBSCAN #20816

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 23 commits into from
Oct 10, 2021

Conversation

SanjayMarreddi
Copy link
Contributor

@SanjayMarreddi SanjayMarreddi commented Aug 23, 2021

Reference Issues/PRs

Solves a part of the Issue #20724
See also PR #20723

What does this implement/fix? Explain your changes.

Using the helper function check_scalar from sklearn.utils to validate scalar parameters and making sure to get consistent error types and messages in different Clustering Algorithms:

  • DBSCAN algorithm.
  • BIRCH algorithm.
  • Validate the changes made and Perform System tests.

References:

  1. docs

Any other comments?

Really excited to contribute to the Machine Learning library scikit-learn that I have been using in my projects from the past year!

@glemaitre
Copy link
Member

Could you merge main into your branch. You should as well use a for loop. We found it clearer to make the function call directly.

@SanjayMarreddi
Copy link
Contributor Author

SanjayMarreddi commented Sep 2, 2021

@glemaitre Sure, I will do it and also try to finish the remaining components of this PR ( like testing & some modifications ) and make sure to make it available to merge by this Weekend/ in 1-2 days.
I am sorry if there is any inconvenience from my side. Thanks!

@SanjayMarreddi SanjayMarreddi marked this pull request as ready for review September 3, 2021 13:30
@SanjayMarreddi
Copy link
Contributor Author

@glemaitre I have made the necessary changes. Kindly review it. Thanks!

@SanjayMarreddi SanjayMarreddi changed the title [WIP] Using check_scalar to validate scalars in different Clustering algorithms [MRG] Using check_scalar to validate scalars in different Clustering algorithms Sep 3, 2021
@glemaitre
Copy link
Member

There is a bug in check_scalar that will be fixed soon: #20921
You can merge this PR into yours in order to use the right parameter naming. Alternatively, you can wait that it is merged in main.

@glemaitre
Copy link
Member

You should check the CIs because the linter is failing. You might have forgotten to use black on the changed file.
I will look at the changes and check if the code is OK.

@glemaitre glemaitre self-requested a review September 15, 2021 11:43
@glemaitre glemaitre changed the title [MRG] Using check_scalar to validate scalars in different Clustering algorithms MNT Use check_scalar in BIRCH and DBSCAN Sep 15, 2021
@glemaitre
Copy link
Member

glemaitre commented Sep 15, 2021

Concretely, this is the patch to correct the failure:

diff --git a/sklearn/cluster/_birch.py b/sklearn/cluster/_birch.py
index c688bc2f95..b418e0b5b7 100644
--- a/sklearn/cluster/_birch.py
+++ b/sklearn/cluster/_birch.py
@@ -13,9 +13,8 @@ from ..metrics import pairwise_distances_argmin
 from ..metrics.pairwise import euclidean_distances
 from ..base import TransformerMixin, ClusterMixin, BaseEstimator
 from ..utils.extmath import row_norms
-from ..utils import deprecated
-from ..utils import check_scalar
-from ..utils.validation import check_is_fitted
+from ..utils import check_scalar, deprecated
+from ..utils.validation import _num_samples, check_is_fitted
 from ..exceptions import ConvergenceWarning
 from . import AgglomerativeClustering
 from .._config import config_context
@@ -481,7 +480,7 @@ class Birch(ClusterMixin, TransformerMixin, BaseEstimator):
     # TODO: Remove in 1.2
     # mypy error: Decorated property not supported
     @deprecated(  # type: ignore
-        "`fit_` is deprecated in 1.0 and will be removed in 1.2."
+        "`fit_` is deprecated in 1.0 and will be removed in 1.2"
     )
     @property
     def fit_(self):
@@ -490,7 +489,7 @@ class Birch(ClusterMixin, TransformerMixin, BaseEstimator):
     # TODO: Remove in 1.2
     # mypy error: Decorated property not supported
     @deprecated(  # type: ignore
-        "`partial_fit_` is deprecated in 1.0 and will be removed in 1.2."
+        "`partial_fit_` is deprecated in 1.0 and will be removed in 1.2"
     )
     @property
     def partial_fit_(self):
@@ -519,22 +518,24 @@ class Birch(ClusterMixin, TransformerMixin, BaseEstimator):
             "threshold",
             target_type=numbers.Real,
             min_val=0.0,
-            closed="neither"
+            include_boundaries="neither",
         )
         check_scalar(
             self.branching_factor,
             "branching_factor",
             target_type=numbers.Integral,
             min_val=1,
-            closed="neither"
-        )
-        check_scalar(
-            self.n_clusters,
-            "n_clusters",
-            target_type=numbers.Integral,
-            min_val=1,
-            closed="left"
+            include_boundaries="neither",
         )
+        if isinstance(self.n_clusters, numbers.Number):
+            check_scalar(
+                self.n_clusters,
+                "n_clusters",
+                target_type=numbers.Integral,
+                min_val=1,
+                max_val=_num_samples(X),
+                include_boundaries="both",
+            )
 
         # TODO: Remove deprected flags in 1.2
         self._deprecated_fit, self._deprecated_partial_fit = True, False
@@ -722,7 +723,7 @@ class Birch(ClusterMixin, TransformerMixin, BaseEstimator):
             if len(centroids) < self.n_clusters:
                 not_enough_centroids = True
         elif clusterer is not None and not hasattr(clusterer, "fit_predict"):
-            raise ValueError(
+            raise TypeError(
                 "n_clusters should be an instance of ClusterMixin or an int"
             )
 
diff --git a/sklearn/cluster/_dbscan.py b/sklearn/cluster/_dbscan.py
index f248d069f1..3422a20742 100644
--- a/sklearn/cluster/_dbscan.py
+++ b/sklearn/cluster/_dbscan.py
@@ -365,34 +365,32 @@ class DBSCAN(ClusterMixin, BaseEstimator):
             "eps",
             target_type=numbers.Real,
             min_val=0.0,
-            closed="neither"
+            include_boundaries="neither",
         )
         check_scalar(
             self.min_samples,
             "min_samples",
             target_type=numbers.Integral,
             min_val=1,
-            closed="left"
+            include_boundaries="left",
         )
         check_scalar(
             self.leaf_size,
             "leaf_size",
             target_type=numbers.Integral,
             min_val=1,
-            closed="left"
-        )
-        check_scalar(
-            self.p,
-            "p",
-            target_type=numbers.Real,
-            min_val=1.0,
-            closed="left"
-        )
-        check_scalar(
-            self.n_jobs,
-            "n_jobs",
-            target_type=numbers.Integral
+            include_boundaries="left",
         )
+        if self.p is not None:
+            check_scalar(
+                self.p,
+                "p",
+                target_type=numbers.Real,
+                min_val=1.0,
+                include_boundaries="left",
+            )
+        if self.n_jobs is not None:
+            check_scalar(self.n_jobs, "n_jobs", target_type=numbers.Integral)
 
         neighbors_model = NearestNeighbors(
             radius=self.eps,
diff --git a/sklearn/cluster/tests/test_birch.py b/sklearn/cluster/tests/test_birch.py
index fdc14e8560..5d8a3222ef 100644
--- a/sklearn/cluster/tests/test_birch.py
+++ b/sklearn/cluster/tests/test_birch.py
@@ -19,35 +19,6 @@ from sklearn.utils._testing import assert_array_equal
 from sklearn.utils._testing import assert_array_almost_equal
 
 
-@pytest.mark.parametrize(
-    "input, params, err_type, err_msg",
-    [
-        (X, {"threshold": -1.0}, ValueError,
-         "threshold == -1.0, must be a positive real number."),
-        (X, {"threshold": 0.0}, ValueError,
-         "threshold == 0.0, must be a positive real number."),
-
-        (X, {"branching_factor": 0}, ValueError,
-         "branching_factor == 0, must be a positive integer greater than 1."),
-        (X, {"branching_factor": 1}, ValueError,
-         "branching_factor == 1, must be a positive integer greater than 1."),
-        (X, {"branching_factor": 1.5}, ValueError,
-         "min_samples == 1.5, must be an integer."),
-        (X, {"branching_factor": -2}, ValueError,
-         "branching_factor == -2, must be a positive integer."),
-
-        (X, {"n_clusters": 0}, ValueError, "n_clusters == 0, must be a positive integer."),
-        (X, {"n_clusters": 2.5}, ValueError,  "n_clusters == 2.5, must be an integer."),
-        (X, {"n_clusters": -3}, ValueError,
-         "n_clusters == -2, must be a positive integer."),
-    ],
-)
-def test_birch_params_validation(input, params, err_type, err_msg):
-    """Check the parameters validation in `Birch`."""
-    with pytest.raises(err_type, match=err_msg):
-        Birch(**params).fit(input)
-
-
 def test_n_samples_leaves_roots():
     # Sanity check for the number of samples in leaves and roots
     X, y = make_blobs(n_samples=10)
@@ -114,7 +85,8 @@ def test_n_clusters():
     # Test that the wrong global clustering step raises an Error.
     clf = ElasticNet()
     brc3 = Birch(n_clusters=clf)
-    with pytest.raises(ValueError):
+    err_msg = "n_clusters should be an instance of ClusterMixin or an int"
+    with pytest.raises(TypeError, match=err_msg):
         brc3.fit(X)
 
     # Test that a small number of clusters raises a warning.
@@ -211,3 +183,39 @@ def test_birch_fit_attributes_deprecated(attribute):
 
     with pytest.warns(FutureWarning, match=msg):
         getattr(brc, attribute)
+
+
+@pytest.mark.parametrize(
+    "params, err_type, err_msg",
+    [
+        ({"threshold": -1.0}, ValueError, "threshold == -1.0, must be > 0.0."),
+        ({"threshold": 0.0}, ValueError, "threshold == 0.0, must be > 0.0."),
+        ({"branching_factor": 0}, ValueError, "branching_factor == 0, must be > 1."),
+        ({"branching_factor": 1}, ValueError, "branching_factor == 1, must be > 1."),
+        (
+            {"branching_factor": 1.5},
+            TypeError,
+            "branching_factor must be an instance of <class 'numbers.Integral'>, not"
+            " <class 'float'>.",
+        ),
+        ({"branching_factor": -2}, ValueError, "branching_factor == -2, must be > 1."),
+        ({"n_clusters": 0}, ValueError, "n_clusters == 0, must be >= 1."),
+        (
+            {"n_clusters": 2.5},
+            TypeError,
+            "n_clusters must be an instance of <class 'numbers.Integral'>, not <class"
+            " 'float'>.",
+        ),
+        (
+            {"n_clusters": "whatever"},
+            TypeError,
+            "n_clusters should be an instance of ClusterMixin or an int",
+        ),
+        ({"n_clusters": -3}, ValueError, "n_clusters == -3, must be >= 1."),
+    ],
+)
+def test_birch_params_validation(params, err_type, err_msg):
+    """Check the parameters validation in `Birch`."""
+    X, _ = make_blobs(n_samples=80, centers=4)
+    with pytest.raises(err_type, match=err_msg):
+        Birch(**params).fit(X)
diff --git a/sklearn/cluster/tests/test_dbscan.py b/sklearn/cluster/tests/test_dbscan.py
index 592745196d..b593d83d45 100644
--- a/sklearn/cluster/tests/test_dbscan.py
+++ b/sklearn/cluster/tests/test_dbscan.py
@@ -25,35 +25,6 @@ n_clusters = 3
 X = generate_clustered_data(n_clusters=n_clusters)
 
 
-@pytest.mark.parametrize(
-    "input, params, err_type, err_msg",
-    [
-        (X, {"eps": -1.0}, ValueError, "eps == -1.0, must be a positive real number."),
-        (X, {"eps": 0.0}, ValueError, "eps == 0.0, must be a positive real number."),
-
-        (X, {"min_samples": 0}, ValueError,
-         "min_samples == 0, must be a positive integer."),
-        (X, {"min_samples": 1.5}, ValueError, "min_samples == 1.5, must be an integer."),
-        (X, {"min_samples": -2}, ValueError,
-         "min_samples == -2, must be a positive integer."),
-
-        (X, {"leaf_size": 0},   ValueError, "leaf_size == 0, must be a positive integer."),
-        (X, {"leaf_size": 2.5}, ValueError, "leaf_size == 1.5, must be an integer."),
-        (X, {"leaf_size": -3},  ValueError,
-         "leaf_size == -2, must be a positive integer."),
-
-        (X, {"p": 0},  ValueError, "p == 0, must be >= 1"),
-        (X, {"p": -2}, ValueError, "p == -2, must be a positive real number."),
-
-        (X, {"n_jobs": 2.5}, ValueError, "n_jobs == 2.5, must be an integer."),
-    ],
-)
-def test_dbscan_params_validation(input, params, err_type, err_msg):
-    """Check the parameters validation in `DBSCAN`."""
-    with pytest.raises(err_type, match=err_msg):
-        dbscan(**params).fit(input)
-
-
 def test_dbscan_similarity():
     # Tests the DBSCAN algorithm with a similarity array.
     # Parameters chosen specifically for this task.
@@ -454,3 +425,40 @@ def test_dbscan_precomputed_metric_with_initial_rows_zero():
     matrix = sparse.csr_matrix(ar)
     labels = DBSCAN(eps=0.2, metric="precomputed", min_samples=2).fit(matrix).labels_
     assert_array_equal(labels, [-1, -1, 0, 0, 0, 1, 1])
+
+
+@pytest.mark.parametrize(
+    "params, err_type, err_msg",
+    [
+        ({"eps": -1.0}, ValueError, "eps == -1.0, must be > 0.0."),
+        ({"eps": 0.0}, ValueError, "eps == 0.0, must be > 0.0."),
+        ({"min_samples": 0}, ValueError, "min_samples == 0, must be >= 1."),
+        (
+            {"min_samples": 1.5},
+            TypeError,
+            "min_samples must be an instance of <class 'numbers.Integral'>, not <class"
+            " 'float'>.",
+        ),
+        ({"min_samples": -2}, ValueError, "min_samples == -2, must be >= 1."),
+        ({"leaf_size": 0}, ValueError, "leaf_size == 0, must be >= 1."),
+        (
+            {"leaf_size": 2.5},
+            TypeError,
+            "leaf_size must be an instance of <class 'numbers.Integral'>, not <class"
+            " 'float'>.",
+        ),
+        ({"leaf_size": -3}, ValueError, "leaf_size == -3, must be >= 1."),
+        ({"p": 0}, ValueError, "p == 0, must be >= 1.0."),
+        ({"p": -2}, ValueError, "p == -2, must be >= 1.0."),
+        (
+            {"n_jobs": 2.5},
+            TypeError,
+            "n_jobs must be an instance of <class 'numbers.Integral'>, not <class"
+            " 'float'>.",
+        ),
+    ],
+)
+def test_dbscan_params_validation(params, err_type, err_msg):
+    """Check the parameters validation in `DBSCAN`."""
+    with pytest.raises(err_type, match=err_msg):
+        DBSCAN(**params).fit(X)

Copy link
Member

@glemaitre glemaitre left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Tagging as "Request changes" to acknowledge that the PR has been reviewed

@SanjayMarreddi
Copy link
Contributor Author

@glemaitre I will try to resolve all the issues by this weekend. Thanks!

@SanjayMarreddi
Copy link
Contributor Author

@glemaitre I have done the suggested changes. Could you guide me in resolving the failing tests? I am using auto formatting with autopep8. Thanks!

@glemaitre
Copy link
Member

glemaitre commented Sep 23, 2021

I am using auto formatting with autopep8. Thanks!

We are using black and you can check the failing CI there:

https://dev.azure.com/scikit-learn/scikit-learn/_build/results?buildId=32760&view=logs&jobId=32e2e1bb-a28f-5b18-6cfc-3f01273f5609&j=32e2e1bb-a28f-5b18-6cfc-3f01273f5609&t=fc67071d-c3d4-58b8-d38e-cafc0d3c731a

Basically, you should reformat the file sklearn/cluster/tests/test_birch.py such that no file is reformatted by the CI, as mentioned in the contributing guide: https://scikit-learn.org/stable/developers/contributing.html#pull-request-checklist

An automatic way for the lining is to install pre-commit that will use black and flake8 before committing.

Copy link
Member

@glemaitre glemaitre left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@SanjayMarreddi
Copy link
Contributor Author

Thanks a lot, @glemaitre for the review and also for guiding me throughout my first PR at scikit-learn!

@SanjayMarreddi
Copy link
Contributor Author

SanjayMarreddi commented Sep 24, 2021

Kind request for a review from @jeremiedbb and @ogrisel so that my PR can be merged!
Thanks!

Copy link
Member

@jjerphan jjerphan left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nothing to change: LGTM! 👌

Thank you @SanjayMarreddi.

@SanjayMarreddi
Copy link
Contributor Author

Thanks for the approval @jjerphan!
Will be waiting for the PR to get merged.

Copy link
Member

@thomasjpfan thomasjpfan left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you for PR @SanjayMarreddi !

SanjayMarreddi and others added 2 commits October 5, 2021 00:27
Co-authored-by: Thomas J. Fan <thomasjpfan@gmail.com>
Co-authored-by: Thomas J. Fan <thomasjpfan@gmail.com>
@SanjayMarreddi
Copy link
Contributor Author

SanjayMarreddi commented Oct 4, 2021

@thomasjpfan Thanks a lot for the review. I modified the code to implement the suggestions on the GitHub UI itself. Will do the linting by tonight. Thanks!

@SanjayMarreddi
Copy link
Contributor Author

@thomasjpfan I have fixed the linting errors. But there are new tests that are failing. I went through the details but could not understand. Can u help me in resolving this? Thanks!

Comment on lines 451 to 452
({"p": 0}, ValueError, "p == 0, must be >= 1.0."),
({"p": -2}, ValueError, "p == -2, must be >= 1.0."),
Copy link
Member

@jjerphan jjerphan Oct 5, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@SanjayMarreddi: This should fix tests failures.

Suggested change
({"p": 0}, ValueError, "p == 0, must be >= 1.0."),
({"p": -2}, ValueError, "p == -2, must be >= 1.0."),
({"p": -2}, ValueError, "p == -2, must be >= 0.0."),

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@jjerphan Thanks! I got the bug.

Copy link
Member

@thomasjpfan thomasjpfan left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

One small comment. Otherwise LGTM

@SanjayMarreddi
Copy link
Contributor Author

@thomasjpfan Done with the changes. Thanks for the review. Waiting for the PR to get merged!!

@thomasjpfan thomasjpfan merged commit 2f2364d into scikit-learn:main Oct 10, 2021
@glemaitre glemaitre mentioned this pull request Oct 23, 2021
10 tasks
samronsin pushed a commit to samronsin/scikit-learn that referenced this pull request Nov 30, 2021
Co-authored-by: Thomas J. Fan <thomasjpfan@gmail.com>
Co-authored-by: Guillaume Lemaitre <g.lemaitre58@gmail.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants