From 00ebfe856c31f4e4441f807babb993ef61983aa0 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Thu, 14 May 2020 09:40:07 -0400 Subject: [PATCH 1/2] Ignore xfail_checks in check_estimator --- doc/developers/develop.rst | 4 ++-- sklearn/utils/estimator_checks.py | 10 +++++++++- sklearn/utils/tests/test_estimator_checks.py | 8 +++++++- 3 files changed, 18 insertions(+), 4 deletions(-) diff --git a/doc/developers/develop.rst b/doc/developers/develop.rst index 13d2010ca7319..8c6c0ea281202 100644 --- a/doc/developers/develop.rst +++ b/doc/developers/develop.rst @@ -525,8 +525,8 @@ _skip_test (default=False) _xfail_checks (default=False) dictionary ``{check_name: reason}`` of common checks that will be marked as `XFAIL` for pytest, when using - :func:`~sklearn.utils.estimator_checks.parametrize_with_checks`. This tag - currently has no effect on + :func:`~sklearn.utils.estimator_checks.parametrize_with_checks`. These + checks will be simply ignored and not run by :func:`~sklearn.utils.estimator_checks.check_estimator`. Don't use this unless there is a *very good* reason for your estimator not to pass the check. diff --git a/sklearn/utils/estimator_checks.py b/sklearn/utils/estimator_checks.py index 1d5bece4fa60d..d728d042e4c4d 100644 --- a/sklearn/utils/estimator_checks.py +++ b/sklearn/utils/estimator_checks.py @@ -350,6 +350,13 @@ def _mark_xfail_checks(estimator, check, pytest): marks=pytest.mark.xfail(reason=reason)) +def _is_xfail(estimator, check): + # Whether the check is part of the _xfail_checks tag of the estimator + xfail_checks = estimator._get_tags()['_xfail_checks'] or {} + check_name = _set_check_estimator_ids(check) + return check_name in xfail_checks + + def parametrize_with_checks(estimators): """Pytest specific decorator for parametrizing estimator checks. @@ -456,7 +463,8 @@ def check_estimator(Estimator, generate_only=False): name = type(estimator).__name__ checks_generator = ((estimator, partial(check, name)) - for check in _yield_all_checks(name, estimator)) + for check in _yield_all_checks(name, estimator) + if not _is_xfail(estimator, check)) if generate_only: return checks_generator diff --git a/sklearn/utils/tests/test_estimator_checks.py b/sklearn/utils/tests/test_estimator_checks.py index ad66021f3ba03..45ceb5b8267dc 100644 --- a/sklearn/utils/tests/test_estimator_checks.py +++ b/sklearn/utils/tests/test_estimator_checks.py @@ -32,7 +32,7 @@ from sklearn.cluster import MiniBatchKMeans from sklearn.decomposition import NMF from sklearn.linear_model import MultiTaskElasticNet, LogisticRegression -from sklearn.svm import SVC +from sklearn.svm import SVC, NuSVC from sklearn.neighbors import KNeighborsRegressor from sklearn.tree import DecisionTreeClassifier from sklearn.utils.validation import check_array @@ -609,3 +609,9 @@ def test_all_estimators_all_public(): # This module is run as a script to check that we have no dependency on # pytest for estimator checks. run_tests_without_pytest() + + +def test_xfail_ignored_in_check_estimator(): + # Make sure checks marked as xfail are just ignored and not run by + # check_estimator(). + check_estimator(NuSVC()) From 1c94856afbdca70e0932d1fa8f773cf56668decf Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Thu, 14 May 2020 10:43:44 -0400 Subject: [PATCH 2/2] Avoid ignoring in parametrize... --- sklearn/utils/estimator_checks.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/sklearn/utils/estimator_checks.py b/sklearn/utils/estimator_checks.py index d728d042e4c4d..250bea5a48a83 100644 --- a/sklearn/utils/estimator_checks.py +++ b/sklearn/utils/estimator_checks.py @@ -6,7 +6,6 @@ import re from copy import deepcopy from functools import partial -from itertools import chain from inspect import signature import numpy as np @@ -399,9 +398,11 @@ def parametrize_with_checks(estimators): "Please pass an instance instead.") raise TypeError(msg) - checks_generator = chain.from_iterable( - check_estimator(estimator, generate_only=True) - for estimator in estimators) + names = (type(estimator).__name__ for estimator in estimators) + + checks_generator = ((estimator, partial(check, name)) + for name, estimator in zip(names, estimators) + for check in _yield_all_checks(name, estimator)) checks_with_marks = ( _mark_xfail_checks(estimator, check, pytest)