
Warn in the main process when a fit fails during a cross-validation #20619

Merged · 19 commits · Aug 10, 2021

3 changes: 3 additions & 0 deletions doc/whats_new/v1.0.rst
@@ -488,6 +488,9 @@ Changelog
:pr:`18649` by :user:`Leandro Hermida <hermidalc>` and
:user:`Rodion Martynov <marrodion>`.

- |Enhancement| warn only once in the main process for per-split fit failures
in cross-validation. :pr:`20619` by :user:`Loïc Estève <lesteve>`.
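
To illustrate the behaviour described by this entry, here is a minimal sketch (not part of the diff): running cross_validate with an estimator whose fit always fails now produces a single aggregated FitFailedWarning in the main process, rather than one warning per failed split emitted from the workers. AlwaysFailingClassifier is a made-up estimator for this illustration only, and the snippet assumes a scikit-learn release that includes this change.

import warnings

import numpy as np
from sklearn.base import BaseEstimator, ClassifierMixin
from sklearn.exceptions import FitFailedWarning
from sklearn.model_selection import cross_validate


class AlwaysFailingClassifier(BaseEstimator, ClassifierMixin):
    # Hypothetical estimator used only for this sketch: fit always raises.
    def fit(self, X, y):
        raise ValueError("deliberate failure")

    def predict(self, X):
        return np.zeros(len(X), dtype=int)


X = np.arange(20).reshape(10, 2)
y = np.r_[np.zeros(5), np.ones(5)]

with warnings.catch_warnings(record=True) as records:
    warnings.simplefilter("always")
    # All 5 fits fail; any error_score other than "raise" keeps the failures non-fatal.
    cross_validate(AlwaysFailingClassifier(), X, y, cv=5, error_score=np.nan)

fit_warnings = [w for w in records if issubclass(w.category, FitFailedWarning)]
assert len(fit_warnings) == 1  # one summary warning, not one per failed split
print(fit_warnings[0].message)  # begins with "5 fits failed out of a total of 5."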

:mod:`sklearn.naive_bayes`
..........................

5 changes: 5 additions & 0 deletions sklearn/model_selection/_search.py
@@ -31,6 +31,7 @@
from ._validation import _aggregate_score_dicts
from ._validation import _insert_error_scores
from ._validation import _normalize_score_results
from ._validation import _warn_about_fit_failures
from ..exceptions import NotFittedError
from joblib import Parallel
from ..utils import check_random_state
@@ -793,14 +794,18 @@ def evaluate_candidates(candidate_params, cv=None, more_results=None):
"splits, got {}".format(n_splits, len(out) // n_candidates)
)

_warn_about_fit_failures(out, self.error_score)

# For callable self.scoring, the return type is only known after
# calling. If the return type is a dictionary, the error scores
# can now be inserted with the correct key. The type checking
# of out will be done in `_insert_error_scores`.
if callable(self.scoring):
_insert_error_scores(out, self.error_score)

all_candidate_params.extend(candidate_params)
all_out.extend(out)

if more_results is not None:
for key, value in more_results.items():
all_more_results[key].extend(value)
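
In a grid search, the counts in the aggregated warning cover every split of every candidate. The sketch below (not part of the diff) mirrors the updated tests: one candidate out of three always fails with 5-fold CV, so the summary reports 5 failed fits out of 15. SometimesFailingClassifier is a hypothetical stand-in for the FailingClassifier test helper and is not part of scikit-learn; the snippet assumes a release that includes this change.

import warnings

import numpy as np
from sklearn.base import BaseEstimator, ClassifierMixin
from sklearn.exceptions import FitFailedWarning
from sklearn.model_selection import GridSearchCV


class SometimesFailingClassifier(BaseEstimator, ClassifierMixin):
    # Hypothetical estimator for this sketch: fails only when parameter == 2.
    def __init__(self, parameter=0):
        self.parameter = parameter

    def fit(self, X, y):
        if self.parameter == 2:
            raise ValueError("Failing classifier failed as required")
        self.classes_ = np.unique(y)
        return self

    def predict(self, X):
        return np.zeros(len(X), dtype=int)


X = np.arange(40).reshape(20, 2)
y = np.r_[np.zeros(10), np.ones(10)]

gs = GridSearchCV(
    SometimesFailingClassifier(),
    {"parameter": [0, 1, 2]},  # the parameter == 2 candidate always fails
    cv=5,
    error_score=np.nan,
    refit=False,
)

with warnings.catch_warnings(record=True) as records:
    warnings.simplefilter("always")
    gs.fit(X, y)

messages = [
    str(w.message) for w in records if issubclass(w.category, FitFailedWarning)
]
assert len(messages) == 1
assert "5 fits failed out of a total of 15" in messages[0]
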
44 changes: 33 additions & 11 deletions sklearn/model_selection/_validation.py
@@ -15,6 +15,7 @@
import time
from traceback import format_exc
from contextlib import suppress
from collections import Counter

import numpy as np
import scipy.sparse as sp
@@ -281,6 +282,8 @@ def cross_validate(
for train, test in cv.split(X, y, groups)
)

_warn_about_fit_failures(results, error_score)

# For callable scoring, the return type is only known after calling. If the
# return type is a dictionary, the error scores can now be inserted with
# the correct key.
@@ -318,7 +321,7 @@ def _insert_error_scores(results, error_score):
successful_score = None
failed_indices = []
for i, result in enumerate(results):
if result["fit_failed"]:
if result["fit_error"] is not None:
failed_indices.append(i)
elif successful_score is None:
successful_score = result["test_scores"]
@@ -343,6 +346,31 @@ def _normalize_score_results(scores, scaler_score_key="score"):
return {scaler_score_key: scores}


def _warn_about_fit_failures(results, error_score):
fit_errors = [
result["fit_error"] for result in results if result["fit_error"] is not None
]
if fit_errors:
num_failed_fits = len(fit_errors)
num_fits = len(results)
fit_errors_counter = Counter(fit_errors)
delimiter = "-" * 80 + "\n"
fit_errors_summary = "\n".join(
f"{delimiter}{n} fits failed with the following error:\n{error}"
for error, n in fit_errors_counter.items()
)

some_fits_failed_message = (
f"\n{num_failed_fits} fits failed out of a total of {num_fits}.\n"
"The score on these train-test partitions for these parameters"
f" will be set to {error_score}.\n"
"If these failures are not expected, you can try to debug them "
"by setting error_score='raise'.\n\n"
f"Below are more details about the failures:\n{fit_errors_summary}"
)
warnings.warn(some_fits_failed_message, FitFailedWarning)
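
A minimal sketch (not part of the diff) of what the helper above does with a hand-built results list: identical "fit_error" traceback strings are grouped with Counter and reported once with their count, inside a single FitFailedWarning. _warn_about_fit_failures is a private helper, so this is for illustration only.

import warnings

import numpy as np
from sklearn.model_selection._validation import _warn_about_fit_failures

# Each dict mimics the result returned by _fit_and_score: "fit_error" is
# None when the fit succeeded and a traceback string when it failed.
results = [
    {"fit_error": None},
    {"fit_error": "Traceback ...\nValueError: bad penalty"},
    {"fit_error": "Traceback ...\nValueError: bad penalty"},
]

with warnings.catch_warnings(record=True) as records:
    warnings.simplefilter("always")
    _warn_about_fit_failures(results, error_score=np.nan)

print(records[0].message)
# Roughly:
#   2 fits failed out of a total of 3.
#   The score on these train-test partitions for these parameters will be set to nan.
#   ...
#   2 fits failed with the following error:
#   Traceback ...
#   ValueError: bad penalty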


def cross_val_score(
estimator,
X,
@@ -598,8 +626,8 @@ def _fit_and_score(
The parameters that have been evaluated.
estimator : estimator object
The fitted estimator.
fit_failed : bool
The estimator failed to fit.
fit_error : str or None
Traceback str if the fit failed, None if the fit succeeded.
"""
if not isinstance(error_score, numbers.Number) and error_score != "raise":
raise ValueError(
@@ -666,15 +694,9 @@
test_scores = error_score
if return_train_score:
train_scores = error_score
warnings.warn(
"Estimator fit failed. The score on this train-test"
" partition for these parameters will be set to %f. "
"Details: \n%s" % (error_score, format_exc()),
FitFailedWarning,
)
result["fit_failed"] = True
result["fit_error"] = format_exc()
else:
result["fit_failed"] = False
result["fit_error"] = None

fit_time = time.time() - start_time
test_scores = _score(estimator, X_test, y_test, scorer, error_score)
33 changes: 23 additions & 10 deletions sklearn/model_selection/tests/test_search.py
@@ -1565,9 +1565,13 @@ def test_grid_search_failing_classifier():
refit=False,
error_score=0.0,
)
warning_message = (
"Estimator fit failed. The score on this train-test partition "
"for these parameters will be set to 0.0.*."

warning_message = re.compile(
"5 fits failed.+total of 15.+The score on these"
r" train-test partitions for these parameters will be set to 0\.0.+"
"5 fits failed with the following error.+ValueError.+Failing classifier failed"
" as required",
flags=re.DOTALL,
)
with pytest.warns(FitFailedWarning, match=warning_message):
gs.fit(X, y)
@@ -1598,9 +1602,12 @@ def get_cand_scores(i):
refit=False,
error_score=float("nan"),
)
warning_message = (
"Estimator fit failed. The score on this train-test partition "
"for these parameters will be set to nan."
warning_message = re.compile(
"5 fits failed.+total of 15.+The score on these"
r" train-test partitions for these parameters will be set to nan.+"
"5 fits failed with the following error.+ValueError.+Failing classifier failed"
" as required",
flags=re.DOTALL,
)
with pytest.warns(FitFailedWarning, match=warning_message):
gs.fit(X, y)
@@ -2112,7 +2119,12 @@ def custom_scorer(est, X, y):
error_score=0.1,
)

with pytest.warns(FitFailedWarning, match="Estimator fit failed"):
warning_message = re.compile(
"5 fits failed.+total of 15.+The score on these"
r" train-test partitions for these parameters will be set to 0\.1",
flags=re.DOTALL,
)
with pytest.warns(FitFailedWarning, match=warning_message):
gs.fit(X, y)

assert_allclose(gs.cv_results_["mean_test_acc"], [1, 1, 0.1])
@@ -2135,9 +2147,10 @@ def custom_scorer(est, X, y):
error_score=0.1,
)

with pytest.warns(FitFailedWarning, match="Estimator fit failed"), pytest.raises(
NotFittedError, match="All estimators failed to fit"
):
with pytest.warns(
FitFailedWarning,
match="15 fits failed.+total of 15",
), pytest.raises(NotFittedError, match="All estimators failed to fit"):
gs.fit(X, y)


66 changes: 35 additions & 31 deletions sklearn/model_selection/tests/test_validation.py
@@ -2082,37 +2082,6 @@ def test_fit_and_score_failing():
y = np.ones(9)
fit_and_score_args = [failing_clf, X, None, dict(), None, None, 0, None, None]
# passing error score to trigger the warning message
fit_and_score_kwargs = {"error_score": 0}
Review comment from the PR author: The checks on the warnings were moved below to test_cross_validate_failing_fits_warnings, since _fit_and_score does not warn anymore.

# check if the warning message type is as expected
warning_message = (
"Estimator fit failed. The score on this train-test partition for "
"these parameters will be set to %f." % (fit_and_score_kwargs["error_score"])
)
with pytest.warns(FitFailedWarning, match=warning_message):
_fit_and_score(*fit_and_score_args, **fit_and_score_kwargs)
# since we're using FailingClassfier, our error will be the following
error_message = "ValueError: Failing classifier failed as required"
# the warning message we're expecting to see
warning_message = (
"Estimator fit failed. The score on this train-test "
"partition for these parameters will be set to %f. "
"Details: \n%s" % (fit_and_score_kwargs["error_score"], error_message)
)

def test_warn_trace(msg):
assert "Traceback (most recent call last):\n" in msg
split = msg.splitlines() # note: handles more than '\n'
mtb = split[0] + "\n" + split[-1]
return warning_message in mtb

# check traceback is included
warning_message = (
"Estimator fit failed. The score on this train-test partition for "
"these parameters will be set to %f." % (fit_and_score_kwargs["error_score"])
)
with pytest.warns(FitFailedWarning, match=warning_message):
_fit_and_score(*fit_and_score_args, **fit_and_score_kwargs)

fit_and_score_kwargs = {"error_score": "raise"}
# check if exception was raised, with default error_score='raise'
with pytest.raises(ValueError, match="Failing classifier failed as required"):
@@ -2161,6 +2130,41 @@ def test_fit_and_score_working():
assert result["parameters"] == fit_and_score_kwargs["parameters"]


@pytest.mark.parametrize("error_score", [np.nan, 0])
def test_cross_validate_failing_fits_warnings(error_score):
# Create a failing classifier to deliberately fail
failing_clf = FailingClassifier(FailingClassifier.FAILING_PARAMETER)
# dummy X data
X = np.arange(1, 10)
y = np.ones(9)
# passing error score to trigger the warning message
cross_validate_args = [failing_clf, X, y]
cross_validate_kwargs = {"cv": 7, "error_score": error_score}
# check if the warning message type is as expected
warning_message = re.compile(
"7 fits failed.+total of 7.+The score on these"
" train-test partitions for these parameters will be set to"
f" {cross_validate_kwargs['error_score']}.",
flags=re.DOTALL,
)

with pytest.warns(FitFailedWarning, match=warning_message):
cross_validate(*cross_validate_args, **cross_validate_kwargs)

# since we're using FailingClassifier, our error will be the following
error_message = "ValueError: Failing classifier failed as required"

# check traceback is included
warning_message = re.compile(
"The score on these train-test partitions for these parameters will be set"
f" to {cross_validate_kwargs['error_score']}.+{error_message}",
re.DOTALL,
)
with pytest.warns(FitFailedWarning, match=warning_message):
cross_validate(*cross_validate_args, **cross_validate_kwargs)


def _failing_scorer(estimator, X, y, error_msg):
raise ValueError(error_msg)
