diff --git a/sklearn/neighbors/_base.py b/sklearn/neighbors/_base.py index 8ba1f4fa7d093..9af85a38f0b6c 100644 --- a/sklearn/neighbors/_base.py +++ b/sklearn/neighbors/_base.py @@ -36,7 +36,7 @@ from ..utils.multiclass import check_classification_targets from ..utils.validation import check_is_fitted from ..utils.validation import check_non_negative -from ..utils._param_validation import Interval, StrOptions +from ..utils._param_validation import Interval, StrOptions, validate_params from ..utils.parallel import delayed, Parallel from ..utils.fixes import parse_version, sp_base_version, sp_version from ..exceptions import DataConversionWarning, EfficiencyWarning @@ -197,6 +197,13 @@ def _check_precomputed(X): return graph +@validate_params( + { + "graph": ["sparse matrix"], + "copy": ["boolean"], + "warn_when_not_sorted": ["boolean"], + } +) def sort_graph_by_row_values(graph, copy=False, warn_when_not_sorted=True): """Sort a sparse graph such that each row is stored with increasing values. @@ -224,9 +231,6 @@ def sort_graph_by_row_values(graph, copy=False, warn_when_not_sorted=True): Distance matrix to other samples, where only non-zero elements are considered neighbors. Matrix is in CSR format. """ - if not issparse(graph): - raise TypeError(f"Input graph must be a sparse matrix, got {graph!r} instead.") - if graph.format == "csr" and _is_sorted_by_data(graph): return graph diff --git a/sklearn/neighbors/tests/test_neighbors.py b/sklearn/neighbors/tests/test_neighbors.py index 3e22193ba55c0..092b85ad9dcd0 100644 --- a/sklearn/neighbors/tests/test_neighbors.py +++ b/sklearn/neighbors/tests/test_neighbors.py @@ -497,10 +497,6 @@ def test_sort_graph_by_row_values_copy(): with pytest.raises(ValueError, match="Use copy=True to allow the conversion"): sort_graph_by_row_values(X.tocsc(), copy=False) - # raise if X is not even sparse - with pytest.raises(TypeError, match="Input graph must be a sparse matrix"): - sort_graph_by_row_values(X.toarray()) - def test_sort_graph_by_row_values_warning(): # Test that the parameter warn_when_not_sorted works as expected. diff --git a/sklearn/tests/test_public_functions.py b/sklearn/tests/test_public_functions.py index 62513b0f63cce..03b4ad9725748 100644 --- a/sklearn/tests/test_public_functions.py +++ b/sklearn/tests/test_public_functions.py @@ -238,6 +238,7 @@ def _check_function_param_validation( "sklearn.metrics.top_k_accuracy_score", "sklearn.metrics.zero_one_loss", "sklearn.model_selection.train_test_split", + "sklearn.neighbors.sort_graph_by_row_values", "sklearn.preprocessing.add_dummy_feature", "sklearn.preprocessing.binarize", "sklearn.preprocessing.label_binarize",