diff --git a/sklearn/calibration.py b/sklearn/calibration.py index 740c039ca91f0..ca0cdb24512a4 100644 --- a/sklearn/calibration.py +++ b/sklearn/calibration.py @@ -371,7 +371,7 @@ class that has the highest probability, and can thus be different def _more_tags(self): return { '_xfail_checks': { - 'check_sample_weights_invariance(kind=zeros)': + 'check_sample_weights_invariance': 'zero sample_weight is not equivalent to removing samples', } } diff --git a/sklearn/cluster/_kmeans.py b/sklearn/cluster/_kmeans.py index 00b5f31a8611b..5b959e0f048d2 100644 --- a/sklearn/cluster/_kmeans.py +++ b/sklearn/cluster/_kmeans.py @@ -1163,7 +1163,7 @@ def score(self, X, y=None, sample_weight=None): def _more_tags(self): return { '_xfail_checks': { - 'check_sample_weights_invariance(kind=zeros)': + 'check_sample_weights_invariance': 'zero sample_weight is not equivalent to removing samples', } } @@ -1889,7 +1889,7 @@ def predict(self, X, sample_weight=None): def _more_tags(self): return { '_xfail_checks': { - 'check_sample_weights_invariance(kind=zeros)': + 'check_sample_weights_invariance': 'zero sample_weight is not equivalent to removing samples', } } diff --git a/sklearn/ensemble/_iforest.py b/sklearn/ensemble/_iforest.py index f46a2c881b0d3..a0cc628e4c366 100644 --- a/sklearn/ensemble/_iforest.py +++ b/sklearn/ensemble/_iforest.py @@ -457,7 +457,7 @@ def _compute_score_samples(self, X, subsample_features): def _more_tags(self): return { '_xfail_checks': { - 'check_sample_weights_invariance(kind=zeros)': + 'check_sample_weights_invariance': 'zero sample_weight is not equivalent to removing samples', } } diff --git a/sklearn/linear_model/_logistic.py b/sklearn/linear_model/_logistic.py index da86a1755c2f1..1370c8f32cf2f 100644 --- a/sklearn/linear_model/_logistic.py +++ b/sklearn/linear_model/_logistic.py @@ -2090,7 +2090,7 @@ def score(self, X, y, sample_weight=None): def _more_tags(self): return { '_xfail_checks': { - 'check_sample_weights_invariance(kind=zeros)': + 'check_sample_weights_invariance': 'zero sample_weight is not equivalent to removing samples', } } diff --git a/sklearn/linear_model/_ransac.py b/sklearn/linear_model/_ransac.py index 03ae65d24a001..c9246c121c387 100644 --- a/sklearn/linear_model/_ransac.py +++ b/sklearn/linear_model/_ransac.py @@ -506,7 +506,7 @@ def score(self, X, y): def _more_tags(self): return { '_xfail_checks': { - 'check_sample_weights_invariance(kind=zeros)': + 'check_sample_weights_invariance': 'zero sample_weight is not equivalent to removing samples', } } diff --git a/sklearn/linear_model/_ridge.py b/sklearn/linear_model/_ridge.py index 3d905bfb5d0f0..cd8d18b6e53bc 100644 --- a/sklearn/linear_model/_ridge.py +++ b/sklearn/linear_model/_ridge.py @@ -1913,7 +1913,7 @@ def classes_(self): def _more_tags(self): return { '_xfail_checks': { - 'check_sample_weights_invariance(kind=zeros)': + 'check_sample_weights_invariance': 'zero sample_weight is not equivalent to removing samples', } } diff --git a/sklearn/linear_model/_stochastic_gradient.py b/sklearn/linear_model/_stochastic_gradient.py index 4c772c0ff79a3..cb311bb641b22 100644 --- a/sklearn/linear_model/_stochastic_gradient.py +++ b/sklearn/linear_model/_stochastic_gradient.py @@ -1098,7 +1098,7 @@ def _predict_log_proba(self, X): def _more_tags(self): return { '_xfail_checks': { - 'check_sample_weights_invariance(kind=zeros)': + 'check_sample_weights_invariance': 'zero sample_weight is not equivalent to removing samples', } } @@ -1588,7 +1588,7 @@ def __init__(self, loss="squared_loss", *, penalty="l2", 
alpha=0.0001, def _more_tags(self): return { '_xfail_checks': { - 'check_sample_weights_invariance(kind=zeros)': + 'check_sample_weights_invariance': 'zero sample_weight is not equivalent to removing samples', } } diff --git a/sklearn/neighbors/_kde.py b/sklearn/neighbors/_kde.py index 190ae8819184f..2f4c0a6b961db 100644 --- a/sklearn/neighbors/_kde.py +++ b/sklearn/neighbors/_kde.py @@ -284,7 +284,7 @@ def sample(self, n_samples=1, random_state=None): def _more_tags(self): return { '_xfail_checks': { - 'check_sample_weights_invariance(kind=zeros)': + 'check_sample_weights_invariance': 'sample_weight must have positive values', } } diff --git a/sklearn/svm/_classes.py b/sklearn/svm/_classes.py index ad3dee1e44ae2..5cfbe935a0186 100644 --- a/sklearn/svm/_classes.py +++ b/sklearn/svm/_classes.py @@ -248,7 +248,7 @@ def fit(self, X, y, sample_weight=None): def _more_tags(self): return { '_xfail_checks': { - 'check_sample_weights_invariance(kind=zeros)': + 'check_sample_weights_invariance': 'zero sample_weight is not equivalent to removing samples', } } @@ -436,7 +436,7 @@ def fit(self, X, y, sample_weight=None): def _more_tags(self): return { '_xfail_checks': { - 'check_sample_weights_invariance(kind=zeros)': + 'check_sample_weights_invariance': 'zero sample_weight is not equivalent to removing samples', } } @@ -670,7 +670,7 @@ def __init__(self, *, C=1.0, kernel='rbf', degree=3, gamma='scale', def _more_tags(self): return { '_xfail_checks': { - 'check_sample_weights_invariance(kind=zeros)': + 'check_sample_weights_invariance': 'zero sample_weight is not equivalent to removing samples', } } @@ -895,7 +895,7 @@ def _more_tags(self): 'check_methods_subset_invariance': 'fails for the decision_function method', 'check_class_weight_classifiers': 'class_weight is ignored.', - 'check_sample_weights_invariance(kind=zeros)': + 'check_sample_weights_invariance': 'zero sample_weight is not equivalent to removing samples', } } @@ -1072,7 +1072,7 @@ def probB_(self): def _more_tags(self): return { '_xfail_checks': { - 'check_sample_weights_invariance(kind=zeros)': + 'check_sample_weights_invariance': 'zero sample_weight is not equivalent to removing samples', } } @@ -1226,7 +1226,7 @@ def __init__(self, *, nu=0.5, C=1.0, kernel='rbf', degree=3, def _more_tags(self): return { '_xfail_checks': { - 'check_sample_weights_invariance(kind=zeros)': + 'check_sample_weights_invariance': 'zero sample_weight is not equivalent to removing samples', } } @@ -1459,7 +1459,7 @@ def probB_(self): def _more_tags(self): return { '_xfail_checks': { - 'check_sample_weights_invariance(kind=zeros)': + 'check_sample_weights_invariance': 'zero sample_weight is not equivalent to removing samples', } } diff --git a/sklearn/tests/test_common.py b/sklearn/tests/test_common.py index b9f50a76f7b30..c41bdb1116a6c 100644 --- a/sklearn/tests/test_common.py +++ b/sklearn/tests/test_common.py @@ -15,24 +15,27 @@ from functools import partial import pytest - +import numpy as np from sklearn.utils import all_estimators from sklearn.utils._testing import ignore_warnings -from sklearn.exceptions import ConvergenceWarning +from sklearn.exceptions import ConvergenceWarning, SkipTestWarning from sklearn.utils.estimator_checks import check_estimator import sklearn from sklearn.base import BiclusterMixin +from sklearn.decomposition import NMF +from sklearn.utils.validation import check_non_negative, check_array from sklearn.linear_model._base import LinearClassifierMixin from sklearn.linear_model import LogisticRegression +from sklearn.svm import 
NuSVC from sklearn.utils import IS_PYPY from sklearn.utils._testing import SkipTest from sklearn.utils.estimator_checks import ( _construct_instance, _set_checking_parameters, - _set_check_estimator_ids, + _get_check_estimator_ids, check_class_weight_balanced_linear_classifier, parametrize_with_checks) @@ -59,8 +62,8 @@ def _sample_func(x, y=1): "LogisticRegression(class_weight='balanced',random_state=1," "solver='newton-cg',warm_start=True)") ]) -def test_set_check_estimator_ids(val, expected): - assert _set_check_estimator_ids(val) == expected +def test_get_check_estimator_ids(val, expected): + assert _get_check_estimator_ids(val) == expected def _tested_estimators(): @@ -204,3 +207,64 @@ def test_class_support_removed(): with pytest.raises(TypeError, match=msg): parametrize_with_checks([LogisticRegression]) + + +class MyNMFWithBadErrorMessage(NMF): + # Same as NMF but raises an uninformative error message if X has negative + # values. This estimator would fail the check suite in strict mode, + # specifically it would fail check_fit_non_negative + def fit(self, X, y=None, **params): + X = check_array(X, accept_sparse=('csr', 'csc'), + dtype=[np.float64, np.float32]) + try: + check_non_negative(X, whom='') + except ValueError: + raise ValueError("Some non-informative error msg") + + return super().fit(X, y, **params) + + +def test_strict_mode_check_estimator(): + # Tests various conditions for the strict mode of check_estimator() + # Details are in the comments + + # LogisticRegression has no _xfail_checks, so when strict_mode is on, there + # should be no skipped tests. + with pytest.warns(None) as caught_warnings: + check_estimator(LogisticRegression(), strict_mode=True) + assert not any(isinstance(w.message, SkipTestWarning) + for w in caught_warnings) + # When strict mode is off, check_n_features_in should be skipped because + # it's a fully strict check + msg_check_n_features_in = 'check_n_features_in is fully strict ' + with pytest.warns(SkipTestWarning, match=msg_check_n_features_in): + check_estimator(LogisticRegression(), strict_mode=False) + + # NuSVC has some _xfail_checks. They should be skipped regardless of + # strict_mode + with pytest.warns(SkipTestWarning, + match='fails for the decision_function method'): + check_estimator(NuSVC(), strict_mode=True) + # When strict mode is off, check_n_features_in is skipped along with the + # rest of the xfail_checks + with pytest.warns(SkipTestWarning, match=msg_check_n_features_in): + check_estimator(NuSVC(), strict_mode=False) + + # MyNMF will fail check_fit_non_negative() in strict mode because it yields + # a bad error message + with pytest.raises(AssertionError, match='does not match'): + check_estimator(MyNMFWithBadErrorMessage(), strict_mode=True) + # However, it should pass the test suite in non-strict mode because when + # strict mode is off, check_fit_non_negative() will not check the exact + # error message. (We still assert that the warning from + # check_n_features_in is raised) + with pytest.warns(SkipTestWarning, match=msg_check_n_features_in): + check_estimator(MyNMFWithBadErrorMessage(), strict_mode=False) + + +@parametrize_with_checks([LogisticRegression(), + NuSVC(), + MyNMFWithBadErrorMessage()], + strict_mode=False) +def test_strict_mode_parametrize_with_checks(estimator, check): + # Ideally we should assert that the strict checks are Xfailed...
+ check(estimator) diff --git a/sklearn/utils/estimator_checks.py b/sklearn/utils/estimator_checks.py index a61de0767f697..6adc6210443c7 100644 --- a/sklearn/utils/estimator_checks.py +++ b/sklearn/utils/estimator_checks.py @@ -140,7 +140,7 @@ def _yield_classifier_checks(classifier): @ignore_warnings(category=FutureWarning) -def check_supervised_y_no_nan(name, estimator_orig): +def check_supervised_y_no_nan(name, estimator_orig, strict_mode=True): # Checks that the Estimator targets are not NaN. estimator = clone(estimator_orig) rng = np.random.RandomState(888) @@ -287,14 +287,14 @@ def _yield_all_checks(estimator): yield check_fit_non_negative -def _set_check_estimator_ids(obj): +def _get_check_estimator_ids(obj): """Create pytest ids for checks. When `obj` is an estimator, this returns the pprint version of the estimator (with `print_changed_only=True`). When `obj` is a function, the name of the function is returned with its keyword arguments. - `_set_check_estimator_ids` is designed to be used as the `id` in + `_get_check_estimator_ids` is designed to be used as the `id` in `pytest.mark.parametrize` where `check_estimator(..., generate_only=True)` is yielding estimators and checks. @@ -344,41 +344,69 @@ def _construct_instance(Estimator): return estimator -def _mark_xfail_checks(estimator, check, pytest): - """Mark (estimator, check) pairs with xfail according to the - _xfail_checks_ tag.""" - xfail_checks = estimator._get_tags()['_xfail_checks'] or {} - check_name = _set_check_estimator_ids(check) +def _maybe_mark_xfail(estimator, check, strict_mode, pytest): + # Mark (estimator, check) pairs as XFAIL if needed (see conditions in + # _should_be_skipped_or_marked()) + # This is similar to _maybe_skip(), but this one is used by + # @parametrize_with_checks() instead of check_estimator() - if check_name not in xfail_checks: - # check isn't part of the xfail_checks tags, just return it + should_be_marked, reason = _should_be_skipped_or_marked(estimator, check, strict_mode) + if not should_be_marked: return estimator, check else: - # check is in the tag, mark it as xfail for pytest - reason = xfail_checks[check_name] return pytest.param(estimator, check, marks=pytest.mark.xfail(reason=reason)) -def _skip_if_xfail(estimator, check): - # wrap a check so that it's skipped with a warning if it's part of the - # xfail_checks tag. - xfail_checks = estimator._get_tags()['_xfail_checks'] or {} - check_name = _set_check_estimator_ids(check) - - if check_name not in xfail_checks: +def _maybe_skip(estimator, check, strict_mode): + # Wrap a check so that it's skipped if needed (see conditions in + # _should_be_skipped_or_marked()) + # This is similar to _maybe_mark_xfail(), but this one is used by + # check_estimator() instead of @parametrize_with_checks which requires + # pytest + should_be_skipped, reason = _should_be_skipped_or_marked(estimator, check, strict_mode) + if not should_be_skipped: return check + check_name = (check.func.__name__ if isinstance(check, partial) + else check.__name__) + @wraps(check) def wrapped(*args, **kwargs): raise SkipTest( - f"Skipping {check_name} for {estimator.__class__.__name__}" + f"Skipping {check_name} for {estimator.__class__.__name__}: " + f"{reason}" ) return wrapped -def parametrize_with_checks(estimators): +def _should_be_skipped_or_marked(estimator, check, strict_mode): + # Return whether a check should be skipped (when using check_estimator()) + # or marked as XFAIL (when using @parametrize_with_checks()), along with a + # reason.
+ # A check should be skipped or marked if either: + # - the check is in the _xfail_checks tag of the estimator + # - the check is fully strict and strict mode is off + # Checks that are only partially strict will not be skipped since we want + # to run their non-strict parts. + + check_name = (check.func.__name__ if isinstance(check, partial) + else check.__name__) + + xfail_checks = estimator._get_tags()['_xfail_checks'] or {} + if check_name in xfail_checks: + return True, xfail_checks[check_name] + + if check_name in _FULLY_STRICT_CHECKS and not strict_mode: + return True, f'{check_name} is fully strict and strict mode is off' + + return False, 'placeholder reason that will never be used' + + +def parametrize_with_checks(estimators, strict_mode=True): """Pytest specific decorator for parametrizing estimator checks. The `id` of each check is set to be a pprint version of the estimator @@ -396,6 +424,21 @@ def parametrize_with_checks(estimators): Passing a class was deprecated in version 0.23, and support for classes was removed in 0.24. Pass an instance instead. + strict_mode : bool, default=True + If True, the full check suite is run. + If False, only the non-strict part of the check suite is run. + + In non-strict mode, some checks will be easier to pass: e.g., they + will only make sure an error is raised instead of also checking the + full error message. + Some checks are considered completely strict, in which case they are + treated as if they were in the estimators' `_xfail_checks` tag: they + will be marked as `xfail` for pytest. See :ref:`estimator_tags` for + more info on the `_xfail_checks` tag. The set of strict checks is in + `sklearn.utils.estimator_checks._FULLY_STRICT_CHECKS`. + + .. versionadded:: 0.24 + Returns ------- decorator : `pytest.mark.parametrize` @@ -420,21 +463,18 @@ def parametrize_with_checks(estimators): "Please pass an instance instead.") raise TypeError(msg) - names = (type(estimator).__name__ for estimator in estimators) - - checks_generator = ((estimator, partial(check, name)) - for name, estimator in zip(names, estimators) - for check in _yield_all_checks(estimator)) + def checks_generator(): + for estimator in estimators: + name = type(estimator).__name__ + for check in _yield_all_checks(estimator): + check = partial(check, name, strict_mode=strict_mode) + yield _maybe_mark_xfail(estimator, check, strict_mode, pytest) - checks_with_marks = ( - _mark_xfail_checks(estimator, check, pytest) - for estimator, check in checks_generator) + return pytest.mark.parametrize("estimator, check", checks_generator(), + ids=_get_check_estimator_ids) - return pytest.mark.parametrize("estimator, check", checks_with_marks, - ids=_set_check_estimator_ids) - -def check_estimator(Estimator, generate_only=False): +def check_estimator(Estimator, generate_only=False, strict_mode=True): """Check if estimator adheres to scikit-learn conventions. This estimator will run an extensive test-suite for input validation, @@ -470,6 +510,21 @@ def check_estimator(Estimator, generate_only=False): .. versionadded:: 0.22 + strict_mode : bool, default=True + If True, the full check suite is run. + If False, only the non-strict part of the check suite is run. + + In non-strict mode, some checks will be easier to pass: e.g., they + will only make sure an error is raised instead of also checking the + full error message.
+ Some checks are considered completely strict, in which case they are + treated as if they were in the estimators' `_xfail_checks` tag: they + will be ignored with a warning. See :ref:`estimator_tags` for more + info on the `_xfail_checks` tag. The set of strict checks is in + `sklearn.utils.estimator_checks._FULLY_STRICT_CHECKS`. + + .. versionadded:: 0.24 + Returns ------- checks_generator : generator @@ -485,14 +540,15 @@ def check_estimator(Estimator, generate_only=False): estimator = Estimator name = type(estimator).__name__ - checks_generator = ((estimator, - partial(_skip_if_xfail(estimator, check), name)) - for check in _yield_all_checks(estimator)) + def checks_generator(): + for check in _yield_all_checks(estimator): + check = _maybe_skip(estimator, check, strict_mode) + yield estimator, partial(check, name, strict_mode=strict_mode) if generate_only: - return checks_generator + return checks_generator() - for estimator, check in checks_generator: + for estimator, check in checks_generator(): try: check(estimator) except SkipTest as exception: @@ -695,7 +751,7 @@ def _generate_sparse_matrix(X_csr): yield sparse_format + "_64", X -def check_estimator_sparse_data(name, estimator_orig): +def check_estimator_sparse_data(name, estimator_orig, strict_mode=True): rng = np.random.RandomState(0) X = rng.rand(40, 10) X[X < .8] = 0 @@ -751,7 +807,7 @@ def check_estimator_sparse_data(name, estimator_orig): @ignore_warnings(category=FutureWarning) -def check_sample_weights_pandas_series(name, estimator_orig): +def check_sample_weights_pandas_series(name, estimator_orig, strict_mode=True): # check that estimators will accept a 'sample_weight' parameter of # type pandas.Series in the 'fit' function. estimator = clone(estimator_orig) @@ -778,7 +834,7 @@ def check_sample_weights_pandas_series(name, estimator_orig): @ignore_warnings(category=(FutureWarning)) -def check_sample_weights_not_an_array(name, estimator_orig): +def check_sample_weights_not_an_array(name, estimator_orig, strict_mode=True): # check that estimators will accept a 'sample_weight' parameter of # type _NotAnArray in the 'fit' function. estimator = clone(estimator_orig) @@ -795,7 +851,7 @@ def check_sample_weights_not_an_array(name, estimator_orig): @ignore_warnings(category=(FutureWarning)) -def check_sample_weights_list(name, estimator_orig): +def check_sample_weights_list(name, estimator_orig, strict_mode=True): # check that estimators will accept a 'sample_weight' parameter of # type list in the 'fit' function.
if has_fit_parameter(estimator_orig, "sample_weight"): @@ -812,7 +868,7 @@ def check_sample_weights_list(name, estimator_orig): @ignore_warnings(category=FutureWarning) -def check_sample_weights_shape(name, estimator_orig): +def check_sample_weights_shape(name, estimator_orig, strict_mode=True): # check that estimators raise an error if sample_weight # shape mismatches the input if (has_fit_parameter(estimator_orig, "sample_weight") and @@ -837,7 +893,8 @@ def check_sample_weights_shape(name, estimator_orig): @ignore_warnings(category=FutureWarning) -def check_sample_weights_invariance(name, estimator_orig, kind="ones"): +def check_sample_weights_invariance(name, estimator_orig, kind="ones", + strict_mode=True): # For kind="ones" check that the estimators yield same results for # unit weights and no weights # For kind="zeros" check that setting sample_weight to 0 is equivalent @@ -889,7 +946,7 @@ def check_sample_weights_invariance(name, estimator_orig, kind="ones"): @ignore_warnings(category=(FutureWarning, UserWarning)) -def check_dtype_object(name, estimator_orig): +def check_dtype_object(name, estimator_orig, strict_mode=True): # check that estimators treat dtype object as numeric if possible rng = np.random.RandomState(0) X = _pairwise_estimator_convert_X(rng.rand(40, 10), estimator_orig) @@ -924,7 +981,7 @@ def check_dtype_object(name, estimator_orig): estimator.fit(X, y) -def check_complex_data(name, estimator_orig): +def check_complex_data(name, estimator_orig, strict_mode=True): # check that estimators raise an exception on providing complex data X = np.random.sample(10) + 1j * np.random.sample(10) X = X.reshape(-1, 1) @@ -935,7 +992,7 @@ def check_complex_data(name, estimator_orig): @ignore_warnings -def check_dict_unchanged(name, estimator_orig): +def check_dict_unchanged(name, estimator_orig, strict_mode=True): # this estimator raises # ValueError: Found array with 0 feature(s) (shape=(23, 0)) # while a minimum of 1 is required. @@ -979,7 +1036,7 @@ def _is_public_parameter(attr): @ignore_warnings(category=FutureWarning) -def check_dont_overwrite_parameters(name, estimator_orig): +def check_dont_overwrite_parameters(name, estimator_orig, strict_mode=True): # check that fit method only changes or sets private attributes if hasattr(estimator_orig.__init__, "deprecated_original"): # to not check deprecated classes @@ -1032,7 +1089,7 @@ def check_dont_overwrite_parameters(name, estimator_orig): @ignore_warnings(category=FutureWarning) -def check_fit2d_predict1d(name, estimator_orig): +def check_fit2d_predict1d(name, estimator_orig, strict_mode=True): # check by fitting a 2d array and predicting with a 1d array rnd = np.random.RandomState(0) X = 3 * rnd.uniform(size=(20, 3)) @@ -1080,7 +1137,7 @@ def _apply_on_subsets(func, X): @ignore_warnings(category=FutureWarning) -def check_methods_subset_invariance(name, estimator_orig): +def check_methods_subset_invariance(name, estimator_orig, strict_mode=True): # check that method gives invariant results if applied # on mini batches or the whole set rnd = np.random.RandomState(0) @@ -1112,7 +1169,7 @@ def check_methods_subset_invariance(name, estimator_orig): @ignore_warnings -def check_fit2d_1sample(name, estimator_orig): +def check_fit2d_1sample(name, estimator_orig, strict_mode=True): # Check that fitting a 2d array with only one sample either works or # returns an informative message. The error message should either mention # the number of samples or the number of classes. 
@@ -1146,7 +1203,7 @@ def check_fit2d_1sample(name, estimator_orig): @ignore_warnings -def check_fit2d_1feature(name, estimator_orig): +def check_fit2d_1feature(name, estimator_orig, strict_mode=True): # check fitting a 2d array with only 1 feature either works or returns # informative message rnd = np.random.RandomState(0) @@ -1180,7 +1237,7 @@ def check_fit2d_1feature(name, estimator_orig): @ignore_warnings -def check_fit1d(name, estimator_orig): +def check_fit1d(name, estimator_orig, strict_mode=True): # check fitting 1d X array raises a ValueError rnd = np.random.RandomState(0) X = 3 * rnd.uniform(size=(20)) @@ -1202,7 +1259,8 @@ def check_fit1d(name, estimator_orig): @ignore_warnings(category=FutureWarning) -def check_transformer_general(name, transformer, readonly_memmap=False): +def check_transformer_general(name, transformer, readonly_memmap=False, + strict_mode=True): X, y = make_blobs(n_samples=30, centers=[[0, 0, 0], [1, 1, 1]], random_state=0, n_features=2, cluster_std=0.1) X = StandardScaler().fit_transform(X) @@ -1216,7 +1274,7 @@ def check_transformer_general(name, transformer, readonly_memmap=False): @ignore_warnings(category=FutureWarning) -def check_transformer_data_not_an_array(name, transformer): +def check_transformer_data_not_an_array(name, transformer, strict_mode=True): X, y = make_blobs(n_samples=30, centers=[[0, 0, 0], [1, 1, 1]], random_state=0, n_features=2, cluster_std=0.1) X = StandardScaler().fit_transform(X) @@ -1232,7 +1290,7 @@ def check_transformer_data_not_an_array(name, transformer): @ignore_warnings(category=FutureWarning) -def check_transformers_unfitted(name, transformer): +def check_transformers_unfitted(name, transformer, strict_mode=True): X, y = _regression_dataset() transformer = clone(transformer) @@ -1243,7 +1301,7 @@ def check_transformers_unfitted(name, transformer): transformer.transform(X) -def _check_transformer(name, transformer_orig, X, y): +def _check_transformer(name, transformer_orig, X, y, strict_mode=True): n_samples, n_features = np.asarray(X).shape transformer = clone(transformer_orig) set_random_state(transformer) @@ -1322,7 +1380,7 @@ def _check_transformer(name, transformer_orig, X, y): @ignore_warnings -def check_pipeline_consistency(name, estimator_orig): +def check_pipeline_consistency(name, estimator_orig, strict_mode=True): if estimator_orig._get_tags()['non_deterministic']: msg = name + ' is non deterministic' raise SkipTest(msg) @@ -1351,7 +1409,7 @@ def check_pipeline_consistency(name, estimator_orig): @ignore_warnings -def check_fit_score_takes_y(name, estimator_orig): +def check_fit_score_takes_y(name, estimator_orig, strict_mode=True): # check that all estimators accept an optional y # in fit and score so they can be used in pipelines rnd = np.random.RandomState(0) @@ -1380,7 +1438,7 @@ def check_fit_score_takes_y(name, estimator_orig): @ignore_warnings -def check_estimators_dtypes(name, estimator_orig): +def check_estimators_dtypes(name, estimator_orig, strict_mode=True): rnd = np.random.RandomState(0) X_train_32 = 3 * rnd.uniform(size=(20, 5)).astype(np.float32) X_train_32 = _pairwise_estimator_convert_X(X_train_32, estimator_orig) @@ -1403,7 +1461,8 @@ def check_estimators_dtypes(name, estimator_orig): @ignore_warnings(category=FutureWarning) -def check_estimators_empty_data_messages(name, estimator_orig): +def check_estimators_empty_data_messages(name, estimator_orig, + strict_mode=True): e = clone(estimator_orig) set_random_state(e, 1) @@ -1426,11 +1485,11 @@ def check_estimators_empty_data_messages(name, 
estimator_orig): @ignore_warnings(category=FutureWarning) -def check_estimators_nan_inf(name, estimator_orig): +def check_estimators_nan_inf(name, estimator_orig, strict_mode=True): # Checks that Estimator X's do not contain NaN or inf. rnd = np.random.RandomState(0) X_train_finite = _pairwise_estimator_convert_X(rnd.uniform(size=(10, 3)), - estimator_orig) + estimator_orig) X_train_nan = rnd.uniform(size=(10, 3)) X_train_nan[0, 0] = np.nan X_train_inf = rnd.uniform(size=(10, 3)) @@ -1497,7 +1556,7 @@ def check_estimators_nan_inf(name, estimator_orig): @ignore_warnings -def check_nonsquare_error(name, estimator_orig): +def check_nonsquare_error(name, estimator_orig, strict_mode=True): """Test that error is thrown when non-square data provided.""" X, y = make_blobs(n_samples=20, n_features=10) @@ -1510,7 +1569,7 @@ def check_nonsquare_error(name, estimator_orig): @ignore_warnings -def check_estimators_pickle(name, estimator_orig): +def check_estimators_pickle(name, estimator_orig, strict_mode=True): """Test that we can pickle all estimators.""" check_methods = ["predict", "transform", "decision_function", "predict_proba"] @@ -1554,7 +1613,8 @@ def check_estimators_pickle(name, estimator_orig): @ignore_warnings(category=FutureWarning) -def check_estimators_partial_fit_n_features(name, estimator_orig): +def check_estimators_partial_fit_n_features(name, estimator_orig, + strict_mode=True): # check if number of features changes between calls to partial_fit. if not hasattr(estimator_orig, 'partial_fit'): return @@ -1581,7 +1641,7 @@ def check_estimators_partial_fit_n_features(name, estimator_orig): @ignore_warnings(category=FutureWarning) -def check_classifier_multioutput(name, estimator): +def check_classifier_multioutput(name, estimator, strict_mode=True): n_samples, n_labels, n_classes = 42, 5, 3 tags = estimator._get_tags() estimator = clone(estimator) @@ -1639,7 +1699,7 @@ def check_classifier_multioutput(name, estimator): @ignore_warnings(category=FutureWarning) -def check_regressor_multioutput(name, estimator): +def check_regressor_multioutput(name, estimator, strict_mode=True): estimator = clone(estimator) n_samples = n_features = 10 @@ -1662,7 +1722,8 @@ def check_regressor_multioutput(name, estimator): @ignore_warnings(category=FutureWarning) -def check_clustering(name, clusterer_orig, readonly_memmap=False): +def check_clustering(name, clusterer_orig, readonly_memmap=False, + strict_mode=True): clusterer = clone(clusterer_orig) X, y = make_blobs(n_samples=50, random_state=1) X, y = shuffle(X, y, random_state=7) @@ -1721,7 +1782,8 @@ def check_clustering(name, clusterer_orig, readonly_memmap=False): @ignore_warnings(category=FutureWarning) -def check_clusterer_compute_labels_predict(name, clusterer_orig): +def check_clusterer_compute_labels_predict(name, clusterer_orig, + strict_mode=True): """Check that predict is invariant of compute_labels.""" X, y = make_blobs(n_samples=20, random_state=0) clusterer = clone(clusterer_orig) @@ -1736,7 +1798,7 @@ def check_clusterer_compute_labels_predict(name, clusterer_orig): @ignore_warnings(category=FutureWarning) -def check_classifiers_one_label(name, classifier_orig): +def check_classifiers_one_label(name, classifier_orig, strict_mode=True): error_string_fit = "Classifier can't train when only one class is present." 
error_string_predict = ("Classifier can't predict when only one class is " "present.") @@ -1771,7 +1833,7 @@ def check_classifiers_one_label(name, classifier_orig): @ignore_warnings # Warnings are raised by decision function def check_classifiers_train(name, classifier_orig, readonly_memmap=False, - X_dtype='float64'): + X_dtype='float64', strict_mode=True): X_m, y_m = make_blobs(n_samples=300, random_state=0) X_m = X_m.astype(X_dtype) X_m, y_m = shuffle(X_m, y_m, random_state=7) @@ -1895,7 +1957,8 @@ def check_classifiers_train(name, classifier_orig, readonly_memmap=False, assert_array_equal(np.argsort(y_log_prob), np.argsort(y_prob)) -def check_outlier_corruption(num_outliers, expected_outliers, decision): +def check_outlier_corruption(num_outliers, expected_outliers, decision, + strict_mode=True): # Check for deviation from the precise given contamination level that may # be due to ties in the anomaly scores. if num_outliers < expected_outliers: @@ -1915,7 +1978,8 @@ def check_outlier_corruption(num_outliers, expected_outliers, decision): assert len(np.unique(sorted_decision[start:end])) == 1, msg -def check_outliers_train(name, estimator_orig, readonly_memmap=True): +def check_outliers_train(name, estimator_orig, readonly_memmap=True, + strict_mode=True): n_samples = 300 X, _ = make_blobs(n_samples=n_samples, random_state=0) X = shuffle(X, random_state=7) @@ -1991,8 +2055,9 @@ def check_outliers_train(name, estimator_orig, readonly_memmap=True): @ignore_warnings(category=(FutureWarning)) -def check_classifiers_multilabel_representation_invariance(name, - classifier_orig): +def check_classifiers_multilabel_representation_invariance( + name, classifier_orig, strict_mode=True): + X, y = make_multilabel_classification(n_samples=100, n_features=20, n_classes=5, n_labels=3, length=50, allow_unlabeled=True, @@ -2026,7 +2091,7 @@ def check_classifiers_multilabel_representation_invariance(name, @ignore_warnings(category=FutureWarning) def check_estimators_fit_returns_self(name, estimator_orig, - readonly_memmap=False): + readonly_memmap=False, strict_mode=True): """Check if self is returned when calling fit.""" X, y = make_blobs(random_state=0, n_samples=21) # some want non-negative input @@ -2044,7 +2109,7 @@ def check_estimators_fit_returns_self(name, estimator_orig, @ignore_warnings -def check_estimators_unfitted(name, estimator_orig): +def check_estimators_unfitted(name, estimator_orig, strict_mode=True): """Check that predict raises an exception in an unfitted estimator. Unfitted estimators should raise a NotFittedError. 
@@ -2060,7 +2125,7 @@ def check_estimators_unfitted(name, estimator_orig): @ignore_warnings(category=FutureWarning) -def check_supervised_y_2d(name, estimator_orig): +def check_supervised_y_2d(name, estimator_orig, strict_mode=True): tags = estimator_orig._get_tags() if tags['multioutput_only']: # These only work on 2d, so this test makes no sense @@ -2097,7 +2162,8 @@ def check_supervised_y_2d(name, estimator_orig): @ignore_warnings -def check_classifiers_predictions(X, y, name, classifier_orig): +def check_classifiers_predictions(X, y, name, classifier_orig, + strict_mode=True): classes = np.unique(y) classifier = clone(classifier_orig) if name == 'BernoulliNB': @@ -2144,7 +2210,7 @@ def _choose_check_classifiers_labels(name, y, y_names): return y if name in ["LabelPropagation", "LabelSpreading"] else y_names -def check_classifiers_classes(name, classifier_orig): +def check_classifiers_classes(name, classifier_orig, strict_mode=True): X_multiclass, y_multiclass = make_blobs(n_samples=30, random_state=0, cluster_std=0.1) X_multiclass, y_multiclass = shuffle(X_multiclass, y_multiclass, @@ -2182,7 +2248,7 @@ def check_classifiers_classes(name, classifier_orig): @ignore_warnings(category=FutureWarning) -def check_regressors_int(name, regressor_orig): +def check_regressors_int(name, regressor_orig, strict_mode=True): X, _ = _regression_dataset() X = _pairwise_estimator_convert_X(X[:50], regressor_orig) rnd = np.random.RandomState(0) @@ -2211,7 +2277,7 @@ def check_regressors_int(name, regressor_orig): @ignore_warnings(category=FutureWarning) def check_regressors_train(name, regressor_orig, readonly_memmap=False, - X_dtype=np.float64): + X_dtype=np.float64, strict_mode=True): X, y = _regression_dataset() X = X.astype(X_dtype) X = _pairwise_estimator_convert_X(X, regressor_orig) @@ -2256,7 +2322,8 @@ def check_regressors_train(name, regressor_orig, readonly_memmap=False, @ignore_warnings -def check_regressors_no_decision_function(name, regressor_orig): +def check_regressors_no_decision_function(name, regressor_orig, + strict_mode=True): # checks whether regressors have decision_function or predict_proba rng = np.random.RandomState(0) regressor = clone(regressor_orig) @@ -2282,7 +2349,7 @@ def check_regressors_no_decision_function(name, regressor_orig): @ignore_warnings(category=FutureWarning) -def check_class_weight_classifiers(name, classifier_orig): +def check_class_weight_classifiers(name, classifier_orig, strict_mode=True): if classifier_orig._get_tags()['binary_only']: problems = [2] @@ -2329,7 +2396,8 @@ def check_class_weight_classifiers(name, classifier_orig): @ignore_warnings(category=FutureWarning) def check_class_weight_balanced_classifiers(name, classifier_orig, X_train, - y_train, X_test, y_test, weights): + y_train, X_test, y_test, weights, + strict_mode=True): classifier = clone(classifier_orig) if hasattr(classifier, "n_iter"): classifier.set_params(n_iter=100) @@ -2348,7 +2416,8 @@ def check_class_weight_balanced_classifiers(name, classifier_orig, X_train, @ignore_warnings(category=FutureWarning) -def check_class_weight_balanced_linear_classifier(name, Classifier): +def check_class_weight_balanced_linear_classifier(name, Classifier, + strict_mode=True): """Test class weights with non-contiguous class labels.""" # this is run on classes, not instances, though this should be changed X = np.array([[-1.0, -1.0], [-1.0, 0], [-.8, -1.0], @@ -2387,7 +2456,7 @@ def check_class_weight_balanced_linear_classifier(name, Classifier): @ignore_warnings(category=FutureWarning) -def 
check_estimators_overwrite_params(name, estimator_orig): +def check_estimators_overwrite_params(name, estimator_orig, strict_mode=True): X, y = make_blobs(random_state=0, n_samples=21) # some want non-negative input X -= X.min() @@ -2422,7 +2491,7 @@ def check_estimators_overwrite_params(name, estimator_orig): @ignore_warnings(category=FutureWarning) -def check_no_attributes_set_in_init(name, estimator_orig): +def check_no_attributes_set_in_init(name, estimator_orig, strict_mode=True): """Check setting during init.""" estimator = clone(estimator_orig) if hasattr(type(estimator).__init__, "deprecated_original"): @@ -2456,7 +2525,7 @@ def check_no_attributes_set_in_init(name, estimator_orig): @ignore_warnings(category=FutureWarning) -def check_sparsify_coefficients(name, estimator_orig): +def check_sparsify_coefficients(name, estimator_orig, strict_mode=True): X = np.array([[-2, -1], [-1, -1], [-1, -2], [1, 1], [1, 2], [2, 1], [-1, -2], [2, 2], [-2, -2]]) y = np.array([1, 1, 1, 2, 2, 2, 3, 3, 3]) @@ -2480,7 +2549,7 @@ def check_sparsify_coefficients(name, estimator_orig): @ignore_warnings(category=FutureWarning) -def check_classifier_data_not_an_array(name, estimator_orig): +def check_classifier_data_not_an_array(name, estimator_orig, strict_mode=True): X = np.array([[3, 0], [0, 1], [0, 2], [1, 1], [1, 2], [2, 1], [0, 3], [1, 0], [2, 0], [4, 4], [2, 3], [3, 2]]) X = _pairwise_estimator_convert_X(X, estimator_orig) @@ -2492,7 +2561,7 @@ def check_classifier_data_not_an_array(name, estimator_orig): @ignore_warnings(category=FutureWarning) -def check_regressor_data_not_an_array(name, estimator_orig): +def check_regressor_data_not_an_array(name, estimator_orig, strict_mode=True): X, y = _regression_dataset() X = _pairwise_estimator_convert_X(X, estimator_orig) y = _enforce_estimator_tags_y(estimator_orig, y) @@ -2502,7 +2571,8 @@ def check_regressor_data_not_an_array(name, estimator_orig): @ignore_warnings(category=FutureWarning) -def check_estimators_data_not_an_array(name, estimator_orig, X, y, obj_type): +def check_estimators_data_not_an_array(name, estimator_orig, X, y, obj_type, + strict_mode=True): if name in CROSS_DECOMPOSITION: raise SkipTest("Skipping check_estimators_data_not_an_array " "for cross decomposition module as estimators " @@ -2544,7 +2614,7 @@ def check_estimators_data_not_an_array(name, estimator_orig, X, y, obj_type): assert_allclose(pred1, pred2, atol=1e-2, err_msg=name) -def check_parameters_default_constructible(name, Estimator): +def check_parameters_default_constructible(name, Estimator, strict_mode=True): # test default-constructibility # get rid of deprecation warnings @@ -2647,7 +2717,8 @@ def _enforce_estimator_tags_x(estimator, X): @ignore_warnings(category=FutureWarning) -def check_non_transformer_estimators_n_iter(name, estimator_orig): +def check_non_transformer_estimators_n_iter(name, estimator_orig, + strict_mode=True): # Test that estimators that are not transformers with a parameter # max_iter, return the attribute of n_iter_ at least 1. @@ -2681,7 +2752,7 @@ def check_non_transformer_estimators_n_iter(name, estimator_orig): @ignore_warnings(category=FutureWarning) -def check_transformer_n_iter(name, estimator_orig): +def check_transformer_n_iter(name, estimator_orig, strict_mode=True): # Test that transformers with a parameter max_iter, return the # attribute of n_iter_ at least 1. 
estimator = clone(estimator_orig) @@ -2707,7 +2778,7 @@ def check_transformer_n_iter(name, estimator_orig): @ignore_warnings(category=FutureWarning) -def check_get_params_invariance(name, estimator_orig): +def check_get_params_invariance(name, estimator_orig, strict_mode=True): # Checks if get_params(deep=False) is a subset of get_params(deep=True) e = clone(estimator_orig) @@ -2719,7 +2790,7 @@ def check_get_params_invariance(name, estimator_orig): @ignore_warnings(category=FutureWarning) -def check_set_params(name, estimator_orig): +def check_set_params(name, estimator_orig, strict_mode=True): # Check that get_params() returns the same thing # before and after set_params() with some fuzz estimator = clone(estimator_orig) @@ -2773,7 +2844,8 @@ def check_set_params(name, estimator_orig): @ignore_warnings(category=FutureWarning) -def check_classifiers_regression_target(name, estimator_orig): +def check_classifiers_regression_target(name, estimator_orig, + strict_mode=True): # Check if classifier throws an exception when fed regression targets X, y = _regression_dataset() @@ -2786,7 +2858,7 @@ def check_classifiers_regression_target(name, estimator_orig): @ignore_warnings(category=FutureWarning) -def check_decision_proba_consistency(name, estimator_orig): +def check_decision_proba_consistency(name, estimator_orig, strict_mode=True): # Check whether an estimator having both decision_function and # predict_proba methods has outputs with perfect rank correlation. @@ -2808,7 +2880,7 @@ def check_decision_proba_consistency(name, estimator_orig): assert_array_equal(rankdata(a), rankdata(b)) -def check_outliers_fit_predict(name, estimator_orig): +def check_outliers_fit_predict(name, estimator_orig, strict_mode=True): # Check fit_predict for outlier detectors. n_samples = 300 @@ -2855,17 +2927,20 @@ def check_outliers_fit_predict(name, estimator_orig): assert_raises(ValueError, estimator.fit_predict, X) -def check_fit_non_negative(name, estimator_orig): +def check_fit_non_negative(name, estimator_orig, strict_mode=True): # Check that proper warning is raised for non-negative X # when tag requires_positive_X is present X = np.array([[-1., 1], [-1., 1]]) y = np.array([1, 2]) estimator = clone(estimator_orig) - assert_raises_regex(ValueError, "Negative values in data passed to", - estimator.fit, X, y) + if strict_mode: + assert_raises_regex(ValueError, "Negative values in data passed to", + estimator.fit, X, y) + else: # Don't check error message if strict mode is off + assert_raises(ValueError, estimator.fit, X, y) -def check_fit_idempotent(name, estimator_orig): +def check_fit_idempotent(name, estimator_orig, strict_mode=True): # Check that est.fit(X) is the same as est.fit(X).fit(X). Ideally we would # check that the estimated parameters during training (e.g. coefs_) are # the same, but having a universal comparison function for those @@ -2920,7 +2995,7 @@ def check_fit_idempotent(name, estimator_orig): ) -def check_n_features_in(name, estimator_orig): +def check_n_features_in(name, estimator_orig, strict_mode=True): # Make sure that n_features_in_ attribute doesn't exist until fit is # called, and that its value is correct. 
@@ -2958,7 +3033,7 @@ def check_n_features_in(name, estimator_orig): ) -def check_requires_y_none(name, estimator_orig): +def check_requires_y_none(name, estimator_orig, strict_mode=True): # Make sure that an estimator with requires_y=True fails gracefully when # given y=None @@ -2988,3 +3063,9 @@ def check_requires_y_none(name, estimator_orig): except ValueError as ve: if not any(msg in str(ve) for msg in expected_err_msgs): warnings.warn(warning_msg, FutureWarning) + + +# set of checks that are completely strict, i.e. they have no non-strict part +_FULLY_STRICT_CHECKS = set([ + 'check_n_features_in', +])
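
Usage sketch (not part of the patch; it mirrors the calls exercised by the new tests in sklearn/tests/test_common.py above): with a scikit-learn build that includes this change, the strict_mode flag is accepted by both check_estimator and the parametrize_with_checks decorator. The test function name below is illustrative only.

    from sklearn.linear_model import LogisticRegression
    from sklearn.utils.estimator_checks import (check_estimator,
                                                parametrize_with_checks)

    # Default behaviour: the full, strict check suite is run.
    check_estimator(LogisticRegression(), strict_mode=True)

    # Non-strict mode: checks only assert that an error is raised rather than
    # matching the full error message, and fully strict checks such as
    # check_n_features_in are skipped with a SkipTestWarning.
    check_estimator(LogisticRegression(), strict_mode=False)

    # The pytest decorator accepts the same flag; in non-strict mode the fully
    # strict checks are marked as xfail instead of being skipped.
    @parametrize_with_checks([LogisticRegression()], strict_mode=False)
    def test_sklearn_compatible_estimator(estimator, check):
        check(estimator)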