doc/developers/develop.rst (4 changes: 2 additions & 2 deletions)
@@ -525,8 +525,8 @@ _skip_test (default=False)
 _xfail_checks (default=False)
     dictionary ``{check_name: reason}`` of common checks that will be marked
     as `XFAIL` for pytest, when using
-    :func:`~sklearn.utils.estimator_checks.parametrize_with_checks`. This tag
-    currently has no effect on
+    :func:`~sklearn.utils.estimator_checks.parametrize_with_checks`. These
+    checks will be simply ignored and not run by
     :func:`~sklearn.utils.estimator_checks.check_estimator`.
     Don't use this unless there is a *very good* reason for your estimator
     not to pass the check.
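For context, a minimal sketch of how an estimator can declare this tag through the `_more_tags` mechanism (the estimator class and reason string are illustrative, not part of this PR; `check_methods_subset_invariance` is one of the existing common checks):

    from sklearn.base import BaseEstimator

    class MyEstimator(BaseEstimator):
        def _more_tags(self):
            # Keys name common checks; values state why the failure is expected.
            return {'_xfail_checks': {
                'check_methods_subset_invariance':
                    'illustrative reason for the expected failure'}}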
sklearn/utils/estimator_checks.py (19 changes: 14 additions & 5 deletions)
@@ -6,7 +6,6 @@
 import re
 from copy import deepcopy
 from functools import partial
-from itertools import chain
 from inspect import signature
 
 import numpy as np
@@ -350,6 +349,13 @@ def _mark_xfail_checks(estimator, check, pytest):
                             marks=pytest.mark.xfail(reason=reason))
 
 
+def _is_xfail(estimator, check):
+    # Whether the check is part of the _xfail_checks tag of the estimator
+    xfail_checks = estimator._get_tags()['_xfail_checks'] or {}
+    check_name = _set_check_estimator_ids(check)
+    return check_name in xfail_checks
+
+
 def parametrize_with_checks(estimators):
     """Pytest specific decorator for parametrizing estimator checks.

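To make the matching concrete: `_set_check_estimator_ids` (an existing helper in this module) reduces a check, possibly wrapped in `functools.partial`, to an id derived from its function name, and `_is_xfail` looks that id up in the tag dictionary. A simplified sketch of the behavior this relies on (hypothetical helper name; the real helper also encodes partial keyword arguments):

    from functools import partial

    def _check_name_sketch(check):
        # Unwrap a partial and return the underlying check function's name.
        return (check.func.__name__ if isinstance(check, partial)
                else check.__name__)

    def check_methods_subset_invariance(name, estimator):
        pass  # stand-in for the real common check

    check = partial(check_methods_subset_invariance, 'NuSVC')
    assert _check_name_sketch(check) == 'check_methods_subset_invariance'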
@@ -392,9 +398,11 @@ def parametrize_with_checks(estimators):
                "Please pass an instance instead.")
         raise TypeError(msg)
 
-    checks_generator = chain.from_iterable(
-        check_estimator(estimator, generate_only=True)
-        for estimator in estimators)
+    names = (type(estimator).__name__ for estimator in estimators)
+
+    checks_generator = ((estimator, partial(check, name))
+                        for name, estimator in zip(names, estimators)
+                        for check in _yield_all_checks(name, estimator))
 
     checks_with_marks = (
         _mark_xfail_checks(estimator, check, pytest)
@@ -456,7 +464,8 @@ def check_estimator(Estimator, generate_only=False):
     name = type(estimator).__name__
 
     checks_generator = ((estimator, partial(check, name))
-                        for check in _yield_all_checks(name, estimator))
+                        for check in _yield_all_checks(name, estimator)
+                        if not _is_xfail(estimator, check))
 
     if generate_only:
         return checks_generator
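Taken together, the two entry points now diverge. A sketch of the resulting behavior, assuming (as the new test below does) that `NuSVC` tags `check_methods_subset_invariance` in its `_xfail_checks`:

    from sklearn.svm import NuSVC
    from sklearn.utils.estimator_checks import (
        check_estimator, parametrize_with_checks)

    # check_estimator now skips the tagged check instead of running it:
    check_estimator(NuSVC())

    # parametrize_with_checks still yields it, marked as xfail for pytest:
    @parametrize_with_checks([NuSVC()])
    def test_sklearn_compatible_estimator(estimator, check):
        check(estimator)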
sklearn/utils/tests/test_estimator_checks.py (8 changes: 7 additions & 1 deletion)
@@ -32,7 +32,7 @@
 from sklearn.cluster import MiniBatchKMeans
 from sklearn.decomposition import NMF
 from sklearn.linear_model import MultiTaskElasticNet, LogisticRegression
-from sklearn.svm import SVC
+from sklearn.svm import SVC, NuSVC
 from sklearn.neighbors import KNeighborsRegressor
 from sklearn.tree import DecisionTreeClassifier
 from sklearn.utils.validation import check_array
@@ -609,3 +609,9 @@ def test_all_estimators_all_public():
     # This module is run as a script to check that we have no dependency on
     # pytest for estimator checks.
     run_tests_without_pytest()
+
+
+def test_xfail_ignored_in_check_estimator():
+    # Make sure checks marked as xfail are just ignored and not run by
+    # check_estimator().
+    check_estimator(NuSVC())
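A stricter variant (hypothetical, not part of this PR; it reuses the file's existing imports and assumes `check_methods_subset_invariance` is the check NuSVC tags) could assert that the tagged check never appears among the generated pairs:

    def test_xfail_absent_from_generator():
        # Each yielded check is a partial(check, name); unwrap to compare names.
        checks = check_estimator(NuSVC(), generate_only=True)
        names = {check.func.__name__ for _, check in checks}
        assert 'check_methods_subset_invariance' not in names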