
ENH Adds support to xfail in check_estimator. #16963


Closed
11 changes: 5 additions & 6 deletions doc/developers/develop.rst
@@ -525,14 +525,13 @@ _skip_test (default=False)
 _xfail_checks (default=False)
     dictionary ``{check_name: reason}`` of common checks that will be marked
     as `XFAIL` for pytest, when using
-    :func:`~sklearn.utils.estimator_checks.parametrize_with_checks`. This tag
-    currently has no effect on
-    :func:`~sklearn.utils.estimator_checks.check_estimator`.
+    :func:`~sklearn.utils.estimator_checks.parametrize_with_checks`.
     Don't use this unless there is a *very good* reason for your estimator
     not to pass the check.
-    Also note that the usage of this tag is highly subject to change because
-    we are trying to make it more flexible: be prepared for breaking changes
-    in the future.
+    :func:`~sklearn.utils.estimator_checks.check_estimator` will raise a
+    :class:`UserWarning` when the check fails. Also note that the usage of this
+    tag is highly subject to change because we are trying to make it more
+    flexible: be prepared for breaking changes in the future.

 stateless (default=False)
     whether the estimator needs access to data for fitting. Even though an
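[Editorial note] The tag itself is declared by the estimator. As an illustration only (the estimator name, check name, and reason string below are hypothetical, not from this PR), a third-party estimator would set it through `_more_tags`:

    from sklearn.base import BaseEstimator

    class MyEstimator(BaseEstimator):
        """Hypothetical estimator known to fail one common check."""

        def _more_tags(self):
            return {
                '_xfail_checks': {
                    'check_methods_subset_invariance':
                        'fails for the decision_function method',
                }
            }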
67 changes: 47 additions & 20 deletions sklearn/utils/estimator_checks.py
Expand Up @@ -6,6 +6,7 @@
import re
from copy import deepcopy
from functools import partial
from functools import wraps
from itertools import chain
from inspect import signature

@@ -334,20 +335,47 @@ def _construct_instance(Estimator):
     return estimator


-def _mark_xfail_checks(estimator, check, pytest):
-    """Mark (estimator, check) pairs with xfail according to the
-    _xfail_checks_ tag"""
-    xfail_checks = estimator._get_tags()['_xfail_checks'] or {}
-    check_name = _set_check_estimator_ids(check)
-
-    if check_name not in xfail_checks:
-        # check isn't part of the xfail_checks tags, just return it
-        return estimator, check
-    else:
-        # check is in the tag, mark it as xfail for pytest
-        reason = xfail_checks[check_name]
-        return pytest.param(estimator, check,
-                            marks=pytest.mark.xfail(reason=reason))
+def _make_check_warn_on_fail(check, xfail_checks_tag):
+    """Wrap the check so that a warning is raised when the check is in the
+    `xfail_checks` tag and the check properly failed as expected.
+
+    Checks that aren't in the xfail_checks tag aren't wrapped and are returned
+    as-is.
+    """
+    check_name = _set_check_estimator_ids(check)
+    if check_name not in xfail_checks_tag:
+        return check
+
+    reason = xfail_checks_tag[check_name]
+    @wraps(check)
+    def wrapped(*args, **kwargs):
+        try:
+            check(*args, **kwargs)
+        except Exception:
+            warnings.warn(reason, UserWarning)
+            return
+    return wrapped
+
+
+def _generate_marked_checks(estimator, pytest):
+    """Generate checks marked with pytest.mark.xfail according to the
+    _xfail_checks tag."""
+    name = type(estimator).__name__
+    checks_generator = ((estimator, partial(check, name))
+                        for check in _yield_all_checks(name, estimator))
+
+    xfail_checks_tag = estimator._get_tags()['_xfail_checks'] or {}
+
+    for estimator, check in checks_generator:
+        check_name = _set_check_estimator_ids(check)
+        if check_name in xfail_checks_tag:
+            # check is in the xfail_checks tag, mark it as xfail for pytest
+            reason = xfail_checks_tag[check_name]
+            yield pytest.param(estimator, check,
+                               marks=pytest.mark.xfail(reason=reason))
+        else:
+            # check isn't part of the xfail_checks tag
+            yield estimator, check

[Review comment from a maintainer, on the `if check_name in xfail_checks_tag:` line: "Any reason to remove the previous comments? These can't hurt IMO"]


 def parametrize_with_checks(estimators):
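[Editorial note] For readers less familiar with the pytest machinery used by `_generate_marked_checks`: `pytest.param(..., marks=pytest.mark.xfail(reason=...))` makes a failing parameter set report as XFAIL instead of FAILED. A self-contained sketch with toy values (not sklearn code):

    import pytest

    @pytest.mark.parametrize("value", [
        1,
        pytest.param(2, marks=pytest.mark.xfail(reason="even values fail")),
    ])
    def test_is_odd(value):
        # value=1 passes; value=2 fails and is reported as XFAIL, not FAILED
        assert value % 2 == 1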
@@ -392,14 +420,10 @@ def parametrize_with_checks(estimators):
                    "Please pass an instance instead.")
             raise TypeError(msg)

-    checks_generator = chain.from_iterable(
-        check_estimator(estimator, generate_only=True)
+    checks_with_marks = chain.from_iterable(
+        _generate_marked_checks(estimator, pytest)
         for estimator in estimators)

-    checks_with_marks = (
-        _mark_xfail_checks(estimator, check, pytest)
-        for estimator, check in checks_generator)
-
     return pytest.mark.parametrize("estimator, check", checks_with_marks,
                                    ids=_set_check_estimator_ids)

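[Editorial note] With `_generate_marked_checks` feeding the parametrization, typical decorator usage is unchanged. A sketch using `DummyClassifier`, which carries an `_xfail_checks` tag on this branch, so its tagged check is collected as XFAIL:

    from sklearn.dummy import DummyClassifier
    from sklearn.utils.estimator_checks import parametrize_with_checks

    @parametrize_with_checks([DummyClassifier()])
    def test_sklearn_compatible_estimator(estimator, check):
        # the check named in the tag is marked xfail; all others run normally
        check(estimator)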
@@ -455,8 +479,11 @@ def check_estimator(Estimator, generate_only=False):
         estimator = Estimator
     name = type(estimator).__name__

-    checks_generator = ((estimator, partial(check, name))
-                        for check in _yield_all_checks(name, estimator))
+    xfail_checks_tag = estimator._get_tags()['_xfail_checks'] or {}
+    checks_generator = (
+        (estimator, _make_check_warn_on_fail(
+            partial(check, name), xfail_checks_tag=xfail_checks_tag))
+        for check in _yield_all_checks(name, estimator))

     if generate_only:
         return checks_generator
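[Editorial note] Under plain `check_estimator`, the same tag now downgrades the expected failure to a `UserWarning` instead of raising. A behavior sketch (it assumes, per the test below, that `DummyClassifier`'s tag reason mentions the predict method):

    import warnings
    from sklearn.dummy import DummyClassifier
    from sklearn.utils.estimator_checks import check_estimator

    with warnings.catch_warnings(record=True) as records:
        warnings.simplefilter("always")
        check_estimator(DummyClassifier())  # completes instead of raising
    # the tag's reason string is surfaced as a UserWarning
    assert any(issubclass(r.category, UserWarning) for r in records)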
9 changes: 9 additions & 0 deletions sklearn/utils/tests/test_estimator_checks.py
@@ -10,6 +10,7 @@
 from sklearn.base import BaseEstimator, ClassifierMixin
 from sklearn.utils import deprecated
 from sklearn.utils._testing import (assert_raises_regex,
+                                    assert_warns_message,
                                     ignore_warnings,
                                     assert_warns, assert_raises,
                                     SkipTest)
@@ -31,6 +32,7 @@
 from sklearn.mixture import GaussianMixture
 from sklearn.cluster import MiniBatchKMeans
 from sklearn.decomposition import NMF
+from sklearn.dummy import DummyClassifier
 from sklearn.linear_model import MultiTaskElasticNet, LogisticRegression
 from sklearn.svm import SVC
 from sklearn.neighbors import KNeighborsRegressor
@@ -574,6 +576,13 @@ def test_check_regressor_data_not_an_array():
                         EstimatorInconsistentForPandas())


+def test_check_estimator_xfail_tag_raises_skip_test_warning():
+    # skips check_complex_data based on _xfail_checks
+    assert_warns_message(UserWarning,
+                         "fails for the predict method",
+                         check_estimator, DummyClassifier())

[Review thread on this test — Reviewer (maintainer): "Can't we just check on an estimator with a properly set xfail_tags instead, like the Dummy one? I don't understand why we need to create a new estimator here." Author @thomasjpfan (Apr 26, 2020): "I was thinking that if the estimator gets fixed and the xfail_tags gets removed this test would start failing. Changed to use Dummy."]


 def run_tests_without_pytest():
     """Runs the tests in this file without using pytest.
     """