Skip to content

Added automatic validation function for sklearn.neighbors.radius_neighbors_graph (#24862) #27245

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 13 commits into from
Sep 5, 2023
20 changes: 16 additions & 4 deletions sklearn/neighbors/_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def kneighbors_graph(
Parameters
----------
X : array-like of shape (n_samples, n_features)
Sample data, in the form of a numpy array.
Sample data.

n_neighbors : int
Number of neighbors for each sample.
Expand Down Expand Up @@ -148,6 +148,19 @@ def kneighbors_graph(
return X.kneighbors_graph(X=query, n_neighbors=n_neighbors, mode=mode)


@validate_params(
{
"X": ["array-like", RadiusNeighborsMixin],
"radius": [Interval(Real, 0, None, closed="both")],
"mode": [StrOptions({"connectivity", "distance"})],
"metric": [StrOptions(set(itertools.chain(*VALID_METRICS.values()))), callable],
"p": [Interval(Real, 0, None, closed="right"), None],
"metric_params": [dict, None],
"include_self": ["boolean", StrOptions({"auto"})],
"n_jobs": [Integral, None],
},
prefer_skip_nested_validation=False, # metric is not validated yet
)
def radius_neighbors_graph(
X,
radius,
Expand All @@ -168,9 +181,8 @@ def radius_neighbors_graph(

Parameters
----------
X : array-like of shape (n_samples, n_features) or BallTree
Sample data, in the form of a numpy array or a precomputed
:class:`BallTree`.
X : array-like of shape (n_samples, n_features)
Sample data.

radius : float
Radius of neighborhoods.
Expand Down
1 change: 1 addition & 0 deletions sklearn/tests/test_public_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -306,6 +306,7 @@ def _check_function_param_validation(
"sklearn.model_selection.train_test_split",
"sklearn.model_selection.validation_curve",
"sklearn.neighbors.kneighbors_graph",
"sklearn.neighbors.radius_neighbors_graph",
"sklearn.neighbors.sort_graph_by_row_values",
"sklearn.preprocessing.add_dummy_feature",
"sklearn.preprocessing.binarize",
Expand Down