From 7343d72ffc78e4d0649f9216dd19e27359297291 Mon Sep 17 00:00:00 2001 From: Sayed Qaiser Ali <66676360+sqali@users.noreply.github.com> Date: Wed, 30 Aug 2023 22:31:05 +0530 Subject: [PATCH 01/10] Update _graph.py Added validation params to sklearn.neighbors.radius_neighbors_graph --- sklearn/neighbors/_graph.py | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/sklearn/neighbors/_graph.py b/sklearn/neighbors/_graph.py index d4ece07fd7a70..79fbea1f2248e 100644 --- a/sklearn/neighbors/_graph.py +++ b/sklearn/neighbors/_graph.py @@ -147,7 +147,18 @@ def kneighbors_graph( query = _query_include_self(X._fit_X, include_self, mode) return X.kneighbors_graph(X=query, n_neighbors=n_neighbors, mode=mode) - +@validate_params( + { + "X": ["array-like", "BallTree", "RadiusNeighborsMixin"], + "radius": [Interval(Real, 0, None, closed="right")] + "mode": [StrOptions({"connectivity", "distance"})], + "metric": [StrOptions(set(itertools.chain(*VALID_METRICS.values()))), callable], + "p": [Interval(Real, 0, None, closed="right"), None], + "metric_params": [dict, None], + "include_self": ["boolean", StrOptions({"auto"})], + "n_jobs": [Integral, None], + } +) def radius_neighbors_graph( X, radius, From 407baaeae26371d9d108c267c4c3eb6ca3ccbc30 Mon Sep 17 00:00:00 2001 From: Sayed Qaiser Ali <66676360+sqali@users.noreply.github.com> Date: Wed, 30 Aug 2023 22:34:25 +0530 Subject: [PATCH 02/10] Update test_public_functions.py Added sklearn.neighbors.radius_neighbors_graph to the test_public_functions list --- sklearn/tests/test_public_functions.py | 1 + 1 file changed, 1 insertion(+) diff --git a/sklearn/tests/test_public_functions.py b/sklearn/tests/test_public_functions.py index 1d9c75180c1ea..f3ecb30ffa531 100644 --- a/sklearn/tests/test_public_functions.py +++ b/sklearn/tests/test_public_functions.py @@ -306,6 +306,7 @@ def _check_function_param_validation( "sklearn.model_selection.train_test_split", "sklearn.model_selection.validation_curve", "sklearn.neighbors.kneighbors_graph", + "sklearn.neighbors.radius_neighbors_graph", "sklearn.neighbors.sort_graph_by_row_values", "sklearn.preprocessing.add_dummy_feature", "sklearn.preprocessing.binarize", From a8b04d113bfbbc7b87f7129472d5be684731c20a Mon Sep 17 00:00:00 2001 From: Sayed Qaiser Ali <66676360+sqali@users.noreply.github.com> Date: Wed, 30 Aug 2023 23:03:15 +0530 Subject: [PATCH 03/10] Update _graph.py corrected a comma --- sklearn/neighbors/_graph.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sklearn/neighbors/_graph.py b/sklearn/neighbors/_graph.py index 79fbea1f2248e..f62238c79c899 100644 --- a/sklearn/neighbors/_graph.py +++ b/sklearn/neighbors/_graph.py @@ -150,7 +150,7 @@ def kneighbors_graph( @validate_params( { "X": ["array-like", "BallTree", "RadiusNeighborsMixin"], - "radius": [Interval(Real, 0, None, closed="right")] + "radius": [Interval(Real, 0, None, closed="right")], "mode": [StrOptions({"connectivity", "distance"})], "metric": [StrOptions(set(itertools.chain(*VALID_METRICS.values()))), callable], "p": [Interval(Real, 0, None, closed="right"), None], From 4bd4d67b87cc3bf4ad008b57e2cf0cc467c76207 Mon Sep 17 00:00:00 2001 From: sqali Date: Thu, 31 Aug 2023 20:43:39 +0530 Subject: [PATCH 04/10] file reformatting done using black, mypy, added prefer_skip argument --- doc/whats_new/v1.4.rst | 4 ++++ sklearn/neighbors/_graph.py | 4 +++- 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/doc/whats_new/v1.4.rst b/doc/whats_new/v1.4.rst index c13922c6cb22e..a89bd2ab5317f 100644 --- a/doc/whats_new/v1.4.rst +++ b/doc/whats_new/v1.4.rst @@ -246,6 +246,10 @@ Changelog :class:`metric.DistanceMetric` objects. :pr:`26267` by :user:`Meekail Zain ` +- |Feature| :func:`neighbors.radius_neighbors_graph` now uses the + `validate_params` decorator for parameter validation. + :pr:`27245` by :user:`Sayed Qaiser Ali ` + :mod:`sklearn.metrics` ...................... diff --git a/sklearn/neighbors/_graph.py b/sklearn/neighbors/_graph.py index f62238c79c899..ed9c2a9c750b5 100644 --- a/sklearn/neighbors/_graph.py +++ b/sklearn/neighbors/_graph.py @@ -147,6 +147,7 @@ def kneighbors_graph( query = _query_include_self(X._fit_X, include_self, mode) return X.kneighbors_graph(X=query, n_neighbors=n_neighbors, mode=mode) + @validate_params( { "X": ["array-like", "BallTree", "RadiusNeighborsMixin"], @@ -157,7 +158,8 @@ def kneighbors_graph( "metric_params": [dict, None], "include_self": ["boolean", StrOptions({"auto"})], "n_jobs": [Integral, None], - } + }, + prefer_skip_nested_validation=False, ) def radius_neighbors_graph( X, From 782a44a2cb119ce28f1ab6637b7b3a50c21f580e Mon Sep 17 00:00:00 2001 From: sqali Date: Thu, 31 Aug 2023 21:36:10 +0530 Subject: [PATCH 05/10] removed balltree from the x parameters --- sklearn/neighbors/_graph.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sklearn/neighbors/_graph.py b/sklearn/neighbors/_graph.py index ed9c2a9c750b5..4d72b5a009924 100644 --- a/sklearn/neighbors/_graph.py +++ b/sklearn/neighbors/_graph.py @@ -150,7 +150,7 @@ def kneighbors_graph( @validate_params( { - "X": ["array-like", "BallTree", "RadiusNeighborsMixin"], + "X": ["array-like", RadiusNeighborsMixin], "radius": [Interval(Real, 0, None, closed="right")], "mode": [StrOptions({"connectivity", "distance"})], "metric": [StrOptions(set(itertools.chain(*VALID_METRICS.values()))), callable], From 785bab3245049c5ae65fc072d48b69d4b15cac97 Mon Sep 17 00:00:00 2001 From: Sayed Qaiser Ali <66676360+sqali@users.noreply.github.com> Date: Fri, 1 Sep 2023 18:33:37 +0530 Subject: [PATCH 06/10] Update v1.4.rst removed changelogs --- doc/whats_new/v1.4.rst | 3 --- 1 file changed, 3 deletions(-) diff --git a/doc/whats_new/v1.4.rst b/doc/whats_new/v1.4.rst index a89bd2ab5317f..3aad85594a3ee 100644 --- a/doc/whats_new/v1.4.rst +++ b/doc/whats_new/v1.4.rst @@ -246,9 +246,6 @@ Changelog :class:`metric.DistanceMetric` objects. :pr:`26267` by :user:`Meekail Zain ` -- |Feature| :func:`neighbors.radius_neighbors_graph` now uses the - `validate_params` decorator for parameter validation. - :pr:`27245` by :user:`Sayed Qaiser Ali ` :mod:`sklearn.metrics` ...................... From 2f86295a373cea8d9e4b43881ba925115cdcd86d Mon Sep 17 00:00:00 2001 From: Sayed Qaiser Ali <66676360+sqali@users.noreply.github.com> Date: Fri, 1 Sep 2023 18:33:53 +0530 Subject: [PATCH 07/10] Update sklearn/neighbors/_graph.py MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Jérémie du Boisberranger <34657725+jeremiedbb@users.noreply.github.com> --- sklearn/neighbors/_graph.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sklearn/neighbors/_graph.py b/sklearn/neighbors/_graph.py index 4d72b5a009924..b8ad8bdf94970 100644 --- a/sklearn/neighbors/_graph.py +++ b/sklearn/neighbors/_graph.py @@ -151,7 +151,7 @@ def kneighbors_graph( @validate_params( { "X": ["array-like", RadiusNeighborsMixin], - "radius": [Interval(Real, 0, None, closed="right")], + "radius": [Interval(Real, 0, None, closed="both")], "mode": [StrOptions({"connectivity", "distance"})], "metric": [StrOptions(set(itertools.chain(*VALID_METRICS.values()))), callable], "p": [Interval(Real, 0, None, closed="right"), None], From f781cd87931d3492aab61331f7fd2494d6e31b75 Mon Sep 17 00:00:00 2001 From: Sayed Qaiser Ali <66676360+sqali@users.noreply.github.com> Date: Fri, 1 Sep 2023 18:34:43 +0530 Subject: [PATCH 08/10] Update sklearn/neighbors/_graph.py MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Jérémie du Boisberranger <34657725+jeremiedbb@users.noreply.github.com> --- sklearn/neighbors/_graph.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sklearn/neighbors/_graph.py b/sklearn/neighbors/_graph.py index b8ad8bdf94970..f5dda51484d14 100644 --- a/sklearn/neighbors/_graph.py +++ b/sklearn/neighbors/_graph.py @@ -159,7 +159,7 @@ def kneighbors_graph( "include_self": ["boolean", StrOptions({"auto"})], "n_jobs": [Integral, None], }, - prefer_skip_nested_validation=False, + prefer_skip_nested_validation=False, # metric is not validated yet ) def radius_neighbors_graph( X, From 1537473519b4a93ab0f5ddf6f1b0db9aabf0dcdb Mon Sep 17 00:00:00 2001 From: Sayed Qaiser Ali <66676360+sqali@users.noreply.github.com> Date: Fri, 1 Sep 2023 18:36:13 +0530 Subject: [PATCH 09/10] Update v1.4.rst --- doc/whats_new/v1.4.rst | 1 - 1 file changed, 1 deletion(-) diff --git a/doc/whats_new/v1.4.rst b/doc/whats_new/v1.4.rst index 3aad85594a3ee..c13922c6cb22e 100644 --- a/doc/whats_new/v1.4.rst +++ b/doc/whats_new/v1.4.rst @@ -246,7 +246,6 @@ Changelog :class:`metric.DistanceMetric` objects. :pr:`26267` by :user:`Meekail Zain ` - :mod:`sklearn.metrics` ...................... From ce2cf5ceae3c0a51ee93ee2e77280fd43e632341 Mon Sep 17 00:00:00 2001 From: jeremie du boisberranger Date: Fri, 1 Sep 2023 16:21:33 +0200 Subject: [PATCH 10/10] fix docstring --- sklearn/neighbors/_graph.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/sklearn/neighbors/_graph.py b/sklearn/neighbors/_graph.py index f5dda51484d14..5fd9be4766c6e 100644 --- a/sklearn/neighbors/_graph.py +++ b/sklearn/neighbors/_graph.py @@ -74,7 +74,7 @@ def kneighbors_graph( Parameters ---------- X : array-like of shape (n_samples, n_features) - Sample data, in the form of a numpy array. + Sample data. n_neighbors : int Number of neighbors for each sample. @@ -181,9 +181,8 @@ def radius_neighbors_graph( Parameters ---------- - X : array-like of shape (n_samples, n_features) or BallTree - Sample data, in the form of a numpy array or a precomputed - :class:`BallTree`. + X : array-like of shape (n_samples, n_features) + Sample data. radius : float Radius of neighborhoods.