From 8212a776b7d51a95ec31373ffa88b79e6d86bd95 Mon Sep 17 00:00:00 2001 From: shadyelgewily-slimstock <90049412+shadyelgewily-slimstock@users.noreply.github.com> Date: Mon, 23 Jan 2023 10:24:56 +0100 Subject: [PATCH] Fixed faulty test, by applying cross-validation to SVC for classification and preventing statistical fluctuations --- .../model_selection/tests/test_validation.py | 38 +++++++++++++------ 1 file changed, 26 insertions(+), 12 deletions(-) diff --git a/sklearn/model_selection/tests/test_validation.py b/sklearn/model_selection/tests/test_validation.py index 9665422d534b1..dcffda71c1f19 100644 --- a/sklearn/model_selection/tests/test_validation.py +++ b/sklearn/model_selection/tests/test_validation.py @@ -426,8 +426,9 @@ def test_cross_validate(): train_r2_scores = [] test_r2_scores = [] fitted_estimators = [] + for train, test in cv.split(X, y): - est = clone(reg).fit(X[train], y[train]) + est = clone(est).fit(X[train], y[train]) train_mse_scores.append(mse_scorer(est, X[train], y[train])) train_r2_scores.append(r2_scorer(est, X[train], y[train])) test_mse_scores.append(mse_scorer(est, X[test], y[test])) @@ -448,11 +449,14 @@ def test_cross_validate(): fitted_estimators, ) - check_cross_validate_single_metric(est, X, y, scores) - check_cross_validate_multi_metric(est, X, y, scores) + # To ensure that the test does not suffer from + # large statistical fluctuations due to slicing small datasets, + # we pass the cross-validation instance + check_cross_validate_single_metric(est, X, y, scores, cv) + check_cross_validate_multi_metric(est, X, y, scores, cv) -def check_cross_validate_single_metric(clf, X, y, scores): +def check_cross_validate_single_metric(clf, X, y, scores, cv): ( train_mse_scores, test_mse_scores, @@ -465,12 +469,22 @@ def check_cross_validate_single_metric(clf, X, y, scores): # Single metric passed as a string if return_train_score: mse_scores_dict = cross_validate( - clf, X, y, scoring="neg_mean_squared_error", return_train_score=True + clf, + X, + y, + scoring="neg_mean_squared_error", + return_train_score=True, + cv=cv, ) assert_array_almost_equal(mse_scores_dict["train_score"], train_mse_scores) else: mse_scores_dict = cross_validate( - clf, X, y, scoring="neg_mean_squared_error", return_train_score=False + clf, + X, + y, + scoring="neg_mean_squared_error", + return_train_score=False, + cv=cv, ) assert isinstance(mse_scores_dict, dict) assert len(mse_scores_dict) == dict_len @@ -480,12 +494,12 @@ def check_cross_validate_single_metric(clf, X, y, scores): if return_train_score: # It must be True by default - deprecated r2_scores_dict = cross_validate( - clf, X, y, scoring=["r2"], return_train_score=True + clf, X, y, scoring=["r2"], return_train_score=True, cv=cv ) assert_array_almost_equal(r2_scores_dict["train_r2"], train_r2_scores, True) else: r2_scores_dict = cross_validate( - clf, X, y, scoring=["r2"], return_train_score=False + clf, X, y, scoring=["r2"], return_train_score=False, cv=cv ) assert isinstance(r2_scores_dict, dict) assert len(r2_scores_dict) == dict_len @@ -493,14 +507,14 @@ def check_cross_validate_single_metric(clf, X, y, scores): # Test return_estimator option mse_scores_dict = cross_validate( - clf, X, y, scoring="neg_mean_squared_error", return_estimator=True + clf, X, y, scoring="neg_mean_squared_error", return_estimator=True, cv=cv ) for k, est in enumerate(mse_scores_dict["estimator"]): assert_almost_equal(est.coef_, fitted_estimators[k].coef_) assert_almost_equal(est.intercept_, fitted_estimators[k].intercept_) -def check_cross_validate_multi_metric(clf, X, y, scores): +def check_cross_validate_multi_metric(clf, X, y, scores, cv): # Test multimetric evaluation when scoring is a list / dict ( train_mse_scores, @@ -541,7 +555,7 @@ def custom_scorer(clf, X, y): if return_train_score: # return_train_score must be True by default - deprecated cv_results = cross_validate( - clf, X, y, scoring=scoring, return_train_score=True + clf, X, y, scoring=scoring, return_train_score=True, cv=cv ) assert_array_almost_equal(cv_results["train_r2"], train_r2_scores) assert_array_almost_equal( @@ -549,7 +563,7 @@ def custom_scorer(clf, X, y): ) else: cv_results = cross_validate( - clf, X, y, scoring=scoring, return_train_score=False + clf, X, y, scoring=scoring, return_train_score=False, cv=cv ) assert isinstance(cv_results, dict) assert set(cv_results.keys()) == (