Skip to content

Commit 93bc20f

Browse files
authored
Raise an error when all fits fail in cross-validation or grid-search (#21026)
1 parent 0a88cf8 commit 93bc20f

File tree

6 files changed

+111
-46
lines changed

6 files changed

+111
-46
lines changed

doc/whats_new/v1.1.rst

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,19 @@ Changelog
5353
message when the solver does not support sparse matrices with int64 indices.
5454
:pr:`21093` by `Tom Dupre la Tour`_.
5555

56+
:mod:`sklearn.model_selection`
57+
..............................
58+
59+
- |Enhancement| raise an error during cross-validation when the fits for all the
60+
splits failed. Similarly raise an error during grid-search when the fits for
61+
all the models and all the splits failed. :pr:`21026` by :user:`Loïc Estève <lesteve>`.
62+
63+
:mod:`sklearn.pipeline`
64+
.......................
65+
66+
- |Enhancement| Added support for "passthrough" in :class:`FeatureUnion`.
67+
Setting a transformer to "passthrough" will pass the features unchanged.
68+
:pr:`20860` by :user:`Shubhraneel Pal <shubhraneel>`.
5669

5770
:mod:`sklearn.utils`
5871
....................
@@ -69,13 +82,6 @@ Changelog
6982
:pr:`20880` by :user:`Guillaume Lemaitre <glemaitre>`
7083
and :user:`András Simon <simonandras>`.
7184

72-
:mod:`sklearn.pipeline`
73-
.......................
74-
75-
- |Enhancement| Added support for "passthrough" in :class:`FeatureUnion`.
76-
Setting a transformer to "passthrough" will pass the features unchanged.
77-
:pr:`20860` by :user:`Shubhraneel Pal <shubhraneel>`.
78-
7985
Code and Documentation Contributors
8086
-----------------------------------
8187

sklearn/model_selection/_search.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@
3131
from ._validation import _aggregate_score_dicts
3232
from ._validation import _insert_error_scores
3333
from ._validation import _normalize_score_results
34-
from ._validation import _warn_about_fit_failures
34+
from ._validation import _warn_or_raise_about_fit_failures
3535
from ..exceptions import NotFittedError
3636
from joblib import Parallel
3737
from ..utils import check_random_state
@@ -865,7 +865,7 @@ def evaluate_candidates(candidate_params, cv=None, more_results=None):
865865
"splits, got {}".format(n_splits, len(out) // n_candidates)
866866
)
867867

868-
_warn_about_fit_failures(out, self.error_score)
868+
_warn_or_raise_about_fit_failures(out, self.error_score)
869869

870870
# For callable self.scoring, the return type is only known after
871871
# calling. If the return type is a dictionary, the error scores

sklearn/model_selection/_validation.py

Lines changed: 22 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@
3030
from ..utils.metaestimators import _safe_split
3131
from ..metrics import check_scoring
3232
from ..metrics._scorer import _check_multimetric_scoring, _MultimetricScorer
33-
from ..exceptions import FitFailedWarning, NotFittedError
33+
from ..exceptions import FitFailedWarning
3434
from ._split import check_cv
3535
from ..preprocessing import LabelEncoder
3636

@@ -283,7 +283,7 @@ def cross_validate(
283283
for train, test in cv.split(X, y, groups)
284284
)
285285

286-
_warn_about_fit_failures(results, error_score)
286+
_warn_or_raise_about_fit_failures(results, error_score)
287287

288288
# For callable scoring, the return type is only known after calling. If the
289289
# return type is a dictionary, the error scores can now be inserted with
@@ -327,9 +327,6 @@ def _insert_error_scores(results, error_score):
327327
elif successful_score is None:
328328
successful_score = result["test_scores"]
329329

330-
if successful_score is None:
331-
raise NotFittedError("All estimators failed to fit")
332-
333330
if isinstance(successful_score, dict):
334331
formatted_error = {name: error_score for name in successful_score}
335332
for i in failed_indices:
@@ -347,7 +344,7 @@ def _normalize_score_results(scores, scaler_score_key="score"):
347344
return {scaler_score_key: scores}
348345

349346

350-
def _warn_about_fit_failures(results, error_score):
347+
def _warn_or_raise_about_fit_failures(results, error_score):
351348
fit_errors = [
352349
result["fit_error"] for result in results if result["fit_error"] is not None
353350
]
@@ -361,15 +358,25 @@ def _warn_about_fit_failures(results, error_score):
361358
for error, n in fit_errors_counter.items()
362359
)
363360

