From 2cfa1af404bfa34d62afb7376f51510d36a7f877 Mon Sep 17 00:00:00 2001 From: Charlie-XIAO Date: Thu, 20 Apr 2023 11:44:26 +0800 Subject: [PATCH 1/4] MAINT Parameters validation for sklearn.model_selection.validation_curve --- sklearn/model_selection/_validation.py | 19 ++++++++++++- .../model_selection/tests/test_validation.py | 28 ------------------- sklearn/tests/test_public_functions.py | 1 + 3 files changed, 19 insertions(+), 29 deletions(-) diff --git a/sklearn/model_selection/_validation.py b/sklearn/model_selection/_validation.py index d3b08169bc058..f80577eaf547e 100644 --- a/sklearn/model_selection/_validation.py +++ b/sklearn/model_selection/_validation.py @@ -1802,6 +1802,23 @@ def _incremental_fit_estimator( return np.array(ret).T +@validate_params( + { + "estimator": [HasMethods(["fit", "predict"])], + "X": ["array-like", "sparse matrix"], + "y": ["array-like", None], + "param_name": [str], + "param_range": ["array-like"], + "groups": ["array-like", None], + "cv": ["cv_object"], + "scoring": [StrOptions(set(get_scorer_names())), callable, None], + "n_jobs": [Integral, None], + "pre_dispatch": [Integral, str], + "verbose": ["verbose"], + "error_score": [StrOptions({"raise"}), Real], + "fit_params": [dict, None], + } +) def validation_curve( estimator, X, @@ -1834,7 +1851,7 @@ def validation_curve( estimator : object type that implements the "fit" and "predict" methods An object of that type which is cloned for each validation. - X : array-like of shape (n_samples, n_features) + X : {array-like, sparse matrix} of shape (n_samples, n_features) Training vector, where `n_samples` is the number of samples and `n_features` is the number of features. diff --git a/sklearn/model_selection/tests/test_validation.py b/sklearn/model_selection/tests/test_validation.py index 2c275af617e40..9c1c3856f4b3e 100644 --- a/sklearn/model_selection/tests/test_validation.py +++ b/sklearn/model_selection/tests/test_validation.py @@ -2094,7 +2094,6 @@ def test_fit_and_score_failing(): failing_clf = FailingClassifier(FailingClassifier.FAILING_PARAMETER) # dummy X data X = np.arange(1, 10) - y = np.ones(9) fit_and_score_args = [failing_clf, X, None, dict(), None, None, 0, None, None] # passing error score to trigger the warning message fit_and_score_kwargs = {"error_score": "raise"} @@ -2102,33 +2101,6 @@ def test_fit_and_score_failing(): with pytest.raises(ValueError, match="Failing classifier failed as required"): _fit_and_score(*fit_and_score_args, **fit_and_score_kwargs) - # check that functions upstream pass error_score param to _fit_and_score - error_message = re.escape( - "error_score must be the string 'raise' or a numeric value. (Hint: if " - "using 'raise', please make sure that it has been spelled correctly.)" - ) - - error_message_cross_validate = ( - "The 'error_score' parameter of cross_validate must be .*. Got .* instead." - ) - - with pytest.raises(ValueError, match=error_message_cross_validate): - cross_val_score(failing_clf, X, cv=3, error_score="unvalid-string") - - with pytest.raises(ValueError, match=error_message): - learning_curve(failing_clf, X, y, cv=3, error_score="unvalid-string") - - with pytest.raises(ValueError, match=error_message): - validation_curve( - failing_clf, - X, - y, - param_name="parameter", - param_range=[FailingClassifier.FAILING_PARAMETER], - cv=3, - error_score="unvalid-string", - ) - assert failing_clf.score() == 0.0 # FailingClassifier coverage diff --git a/sklearn/tests/test_public_functions.py b/sklearn/tests/test_public_functions.py index 2bb6846dc4cbf..6d5d92ab5a7e4 100644 --- a/sklearn/tests/test_public_functions.py +++ b/sklearn/tests/test_public_functions.py @@ -248,6 +248,7 @@ def _check_function_param_validation( "sklearn.metrics.zero_one_loss", "sklearn.model_selection.cross_validate", "sklearn.model_selection.train_test_split", + "sklearn.model_selection.validation_curve", "sklearn.neighbors.sort_graph_by_row_values", "sklearn.preprocessing.add_dummy_feature", "sklearn.preprocessing.binarize", From 90cf17874a4ea3a141f5aae7f6ec1b66d3193aac Mon Sep 17 00:00:00 2001 From: Charlie-XIAO Date: Fri, 21 Apr 2023 04:42:35 +0800 Subject: [PATCH 2/4] resolved conversations --- .../model_selection/tests/test_validation.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/sklearn/model_selection/tests/test_validation.py b/sklearn/model_selection/tests/test_validation.py index 9c1c3856f4b3e..08721921d872a 100644 --- a/sklearn/model_selection/tests/test_validation.py +++ b/sklearn/model_selection/tests/test_validation.py @@ -2094,6 +2094,7 @@ def test_fit_and_score_failing(): failing_clf = FailingClassifier(FailingClassifier.FAILING_PARAMETER) # dummy X data X = np.arange(1, 10) + y = np.ones(9) fit_and_score_args = [failing_clf, X, None, dict(), None, None, 0, None, None] # passing error score to trigger the warning message fit_and_score_kwargs = {"error_score": "raise"} @@ -2101,6 +2102,22 @@ def test_fit_and_score_failing(): with pytest.raises(ValueError, match="Failing classifier failed as required"): _fit_and_score(*fit_and_score_args, **fit_and_score_kwargs) + # check that functions upstream pass error_score param to _fit_and_score + error_message = re.escape( + "error_score must be the string 'raise' or a numeric value. (Hint: if " + "using 'raise', please make sure that it has been spelled correctly.)" + ) + + error_message_cross_validate = ( + "The 'error_score' parameter of cross_validate must be .*. Got .* instead." + ) + + with pytest.raises(ValueError, match=error_message_cross_validate): + cross_val_score(failing_clf, X, cv=3, error_score="unvalid-string") + + with pytest.raises(ValueError, match=error_message): + learning_curve(failing_clf, X, y, cv=3, error_score="unvalid-string") + assert failing_clf.score() == 0.0 # FailingClassifier coverage From 4e92b503ce9c66d99042842ee6d35a42b64ef885 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=A9r=C3=A9mie=20du=20Boisberranger?= <34657725+jeremiedbb@users.noreply.github.com> Date: Wed, 26 Apr 2023 18:07:58 +0200 Subject: [PATCH 3/4] predict not mandatory --- sklearn/model_selection/_validation.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/sklearn/model_selection/_validation.py b/sklearn/model_selection/_validation.py index 65dfc8c1090b0..d65f2f8142c6f 100644 --- a/sklearn/model_selection/_validation.py +++ b/sklearn/model_selection/_validation.py @@ -1820,7 +1820,7 @@ def _incremental_fit_estimator( @validate_params( { - "estimator": [HasMethods(["fit", "predict"])], + "estimator": [HasMethods(["fit"])], "X": ["array-like", "sparse matrix"], "y": ["array-like", None], "param_name": [str], @@ -1864,8 +1864,10 @@ def validation_curve( Parameters ---------- - estimator : object type that implements the "fit" and "predict" methods - An object of that type which is cloned for each validation. + estimator : object type that implements the "fit" method. + An object of that type which is cloned for each validation. It must + also implement "predict" unless `scoring` is a callable that doesn't + rely on "predict" to compute a score. X : {array-like, sparse matrix} of shape (n_samples, n_features) Training vector, where `n_samples` is the number of samples and From 9ba8032cd3e10c674f1267d246a28f04f54ec86d Mon Sep 17 00:00:00 2001 From: Charlie-XIAO Date: Thu, 27 Apr 2023 01:15:25 +0800 Subject: [PATCH 4/4] removed trailing period --- sklearn/model_selection/_validation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sklearn/model_selection/_validation.py b/sklearn/model_selection/_validation.py index d65f2f8142c6f..60276db8ca361 100644 --- a/sklearn/model_selection/_validation.py +++ b/sklearn/model_selection/_validation.py @@ -1864,7 +1864,7 @@ def validation_curve( Parameters ---------- - estimator : object type that implements the "fit" method. + estimator : object type that implements the "fit" method An object of that type which is cloned for each validation. It must also implement "predict" unless `scoring` is a callable that doesn't rely on "predict" to compute a score.