From 917817167a7ae68a010117b9fef71632b48907d3 Mon Sep 17 00:00:00 2001 From: Stefanie Senger Date: Sun, 28 May 2023 18:57:00 +0200 Subject: [PATCH 1/2] param_validation according to issue 24862 --- sklearn/neighbors/_graph.py | 31 ++++++++++++++++++++------ sklearn/tests/test_public_functions.py | 1 + 2 files changed, 25 insertions(+), 7 deletions(-) diff --git a/sklearn/neighbors/_graph.py b/sklearn/neighbors/_graph.py index 418761c2d21ee..639b973a1ea2f 100644 --- a/sklearn/neighbors/_graph.py +++ b/sklearn/neighbors/_graph.py @@ -8,7 +8,13 @@ from ._base import NeighborsBase from ._unsupervised import NearestNeighbors from ..base import TransformerMixin, ClassNamePrefixFeaturesOutMixin -from ..utils._param_validation import StrOptions +from sklearn.metrics.pairwise import distance_metrics +from ..utils._param_validation import ( + StrOptions, + Interval, + Integral, + validate_params, +) from ..utils.validation import check_is_fitted @@ -224,6 +230,23 @@ def radius_neighbors_graph( return X.radius_neighbors_graph(query, radius, mode) +@validate_params( + { + "mode": [StrOptions({"distance", "connectivity"}), None], + "n_neighbors": [Interval(Integral, 1, None, closed="left"), None], + "algorithm": [StrOptions({"auto", "ball_tree", "kd_tree", "brute"}), None], + "leaf_size": [Interval(Integral, 1, None, closed="left"), None], + "metric": [ + StrOptions(set(distance_metrics())), + StrOptions({"minkowski"}), + callable, + None, + ], + "p": [Interval(Integral, 1, None, closed="left"), None], + "metric_params": [dict, None], + "n_jobs": [Integral, None], + } +) class KNeighborsTransformer( ClassNamePrefixFeaturesOutMixin, KNeighborsMixin, TransformerMixin, NeighborsBase ): @@ -342,12 +365,6 @@ class KNeighborsTransformer( (178, 178) """ - _parameter_constraints: dict = { - **NeighborsBase._parameter_constraints, - "mode": [StrOptions({"distance", "connectivity"})], - } - _parameter_constraints.pop("radius") - def __init__( self, *, diff --git a/sklearn/tests/test_public_functions.py b/sklearn/tests/test_public_functions.py index 3157e344cbef3..02af7d2b2032e 100644 --- a/sklearn/tests/test_public_functions.py +++ b/sklearn/tests/test_public_functions.py @@ -263,6 +263,7 @@ def _check_function_param_validation( "sklearn.model_selection.permutation_test_score", "sklearn.model_selection.train_test_split", "sklearn.model_selection.validation_curve", + "sklearn.neighbors.KNeighborsTransformer", "sklearn.neighbors.sort_graph_by_row_values", "sklearn.preprocessing.add_dummy_feature", "sklearn.preprocessing.binarize", From 6fcd0d9e73f02415eec31d96938aeeda72edcd99 Mon Sep 17 00:00:00 2001 From: Stefanie Senger Date: Sun, 28 May 2023 19:24:29 +0200 Subject: [PATCH 2/2] style --- sklearn/neighbors/_graph.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sklearn/neighbors/_graph.py b/sklearn/neighbors/_graph.py index 639b973a1ea2f..ca85c64ab8f61 100644 --- a/sklearn/neighbors/_graph.py +++ b/sklearn/neighbors/_graph.py @@ -8,7 +8,7 @@ from ._base import NeighborsBase from ._unsupervised import NearestNeighbors from ..base import TransformerMixin, ClassNamePrefixFeaturesOutMixin -from sklearn.metrics.pairwise import distance_metrics +from ..metrics.pairwise import distance_metrics from ..utils._param_validation import ( StrOptions, Interval,