364-
some_fits_failed_message = (
365-
f"\n{num_failed_fits} fits failed out of a total of {num_fits}.\n"
366-
"The score on these train-test partitions for these parameters"
367-
f" will be set to {error_score}.\n"
368-
"If these failures are not expected, you can try to debug them "
369-
"by setting error_score='raise'.\n\n"
370-
f"Below are more details about the failures:\n{fit_errors_summary}"
371-
)
372-
warnings.warn(some_fits_failed_message, FitFailedWarning)
361+
if num_failed_fits == num_fits:
362+
all_fits_failed_message = (
363+
f"\nAll the {num_fits} fits failed.\n"
364+
"It is very likely that your model is misconfigured.\n"
365+
"You can try to debug the error by setting error_score='raise'.\n\n"
366+
f"Below are more details about the failures:\n{fit_errors_summary}"
367+
)
368+
raise ValueError(all_fits_failed_message)
369+
370+
else:
371+
some_fits_failed_message = (
372+
f"\n{num_failed_fits} fits failed out of a total of {num_fits}.\n"
373+
"The score on these train-test partitions for these parameters"
374+
f" will be set to {error_score}.\n"
375+
"If these failures are not expected, you can try to debug them "
376+
"by setting error_score='raise'.\n\n"
377+
f"Below are more details about the failures:\n{fit_errors_summary}"
378+
)
379+
warnings.warn(some_fits_failed_message, FitFailedWarning)
373380

374381

375382
def cross_val_score(

sklearn/model_selection/tests/test_search.py

Lines changed: 31 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,6 @@
2929

3030
from sklearn.base import BaseEstimator, ClassifierMixin
3131
from sklearn.base import is_classifier
32-
from sklearn.exceptions import NotFittedError
3332
from sklearn.datasets import make_classification
3433
from sklearn.datasets import make_blobs
3534
from sklearn.datasets import make_multilabel_classification
@@ -1628,6 +1627,27 @@ def get_cand_scores(i):
16281627
assert gs.best_index_ != clf.FAILING_PARAMETER
16291628

16301629

1630+
def test_grid_search_classifier_all_fits_fail():
1631+
X, y = make_classification(n_samples=20, n_features=10, random_state=0)
1632+
1633+
clf = FailingClassifier()
1634+
1635+
gs = GridSearchCV(
1636+
clf,
1637+
[{"parameter": [FailingClassifier.FAILING_PARAMETER] * 3}],
1638+
error_score=0.0,
1639+
)
1640+
1641+
warning_message = re.compile(
1642+
"All the 15 fits failed.+"
1643+
"15 fits failed with the following error.+ValueError.+Failing classifier failed"
1644+
" as required",
1645+
flags=re.DOTALL,
1646+
)
1647+
with pytest.raises(ValueError, match=warning_message):
1648+
gs.fit(X, y)
1649+
1650+
16311651
def test_grid_search_failing_classifier_raise():
16321652
# GridSearchCV with on_error == 'raise' raises the error
16331653

@@ -2130,7 +2150,7 @@ def custom_scorer(est, X, y):
21302150
assert_allclose(gs.cv_results_["mean_test_acc"], [1, 1, 0.1])
21312151

21322152

2133-
def test_callable_multimetric_clf_all_fails():
2153+
def test_callable_multimetric_clf_all_fits_fail():
21342154
# Warns and raises when all estimator fails to fit.
21352155
def custom_scorer(est, X, y):
21362156
return {"acc": 1}
@@ -2141,16 +2161,20 @@ def custom_scorer(est, X, y):
21412161

21422162
gs = GridSearchCV(
21432163
clf,
2144-
[{"parameter": [2, 2, 2]}],
2164+
[{"parameter": [FailingClassifier.FAILING_PARAMETER] * 3}],
21452165
scoring=custom_scorer,
21462166
refit=False,
21472167
error_score=0.1,
21482168
)
21492169

2150-
with pytest.warns(
2151-
FitFailedWarning,
2152-
match="15 fits failed.+total of 15",
2153-
), pytest.raises(NotFittedError, match="All estimators failed to fit"):
2170+
individual_fit_error_message = "ValueError: Failing classifier failed as required"
2171+
error_message = re.compile(
2172+
"All the 15 fits failed.+your model is misconfigured.+"
2173+
f"{individual_fit_error_message}",
2174+
flags=re.DOTALL,
2175+
)
2176+
2177+
with pytest.raises(ValueError, match=error_message):
21542178
gs.fit(X, y)
21552179

21562180

sklearn/model_selection/tests/test_split.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1774,7 +1774,7 @@ def test_nested_cv():
17741774
LeaveOneOut(),
17751775
GroupKFold(n_splits=3),
17761776
StratifiedKFold(),
1777-
StratifiedGroupKFold(),
1777+
StratifiedGroupKFold(n_splits=3),
17781778
StratifiedShuffleSplit(n_splits=3, random_state=0),
17791779
]
17801780

