33 changes: 24 additions & 9 deletions sklearn/grid_search.py
@@ -264,18 +264,33 @@ def fit_grid_point(X, y, base_estimator, parameters, train, test, scorer,
     if y is not None:
         y_test = y[safe_mask(y, test)]
         y_train = y[safe_mask(y, train)]
-        clf.fit(X_train, y_train, **fit_params)
-
-        if scorer is not None:
-            this_score = scorer(clf, X_test, y_test)
+        try:
+            clf.fit(X_train, y_train, **fit_params)
Member: The duplication could be avoided with something like clf.fit(*fit_args, **fit_params), where fit_args is set differently for the y is None and y is not None cases. (In some PR related to grid search I have implemented it this way, but the remainder of that PR was too controversial to merge as yet.)
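
For concreteness, a rough sketch of how that de-duplicated form could look inside fit_grid_point, reusing the function's existing local names; fit_args and score_args are illustrative names, not code from this PR:

    fit_args = (X_train,) if y is None else (X_train, y_train)
    score_args = (X_test,) if y is None else (X_test, y_test)
    try:
        clf.fit(*fit_args, **fit_params)
    except ValueError as e:
        # Failed fit: fall back to a default score and warn, as in this PR.
        this_score = 0
        warnings.warn("Classifier fit failed. The score for this fold on "
                      "this test point will be set to zero. Details: " +
                      str(e), RuntimeWarning)
    else:
        if scorer is not None:
            this_score = scorer(clf, *score_args)
        else:
            this_score = clf.score(*score_args)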

+        except ValueError as e:
Member: We should probably be catching Exception, not ValueError.

Author: My reasoning here was that if the classifier fails on some specific folds but not others, this is due to the data and should probably result in a ValueError. It is generally recommended to catch only a subset of exceptions rather than all of them. But perhaps assuming that it will be a ValueError is going too far.
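
As a side note on this trade-off: a broad except Exception would still let Ctrl-C propagate, since KeyboardInterrupt derives from BaseException rather than Exception, but it would also swallow programming errors such as TypeError. A quick illustration:

    print(issubclass(ValueError, Exception))         # True  -- caught in either variant
    print(issubclass(TypeError, Exception))          # True  -- swallowed by a broad "except Exception"
    print(issubclass(KeyboardInterrupt, Exception))  # False -- Ctrl-C still propagates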

+            # If the classifier fails, the score is set to 0 by default
+            this_score = 0
Member: I'm not certain 0 is always the correct value. For example, for correlation metrics, might -1 be preferable sometimes? Maybe to support this functionality we should have an additional parameter to *SearchCV:

    on_error : float or 'raise', default 0.0
        Set the score to the given float value if an exception occurs while fitting a model. If set to `raise`, exceptions will be raised instead.

Author: I like this idea.
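
If that option were adopted, the except branch might dispatch on it roughly as below. This is only a sketch: on_error is the name proposed in the comment above, _score_or_raise is a hypothetical helper, and nothing in this PR implements it yet.

    import warnings

    def _score_or_raise(error, on_error=0.0):
        """Hypothetical helper for the proposed on_error option."""
        if on_error == 'raise':
            raise error
        warnings.warn("Classifier fit failed. Setting score to %r. "
                      "Details: %r" % (on_error, error), RuntimeWarning)
        return on_error

    # Inside the except block this would become:
    #     this_score = _score_or_raise(e, on_error)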

warnings.warn("Classifier fit failed. The score for this fold on "
Member: I wonder how much we need to do to help the user understand why/where it failed.

In terms of where: should we be noting the parameters and the fold number that broke? Perhaps not, if all non-failing folds scored > 0, in which case it's easy enough to inspect the search results.

In terms of why: firstly, str(e) doesn't say a lot, and repr(e) is better, but it's still not a full traceback. Providing a switch to raise errors instead might be beneficial, and would mean we needn't provide a detailed account of the error here.

Author: I suppose we could print out the parameters that break the classifier, but the fold number is trickier because fit_grid_point() doesn't know the index of the fold it is fitting. We could add an extra parameter to tell it, but I'm not sure the benefit is worth the extra complexity.

Switching from str(e) to repr(e) seems like a good idea.

The same goes for providing a switch to raise errors, although again this means an extra parameter for fit_grid_point().
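
To illustrate the str(e) vs repr(e) point (exact repr formatting varies across Python versions):

    e = ValueError("Input contains NaN")
    print(str(e))   # Input contains NaN -- no hint of the exception type
    print(repr(e))  # e.g. ValueError('Input contains NaN') -- the type is preserved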

Member: I don't think there's harm in an extra parameter for fit_grid_point, and elsewhere there have been other motivations for that function knowing which fold index it corresponds to (I recall a discussion that I now can't find, related to tracking errors in folds to a particular fold index; see also #2079).

"this test point will be set to zero. Details: " +
str(e), RuntimeWarning)
else:
this_score = clf.score(X_test, y_test)
if scorer is not None:
this_score = scorer(clf, X_test, y_test)
else:
this_score = clf.score(X_test, y_test)
     else:
-        clf.fit(X_train, **fit_params)
-        if scorer is not None:
-            this_score = scorer(clf, X_test)
+        try:
+            clf.fit(X_train, **fit_params)
+        except ValueError as e:
+            # If the classifier fails, the score is set to 0 by default
+            this_score = 0
+            warnings.warn("Classifier fit failed. The score for this fold on "
+                          "this test point will be set to zero. Details: " +
+                          str(e), RuntimeWarning)
         else:
-            this_score = clf.score(X_test)
+            if scorer is not None:
+                this_score = scorer(clf, X_test)
+            else:
+                this_score = clf.score(X_test)

     if not isinstance(this_score, numbers.Number):
         raise ValueError("scoring must return a number, got %s (%s)"
76 changes: 76 additions & 0 deletions sklearn/tests/test_grid_search.py
@@ -642,3 +642,79 @@ def test_grid_search_with_multioutput_data():
                 correct_score = est.score(X[test], y[test])
                 assert_almost_equal(correct_score,
                                     cv_validation_scores[i])
+
+
+class FailingClassifier(BaseEstimator):
+    """Classifier that raises a ValueError on fit()."""
+
+    def __init__(self, parameter=None):
+        self.parameter = parameter
+
+    def fit(self, X, y=None):
+        raise ValueError("Failing classifier failed as required")
+
+    def predict(self, X):
+        return np.zeros(X.shape[0])
+
+
+def test_grid_search_failing_classifier():
+    """GridSearchCV with a failing classifier catches the error, sets the
+    score to zero and raises a warning."""
+
+    with warnings.catch_warnings(record=True) as w:
+
+        # Cause all warnings to always be triggered.
+        warnings.simplefilter("always")
+
+        X, y = make_classification(n_samples=20, n_features=10, random_state=0)
+
+        clf = FailingClassifier()
+
+        # refit=False because we only want to check that errors caused by fits
+        # to individual folds will be caught and warnings raised instead. If
+        # refit was done, then an exception would be raised on refit and not
+        # caught by grid_search (expected behavior), and this would cause an
+        # error in this test.
+        gs = GridSearchCV(clf, [{'parameter': [0, 1, 2]}],
+                          scoring='accuracy', refit=False)
+        gs.fit(X, y)
+
+        # Ensure that grid scores were set to zero as required
+        assert all(np.all(this_score.cv_validation_scores == 0.)
+                   for this_score in gs.grid_scores_)
+
+        # Ensure that a warning was raised
+        assert len(w) > 0
+
+
+def test_grid_search_failing_classifier_no_y():
+    """GridSearchCV with a failing classifier catches the error, sets the
+    score to zero and raises a warning.
+
+    This test is for the additional case when no y is given to grid_search.
+    """
+
+    with warnings.catch_warnings(record=True) as w:
+
+        # Cause all warnings to always be triggered.
+        warnings.simplefilter("always")
+
+        X, _ = make_classification(n_samples=20, n_features=10, random_state=0)
+
+        clf = FailingClassifier()
+
+        # refit=False because we only want to check that errors caused by fits
+        # to individual folds will be caught and warnings raised instead. If
+        # refit was done, then an exception would be raised on refit and not
+        # caught by grid_search (expected behavior), and this would cause an
+        # error in this test.
+        gs = GridSearchCV(clf, [{'parameter': [0, 1, 2]}],
+                          scoring='accuracy', refit=False)
+        gs.fit(X)
+
+        # Ensure that grid scores were set to zero as required
+        assert all(np.all(this_score.cv_validation_scores == 0.)
+                   for this_score in gs.grid_scores_)
+
+        # Ensure that a warning was raised
+        assert len(w) > 0
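
A possible follow-up, not part of this PR: the warning check in these tests could also verify the warning category and message prefix. A sketch, reusing FailingClassifier from this file and the module's existing imports (warnings, numpy, make_classification, GridSearchCV); the helper name is illustrative:

    def check_fit_failure_warning_details():
        # Illustrative only: assert the RuntimeWarning category and the
        # "Classifier fit failed" message prefix emitted by fit_grid_point.
        X, y = make_classification(n_samples=20, n_features=10, random_state=0)
        gs = GridSearchCV(FailingClassifier(), [{'parameter': [0, 1, 2]}],
                          scoring='accuracy', refit=False)
        with warnings.catch_warnings(record=True) as w:
            warnings.simplefilter("always")
            gs.fit(X, y)
        assert any(issubclass(rec.category, RuntimeWarning) and
                   str(rec.message).startswith("Classifier fit failed")
                   for rec in w)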