diff --git a/sklearn/tests/test_common.py b/sklearn/tests/test_common.py index 92cbed36044bf..af98c1bc50a74 100644 --- a/sklearn/tests/test_common.py +++ b/sklearn/tests/test_common.py @@ -31,6 +31,7 @@ from sklearn.utils import IS_PYPY from sklearn.utils._testing import SkipTest from sklearn.utils.estimator_checks import ( + _mark_xfail_checks, _construct_instance, _set_checking_parameters, _set_check_estimator_ids, @@ -47,6 +48,24 @@ def test_all_estimator_no_base_class(): assert not name.lower().startswith('base'), msg +def test_estimator_cls_parameterize_with_checks(): + # Non-regression test for #16707 to ensure that parametrize_with_checks + # works with estimator classes + param_checks = parametrize_with_checks([LogisticRegression]) + # Using the generator does not raise + list(param_checks.args[1]) + + +def test_mark_xfail_checks_with_unconsructable_estimator(): + class MyEstimator: + def __init__(self): + raise ValueError("This is bad") + + estimator, check = _mark_xfail_checks(MyEstimator, 42, None) + assert estimator == MyEstimator + assert check == 42 + + @pytest.mark.parametrize( 'name, Estimator', all_estimators() diff --git a/sklearn/utils/estimator_checks.py b/sklearn/utils/estimator_checks.py index 2cfb06c7994db..34a0e25c7fcaa 100644 --- a/sklearn/utils/estimator_checks.py +++ b/sklearn/utils/estimator_checks.py @@ -360,8 +360,17 @@ def _generate_class_checks(Estimator): def _mark_xfail_checks(estimator, check, pytest): """Mark estimator check pairs with xfail""" + if isinstance(estimator, type): + # try to construct estimator to get tags, if it is unable to then + # return the estimator class + try: + xfail_checks = _safe_tags(_construct_instance(estimator), + '_xfail_test') + except Exception: + return estimator, check + else: + xfail_checks = _safe_tags(estimator, '_xfail_test') - xfail_checks = _safe_tags(estimator, '_xfail_test') if not xfail_checks: return estimator, check