sklearn/model_selection/tests/test_validation.py

Lines changed: 42 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -2130,38 +2130,66 @@ def test_fit_and_score_working():
21302130
assert result["parameters"] == fit_and_score_kwargs["parameters"]
21312131

21322132

2133+
class DataDependentFailingClassifier(BaseEstimator):
2134+
def __init__(self, max_x_value=None):
2135+
self.max_x_value = max_x_value
2136+
2137+
def fit(self, X, y=None):
2138+
num_values_too_high = (X > self.max_x_value).sum()
2139+
if num_values_too_high:
2140+
raise ValueError(
2141+
f"Classifier fit failed with {num_values_too_high} values too high"
2142+
)
2143+
2144+
def score(self, X=None, Y=None):
2145+
return 0.0
2146+
2147+
21332148
@pytest.mark.parametrize("error_score", [np.nan, 0])
2134-
def test_cross_validate_failing_fits_warnings(error_score):
2149+
def test_cross_validate_some_failing_fits_warning(error_score):
21352150
# Create a failing classifier to deliberately fail
2136-
failing_clf = FailingClassifier(FailingClassifier.FAILING_PARAMETER)
2151+
failing_clf = DataDependentFailingClassifier(max_x_value=8)
21372152
# dummy X data
21382153
X = np.arange(1, 10)
21392154
y = np.ones(9)
2140-
# fit_and_score_args = [failing_clf, X, None, dict(), None, None, 0, None, None]
21412155
# passing error score to trigger the warning message
21422156
cross_validate_args = [failing_clf, X, y]
2143-
cross_validate_kwargs = {"cv": 7, "error_score": error_score}
2157+
cross_validate_kwargs = {"cv": 3, "error_score": error_score}
21442158
# check if the warning message type is as expected
2159+
2160+
individual_fit_error_message = (
2161+
"ValueError: Classifier fit failed with 1 values too high"
2162+
)
21452163
warning_message = re.compile(
2146-
"7 fits failed.+total of 7.+The score on these"
2164+
"2 fits failed.+total of 3.+The score on these"
21472165
" train-test partitions for these parameters will be set to"
2148-
f" {cross_validate_kwargs['error_score']}.",
2166+
f" {cross_validate_kwargs['error_score']}.+{individual_fit_error_message}",
21492167
flags=re.DOTALL,
21502168
)
21512169

21522170
with pytest.warns(FitFailedWarning, match=warning_message):
21532171
cross_validate(*cross_validate_args, **cross_validate_kwargs)
21542172

2155-
# since we're using FailingClassfier, our error will be the following
2156-
error_message = "ValueError: Failing classifier failed as required"
21572173

2158-
# check traceback is included
2159-
warning_message = re.compile(
2160-
"The score on these train-test partitions for these parameters will be set"
2161-
f" to {cross_validate_kwargs['error_score']}.+{error_message}",
2162-
re.DOTALL,
2174+
@pytest.mark.parametrize("error_score", [np.nan, 0])
2175+
def test_cross_validate_all_failing_fits_error(error_score):
2176+
# Create a failing classifier to deliberately fail
2177+
failing_clf = FailingClassifier(FailingClassifier.FAILING_PARAMETER)
2178+
# dummy X data
2179+
X = np.arange(1, 10)
2180+
y = np.ones(9)
2181+
2182+
cross_validate_args = [failing_clf, X, y]
2183+
cross_validate_kwargs = {"cv": 7, "error_score": error_score}
2184+
2185+
individual_fit_error_message = "ValueError: Failing classifier failed as required"
2186+
error_message = re.compile(
2187+
"All the 7 fits failed.+your model is misconfigured.+"
2188+
f"{individual_fit_error_message}",
2189+
flags=re.DOTALL,
21632190
)
2164-
with pytest.warns(FitFailedWarning, match=warning_message):
2191+
2192+
with pytest.raises(ValueError, match=error_message):
21652193
cross_validate(*cross_validate_args, **cross_validate_kwargs)
21662194

21672195

0 commit comments

Comments
 (0)