diff --git a/sklearn/neighbors/_graph.py b/sklearn/neighbors/_graph.py index d4ece07fd7a70..5fd9be4766c6e 100644 --- a/sklearn/neighbors/_graph.py +++ b/sklearn/neighbors/_graph.py @@ -74,7 +74,7 @@ def kneighbors_graph( Parameters ---------- X : array-like of shape (n_samples, n_features) - Sample data, in the form of a numpy array. + Sample data. n_neighbors : int Number of neighbors for each sample. @@ -148,6 +148,19 @@ def kneighbors_graph( return X.kneighbors_graph(X=query, n_neighbors=n_neighbors, mode=mode) +@validate_params( + { + "X": ["array-like", RadiusNeighborsMixin], + "radius": [Interval(Real, 0, None, closed="both")], + "mode": [StrOptions({"connectivity", "distance"})], + "metric": [StrOptions(set(itertools.chain(*VALID_METRICS.values()))), callable], + "p": [Interval(Real, 0, None, closed="right"), None], + "metric_params": [dict, None], + "include_self": ["boolean", StrOptions({"auto"})], + "n_jobs": [Integral, None], + }, + prefer_skip_nested_validation=False, # metric is not validated yet +) def radius_neighbors_graph( X, radius, @@ -168,9 +181,8 @@ def radius_neighbors_graph( Parameters ---------- - X : array-like of shape (n_samples, n_features) or BallTree - Sample data, in the form of a numpy array or a precomputed - :class:`BallTree`. + X : array-like of shape (n_samples, n_features) + Sample data. radius : float Radius of neighborhoods. diff --git a/sklearn/tests/test_public_functions.py b/sklearn/tests/test_public_functions.py index 791ff4205e3cd..20ac2e1e1615e 100644 --- a/sklearn/tests/test_public_functions.py +++ b/sklearn/tests/test_public_functions.py @@ -306,6 +306,7 @@ def _check_function_param_validation( "sklearn.model_selection.train_test_split", "sklearn.model_selection.validation_curve", "sklearn.neighbors.kneighbors_graph", + "sklearn.neighbors.radius_neighbors_graph", "sklearn.neighbors.sort_graph_by_row_values", "sklearn.preprocessing.add_dummy_feature", "sklearn.preprocessing.binarize",