From 4a8df21078f6c99306842f903442f296aaf2266e Mon Sep 17 00:00:00 2001 From: ROMEEZHOU Date: Fri, 14 Apr 2023 08:15:17 +0800 Subject: [PATCH 1/3] MAINT Parameters validation for sklearn.neighbors.sort_graph_by_row_values --- sklearn/neighbors/_base.py | 9 ++++++++- sklearn/neighbors/tests/test_neighbors.py | 4 ---- 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/sklearn/neighbors/_base.py b/sklearn/neighbors/_base.py index 8ba1f4fa7d093..92ba487a2bbe7 100644 --- a/sklearn/neighbors/_base.py +++ b/sklearn/neighbors/_base.py @@ -36,7 +36,7 @@ from ..utils.multiclass import check_classification_targets from ..utils.validation import check_is_fitted from ..utils.validation import check_non_negative -from ..utils._param_validation import Interval, StrOptions +from ..utils._param_validation import Interval, StrOptions, validate_params from ..utils.parallel import delayed, Parallel from ..utils.fixes import parse_version, sp_base_version, sp_version from ..exceptions import DataConversionWarning, EfficiencyWarning @@ -197,6 +197,13 @@ def _check_precomputed(X): return graph +@validate_params( + { + "graph": ["sparse matrix"], + "copy": ["boolean"], + "warn_when_not_sorted": ["boolean"], + } +) def sort_graph_by_row_values(graph, copy=False, warn_when_not_sorted=True): """Sort a sparse graph such that each row is stored with increasing values. diff --git a/sklearn/neighbors/tests/test_neighbors.py b/sklearn/neighbors/tests/test_neighbors.py index 3e22193ba55c0..092b85ad9dcd0 100644 --- a/sklearn/neighbors/tests/test_neighbors.py +++ b/sklearn/neighbors/tests/test_neighbors.py @@ -497,10 +497,6 @@ def test_sort_graph_by_row_values_copy(): with pytest.raises(ValueError, match="Use copy=True to allow the conversion"): sort_graph_by_row_values(X.tocsc(), copy=False) - # raise if X is not even sparse - with pytest.raises(TypeError, match="Input graph must be a sparse matrix"): - sort_graph_by_row_values(X.toarray()) - def test_sort_graph_by_row_values_warning(): # Test that the parameter warn_when_not_sorted works as expected. From a8776d6959679a71ffcbfcfa6ec2747f4f3f0c04 Mon Sep 17 00:00:00 2001 From: ROMEEZHOU Date: Fri, 14 Apr 2023 21:44:34 +0800 Subject: [PATCH 2/3] add to test_public_functions --- sklearn/tests/test_public_functions.py | 1 + 1 file changed, 1 insertion(+) diff --git a/sklearn/tests/test_public_functions.py b/sklearn/tests/test_public_functions.py index 62513b0f63cce..03b4ad9725748 100644 --- a/sklearn/tests/test_public_functions.py +++ b/sklearn/tests/test_public_functions.py @@ -238,6 +238,7 @@ def _check_function_param_validation( "sklearn.metrics.top_k_accuracy_score", "sklearn.metrics.zero_one_loss", "sklearn.model_selection.train_test_split", + "sklearn.neighbors.sort_graph_by_row_values", "sklearn.preprocessing.add_dummy_feature", "sklearn.preprocessing.binarize", "sklearn.preprocessing.label_binarize", From 214a4b5944d6ef50dc85ab3ddf8feacfb703f46e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=A9r=C3=A9mie=20du=20Boisberranger?= <34657725+jeremiedbb@users.noreply.github.com> Date: Mon, 17 Apr 2023 15:29:39 +0200 Subject: [PATCH 3/3] Update _base.py --- sklearn/neighbors/_base.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/sklearn/neighbors/_base.py b/sklearn/neighbors/_base.py index 92ba487a2bbe7..9af85a38f0b6c 100644 --- a/sklearn/neighbors/_base.py +++ b/sklearn/neighbors/_base.py @@ -231,9 +231,6 @@ def sort_graph_by_row_values(graph, copy=False, warn_when_not_sorted=True): Distance matrix to other samples, where only non-zero elements are considered neighbors. Matrix is in CSR format. """ - if not issparse(graph): - raise TypeError(f"Input graph must be a sparse matrix, got {graph!r} instead.") - if graph.format == "csr" and _is_sorted_by_data(graph): return graph