diff --git a/sklearn/neighbors/_graph.py b/sklearn/neighbors/_graph.py index 418761c2d21ee..ca85c64ab8f61 100644 --- a/sklearn/neighbors/_graph.py +++ b/sklearn/neighbors/_graph.py @@ -8,7 +8,13 @@ from ._base import NeighborsBase from ._unsupervised import NearestNeighbors from ..base import TransformerMixin, ClassNamePrefixFeaturesOutMixin -from ..utils._param_validation import StrOptions +from ..metrics.pairwise import distance_metrics +from ..utils._param_validation import ( + StrOptions, + Interval, + Integral, + validate_params, +) from ..utils.validation import check_is_fitted @@ -224,6 +230,23 @@ def radius_neighbors_graph( return X.radius_neighbors_graph(query, radius, mode) +@validate_params( + { + "mode": [StrOptions({"distance", "connectivity"}), None], + "n_neighbors": [Interval(Integral, 1, None, closed="left"), None], + "algorithm": [StrOptions({"auto", "ball_tree", "kd_tree", "brute"}), None], + "leaf_size": [Interval(Integral, 1, None, closed="left"), None], + "metric": [ + StrOptions(set(distance_metrics())), + StrOptions({"minkowski"}), + callable, + None, + ], + "p": [Interval(Integral, 1, None, closed="left"), None], + "metric_params": [dict, None], + "n_jobs": [Integral, None], + } +) class KNeighborsTransformer( ClassNamePrefixFeaturesOutMixin, KNeighborsMixin, TransformerMixin, NeighborsBase ): @@ -342,12 +365,6 @@ class KNeighborsTransformer( (178, 178) """ - _parameter_constraints: dict = { - **NeighborsBase._parameter_constraints, - "mode": [StrOptions({"distance", "connectivity"})], - } - _parameter_constraints.pop("radius") - def __init__( self, *, diff --git a/sklearn/tests/test_public_functions.py b/sklearn/tests/test_public_functions.py index 3157e344cbef3..02af7d2b2032e 100644 --- a/sklearn/tests/test_public_functions.py +++ b/sklearn/tests/test_public_functions.py @@ -263,6 +263,7 @@ def _check_function_param_validation( "sklearn.model_selection.permutation_test_score", "sklearn.model_selection.train_test_split", "sklearn.model_selection.validation_curve", + "sklearn.neighbors.KNeighborsTransformer", "sklearn.neighbors.sort_graph_by_row_values", "sklearn.preprocessing.add_dummy_feature", "sklearn.preprocessing.binarize",