diff --git a/sklearn/model_selection/_validation.py b/sklearn/model_selection/_validation.py index d3b08169bc058..6bb4e2270f359 100644 --- a/sklearn/model_selection/_validation.py +++ b/sklearn/model_selection/_validation.py @@ -32,6 +32,7 @@ from ..utils.metaestimators import _safe_split from ..utils._param_validation import ( HasMethods, + Interval, Integral, StrOptions, validate_params, @@ -1235,6 +1236,21 @@ def _check_is_permutation(indices, n_samples): return True +@validate_params( + { + "estimator": [HasMethods("fit")], + "X": ["array-like", "sparse matrix"], + "y": ["array-like", None], + "groups": ["array-like", None], + "cv": ["cv_object"], + "n_permutations": [Interval(Integral, 1, None, closed="left")], + "n_jobs": [Integral, None], + "random_state": ["random_state"], + "verbose": ["verbose"], + "scoring": [StrOptions(set(get_scorer_names())), callable, None], + "fit_params": [dict, None], + } +) def permutation_test_score( estimator, X, diff --git a/sklearn/tests/test_public_functions.py b/sklearn/tests/test_public_functions.py index 2bb6846dc4cbf..55d201c87072e 100644 --- a/sklearn/tests/test_public_functions.py +++ b/sklearn/tests/test_public_functions.py @@ -247,6 +247,7 @@ def _check_function_param_validation( "sklearn.metrics.top_k_accuracy_score", "sklearn.metrics.zero_one_loss", "sklearn.model_selection.cross_validate", + "sklearn.model_selection.permutation_test_score", "sklearn.model_selection.train_test_split", "sklearn.neighbors.sort_graph_by_row_values", "sklearn.preprocessing.add_dummy_feature",