From 8775c1eb8c7336e34ba1a430502a3ed2158bc891 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Sat, 16 May 2020 10:48:52 -0400 Subject: [PATCH 01/15] treat strict checks as xfail checks --- sklearn/tests/test_common.py | 25 ++++++++++++++- sklearn/utils/estimator_checks.py | 52 +++++++++++++++++++++++++------ 2 files changed, 67 insertions(+), 10 deletions(-) diff --git a/sklearn/tests/test_common.py b/sklearn/tests/test_common.py index b9f50a76f7b30..c4e064f0fb616 100644 --- a/sklearn/tests/test_common.py +++ b/sklearn/tests/test_common.py @@ -19,7 +19,7 @@ from sklearn.utils import all_estimators from sklearn.utils._testing import ignore_warnings -from sklearn.exceptions import ConvergenceWarning +from sklearn.exceptions import ConvergenceWarning, SkipTestWarning from sklearn.utils.estimator_checks import check_estimator import sklearn @@ -27,6 +27,7 @@ from sklearn.linear_model._base import LinearClassifierMixin from sklearn.linear_model import LogisticRegression +from sklearn.svm import NuSVC from sklearn.utils import IS_PYPY from sklearn.utils._testing import SkipTest from sklearn.utils.estimator_checks import ( @@ -204,3 +205,25 @@ def test_class_support_removed(): with pytest.raises(TypeError, match=msg): parametrize_with_checks([LogisticRegression]) + + +def test_strict_mode_check_estimator(): + # Make sure the strict checks are properly ignored when strict mode is off + # in check_estimator. + # We can't check the message because check_estimator doesn't give one. + + with pytest.warns(SkipTestWarning): + # LogisticRegression has no _xfail_checks, but check_n_features_in is + # still skipped because it's a strict check + check_estimator(LogisticRegression(), strict_mode=False) + + with pytest.warns(SkipTestWarning): + # NuSVC has some _xfail_checks. check_n_features_in is skipped along + # with the other checks in the tag. + check_estimator(NuSVC(), strict_mode=False) + + +@parametrize_with_checks([LogisticRegression(), NuSVC()], strict_mode=False) +def test_strict_mode_parametrize_with_checks(estimator, check): + # Not sure how to test parametrize_with_checks correctly?? + check(estimator) diff --git a/sklearn/utils/estimator_checks.py b/sklearn/utils/estimator_checks.py index a95df7103503e..289f6b2e10b69 100644 --- a/sklearn/utils/estimator_checks.py +++ b/sklearn/utils/estimator_checks.py @@ -339,10 +339,10 @@ def _construct_instance(Estimator): return estimator -def _mark_xfail_checks(estimator, check, pytest): +def _mark_xfail_checks(estimator, check, strict_mode, pytest): """Mark (estimator, check) pairs with xfail according to the _xfail_checks_ tag""" - xfail_checks = estimator._get_tags()['_xfail_checks'] or {} + xfail_checks = _get_xfail_checks(estimator, strict_mode) check_name = _set_check_estimator_ids(check) if check_name not in xfail_checks: @@ -355,10 +355,10 @@ def _mark_xfail_checks(estimator, check, pytest): marks=pytest.mark.xfail(reason=reason)) -def _skip_if_xfail(estimator, check): +def _skip_if_xfail(estimator, check, strict_mode): # wrap a check so that it's skipped with a warning if it's part of the # xfail_checks tag. 
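(For context, both helpers above are driven by the `_xfail_checks` estimator tag, which an estimator declares through `_more_tags()`. A minimal sketch of that declaration follows; the class name and reason string are illustrative, not taken from this diff:)

from sklearn.base import BaseEstimator

class IllustrativeEstimator(BaseEstimator):
    def _more_tags(self):
        # The tag maps a check's name to a human-readable reason explaining
        # why that check is expected to fail for this estimator.
        return {
            '_xfail_checks': {
                'check_methods_subset_invariance':
                    'fails for the transform method',
            }
        }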
- xfail_checks = estimator._get_tags()['_xfail_checks'] or {} + xfail_checks = _get_xfail_checks(estimator, strict_mode) check_name = _set_check_estimator_ids(check) if check_name not in xfail_checks: @@ -373,7 +373,23 @@ def wrapped(*args, **kwargs): return wrapped -def parametrize_with_checks(estimators): +def _get_xfail_checks(estimator, strict_mode): + # Return the checks that are in the estimator's _xfail_checks tag, along + # with the strict checks if strict_mode is False. + xfail_checks = estimator._get_tags()['_xfail_checks'] or {} + + if not strict_mode: + strict_checks = { + _set_check_estimator_ids(check): + 'The check is strict and strict mode is off' # the reason + for check in _STRICT_CHECKS + } + xfail_checks.update(strict_checks) + + return xfail_checks + + +def parametrize_with_checks(estimators, strict_mode=True): """Pytest specific decorator for parametrizing estimator checks. The `id` of each check is set to be a pprint version of the estimator @@ -391,6 +407,12 @@ def parametrize_with_checks(estimators): Passing a class was deprecated in version 0.23, and support for classes was removed in 0.24. Pass an instance instead. + strict_mode : bool, default=True + If False, the strict checks will be treated as if they were in the + estimators' `_xfails_checks` tag: they will be marked as `xfail` for + pytest. The list of strict checks is at TODO. See TODO link for more + info on the `_xfails_check` tag. + Returns ------- decorator : `pytest.mark.parametrize` @@ -422,14 +444,14 @@ def parametrize_with_checks(estimators): for check in _yield_all_checks(estimator)) checks_with_marks = ( - _mark_xfail_checks(estimator, check, pytest) + _mark_xfail_checks(estimator, check, strict_mode, pytest) for estimator, check in checks_generator) return pytest.mark.parametrize("estimator, check", checks_with_marks, ids=_set_check_estimator_ids) -def check_estimator(Estimator, generate_only=False): +def check_estimator(Estimator, generate_only=False, strict_mode=True): """Check if estimator adheres to scikit-learn conventions. This estimator will run an extensive test-suite for input validation, @@ -457,7 +479,7 @@ def check_estimator(Estimator, generate_only=False): Passing a class was deprecated in version 0.23, and support for classes was removed in 0.24. - generate_only : bool, optional (default=False) + generate_only : bool, default=False When `False`, checks are evaluated when `check_estimator` is called. When `True`, `check_estimator` returns a generator that yields (estimator, check) tuples. The check is run by calling @@ -465,6 +487,12 @@ def check_estimator(Estimator, generate_only=False): .. versionadded:: 0.22 + strict_mode : bool, default=True + If False, the strict checks will be treated as if they were in the + estimator's `_xfails_checks` tag: they will be ignored with a + warning. The list of strict checks is at TODO. See TODO link for more + info on the `_xfails_check` tag. 
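(A usage sketch of the two public entry points with the new flag, using an estimator from the tests above; the pytest function name is illustrative:)

from sklearn.linear_model import LogisticRegression
from sklearn.utils.estimator_checks import (check_estimator,
                                            parametrize_with_checks)

# Standalone: runs all checks now; strict checks are skipped with a warning.
check_estimator(LogisticRegression(), strict_mode=False)

# Pytest integration: strict checks are marked as xfail instead of skipped.
@parametrize_with_checks([LogisticRegression()], strict_mode=False)
def test_sklearn_compatible_estimator(estimator, check):
    check(estimator)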
+ Returns ------- checks_generator : generator @@ -481,7 +509,8 @@ def check_estimator(Estimator, generate_only=False): name = type(estimator).__name__ checks_generator = ((estimator, - partial(_skip_if_xfail(estimator, check), name)) + partial(_skip_if_xfail(estimator, check, strict_mode), + name)) for check in _yield_all_checks(estimator)) if generate_only: @@ -3026,3 +3055,8 @@ def check_requires_y_none(name, estimator_orig): except ValueError as ve: if not any(msg in str(ve) for msg in expected_err_msgs): warnings.warn(warning_msg, FutureWarning) + + +_STRICT_CHECKS = set([ + check_n_features_in, # arbitrary, we can decide on actual list later? +]) From 664552f22581a4e6ba71a5448e900ce56b572f77 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Tue, 19 May 2020 09:42:42 -0400 Subject: [PATCH 02/15] different names --- sklearn/utils/estimator_checks.py | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/sklearn/utils/estimator_checks.py b/sklearn/utils/estimator_checks.py index 289f6b2e10b69..c142db1395f1e 100644 --- a/sklearn/utils/estimator_checks.py +++ b/sklearn/utils/estimator_checks.py @@ -339,9 +339,9 @@ def _construct_instance(Estimator): return estimator -def _mark_xfail_checks(estimator, check, strict_mode, pytest): - """Mark (estimator, check) pairs with xfail according to the - _xfail_checks_ tag""" +def _maybe_mark_xfail(estimator, check, strict_mode, pytest): + # Mark (estimator, check) pairs as XFAIL if the check is in the + # _xfail_checks_ tag or if it's a strict check and strict_mode=False. xfail_checks = _get_xfail_checks(estimator, strict_mode) check_name = _set_check_estimator_ids(check) @@ -355,9 +355,9 @@ def _mark_xfail_checks(estimator, check, strict_mode, pytest): marks=pytest.mark.xfail(reason=reason)) -def _skip_if_xfail(estimator, check, strict_mode): +def _maybe_skip(estimator, check, strict_mode): # wrap a check so that it's skipped with a warning if it's part of the - # xfail_checks tag. 
+    # xfail_checks tag, or if it's a strict check and strict_mode=False
     xfail_checks = _get_xfail_checks(estimator, strict_mode)
     check_name = _set_check_estimator_ids(check)
@@ -444,7 +444,7 @@ def parametrize_with_checks(estimators, strict_mode=True):
                        for check in _yield_all_checks(estimator))
 
     checks_with_marks = (
-        _mark_xfail_checks(estimator, check, strict_mode, pytest)
+        _maybe_mark_xfail(estimator, check, strict_mode, pytest)
         for estimator, check in checks_generator)
 
     return pytest.mark.parametrize("estimator, check", checks_with_marks,
@@ -508,10 +508,10 @@ def check_estimator(Estimator, generate_only=False, strict_mode=True):
     estimator = Estimator
     name = type(estimator).__name__
 
-    checks_generator = ((estimator,
-                         partial(_skip_if_xfail(estimator, check, strict_mode),
-                                 name))
-                        for check in _yield_all_checks(estimator))
+    checks_generator = (
+        (estimator, partial(_maybe_skip(estimator, check, strict_mode), name))
+        for check in _yield_all_checks(estimator)
+    )
 
     if generate_only:
         return checks_generator

From ecff04c5ceedc9bc10166955881452eccfabd66c Mon Sep 17 00:00:00 2001
From: Nicolas Hug
Date: Tue, 19 May 2020 09:46:58 -0400
Subject: [PATCH 03/15] Comments

---
 sklearn/utils/estimator_checks.py | 7 ++++++-
 1 file changed, 6 insertions(+), 1 deletion(-)

diff --git a/sklearn/utils/estimator_checks.py b/sklearn/utils/estimator_checks.py
index c142db1395f1e..da7a3145ec621 100644
--- a/sklearn/utils/estimator_checks.py
+++ b/sklearn/utils/estimator_checks.py
@@ -342,6 +342,8 @@ def _construct_instance(Estimator):
 def _maybe_mark_xfail(estimator, check, strict_mode, pytest):
     # Mark (estimator, check) pairs as XFAIL if the check is in the
     # _xfail_checks_ tag or if it's a strict check and strict_mode=False.
+    # This is similar to _maybe_skip(), but this one is used by
+    # @parametrize_with_checks() instead of check_estimator()
     xfail_checks = _get_xfail_checks(estimator, strict_mode)
     check_name = _set_check_estimator_ids(check)
@@ -356,8 +358,11 @@ def _maybe_mark_xfail(estimator, check, strict_mode, pytest):
 
 def _maybe_skip(estimator, check, strict_mode):
-    # wrap a check so that it's skipped with a warning if it's part of the
+    # Wrap a check so that it's skipped with a warning if it's part of the
     # xfail_checks tag, or if it's a strict check and strict_mode=False
+    # This is similar to _maybe_mark_xfail(), but this one is used by
+    # check_estimator() instead of @parametrize_with_checks which requires
+    # pytest
     xfail_checks = _get_xfail_checks(estimator, strict_mode)
     check_name = _set_check_estimator_ids(check)

From c7c5c8d3dec6082a20e61fcde675f2b7d6a7f2d4 Mon Sep 17 00:00:00 2001
From: Nicolas Hug
Date: Wed, 20 May 2020 09:19:00 -0400
Subject: [PATCH 04/15] some cleaning

---
 sklearn/tests/test_common.py | 2 +-
 sklearn/utils/estimator_checks.py | 12 +++++++-----
 2 files changed, 8 insertions(+), 6 deletions(-)

diff --git a/sklearn/tests/test_common.py b/sklearn/tests/test_common.py
index c4e064f0fb616..53244d0de6f0f 100644
--- a/sklearn/tests/test_common.py
+++ b/sklearn/tests/test_common.py
@@ -225,5 +225,5 @@
 @parametrize_with_checks([LogisticRegression(), NuSVC()], strict_mode=False)
 def test_strict_mode_parametrize_with_checks(estimator, check):
-    # Not sure how to test parametrize_with_checks correctly??
+    # Ideally we should assert that the strict checks are Xfailed...
check(estimator) diff --git a/sklearn/utils/estimator_checks.py b/sklearn/utils/estimator_checks.py index da7a3145ec621..0e647100f8c6b 100644 --- a/sklearn/utils/estimator_checks.py +++ b/sklearn/utils/estimator_checks.py @@ -341,7 +341,7 @@ def _construct_instance(Estimator): def _maybe_mark_xfail(estimator, check, strict_mode, pytest): # Mark (estimator, check) pairs as XFAIL if the check is in the - # _xfail_checks_ tag or if it's a strict check and strict_mode=False. + # _xfail_checks tag or if it's a strict check and strict_mode=False. # This is similar to _maybe_skip(), but this one is used by # @parametrize_with_checks() instead of check_estimator() xfail_checks = _get_xfail_checks(estimator, strict_mode) @@ -415,8 +415,9 @@ def parametrize_with_checks(estimators, strict_mode=True): strict_mode : bool, default=True If False, the strict checks will be treated as if they were in the estimators' `_xfails_checks` tag: they will be marked as `xfail` for - pytest. The list of strict checks is at TODO. See TODO link for more - info on the `_xfails_check` tag. + pytest. See :ref:`estimator_tags` for more info on the + `_xfails_check` tag. The set of strict checks is in + `sklearn.utils.estimator_checks._STRICT_CHECKS`. Returns ------- @@ -495,8 +496,9 @@ def check_estimator(Estimator, generate_only=False, strict_mode=True): strict_mode : bool, default=True If False, the strict checks will be treated as if they were in the estimator's `_xfails_checks` tag: they will be ignored with a - warning. The list of strict checks is at TODO. See TODO link for more - info on the `_xfails_check` tag. + warning. See :ref:`estimator_tags` for more info on the + `_xfails_check` tag. The set of strict checks is in + `sklearn.utils.estimator_checks._STRICT_CHECKS`. Returns ------- From a2b5bf5d4a561c19072b7e05555de90edea35bc8 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Tue, 26 May 2020 19:21:21 -0400 Subject: [PATCH 05/15] This is hard --- sklearn/decomposition/_sparse_pca.py | 4 +- sklearn/dummy.py | 4 +- sklearn/neural_network/_rbm.py | 4 +- sklearn/svm/_classes.py | 7 +- sklearn/tests/test_common.py | 38 ++++- sklearn/utils/estimator_checks.py | 214 ++++++++++++++------------- 6 files changed, 157 insertions(+), 114 deletions(-) diff --git a/sklearn/decomposition/_sparse_pca.py b/sklearn/decomposition/_sparse_pca.py index 8f766b734ffab..efe572576ff02 100644 --- a/sklearn/decomposition/_sparse_pca.py +++ b/sklearn/decomposition/_sparse_pca.py @@ -208,8 +208,8 @@ def transform(self, X): def _more_tags(self): return { '_xfail_checks': { - "check_methods_subset_invariance": - "fails for the transform method" + # fails for the transform method" + "check_methods_subset_invariance": {}, } } diff --git a/sklearn/dummy.py b/sklearn/dummy.py index cee7294ab5afd..c8cdda9a23b5f 100644 --- a/sklearn/dummy.py +++ b/sklearn/dummy.py @@ -349,8 +349,8 @@ def _more_tags(self): return { 'poor_score': True, 'no_validation': True, '_xfail_checks': { - 'check_methods_subset_invariance': - 'fails for the predict method' + # fails for the predict method + 'check_methods_subset_invariance': {} } } diff --git a/sklearn/neural_network/_rbm.py b/sklearn/neural_network/_rbm.py index fcb4e90772598..fa45df96c520c 100644 --- a/sklearn/neural_network/_rbm.py +++ b/sklearn/neural_network/_rbm.py @@ -377,7 +377,7 @@ def fit(self, X, y=None): def _more_tags(self): return { '_xfail_checks': { - 'check_methods_subset_invariance': - 'fails for the decision_function method' + # fails for the decision_function method + 
'check_methods_subset_invariance': {} } } diff --git a/sklearn/svm/_classes.py b/sklearn/svm/_classes.py index d082c22d0a3bc..334c26a1b5cad 100644 --- a/sklearn/svm/_classes.py +++ b/sklearn/svm/_classes.py @@ -864,9 +864,10 @@ def __init__(self, *, nu=0.5, kernel='rbf', degree=3, gamma='scale', def _more_tags(self): return { '_xfail_checks': { - 'check_methods_subset_invariance': - 'fails for the decision_function method', - 'check_class_weight_classifiers': 'class_weight is ignored.' + # fails for the decision_function method' + 'check_methods_subset_invariance': {}, + # class_weight is ignored + 'check_class_weight_classifiers': {} } } diff --git a/sklearn/tests/test_common.py b/sklearn/tests/test_common.py index 53244d0de6f0f..9deb8a0c34fec 100644 --- a/sklearn/tests/test_common.py +++ b/sklearn/tests/test_common.py @@ -15,7 +15,7 @@ from functools import partial import pytest - +import numpy as np from sklearn.utils import all_estimators from sklearn.utils._testing import ignore_warnings @@ -25,6 +25,8 @@ import sklearn from sklearn.base import BiclusterMixin +from sklearn.decomposition import NMF +from sklearn.utils.validation import check_non_negative, check_array from sklearn.linear_model._base import LinearClassifierMixin from sklearn.linear_model import LogisticRegression from sklearn.svm import NuSVC @@ -33,7 +35,7 @@ from sklearn.utils.estimator_checks import ( _construct_instance, _set_checking_parameters, - _set_check_estimator_ids, + _get_check_estimator_ids, check_class_weight_balanced_linear_classifier, parametrize_with_checks) @@ -60,8 +62,8 @@ def _sample_func(x, y=1): "LogisticRegression(class_weight='balanced',random_state=1," "solver='newton-cg',warm_start=True)") ]) -def test_set_check_estimator_ids(val, expected): - assert _set_check_estimator_ids(val) == expected +def test_get_check_estimator_ids(val, expected): + assert _get_check_estimator_ids(val) == expected def _tested_estimators(): @@ -207,6 +209,21 @@ def test_class_support_removed(): parametrize_with_checks([LogisticRegression]) +class MyNMFWithBadErrorMessage(NMF): + # Same as NMF but raises an uninformative error message if X has negative + # value. This estimator would fail the check suite in strict mode, + # specifically it would fail check_fit_non_negative + def fit(self, X, y=None, **params): + X = check_array(X, accept_sparse=('csr', 'csc'), + dtype=[np.float64, np.float32]) + try: + check_non_negative(X, whom='') + except ValueError: + raise ValueError("Some non-informative error msg") + + return super().fit(X, y, **params) + + def test_strict_mode_check_estimator(): # Make sure the strict checks are properly ignored when strict mode is off # in check_estimator. @@ -214,7 +231,7 @@ def test_strict_mode_check_estimator(): with pytest.warns(SkipTestWarning): # LogisticRegression has no _xfail_checks, but check_n_features_in is - # still skipped because it's a strict check + # still skipped because it's a fully strict check check_estimator(LogisticRegression(), strict_mode=False) with pytest.warns(SkipTestWarning): @@ -222,8 +239,17 @@ def test_strict_mode_check_estimator(): # with the other checks in the tag. check_estimator(NuSVC(), strict_mode=False) + # MyNMF will fail check_fit_non_negative in strict mode, but it will pass + # in non-strict mode which doesn't check the exact error message. 
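(The strict/non-strict distinction being tested here boils down to message-matching versus type-only assertions. A standalone sketch using the testing helpers this diff already relies on; `bad_fit` is a stand-in for `MyNMFWithBadErrorMessage.fit`:)

from sklearn.utils._testing import assert_raises, assert_raises_regex

def bad_fit():
    raise ValueError("Some non-informative error msg")

assert_raises(ValueError, bad_fit)  # non-strict style: any ValueError passes

try:
    # strict style: the message must also match, so this assertion fails
    assert_raises_regex(ValueError, "Negative values in data passed to",
                        bad_fit)
except AssertionError:
    pass  # expected: the message does not match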
+ with pytest.raises(AssertionError, match='does not match'): + check_estimator(MyNMFWithBadErrorMessage(), strict_mode=True) + check_estimator(MyNMFWithBadErrorMessage(), strict_mode=False) + -@parametrize_with_checks([LogisticRegression(), NuSVC()], strict_mode=False) +@parametrize_with_checks([LogisticRegression(), + NuSVC(), + MyNMFWithBadErrorMessage()], + strict_mode=False) def test_strict_mode_parametrize_with_checks(estimator, check): # Ideally we should assert that the strict checks are Xfailed... check(estimator) diff --git a/sklearn/utils/estimator_checks.py b/sklearn/utils/estimator_checks.py index 340d50378b553..606fe3fe414b2 100644 --- a/sklearn/utils/estimator_checks.py +++ b/sklearn/utils/estimator_checks.py @@ -134,7 +134,7 @@ def _yield_classifier_checks(classifier): @ignore_warnings(category=FutureWarning) -def check_supervised_y_no_nan(name, estimator_orig): +def check_supervised_y_no_nan(name, estimator_orig, strict_mode=True): # Checks that the Estimator targets are not NaN. estimator = clone(estimator_orig) rng = np.random.RandomState(888) @@ -281,14 +281,14 @@ def _yield_all_checks(estimator): yield check_fit_non_negative -def _set_check_estimator_ids(obj): +def _get_check_estimator_ids(obj): """Create pytest ids for checks. When `obj` is an estimator, this returns the pprint version of the estimator (with `print_changed_only=True`). When `obj` is a function, the name of the function is returned with its keyworld arguments. - `_set_check_estimator_ids` is designed to be used as the `id` in + `_get_check_estimator_ids` is designed to be used as the `id` in `pytest.mark.parametrize` where `check_estimator(..., generate_only=True)` is yielding estimators and checks. @@ -343,15 +343,12 @@ def _maybe_mark_xfail(estimator, check, strict_mode, pytest): # _xfail_checks tag or if it's a strict check and strict_mode=False. # This is similar to _maybe_skip(), but this one is used by # @parametrize_with_checks() instead of check_estimator() - xfail_checks = _get_xfail_checks(estimator, strict_mode) - check_name = _set_check_estimator_ids(check) - if check_name not in xfail_checks: - # check isn't part of the xfail_checks tags, just return it + if not _should_be_skipped_or_marked(estimator, check, strict_mode): return estimator, check else: - # check is in the tag, mark it as xfail for pytest - reason = xfail_checks[check_name] + reason = ('This check is in the _xfail_checks tag, or it is ' + 'a strict check and strict mode is off.') return pytest.param(estimator, check, marks=pytest.mark.xfail(reason=reason)) @@ -362,12 +359,13 @@ def _maybe_skip(estimator, check, strict_mode): # This is similar to _maybe_mark_xfail(), but this one is used by # check_estimator() instead of @parametrize_with_checks which requires # pytest - xfail_checks = _get_xfail_checks(estimator, strict_mode) - check_name = _set_check_estimator_ids(check) - if check_name not in xfail_checks: + if not _should_be_skipped_or_marked(estimator, check, strict_mode): return check + check_name = (check.func.__name__ if isinstance(check, partial) + else check.__name__) + @wraps(check) def wrapped(*args, **kwargs): raise SkipTest( @@ -377,20 +375,17 @@ def wrapped(*args, **kwargs): return wrapped -def _get_xfail_checks(estimator, strict_mode): - # Return the checks that are in the estimator's _xfail_checks tag, along - # with the strict checks if strict_mode is False. 
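(The helper that replaces `_get_xfail_checks` just below identifies a check by name even when it is wrapped in `functools.partial`. That idiom in isolation; `demo_check` is hypothetical:)

from functools import partial

def check_name_of(check):
    # partial objects expose the wrapped function via .func
    return (check.func.__name__ if isinstance(check, partial)
            else check.__name__)

def demo_check(name, estimator_orig, strict_mode=True):
    pass

assert check_name_of(demo_check) == 'demo_check'
assert check_name_of(partial(demo_check, strict_mode=False)) == 'demo_check'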
- xfail_checks = estimator._get_tags()['_xfail_checks'] or {} +def _should_be_skipped_or_marked(estimator, check, strict_mode): - if not strict_mode: - strict_checks = { - _set_check_estimator_ids(check): - 'The check is strict and strict mode is off' # the reason - for check in _STRICT_CHECKS - } - xfail_checks.update(strict_checks) + check_name = (check.func.__name__ if isinstance(check, partial) + else check.__name__) + + xfail_checks = estimator._get_tags()['_xfail_checks'] or {} - return xfail_checks + return ( + check_name in xfail_checks or + check_name in _FULLY_STRICT_CHECKS and not strict_mode + ) def parametrize_with_checks(estimators, strict_mode=True): @@ -444,16 +439,17 @@ def parametrize_with_checks(estimators, strict_mode=True): names = (type(estimator).__name__ for estimator in estimators) - checks_generator = ((estimator, partial(check, name)) - for name, estimator in zip(names, estimators) - for check in _yield_all_checks(estimator)) + checks_generator = ( + (estimator, partial(check, name, strict_mode=strict_mode)) + for name, estimator in zip(names, estimators) + for check in _yield_all_checks(estimator)) checks_with_marks = ( _maybe_mark_xfail(estimator, check, strict_mode, pytest) for estimator, check in checks_generator) return pytest.mark.parametrize("estimator, check", checks_with_marks, - ids=_set_check_estimator_ids) + ids=_get_check_estimator_ids) def check_estimator(Estimator, generate_only=False, strict_mode=True): @@ -515,7 +511,8 @@ def check_estimator(Estimator, generate_only=False, strict_mode=True): name = type(estimator).__name__ checks_generator = ( - (estimator, partial(_maybe_skip(estimator, check, strict_mode), name)) + (estimator, partial(_maybe_skip(estimator, check, strict_mode), + name, strict_mode=strict_mode)) for check in _yield_all_checks(estimator) ) @@ -724,7 +721,7 @@ def _generate_sparse_matrix(X_csr): yield sparse_format + "_64", X -def check_estimator_sparse_data(name, estimator_orig): +def check_estimator_sparse_data(name, estimator_orig, strict_mode=True): rng = np.random.RandomState(0) X = rng.rand(40, 10) X[X < .8] = 0 @@ -783,7 +780,7 @@ def check_estimator_sparse_data(name, estimator_orig): @ignore_warnings(category=FutureWarning) -def check_sample_weights_pandas_series(name, estimator_orig): +def check_sample_weights_pandas_series(name, estimator_orig, strict_mode=True): # check that estimators will accept a 'sample_weight' parameter of # type pandas.Series in the 'fit' function. estimator = clone(estimator_orig) @@ -810,7 +807,7 @@ def check_sample_weights_pandas_series(name, estimator_orig): @ignore_warnings(category=(FutureWarning)) -def check_sample_weights_not_an_array(name, estimator_orig): +def check_sample_weights_not_an_array(name, estimator_orig, strict_mode=True): # check that estimators will accept a 'sample_weight' parameter of # type _NotAnArray in the 'fit' function. estimator = clone(estimator_orig) @@ -827,7 +824,7 @@ def check_sample_weights_not_an_array(name, estimator_orig): @ignore_warnings(category=(FutureWarning)) -def check_sample_weights_list(name, estimator_orig): +def check_sample_weights_list(name, estimator_orig, strict_mode=True): # check that estimators will accept a 'sample_weight' parameter of # type list in the 'fit' function. 
if has_fit_parameter(estimator_orig, "sample_weight"): @@ -847,7 +844,7 @@ def check_sample_weights_list(name, estimator_orig): @ignore_warnings(category=FutureWarning) -def check_sample_weights_shape(name, estimator_orig): +def check_sample_weights_shape(name, estimator_orig, strict_mode=True): # check that estimators raise an error if sample_weight # shape mismatches the input if (has_fit_parameter(estimator_orig, "sample_weight") and @@ -872,7 +869,7 @@ def check_sample_weights_shape(name, estimator_orig): @ignore_warnings(category=FutureWarning) -def check_sample_weights_invariance(name, estimator_orig): +def check_sample_weights_invariance(name, estimator_orig, strict_mode=True): # check that the estimators yield same results for # unit weights and no weights if (has_fit_parameter(estimator_orig, "sample_weight") and @@ -910,7 +907,7 @@ def check_sample_weights_invariance(name, estimator_orig): @ignore_warnings(category=(FutureWarning, UserWarning)) -def check_dtype_object(name, estimator_orig): +def check_dtype_object(name, estimator_orig, strict_mode=True): # check that estimators treat dtype object as numeric if possible rng = np.random.RandomState(0) X = _pairwise_estimator_convert_X(rng.rand(40, 10), estimator_orig) @@ -948,7 +945,7 @@ def check_dtype_object(name, estimator_orig): estimator.fit(X, y) -def check_complex_data(name, estimator_orig): +def check_complex_data(name, estimator_orig, strict_mode=True): # check that estimators raise an exception on providing complex data X = np.random.sample(10) + 1j * np.random.sample(10) X = X.reshape(-1, 1) @@ -959,7 +956,7 @@ def check_complex_data(name, estimator_orig): @ignore_warnings -def check_dict_unchanged(name, estimator_orig): +def check_dict_unchanged(name, estimator_orig, strict_mode=True): # this estimator raises # ValueError: Found array with 0 feature(s) (shape=(23, 0)) # while a minimum of 1 is required. @@ -1003,7 +1000,7 @@ def _is_public_parameter(attr): @ignore_warnings(category=FutureWarning) -def check_dont_overwrite_parameters(name, estimator_orig): +def check_dont_overwrite_parameters(name, estimator_orig, strict_mode=True): # check that fit method only changes or sets private attributes if hasattr(estimator_orig.__init__, "deprecated_original"): # to not check deprecated classes @@ -1058,7 +1055,7 @@ def check_dont_overwrite_parameters(name, estimator_orig): @ignore_warnings(category=FutureWarning) -def check_fit2d_predict1d(name, estimator_orig): +def check_fit2d_predict1d(name, estimator_orig, strict_mode=True): # check by fitting a 2d array and predicting with a 1d array rnd = np.random.RandomState(0) X = 3 * rnd.uniform(size=(20, 3)) @@ -1108,7 +1105,7 @@ def _apply_on_subsets(func, X): @ignore_warnings(category=FutureWarning) -def check_methods_subset_invariance(name, estimator_orig): +def check_methods_subset_invariance(name, estimator_orig, strict_mode=True): # check that method gives invariant results if applied # on mini batches or the whole set rnd = np.random.RandomState(0) @@ -1142,7 +1139,7 @@ def check_methods_subset_invariance(name, estimator_orig): @ignore_warnings -def check_fit2d_1sample(name, estimator_orig): +def check_fit2d_1sample(name, estimator_orig, strict_mode=True): # Check that fitting a 2d array with only one sample either works or # returns an informative message. The error message should either mention # the number of samples or the number of classes. 
@@ -1176,7 +1173,7 @@ def check_fit2d_1sample(name, estimator_orig): @ignore_warnings -def check_fit2d_1feature(name, estimator_orig): +def check_fit2d_1feature(name, estimator_orig, strict_mode=True): # check fitting a 2d array with only 1 feature either works or returns # informative message rnd = np.random.RandomState(0) @@ -1210,7 +1207,7 @@ def check_fit2d_1feature(name, estimator_orig): @ignore_warnings -def check_fit1d(name, estimator_orig): +def check_fit1d(name, estimator_orig, strict_mode=True): # check fitting 1d X array raises a ValueError rnd = np.random.RandomState(0) X = 3 * rnd.uniform(size=(20)) @@ -1232,7 +1229,8 @@ def check_fit1d(name, estimator_orig): @ignore_warnings(category=FutureWarning) -def check_transformer_general(name, transformer, readonly_memmap=False): +def check_transformer_general(name, transformer, readonly_memmap=False, + strict_mode=True): X, y = make_blobs(n_samples=30, centers=[[0, 0, 0], [1, 1, 1]], random_state=0, n_features=2, cluster_std=0.1) X = StandardScaler().fit_transform(X) @@ -1246,7 +1244,7 @@ def check_transformer_general(name, transformer, readonly_memmap=False): @ignore_warnings(category=FutureWarning) -def check_transformer_data_not_an_array(name, transformer): +def check_transformer_data_not_an_array(name, transformer, strict_mode=True): X, y = make_blobs(n_samples=30, centers=[[0, 0, 0], [1, 1, 1]], random_state=0, n_features=2, cluster_std=0.1) X = StandardScaler().fit_transform(X) @@ -1262,7 +1260,7 @@ def check_transformer_data_not_an_array(name, transformer): @ignore_warnings(category=FutureWarning) -def check_transformers_unfitted(name, transformer): +def check_transformers_unfitted(name, transformer, strict_mode=True): X, y = _boston_subset() transformer = clone(transformer) @@ -1273,7 +1271,7 @@ def check_transformers_unfitted(name, transformer): transformer.transform(X) -def _check_transformer(name, transformer_orig, X, y): +def _check_transformer(name, transformer_orig, X, y, strict_mode=True): n_samples, n_features = np.asarray(X).shape transformer = clone(transformer_orig) set_random_state(transformer) @@ -1352,7 +1350,7 @@ def _check_transformer(name, transformer_orig, X, y): @ignore_warnings -def check_pipeline_consistency(name, estimator_orig): +def check_pipeline_consistency(name, estimator_orig, strict_mode=True): if estimator_orig._get_tags()['non_deterministic']: msg = name + ' is non deterministic' raise SkipTest(msg) @@ -1381,7 +1379,7 @@ def check_pipeline_consistency(name, estimator_orig): @ignore_warnings -def check_fit_score_takes_y(name, estimator_orig): +def check_fit_score_takes_y(name, estimator_orig, strict_mode=True): # check that all estimators accept an optional y # in fit and score so they can be used in pipelines rnd = np.random.RandomState(0) @@ -1413,7 +1411,7 @@ def check_fit_score_takes_y(name, estimator_orig): @ignore_warnings -def check_estimators_dtypes(name, estimator_orig): +def check_estimators_dtypes(name, estimator_orig, strict_mode=True): rnd = np.random.RandomState(0) X_train_32 = 3 * rnd.uniform(size=(20, 5)).astype(np.float32) X_train_32 = _pairwise_estimator_convert_X(X_train_32, estimator_orig) @@ -1438,7 +1436,8 @@ def check_estimators_dtypes(name, estimator_orig): @ignore_warnings(category=FutureWarning) -def check_estimators_empty_data_messages(name, estimator_orig): +def check_estimators_empty_data_messages(name, estimator_orig, + strict_mode=True): e = clone(estimator_orig) set_random_state(e, 1) @@ -1461,7 +1460,7 @@ def check_estimators_empty_data_messages(name, 
estimator_orig): @ignore_warnings(category=FutureWarning) -def check_estimators_nan_inf(name, estimator_orig): +def check_estimators_nan_inf(name, estimator_orig, strict_mode=True): # Checks that Estimator X's do not contain NaN or inf. rnd = np.random.RandomState(0) X_train_finite = _pairwise_estimator_convert_X(rnd.uniform(size=(10, 3)), @@ -1532,7 +1531,7 @@ def check_estimators_nan_inf(name, estimator_orig): @ignore_warnings -def check_nonsquare_error(name, estimator_orig): +def check_nonsquare_error(name, estimator_orig, strict_mode=True): """Test that error is thrown when non-square data provided""" X, y = make_blobs(n_samples=20, n_features=10) @@ -1545,7 +1544,7 @@ def check_nonsquare_error(name, estimator_orig): @ignore_warnings -def check_estimators_pickle(name, estimator_orig): +def check_estimators_pickle(name, estimator_orig, strict_mode=True): """Test that we can pickle all estimators""" check_methods = ["predict", "transform", "decision_function", "predict_proba"] @@ -1589,7 +1588,8 @@ def check_estimators_pickle(name, estimator_orig): @ignore_warnings(category=FutureWarning) -def check_estimators_partial_fit_n_features(name, estimator_orig): +def check_estimators_partial_fit_n_features(name, estimator_orig, + strict_mode=True): # check if number of features changes between calls to partial_fit. if not hasattr(estimator_orig, 'partial_fit'): return @@ -1615,7 +1615,7 @@ def check_estimators_partial_fit_n_features(name, estimator_orig): @ignore_warnings(category=FutureWarning) -def check_classifier_multioutput(name, estimator): +def check_classifier_multioutput(name, estimator, strict_mode=True): n_samples, n_labels, n_classes = 42, 5, 3 tags = estimator._get_tags() estimator = clone(estimator) @@ -1673,7 +1673,7 @@ def check_classifier_multioutput(name, estimator): @ignore_warnings(category=FutureWarning) -def check_regressor_multioutput(name, estimator): +def check_regressor_multioutput(name, estimator, strict_mode=True): estimator = clone(estimator) n_samples = n_features = 10 @@ -1696,7 +1696,8 @@ def check_regressor_multioutput(name, estimator): @ignore_warnings(category=FutureWarning) -def check_clustering(name, clusterer_orig, readonly_memmap=False): +def check_clustering(name, clusterer_orig, readonly_memmap=False, + strict_mode=True): clusterer = clone(clusterer_orig) X, y = make_blobs(n_samples=50, random_state=1) X, y = shuffle(X, y, random_state=7) @@ -1755,7 +1756,8 @@ def check_clustering(name, clusterer_orig, readonly_memmap=False): @ignore_warnings(category=FutureWarning) -def check_clusterer_compute_labels_predict(name, clusterer_orig): +def check_clusterer_compute_labels_predict(name, clusterer_orig, + strict_mode=True): """Check that predict is invariant of compute_labels""" X, y = make_blobs(n_samples=20, random_state=0) clusterer = clone(clusterer_orig) @@ -1770,7 +1772,7 @@ def check_clusterer_compute_labels_predict(name, clusterer_orig): @ignore_warnings(category=FutureWarning) -def check_classifiers_one_label(name, classifier_orig): +def check_classifiers_one_label(name, classifier_orig, strict_mode=True): error_string_fit = "Classifier can't train when only one class is present." 
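(The hunks in this region all apply the same mechanical convention: every check gains a trailing `strict_mode=True` keyword, even when it has no strict-only assertions yet. A sketch of that convention; `check_example_property` is a hypothetical check:)

from sklearn.base import clone

def check_example_property(name, estimator_orig, strict_mode=True):
    estimator = clone(estimator_orig)
    # ...assertions that run in both modes go here...
    if strict_mode:
        # stricter assertions, e.g. matching exact error messages, go here
        pass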
error_string_predict = ("Classifier can't predict when only one class is " "present.") @@ -1805,7 +1807,7 @@ def check_classifiers_one_label(name, classifier_orig): @ignore_warnings # Warnings are raised by decision function def check_classifiers_train(name, classifier_orig, readonly_memmap=False, - X_dtype='float64'): + X_dtype='float64', strict_mode=True): X_m, y_m = make_blobs(n_samples=300, random_state=0) X_m = X_m.astype(X_dtype) X_m, y_m = shuffle(X_m, y_m, random_state=7) @@ -1929,7 +1931,8 @@ def check_classifiers_train(name, classifier_orig, readonly_memmap=False, assert_array_equal(np.argsort(y_log_prob), np.argsort(y_prob)) -def check_outlier_corruption(num_outliers, expected_outliers, decision): +def check_outlier_corruption(num_outliers, expected_outliers, decision, + strict_mode=True): # Check for deviation from the precise given contamination level that may # be due to ties in the anomaly scores. if num_outliers < expected_outliers: @@ -1949,7 +1952,8 @@ def check_outlier_corruption(num_outliers, expected_outliers, decision): assert len(np.unique(sorted_decision[start:end])) == 1, msg -def check_outliers_train(name, estimator_orig, readonly_memmap=True): +def check_outliers_train(name, estimator_orig, readonly_memmap=True, + strict_mode=True): n_samples = 300 X, _ = make_blobs(n_samples=n_samples, random_state=0) X = shuffle(X, random_state=7) @@ -2025,8 +2029,9 @@ def check_outliers_train(name, estimator_orig, readonly_memmap=True): @ignore_warnings(category=(FutureWarning)) -def check_classifiers_multilabel_representation_invariance(name, - classifier_orig): +def check_classifiers_multilabel_representation_invariance( + name, classifier_orig, strict_mode=True): + X, y = make_multilabel_classification(n_samples=100, n_features=20, n_classes=5, n_labels=3, length=50, allow_unlabeled=True, @@ -2060,7 +2065,7 @@ def check_classifiers_multilabel_representation_invariance(name, @ignore_warnings(category=FutureWarning) def check_estimators_fit_returns_self(name, estimator_orig, - readonly_memmap=False): + readonly_memmap=False, strict_mode=True): """Check if self is returned when calling fit""" if estimator_orig._get_tags()['binary_only']: n_centers = 2 @@ -2082,7 +2087,7 @@ def check_estimators_fit_returns_self(name, estimator_orig, @ignore_warnings -def check_estimators_unfitted(name, estimator_orig): +def check_estimators_unfitted(name, estimator_orig, strict_mode=True): """Check that predict raises an exception in an unfitted estimator. Unfitted estimators should raise a NotFittedError. 
@@ -2098,7 +2103,7 @@ def check_estimators_unfitted(name, estimator_orig): @ignore_warnings(category=FutureWarning) -def check_supervised_y_2d(name, estimator_orig): +def check_supervised_y_2d(name, estimator_orig, strict_mode=True): tags = estimator_orig._get_tags() if tags['multioutput_only']: # These only work on 2d, so this test makes no sense @@ -2138,7 +2143,8 @@ def check_supervised_y_2d(name, estimator_orig): @ignore_warnings -def check_classifiers_predictions(X, y, name, classifier_orig): +def check_classifiers_predictions(X, y, name, classifier_orig, + strict_mode=True): classes = np.unique(y) classifier = clone(classifier_orig) if name == 'BernoulliNB': @@ -2185,7 +2191,7 @@ def _choose_check_classifiers_labels(name, y, y_names): return y if name in ["LabelPropagation", "LabelSpreading"] else y_names -def check_classifiers_classes(name, classifier_orig): +def check_classifiers_classes(name, classifier_orig, strict_mode=True): X_multiclass, y_multiclass = make_blobs(n_samples=30, random_state=0, cluster_std=0.1) X_multiclass, y_multiclass = shuffle(X_multiclass, y_multiclass, @@ -2223,7 +2229,7 @@ def check_classifiers_classes(name, classifier_orig): @ignore_warnings(category=FutureWarning) -def check_regressors_int(name, regressor_orig): +def check_regressors_int(name, regressor_orig, strict_mode=True): X, _ = _boston_subset() X = _pairwise_estimator_convert_X(X[:50], regressor_orig) rnd = np.random.RandomState(0) @@ -2252,7 +2258,7 @@ def check_regressors_int(name, regressor_orig): @ignore_warnings(category=FutureWarning) def check_regressors_train(name, regressor_orig, readonly_memmap=False, - X_dtype=np.float64): + X_dtype=np.float64, strict_mode=True): X, y = _boston_subset() X = X.astype(X_dtype) X = _pairwise_estimator_convert_X(X, regressor_orig) @@ -2298,7 +2304,8 @@ def check_regressors_train(name, regressor_orig, readonly_memmap=False, @ignore_warnings -def check_regressors_no_decision_function(name, regressor_orig): +def check_regressors_no_decision_function(name, regressor_orig, + strict_mode=True): # checks whether regressors have decision_function or predict_proba rng = np.random.RandomState(0) regressor = clone(regressor_orig) @@ -2324,7 +2331,7 @@ def check_regressors_no_decision_function(name, regressor_orig): @ignore_warnings(category=FutureWarning) -def check_class_weight_classifiers(name, classifier_orig): +def check_class_weight_classifiers(name, classifier_orig, strict_mode=True): if classifier_orig._get_tags()['binary_only']: problems = [2] @@ -2370,7 +2377,8 @@ def check_class_weight_classifiers(name, classifier_orig): @ignore_warnings(category=FutureWarning) def check_class_weight_balanced_classifiers(name, classifier_orig, X_train, - y_train, X_test, y_test, weights): + y_train, X_test, y_test, weights, + strict_mode=True): classifier = clone(classifier_orig) if hasattr(classifier, "n_iter"): classifier.set_params(n_iter=100) @@ -2389,7 +2397,8 @@ def check_class_weight_balanced_classifiers(name, classifier_orig, X_train, @ignore_warnings(category=FutureWarning) -def check_class_weight_balanced_linear_classifier(name, Classifier): +def check_class_weight_balanced_linear_classifier(name, Classifier, + strict_mode=True): """Test class weights with non-contiguous class labels.""" # this is run on classes, not instances, though this should be changed X = np.array([[-1.0, -1.0], [-1.0, 0], [-.8, -1.0], @@ -2428,7 +2437,7 @@ def check_class_weight_balanced_linear_classifier(name, Classifier): @ignore_warnings(category=FutureWarning) -def 
check_estimators_overwrite_params(name, estimator_orig): +def check_estimators_overwrite_params(name, estimator_orig, strict_mode=True): if estimator_orig._get_tags()['binary_only']: n_centers = 2 else: @@ -2467,7 +2476,7 @@ def check_estimators_overwrite_params(name, estimator_orig): @ignore_warnings(category=FutureWarning) -def check_no_attributes_set_in_init(name, estimator_orig): +def check_no_attributes_set_in_init(name, estimator_orig, strict_mode=True): """Check setting during init. """ estimator = clone(estimator_orig) if hasattr(type(estimator).__init__, "deprecated_original"): @@ -2501,7 +2510,7 @@ def check_no_attributes_set_in_init(name, estimator_orig): @ignore_warnings(category=FutureWarning) -def check_sparsify_coefficients(name, estimator_orig): +def check_sparsify_coefficients(name, estimator_orig, strict_mode=True): X = np.array([[-2, -1], [-1, -1], [-1, -2], [1, 1], [1, 2], [2, 1], [-1, -2], [2, 2], [-2, -2]]) y = [1, 1, 1, 2, 2, 2, 3, 3, 3] @@ -2524,7 +2533,7 @@ def check_sparsify_coefficients(name, estimator_orig): @ignore_warnings(category=FutureWarning) -def check_classifier_data_not_an_array(name, estimator_orig): +def check_classifier_data_not_an_array(name, estimator_orig, strict_mode=True): X = np.array([[3, 0], [0, 1], [0, 2], [1, 1], [1, 2], [2, 1], [0, 3], [1, 0], [2, 0], [4, 4], [2, 3], [3, 2]]) X = _pairwise_estimator_convert_X(X, estimator_orig) @@ -2536,7 +2545,7 @@ def check_classifier_data_not_an_array(name, estimator_orig): @ignore_warnings(category=FutureWarning) -def check_regressor_data_not_an_array(name, estimator_orig): +def check_regressor_data_not_an_array(name, estimator_orig, strict_mode=True): X, y = _boston_subset(n_samples=50) X = _pairwise_estimator_convert_X(X, estimator_orig) y = _enforce_estimator_tags_y(estimator_orig, y) @@ -2546,7 +2555,8 @@ def check_regressor_data_not_an_array(name, estimator_orig): @ignore_warnings(category=FutureWarning) -def check_estimators_data_not_an_array(name, estimator_orig, X, y, obj_type): +def check_estimators_data_not_an_array(name, estimator_orig, X, y, obj_type, + strict_mode=True): if name in CROSS_DECOMPOSITION: raise SkipTest("Skipping check_estimators_data_not_an_array " "for cross decomposition module as estimators " @@ -2588,7 +2598,7 @@ def check_estimators_data_not_an_array(name, estimator_orig, X, y, obj_type): assert_allclose(pred1, pred2, atol=1e-2, err_msg=name) -def check_parameters_default_constructible(name, Estimator): +def check_parameters_default_constructible(name, Estimator, strict_mode=True): # test default-constructibility # get rid of deprecation warnings @@ -2688,7 +2698,8 @@ def _enforce_estimator_tags_x(estimator, X): @ignore_warnings(category=FutureWarning) -def check_non_transformer_estimators_n_iter(name, estimator_orig): +def check_non_transformer_estimators_n_iter(name, estimator_orig, + strict_mode=True): # Test that estimators that are not transformers with a parameter # max_iter, return the attribute of n_iter_ at least 1. @@ -2722,7 +2733,7 @@ def check_non_transformer_estimators_n_iter(name, estimator_orig): @ignore_warnings(category=FutureWarning) -def check_transformer_n_iter(name, estimator_orig): +def check_transformer_n_iter(name, estimator_orig, strict_mode=True): # Test that transformers with a parameter max_iter, return the # attribute of n_iter_ at least 1. 
estimator = clone(estimator_orig) @@ -2748,7 +2759,7 @@ def check_transformer_n_iter(name, estimator_orig): @ignore_warnings(category=FutureWarning) -def check_get_params_invariance(name, estimator_orig): +def check_get_params_invariance(name, estimator_orig, strict_mode=True): # Checks if get_params(deep=False) is a subset of get_params(deep=True) e = clone(estimator_orig) @@ -2760,7 +2771,7 @@ def check_get_params_invariance(name, estimator_orig): @ignore_warnings(category=FutureWarning) -def check_set_params(name, estimator_orig): +def check_set_params(name, estimator_orig, strict_mode=True): # Check that get_params() returns the same thing # before and after set_params() with some fuzz estimator = clone(estimator_orig) @@ -2814,7 +2825,8 @@ def check_set_params(name, estimator_orig): @ignore_warnings(category=FutureWarning) -def check_classifiers_regression_target(name, estimator_orig): +def check_classifiers_regression_target(name, estimator_orig, + strict_mode=True): # Check if classifier throws an exception when fed regression targets X, y = load_boston(return_X_y=True) @@ -2825,7 +2837,7 @@ def check_classifiers_regression_target(name, estimator_orig): @ignore_warnings(category=FutureWarning) -def check_decision_proba_consistency(name, estimator_orig): +def check_decision_proba_consistency(name, estimator_orig, strict_mode=True): # Check whether an estimator having both decision_function and # predict_proba methods has outputs with perfect rank correlation. @@ -2847,7 +2859,7 @@ def check_decision_proba_consistency(name, estimator_orig): assert_array_equal(rankdata(a), rankdata(b)) -def check_outliers_fit_predict(name, estimator_orig): +def check_outliers_fit_predict(name, estimator_orig, strict_mode=True): # Check fit_predict for outlier detectors. n_samples = 300 @@ -2894,17 +2906,20 @@ def check_outliers_fit_predict(name, estimator_orig): assert_raises(ValueError, estimator.fit_predict, X) -def check_fit_non_negative(name, estimator_orig): +def check_fit_non_negative(name, estimator_orig, strict_mode=True): # Check that proper warning is raised for non-negative X # when tag requires_positive_X is present X = np.array([[-1., 1], [-1., 1]]) y = np.array([1, 2]) estimator = clone(estimator_orig) - assert_raises_regex(ValueError, "Negative values in data passed to", - estimator.fit, X, y) + if strict_mode: + assert_raises_regex(ValueError, "Negative values in data passed to", + estimator.fit, X, y) + else: # Don't check error message if strict mode is off + assert_raises(ValueError, estimator.fit, X, y) -def check_fit_idempotent(name, estimator_orig): +def check_fit_idempotent(name, estimator_orig, strict_mode=True): # Check that est.fit(X) is the same as est.fit(X).fit(X). Ideally we would # check that the estimated parameters during training (e.g. coefs_) are # the same, but having a universal comparison function for those @@ -2959,7 +2974,7 @@ def check_fit_idempotent(name, estimator_orig): ) -def check_n_features_in(name, estimator_orig): +def check_n_features_in(name, estimator_orig, strict_mode=True): # Make sure that n_features_in_ attribute doesn't exist until fit is # called, and that its value is correct. 
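(As a reminder of the contract this fully strict check verifies, a runnable sketch using `LogisticRegression` as an arbitrary estimator:)

import numpy as np
from sklearn.linear_model import LogisticRegression

X = np.random.RandomState(0).rand(10, 3)
y = np.array([0, 1] * 5)

est = LogisticRegression()
assert not hasattr(est, 'n_features_in_')  # must not exist before fit
est.fit(X, y)
assert est.n_features_in_ == X.shape[1]    # must equal the number of features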
@@ -2997,7 +3012,7 @@ def check_n_features_in(name, estimator_orig): ) -def check_requires_y_none(name, estimator_orig): +def check_requires_y_none(name, estimator_orig, strict_mode=True): # Make sure that an estimator with requires_y=True fails gracefully when # given y=None @@ -3029,6 +3044,7 @@ def check_requires_y_none(name, estimator_orig): warnings.warn(warning_msg, FutureWarning) -_STRICT_CHECKS = set([ - check_n_features_in, # arbitrary, we can decide on actual list later? +# set of checks that are completely strict, i.e. they have no non-strict part +_FULLY_STRICT_CHECKS = set([ + 'check_n_features_in', ]) From 2e2bebd7caff3e7f62a8e94bc9e90a15a36ff8d0 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Wed, 27 May 2020 08:08:18 -0400 Subject: [PATCH 06/15] put back reasons --- sklearn/decomposition/_sparse_pca.py | 4 ++-- sklearn/dummy.py | 4 ++-- sklearn/neural_network/_rbm.py | 4 ++-- sklearn/svm/_classes.py | 7 +++---- sklearn/utils/estimator_checks.py | 25 +++++++++++++++---------- 5 files changed, 24 insertions(+), 20 deletions(-) diff --git a/sklearn/decomposition/_sparse_pca.py b/sklearn/decomposition/_sparse_pca.py index efe572576ff02..ae920a9c697e7 100644 --- a/sklearn/decomposition/_sparse_pca.py +++ b/sklearn/decomposition/_sparse_pca.py @@ -208,8 +208,8 @@ def transform(self, X): def _more_tags(self): return { '_xfail_checks': { - # fails for the transform method" - "check_methods_subset_invariance": {}, + "check_methods_subset_invariance": + 'fails for the transform method', } } diff --git a/sklearn/dummy.py b/sklearn/dummy.py index c8cdda9a23b5f..84d7001ca0cd7 100644 --- a/sklearn/dummy.py +++ b/sklearn/dummy.py @@ -349,8 +349,8 @@ def _more_tags(self): return { 'poor_score': True, 'no_validation': True, '_xfail_checks': { - # fails for the predict method - 'check_methods_subset_invariance': {} + 'check_methods_subset_invariance': + 'fails for the predict method', } } diff --git a/sklearn/neural_network/_rbm.py b/sklearn/neural_network/_rbm.py index fa45df96c520c..52982ceb6e3a0 100644 --- a/sklearn/neural_network/_rbm.py +++ b/sklearn/neural_network/_rbm.py @@ -377,7 +377,7 @@ def fit(self, X, y=None): def _more_tags(self): return { '_xfail_checks': { - # fails for the decision_function method - 'check_methods_subset_invariance': {} + 'check_methods_subset_invariance': + 'fails for the decision_function method', } } diff --git a/sklearn/svm/_classes.py b/sklearn/svm/_classes.py index 334c26a1b5cad..eb9d8e49cf80c 100644 --- a/sklearn/svm/_classes.py +++ b/sklearn/svm/_classes.py @@ -864,10 +864,9 @@ def __init__(self, *, nu=0.5, kernel='rbf', degree=3, gamma='scale', def _more_tags(self): return { '_xfail_checks': { - # fails for the decision_function method' - 'check_methods_subset_invariance': {}, - # class_weight is ignored - 'check_class_weight_classifiers': {} + 'check_methods_subset_invariance': + 'fails for the decision_function method', + 'check_class_weight_classifiers': 'class_weight is ignored' } } diff --git a/sklearn/utils/estimator_checks.py b/sklearn/utils/estimator_checks.py index 606fe3fe414b2..548ec14b195a3 100644 --- a/sklearn/utils/estimator_checks.py +++ b/sklearn/utils/estimator_checks.py @@ -344,11 +344,11 @@ def _maybe_mark_xfail(estimator, check, strict_mode, pytest): # This is similar to _maybe_skip(), but this one is used by # @parametrize_with_checks() instead of check_estimator() - if not _should_be_skipped_or_marked(estimator, check, strict_mode): + should_be_marked, reason = _should_be_skipped_or_marked(estimator, check, + strict_mode) + if 
not should_be_marked: return estimator, check else: - reason = ('This check is in the _xfail_checks tag, or it is ' - 'a strict check and strict mode is off.') return pytest.param(estimator, check, marks=pytest.mark.xfail(reason=reason)) @@ -359,8 +359,9 @@ def _maybe_skip(estimator, check, strict_mode): # This is similar to _maybe_mark_xfail(), but this one is used by # check_estimator() instead of @parametrize_with_checks which requires # pytest - - if not _should_be_skipped_or_marked(estimator, check, strict_mode): + should_be_skipped, reason = _should_be_skipped_or_marked(estimator, check, + strict_mode) + if not should_be_skipped: return check check_name = (check.func.__name__ if isinstance(check, partial) @@ -369,7 +370,8 @@ def _maybe_skip(estimator, check, strict_mode): @wraps(check) def wrapped(*args, **kwargs): raise SkipTest( - f"Skipping {check_name} for {estimator.__class__.__name__}" + f"Skipping {check_name} for {estimator.__class__.__name__}: " + f"{reason}" ) return wrapped @@ -382,10 +384,13 @@ def _should_be_skipped_or_marked(estimator, check, strict_mode): xfail_checks = estimator._get_tags()['_xfail_checks'] or {} - return ( - check_name in xfail_checks or - check_name in _FULLY_STRICT_CHECKS and not strict_mode - ) + if check_name in _FULLY_STRICT_CHECKS and not strict_mode: + return True, 'The check is fully strict and strict mode is off' + + if check_name in xfail_checks: + return True, xfail_checks[check_name] + + return False, 'placeholder reason that will never be used' def parametrize_with_checks(estimators, strict_mode=True): From 0a61d69ff2c7975fb93e6110e00ff92815a765fe Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Wed, 27 May 2020 08:37:37 -0400 Subject: [PATCH 07/15] comments and cleaning --- sklearn/decomposition/_sparse_pca.py | 4 +- sklearn/dummy.py | 2 +- sklearn/neural_network/_rbm.py | 2 +- sklearn/svm/_classes.py | 2 +- sklearn/tests/test_common.py | 43 ++++++++++++++------- sklearn/utils/estimator_checks.py | 57 +++++++++++++++++++--------- 6 files changed, 75 insertions(+), 35 deletions(-) diff --git a/sklearn/decomposition/_sparse_pca.py b/sklearn/decomposition/_sparse_pca.py index ae920a9c697e7..8f766b734ffab 100644 --- a/sklearn/decomposition/_sparse_pca.py +++ b/sklearn/decomposition/_sparse_pca.py @@ -208,8 +208,8 @@ def transform(self, X): def _more_tags(self): return { '_xfail_checks': { - "check_methods_subset_invariance": - 'fails for the transform method', + "check_methods_subset_invariance": + "fails for the transform method" } } diff --git a/sklearn/dummy.py b/sklearn/dummy.py index 84d7001ca0cd7..cee7294ab5afd 100644 --- a/sklearn/dummy.py +++ b/sklearn/dummy.py @@ -350,7 +350,7 @@ def _more_tags(self): 'poor_score': True, 'no_validation': True, '_xfail_checks': { 'check_methods_subset_invariance': - 'fails for the predict method', + 'fails for the predict method' } } diff --git a/sklearn/neural_network/_rbm.py b/sklearn/neural_network/_rbm.py index 52982ceb6e3a0..fcb4e90772598 100644 --- a/sklearn/neural_network/_rbm.py +++ b/sklearn/neural_network/_rbm.py @@ -378,6 +378,6 @@ def _more_tags(self): return { '_xfail_checks': { 'check_methods_subset_invariance': - 'fails for the decision_function method', + 'fails for the decision_function method' } } diff --git a/sklearn/svm/_classes.py b/sklearn/svm/_classes.py index eb9d8e49cf80c..d082c22d0a3bc 100644 --- a/sklearn/svm/_classes.py +++ b/sklearn/svm/_classes.py @@ -866,7 +866,7 @@ def _more_tags(self): '_xfail_checks': { 'check_methods_subset_invariance': 'fails for the 
decision_function method',
-            'check_class_weight_classifiers': 'class_weight is ignored'
+            'check_class_weight_classifiers': 'class_weight is ignored.'
         }
     }

diff --git a/sklearn/tests/test_common.py b/sklearn/tests/test_common.py
index 9deb8a0c34fec..ace385ddf5903 100644
--- a/sklearn/tests/test_common.py
+++ b/sklearn/tests/test_common.py
@@ -225,25 +225,42 @@ def fit(self, X, y=None, **params):
 
 
 def test_strict_mode_check_estimator():
-    # Make sure the strict checks are properly ignored when strict mode is off
-    # in check_estimator.
-    # We can't check the message because check_estimator doesn't give one.
-
-    with pytest.warns(SkipTestWarning):
-        # LogisticRegression has no _xfail_checks, but check_n_features_in is
-        # still skipped because it's a fully strict check
+    # Tests various conditions for the strict mode of check_estimator()
+    # Details are in the comments
+
+    # LogisticRegression has no _xfail_checks, so when strict_mode is on, there
+    # should be no skipped tests.
+    with pytest.warns(None) as caught_warnings:
+        check_estimator(LogisticRegression(), strict_mode=True)
+    assert not any(isinstance(w, SkipTestWarning) for w in caught_warnings)
+    # When strict mode is off, check_n_features_in should be skipped because it's
+    # a fully strict check
+    with pytest.warns(SkipTestWarning,
+                      match='check is fully strict and strict mode is off'):
         check_estimator(LogisticRegression(), strict_mode=False)
 
-    with pytest.warns(SkipTestWarning):
-        # NuSVC has some _xfail_checks. check_n_features_in is skipped along
-        # with the other checks in the tag.
+    # NuSVC has some _xfail_checks. They should be skipped regardless of
+    # strict_mode
+    with pytest.warns(SkipTestWarning,
+                      match='fails for the decision_function method'):
+        check_estimator(NuSVC(), strict_mode=True)
+    # When strict mode is off, check_n_features_in is skipped along with the
+    # rest of the xfail_checks
+    with pytest.warns(SkipTestWarning,
+                      match='check is fully strict and strict mode is off'):
         check_estimator(NuSVC(), strict_mode=False)
 
-    # MyNMF will fail check_fit_non_negative in strict mode, but it will pass
-    # in non-strict mode which doesn't check the exact error message.
+    # MyNMF will fail check_fit_non_negative() in strict mode because it yields
+    # a bad error message
     with pytest.raises(AssertionError, match='does not match'):
         check_estimator(MyNMFWithBadErrorMessage(), strict_mode=True)
-    check_estimator(MyNMFWithBadErrorMessage(), strict_mode=False)
+    # However, it should pass the test suite in non-strict mode because when
+    # strict mode is off, check_fit_non_negative() will not check the exact
+    # error message. (We still assert that the warning from
+    # check_n_features_in is raised)
+    with pytest.warns(SkipTestWarning,
+                      match='check is fully strict and strict mode is off'):
+        check_estimator(MyNMFWithBadErrorMessage(), strict_mode=False)

diff --git a/sklearn/utils/estimator_checks.py b/sklearn/utils/estimator_checks.py
index 548ec14b195a3..cb0b0d4321486 100644
--- a/sklearn/utils/estimator_checks.py
+++ b/sklearn/utils/estimator_checks.py
@@ -339,8 +339,8 @@ def _construct_instance(Estimator):
 
 
 def _maybe_mark_xfail(estimator, check, strict_mode, pytest):
-    # Mark (estimator, check) pairs as XFAIL if the check is in the
-    # _xfail_checks tag or if it's a strict check and strict_mode=False.
+    # Mark (estimator, check) pairs as XFAIL if needed (see conditions in
+    # strict_mode_xfails_partially_strict_checks())
     # This is similar to _maybe_skip(), but this one is used by
     # @parametrize_with_checks() instead of check_estimator()
 
@@ -354,8 +354,8 @@ def _maybe_mark_xfail(estimator, check, strict_mode, pytest):
 
 
 def _maybe_skip(estimator, check, strict_mode):
-    # Wrap a check so that it's skipped with a warning if it's part of the
-    # xfail_checks tag, or if it's a strict check and strict_mode=False
+    # Wrap a check so that it's skipped if needed (see conditions in
+    # strict_mode_xfails_partially_strict_checks())
     # This is similar to _maybe_mark_xfail(), but this one is used by
     # check_estimator() instead of @parametrize_with_checks which requires
     # pytest
 
@@ -378,18 +378,25 @@ def wrapped(*args, **kwargs):
 
 
 def _should_be_skipped_or_marked(estimator, check, strict_mode):
+    # Return whether a check should be skipped (when using check_estimator())
+    # or marked as XFAIL (when using @parametrize_with_checks()), along with a
+    # reason.
+    # A check should be skipped or marked if:
+    # - the check is in the _xfail_checks tag of the estimator
+    # - the check is fully strict and strict mode is off
+    # Checks that are only partially strict will not be skipped since we want
+    # to run their non-strict parts.
     check_name = (check.func.__name__
                   if isinstance(check, partial)
                   else check.__name__)
 
     xfail_checks = estimator._get_tags()['_xfail_checks'] or {}
+    if check_name in xfail_checks:
+        return True, xfail_checks[check_name]
 
     if check_name in _FULLY_STRICT_CHECKS and not strict_mode:
         return True, 'The check is fully strict and strict mode is off'
 
-    if check_name in xfail_checks:
-        return True, xfail_checks[check_name]
-
     return False, 'placeholder reason that will never be used'
 
 
@@ -412,11 +419,19 @@ def parametrize_with_checks(estimators, strict_mode=True):
         classes was removed in 0.24. Pass an instance instead.
 
     strict_mode : bool, default=True
-        If False, the strict checks will be treated as if they were in the
-        estimators' `_xfails_checks` tag: they will be marked as `xfail` for
-        pytest. See :ref:`estimator_tags` for more info on the
-        `_xfails_check` tag. The set of strict checks is in
-        `sklearn.utils.estimator_checks._STRICT_CHECKS`.
+        If True, the full check suite is run.
+        If False, only the non-strict part of the check suite is run.
+
+        In non-strict mode, some checks will be easier to pass: e.g., they
+        will only make sure an error is raised instead of also checking the
+        full error message.
+        Some checks are considered completely strict, in which case they are
+        treated as if they were in the estimators' `_xfail_checks` tag: they
+        will be marked as `xfail` for pytest. See :ref:`estimator_tags` for
+        more info on the `_xfail_checks` tag. The set of strict checks is in
+        `sklearn.utils.estimator_checks._FULLY_STRICT_CHECKS`.
+
+        .. versionadded:: 0.24
 
     Returns
     -------
@@ -494,11 +509,19 @@ def check_estimator(Estimator, generate_only=False, strict_mode=True):
 
         .. versionadded:: 0.22
 
     strict_mode : bool, default=True
-        If False, the strict checks will be treated as if they were in the
-        estimator's `_xfails_checks` tag: they will be ignored with a
-        warning. See :ref:`estimator_tags` for more info on the
-        `_xfails_check` tag. The set of strict checks is in
-        `sklearn.utils.estimator_checks._STRICT_CHECKS`.
+        If True, the full check suite is run.
+        If False, only the non-strict part of the check suite is run.
+
+        In non-strict mode, some checks will be easier to pass: e.g., they
+        will only make sure an error is raised instead of also checking the
+        full error message.
+        Some checks are considered completely strict, in which case they are
+        treated as if they were in the estimator's `_xfail_checks` tag: they
+        will be ignored with a warning. See :ref:`estimator_tags` for more
+        info on the `_xfail_checks` tag. The set of strict checks is in
+        `sklearn.utils.estimator_checks._FULLY_STRICT_CHECKS`.
+
+        .. versionadded:: 0.24
 
     Returns
     -------
From 9c9757b5fbe0826d935e36e77ba0f3529a56e11a Mon Sep 17 00:00:00 2001
From: Nicolas Hug
Date: Wed, 27 May 2020 09:14:11 -0400
Subject: [PATCH 08/15] typo

---
 sklearn/utils/estimator_checks.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/sklearn/utils/estimator_checks.py b/sklearn/utils/estimator_checks.py
index cb0b0d4321486..41b03cd997edb 100644
--- a/sklearn/utils/estimator_checks.py
+++ b/sklearn/utils/estimator_checks.py
@@ -340,7 +340,7 @@ def _construct_instance(Estimator):
 
 def _maybe_mark_xfail(estimator, check, strict_mode, pytest):
     # Mark (estimator, check) pairs as XFAIL if needed (see conditions in
-    # strict_mode_xfails_partially_strict_checks())
+    # _should_be_skipped_or_marked())
     # This is similar to _maybe_skip(), but this one is used by
     # @parametrize_with_checks() instead of check_estimator()
 
@@ -355,7 +355,7 @@ def _maybe_skip(estimator, check, strict_mode):
     # Wrap a check so that it's skipped if needed (see conditions in
-    # strict_mode_xfails_partially_strict_checks())
+    # _should_be_skipped_or_marked())
     # This is similar to _maybe_mark_xfail(), but this one is used by
     # check_estimator() instead of @parametrize_with_checks which requires
     # pytest

From 92b45e12e4bfa9287de657f37ede082e87e7ee49 Mon Sep 17 00:00:00 2001
From: Nicolas Hug
Date: Wed, 3 Jun 2020 13:31:57 -0400
Subject: [PATCH 09/15] check name in xfail message

---
 sklearn/tests/test_common.py      | 10 ++++------
 sklearn/utils/estimator_checks.py |  2 +-
 2 files changed, 5 insertions(+), 7 deletions(-)

diff --git a/sklearn/tests/test_common.py b/sklearn/tests/test_common.py
index ace385ddf5903..c41bdb1116a6c 100644
--- a/sklearn/tests/test_common.py
+++ b/sklearn/tests/test_common.py
@@ -235,8 +235,8 @@ def test_strict_mode_check_estimator():
     assert not any(isinstance(w, SkipTestWarning) for w in caught_warnings)
     # When strict mode is off, check_n_features_in should be skipped because
     # it's a fully strict check
-    with pytest.warns(SkipTestWarning,
-                      match='check is fully strict and strict mode is off'):
+    msg_check_n_features_in = 'check_n_features_in is fully strict '
+    with pytest.warns(SkipTestWarning, match=msg_check_n_features_in):
         check_estimator(LogisticRegression(), strict_mode=False)
 
     # NuSVC has some _xfail_checks. They should be skipped regardless of
@@ -246,8 +246,7 @@ def test_strict_mode_check_estimator():
         check_estimator(NuSVC(), strict_mode=True)
     # When strict mode is off, check_n_features_in is skipped along with the
     # rest of the xfail_checks
-    with pytest.warns(SkipTestWarning,
-                      match='check is fully strict and strict mode is off'):
+    with pytest.warns(SkipTestWarning, match=msg_check_n_features_in):
         check_estimator(NuSVC(), strict_mode=False)
 
     # MyNMF will fail check_fit_non_negative() in strict mode because it yields
@@ -258,8 +257,7 @@ def test_strict_mode_check_estimator():
     # strict mode is off, check_fit_non_negative() will not check the exact
    # error message.
(We still assert that the warning from # check_n_features_in is raised) - with pytest.warns(SkipTestWarning, - match='check is fully strict and strict mode is off'): + with pytest.warns(SkipTestWarning, match=msg_check_n_features_in): check_estimator(MyNMFWithBadErrorMessage(), strict_mode=False) diff --git a/sklearn/utils/estimator_checks.py b/sklearn/utils/estimator_checks.py index 41b03cd997edb..37f35624e1f89 100644 --- a/sklearn/utils/estimator_checks.py +++ b/sklearn/utils/estimator_checks.py @@ -395,7 +395,7 @@ def _should_be_skipped_or_marked(estimator, check, strict_mode): return True, xfail_checks[check_name] if check_name in _FULLY_STRICT_CHECKS and not strict_mode: - return True, 'The check is fully strict and strict mode is off' + return True, f'{check_name} is fully strict and strict mode is off' return False, 'placeholder reason that will never be used' From a92320dbaccf2d78d6b5fe8a2f30bb660c86b8c2 Mon Sep 17 00:00:00 2001 From: Roman Yurchak Date: Fri, 10 Jul 2020 11:28:24 +0200 Subject: [PATCH 10/15] Lint --- sklearn/utils/estimator_checks.py | 1 - 1 file changed, 1 deletion(-) diff --git a/sklearn/utils/estimator_checks.py b/sklearn/utils/estimator_checks.py index 443782d6ea63c..115b87887ed2d 100644 --- a/sklearn/utils/estimator_checks.py +++ b/sklearn/utils/estimator_checks.py @@ -2253,7 +2253,6 @@ def check_classifiers_classes(name, classifier_orig, strict_mode=True): @ignore_warnings(category=FutureWarning) def check_regressors_int(name, regressor_orig, strict_mode=True): - X, _ = _boston_subset() X, _ = _regression_dataset() X = _pairwise_estimator_convert_X(X[:50], regressor_orig) rnd = np.random.RandomState(0) From 82584dcc05993a9b1d3a7f6d8d59e79b446cba20 Mon Sep 17 00:00:00 2001 From: Roman Yurchak Date: Sat, 11 Jul 2020 19:18:34 +0200 Subject: [PATCH 11/15] Lint --- sklearn/utils/estimator_checks.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/sklearn/utils/estimator_checks.py b/sklearn/utils/estimator_checks.py index 115b87887ed2d..fcd92a89c8ffd 100644 --- a/sklearn/utils/estimator_checks.py +++ b/sklearn/utils/estimator_checks.py @@ -898,7 +898,8 @@ def check_sample_weights_shape(name, estimator_orig, strict_mode=True): @ignore_warnings(category=FutureWarning) -def check_sample_weights_invariance(name, estimator_orig, kind="ones", strict_mode=True): +def check_sample_weights_invariance(name, estimator_orig, kind="ones", + strict_mode=True): # For kind="ones" check that the estimators yield same results for # unit weights and no weights # For kind="zeros" check that setting sample_weight to 0 is equivalent @@ -1493,7 +1494,7 @@ def check_estimators_nan_inf(name, estimator_orig, strict_mode=True): # Checks that Estimator X's do not contain NaN or inf. 
rnd = np.random.RandomState(0) X_train_finite = _pairwise_estimator_convert_X(rnd.uniform(size=(10, 3)), - estimator_orig) + estimator_orig) X_train_nan = rnd.uniform(size=(10, 3)) X_train_nan[0, 0] = np.nan X_train_inf = rnd.uniform(size=(10, 3)) From a0195f73b4a4fedcd8cb0dff5293a5b9d100f30c Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Mon, 20 Jul 2020 07:37:34 -0400 Subject: [PATCH 12/15] removed use of parenthesis in xfail_checks tag --- sklearn/calibration.py | 2 +- sklearn/cluster/_kmeans.py | 4 ++-- sklearn/ensemble/_iforest.py | 2 +- sklearn/linear_model/_logistic.py | 2 +- sklearn/linear_model/_ransac.py | 2 +- sklearn/linear_model/_ridge.py | 2 +- sklearn/linear_model/_stochastic_gradient.py | 4 ++-- sklearn/neighbors/_kde.py | 2 +- sklearn/svm/_classes.py | 14 +++++++------- 9 files changed, 17 insertions(+), 17 deletions(-) diff --git a/sklearn/calibration.py b/sklearn/calibration.py index 18aca33609ad3..b276c94173d19 100644 --- a/sklearn/calibration.py +++ b/sklearn/calibration.py @@ -307,7 +307,7 @@ class that has the highest probability, and can thus be different def _more_tags(self): return { '_xfail_checks': { - 'check_sample_weights_invariance(kind=zeros)': + 'check_sample_weights_invariance': 'zero sample_weight is not equivalent to removing samples', } } diff --git a/sklearn/cluster/_kmeans.py b/sklearn/cluster/_kmeans.py index a42a3b309a57d..23e920e97d364 100644 --- a/sklearn/cluster/_kmeans.py +++ b/sklearn/cluster/_kmeans.py @@ -1142,7 +1142,7 @@ def score(self, X, y=None, sample_weight=None): def _more_tags(self): return { '_xfail_checks': { - 'check_sample_weights_invariance(kind=zeros)': + 'check_sample_weights_invariance': 'zero sample_weight is not equivalent to removing samples', } } @@ -1868,7 +1868,7 @@ def predict(self, X, sample_weight=None): def _more_tags(self): return { '_xfail_checks': { - 'check_sample_weights_invariance(kind=zeros)': + 'check_sample_weights_invariance': 'zero sample_weight is not equivalent to removing samples', } } diff --git a/sklearn/ensemble/_iforest.py b/sklearn/ensemble/_iforest.py index f46a2c881b0d3..a0cc628e4c366 100644 --- a/sklearn/ensemble/_iforest.py +++ b/sklearn/ensemble/_iforest.py @@ -457,7 +457,7 @@ def _compute_score_samples(self, X, subsample_features): def _more_tags(self): return { '_xfail_checks': { - 'check_sample_weights_invariance(kind=zeros)': + 'check_sample_weights_invariance': 'zero sample_weight is not equivalent to removing samples', } } diff --git a/sklearn/linear_model/_logistic.py b/sklearn/linear_model/_logistic.py index da86a1755c2f1..1370c8f32cf2f 100644 --- a/sklearn/linear_model/_logistic.py +++ b/sklearn/linear_model/_logistic.py @@ -2090,7 +2090,7 @@ def score(self, X, y, sample_weight=None): def _more_tags(self): return { '_xfail_checks': { - 'check_sample_weights_invariance(kind=zeros)': + 'check_sample_weights_invariance': 'zero sample_weight is not equivalent to removing samples', } } diff --git a/sklearn/linear_model/_ransac.py b/sklearn/linear_model/_ransac.py index 03ae65d24a001..c9246c121c387 100644 --- a/sklearn/linear_model/_ransac.py +++ b/sklearn/linear_model/_ransac.py @@ -506,7 +506,7 @@ def score(self, X, y): def _more_tags(self): return { '_xfail_checks': { - 'check_sample_weights_invariance(kind=zeros)': + 'check_sample_weights_invariance': 'zero sample_weight is not equivalent to removing samples', } } diff --git a/sklearn/linear_model/_ridge.py b/sklearn/linear_model/_ridge.py index 3d905bfb5d0f0..cd8d18b6e53bc 100644 --- a/sklearn/linear_model/_ridge.py +++ 
b/sklearn/linear_model/_ridge.py @@ -1913,7 +1913,7 @@ def classes_(self): def _more_tags(self): return { '_xfail_checks': { - 'check_sample_weights_invariance(kind=zeros)': + 'check_sample_weights_invariance': 'zero sample_weight is not equivalent to removing samples', } } diff --git a/sklearn/linear_model/_stochastic_gradient.py b/sklearn/linear_model/_stochastic_gradient.py index 4c772c0ff79a3..cb311bb641b22 100644 --- a/sklearn/linear_model/_stochastic_gradient.py +++ b/sklearn/linear_model/_stochastic_gradient.py @@ -1098,7 +1098,7 @@ def _predict_log_proba(self, X): def _more_tags(self): return { '_xfail_checks': { - 'check_sample_weights_invariance(kind=zeros)': + 'check_sample_weights_invariance': 'zero sample_weight is not equivalent to removing samples', } } @@ -1588,7 +1588,7 @@ def __init__(self, loss="squared_loss", *, penalty="l2", alpha=0.0001, def _more_tags(self): return { '_xfail_checks': { - 'check_sample_weights_invariance(kind=zeros)': + 'check_sample_weights_invariance': 'zero sample_weight is not equivalent to removing samples', } } diff --git a/sklearn/neighbors/_kde.py b/sklearn/neighbors/_kde.py index b342ebcb9c26b..9802285ff7213 100644 --- a/sklearn/neighbors/_kde.py +++ b/sklearn/neighbors/_kde.py @@ -283,7 +283,7 @@ def sample(self, n_samples=1, random_state=None): def _more_tags(self): return { '_xfail_checks': { - 'check_sample_weights_invariance(kind=zeros)': + 'check_sample_weights_invariance': 'sample_weight must have positive values', } } diff --git a/sklearn/svm/_classes.py b/sklearn/svm/_classes.py index ad3dee1e44ae2..5cfbe935a0186 100644 --- a/sklearn/svm/_classes.py +++ b/sklearn/svm/_classes.py @@ -248,7 +248,7 @@ def fit(self, X, y, sample_weight=None): def _more_tags(self): return { '_xfail_checks': { - 'check_sample_weights_invariance(kind=zeros)': + 'check_sample_weights_invariance': 'zero sample_weight is not equivalent to removing samples', } } @@ -436,7 +436,7 @@ def fit(self, X, y, sample_weight=None): def _more_tags(self): return { '_xfail_checks': { - 'check_sample_weights_invariance(kind=zeros)': + 'check_sample_weights_invariance': 'zero sample_weight is not equivalent to removing samples', } } @@ -670,7 +670,7 @@ def __init__(self, *, C=1.0, kernel='rbf', degree=3, gamma='scale', def _more_tags(self): return { '_xfail_checks': { - 'check_sample_weights_invariance(kind=zeros)': + 'check_sample_weights_invariance': 'zero sample_weight is not equivalent to removing samples', } } @@ -895,7 +895,7 @@ def _more_tags(self): 'check_methods_subset_invariance': 'fails for the decision_function method', 'check_class_weight_classifiers': 'class_weight is ignored.', - 'check_sample_weights_invariance(kind=zeros)': + 'check_sample_weights_invariance': 'zero sample_weight is not equivalent to removing samples', } } @@ -1072,7 +1072,7 @@ def probB_(self): def _more_tags(self): return { '_xfail_checks': { - 'check_sample_weights_invariance(kind=zeros)': + 'check_sample_weights_invariance': 'zero sample_weight is not equivalent to removing samples', } } @@ -1226,7 +1226,7 @@ def __init__(self, *, nu=0.5, C=1.0, kernel='rbf', degree=3, def _more_tags(self): return { '_xfail_checks': { - 'check_sample_weights_invariance(kind=zeros)': + 'check_sample_weights_invariance': 'zero sample_weight is not equivalent to removing samples', } } @@ -1459,7 +1459,7 @@ def probB_(self): def _more_tags(self): return { '_xfail_checks': { - 'check_sample_weights_invariance(kind=zeros)': + 'check_sample_weights_invariance': 'zero sample_weight is not equivalent to removing 
samples', } } From 3b6942631a3cf58a79b9dd5f77570225e5adb7ff Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Sun, 2 Aug 2020 09:36:37 -0400 Subject: [PATCH 13/15] Addressed comments from Joel --- sklearn/utils/estimator_checks.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/sklearn/utils/estimator_checks.py b/sklearn/utils/estimator_checks.py index fcd92a89c8ffd..dc60fa0f5fe01 100644 --- a/sklearn/utils/estimator_checks.py +++ b/sklearn/utils/estimator_checks.py @@ -387,7 +387,7 @@ def _should_be_skipped_or_marked(estimator, check, strict_mode): # Return whether a check should be skipped (when using check_estimator()) # or marked as XFAIL (when using @parametrize_with_checks()), along with a # reason. - # A check should be skipped or marked if: + # A check should be skipped or marked if either: # - the check is in the _xfail_checks tag of the estimator # - the check is fully strict and strict mode is off # Checks that are only partially strict will not be skipped since we want @@ -544,16 +544,15 @@ def check_estimator(Estimator, generate_only=False, strict_mode=True): estimator = Estimator name = type(estimator).__name__ - checks_generator = ( - (estimator, partial(_maybe_skip(estimator, check, strict_mode), - name, strict_mode=strict_mode)) - for check in _yield_all_checks(estimator) - ) + def checks_generator(): + for check in _yield_all_checks(estimator): + check = _maybe_skip(estimator, check, strict_mode) + yield estimator, partial(check, name, strict_mode=strict_mode) if generate_only: return checks_generator - for estimator, check in checks_generator: + for estimator, check in checks_generator(): try: check(estimator) except SkipTest as exception: From 993dd94c1e16f76757eb427e21db4b6527705c68 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Sun, 2 Aug 2020 10:18:33 -0400 Subject: [PATCH 14/15] probably fixed test --- sklearn/utils/estimator_checks.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sklearn/utils/estimator_checks.py b/sklearn/utils/estimator_checks.py index 7623ee4e07592..af1bac32c8dab 100644 --- a/sklearn/utils/estimator_checks.py +++ b/sklearn/utils/estimator_checks.py @@ -550,7 +550,7 @@ def checks_generator(): yield estimator, partial(check, name, strict_mode=strict_mode) if generate_only: - return checks_generator + return checks_generator() for estimator, check in checks_generator(): try: From b55ee3c283713e143552706ec9e54e2cda71b348 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Sun, 2 Aug 2020 10:28:00 -0400 Subject: [PATCH 15/15] use generator function instead of comprehension, hopefully clearer --- sklearn/utils/estimator_checks.py | 18 +++++++----------- 1 file changed, 7 insertions(+), 11 deletions(-) diff --git a/sklearn/utils/estimator_checks.py b/sklearn/utils/estimator_checks.py index af1bac32c8dab..6adc6210443c7 100644 --- a/sklearn/utils/estimator_checks.py +++ b/sklearn/utils/estimator_checks.py @@ -463,18 +463,14 @@ def parametrize_with_checks(estimators, strict_mode=True): "Please pass an instance instead.") raise TypeError(msg) - names = (type(estimator).__name__ for estimator in estimators) - - checks_generator = ( - (estimator, partial(check, name, strict_mode=strict_mode)) - for name, estimator in zip(names, estimators) - for check in _yield_all_checks(estimator)) - - checks_with_marks = ( - _maybe_mark_xfail(estimator, check, strict_mode, pytest) - for estimator, check in checks_generator) + def checks_generator(): + for estimator in estimators: + name = type(estimator).__name__ + for 
check in _yield_all_checks(estimator): + check = partial(check, name, strict_mode=strict_mode) + yield _maybe_mark_xfail(estimator, check, strict_mode, pytest) - return pytest.mark.parametrize("estimator, check", checks_with_marks, + return pytest.mark.parametrize("estimator, check", checks_generator(), ids=_get_check_estimator_ids)
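
A usage sketch, assuming this series is applied as-is (the test function name
below is hypothetical; the estimators are the same ones exercised in
sklearn/tests/test_common.py above):

    from sklearn.linear_model import LogisticRegression
    from sklearn.svm import NuSVC
    from sklearn.utils.estimator_checks import (
        check_estimator,
        parametrize_with_checks,
    )

    # Standalone entry point: the full suite runs by default; checks listed
    # in an estimator's _xfail_checks tag are skipped with a SkipTestWarning.
    check_estimator(LogisticRegression(), strict_mode=True)

    # Non-strict mode: fully strict checks such as check_n_features_in are
    # treated as if they were in the _xfail_checks tag and skipped as well.
    check_estimator(NuSVC(), strict_mode=False)

    # Pytest entry point: each (estimator, check) pair becomes one
    # parametrized test; pairs that would be skipped above are marked xfail.
    @parametrize_with_checks([LogisticRegression(), NuSVC()],
                             strict_mode=False)
    def test_estimators(estimator, check):  # hypothetical test name
        check(estimator)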