diff --git a/sklearn/model_selection/_validation.py b/sklearn/model_selection/_validation.py index 6bb4e2270f359..60276db8ca361 100644 --- a/sklearn/model_selection/_validation.py +++ b/sklearn/model_selection/_validation.py @@ -1818,6 +1818,23 @@ def _incremental_fit_estimator( return np.array(ret).T +@validate_params( + { + "estimator": [HasMethods(["fit"])], + "X": ["array-like", "sparse matrix"], + "y": ["array-like", None], + "param_name": [str], + "param_range": ["array-like"], + "groups": ["array-like", None], + "cv": ["cv_object"], + "scoring": [StrOptions(set(get_scorer_names())), callable, None], + "n_jobs": [Integral, None], + "pre_dispatch": [Integral, str], + "verbose": ["verbose"], + "error_score": [StrOptions({"raise"}), Real], + "fit_params": [dict, None], + } +) def validation_curve( estimator, X, @@ -1847,10 +1864,12 @@ def validation_curve( Parameters ---------- - estimator : object type that implements the "fit" and "predict" methods - An object of that type which is cloned for each validation. + estimator : object type that implements the "fit" method + An object of that type which is cloned for each validation. It must + also implement "predict" unless `scoring` is a callable that doesn't + rely on "predict" to compute a score. - X : array-like of shape (n_samples, n_features) + X : {array-like, sparse matrix} of shape (n_samples, n_features) Training vector, where `n_samples` is the number of samples and `n_features` is the number of features. diff --git a/sklearn/model_selection/tests/test_validation.py b/sklearn/model_selection/tests/test_validation.py index 2c275af617e40..08721921d872a 100644 --- a/sklearn/model_selection/tests/test_validation.py +++ b/sklearn/model_selection/tests/test_validation.py @@ -2118,17 +2118,6 @@ def test_fit_and_score_failing(): with pytest.raises(ValueError, match=error_message): learning_curve(failing_clf, X, y, cv=3, error_score="unvalid-string") - with pytest.raises(ValueError, match=error_message): - validation_curve( - failing_clf, - X, - y, - param_name="parameter", - param_range=[FailingClassifier.FAILING_PARAMETER], - cv=3, - error_score="unvalid-string", - ) - assert failing_clf.score() == 0.0 # FailingClassifier coverage diff --git a/sklearn/tests/test_public_functions.py b/sklearn/tests/test_public_functions.py index 461b9364ca006..205a8efa9c9d8 100644 --- a/sklearn/tests/test_public_functions.py +++ b/sklearn/tests/test_public_functions.py @@ -258,6 +258,7 @@ def _check_function_param_validation( "sklearn.model_selection.cross_validate", "sklearn.model_selection.permutation_test_score", "sklearn.model_selection.train_test_split", + "sklearn.model_selection.validation_curve", "sklearn.neighbors.sort_graph_by_row_values", "sklearn.preprocessing.add_dummy_feature", "sklearn.preprocessing.binarize",