diff --git a/doc/developers/develop.rst b/doc/developers/develop.rst index b7b5d2ac0316f..858fd92d2e69b 100644 --- a/doc/developers/develop.rst +++ b/doc/developers/develop.rst @@ -249,22 +249,16 @@ Rolling your own estimator If you want to implement a new estimator that is scikit-learn-compatible, whether it is just for you or for contributing it to scikit-learn, there are several internals of scikit-learn that you should be aware of in addition to -the scikit-learn API outlined above. You can check whether your estimator -adheres to the scikit-learn interface and standards by running -:func:`~sklearn.utils.estimator_checks.check_estimator` on an instance. The -:func:`~sklearn.utils.estimator_checks.parametrize_with_checks` pytest -decorator can also be used (see its docstring for details and possible -interactions with `pytest`):: - - >>> from sklearn.utils.estimator_checks import check_estimator - >>> from sklearn.svm import LinearSVC - >>> check_estimator(LinearSVC()) # passes +the scikit-learn API outlined above. The main motivation to make a class compatible to the scikit-learn estimator interface might be that you want to use it together with model evaluation and selection tools such as :class:`model_selection.GridSearchCV` and :class:`pipeline.Pipeline`. +Checking the compatibility of your estimator with scikit-learn is described +in :ref:`checking_compatibility`. + Before detailing the required interface below, we describe two ways to achieve the correct interface more easily. @@ -499,6 +493,35 @@ patterns. The :mod:`sklearn.utils.multiclass` module contains useful functions for working with multiclass and multilabel problems. +.. _checking_compatibility: + +Checking the estimator's compatibility +-------------------------------------- + +You can check whether your estimator adheres to the scikit-learn interface +and standards by running +:func:`~sklearn.utils.estimator_checks.check_estimator` on an instance:: + + >>> from sklearn.utils.estimator_checks import check_estimator + >>> from sklearn.svm import LinearSVC + >>> check_estimator(LinearSVC()) # passes + +The :func:`~sklearn.utils.estimator_checks.parametrize_with_checks` pytest +decorator can also be used (see its docstring for details and possible +interactions with `pytest`). + +Both :func:`~sklearn.utils.estimator_checks.check_estimator` and +:func:`~sklearn.utils.estimator_checks.parametrize_with_checks` expose an +`api_only` parameter: when True, the check suite will only consider pure +API-compatibility checks. Some more advanced checks will be ignored, such as +ensuring that error messages are informative, or ensuring that a classifier +is able to properly discriminate classes on a simple problem. We recommend +leaving this parameter set to False to guarantee robust and user-friendly +estimators.
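+ +For instance, running only the API-compatibility checks on the same estimator +might look like this (an illustrative call; by default the full suite runs):: + + >>> from sklearn.utils.estimator_checks import check_estimator + >>> from sklearn.svm import LinearSVC + >>> check_estimator(LinearSVC(), api_only=True) # runs only API checks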
+ +The kind of checks that the check suite will run can also be partially +controlled by setting estimator tags, described below. + .. _estimator_tags: Estimator Tags diff --git a/doc/glossary.rst b/doc/glossary.rst index fe78c7e5bd8d3..a47b39f0a3879 100644 --- a/doc/glossary.rst +++ b/doc/glossary.rst @@ -142,7 +142,9 @@ General Concepts We provide limited backwards compatibility assurances for the estimator checks: we may add extra requirements on estimators tested with this function, usually when these were informally - assumed but not formally tested. + assumed but not formally tested. In particular, checks that are + not API-related (i.e. those that are ignored when `api_only` is + True) may enforce backward-incompatible requirements. Despite this informal contract with our users, the software is provided as is, as stated in the license. When a release inadvertently diff --git a/doc/whats_new/v0.24.rst b/doc/whats_new/v0.24.rst index 7460f8aa3aad9..3bc09b6f91b5e 100644 --- a/doc/whats_new/v0.24.rst +++ b/doc/whats_new/v0.24.rst @@ -97,7 +97,7 @@ Changelog `init_size_`, are deprecated and will be removed in 0.26. :pr:`17864` by :user:`Jérémie du Boisberranger `. -- |Enhancement| Added :func:`cluster.kmeans_plusplus` as public function. +- |Enhancement| Added :func:`cluster.kmeans_plusplus` as public function. Initialization by KMeans++ can now be called separately to generate initial cluster centroids. :pr:`17937` by :user:`g-walsh` @@ -736,7 +736,7 @@ Changelog when `handle_unknown='error'` and `drop=None` for samples encoded as all zeros. :pr:`14982` by :user:`Kevin Winata `. - + :mod:`sklearn.semi_supervised` .............................. @@ -775,6 +775,12 @@ Changelog :mod:`sklearn.utils` .................... +- |Feature| :func:`~utils.estimator_checks.check_estimator` and + :func:`~utils.estimator_checks.parametrize_with_checks` now expose an + `api_only` parameter which controls whether the check suite should only + check for pure API compatibility, or also run more advanced checks. + :pr:`18582` and :pr:`17361` by `Nicolas Hug`_. + - |Enhancement| Add ``check_methods_sample_order_invariance`` to :func:`~utils.estimator_checks.check_estimator`, which checks that estimator methods are invariant if applied to the same dataset @@ -793,12 +799,10 @@ Changelog dimensions do not match in :func:`utils.sparse_func.incr_mean_variance_axis`. By :user:`Alex Gramfort `. - - |Enhancement| Add support for weights in :func:`utils.sparse_func.incr_mean_variance_axis`. By :user:`Maria Telenczuk ` and :user:`Alex Gramfort `. - Miscellaneous ............. diff --git a/sklearn/tests/test_common.py b/sklearn/tests/test_common.py index 0a34f30765862..d93c8f1f65a93 100644 --- a/sklearn/tests/test_common.py +++ b/sklearn/tests/test_common.py @@ -219,10 +219,13 @@ def test_class_support_removed(): class MyNMFWithBadErrorMessage(NMF): # Same as NMF but raises an uninformative error message if X has negative - # value. This estimator would fail the check suite in strict mode, - # specifically it would fail check_fit_non_negative - # FIXME : should be removed in 0.26 + # value. This estimator would fail the check suite with api_only=False, + # specifically it would fail check_fit_non_negative because its error + # message doesn't match the expected one. + def __init__(self): + # declare init to avoid deprecation warning since default has changed + # FIXME : __init__ should be removed in 0.26 super().__init__() self.init = 'nndsvda' self.max_iter = 500 @@ -238,51 +241,52 @@ def fit(self, X, y=None, **params): return super().fit(X, y, **params) -def test_strict_mode_check_estimator(): - # Tests various conditions for the strict mode of check_estimator() +def test_api_only_check_estimator(): + # Tests various conditions for the api_only parameter of check_estimator() # Details are in the comments - # LogisticRegression has no _xfail_checks, so when strict_mode is on, there + # LogisticRegression has no _xfail_checks, so when api_only=False, there # should be no skipped tests.
with pytest.warns(None) as catched_warnings: - check_estimator(LogisticRegression(), strict_mode=True) + check_estimator(LogisticRegression(), api_only=False) assert not any(isinstance(w, SkipTestWarning) for w in catched_warnings) - # When strict mode is off, check_n_features should be skipped because it's - # a fully strict check - msg_check_n_features_in = 'check_n_features_in is fully strict ' - with pytest.warns(SkipTestWarning, match=msg_check_n_features_in): - check_estimator(LogisticRegression(), strict_mode=False) + # When api_only is True, check_fit2d_1sample should be skipped + # because it's not an API check + skip_match = 'check_fit2d_1sample is not an API check' + with pytest.warns(SkipTestWarning, match=skip_match): + check_estimator(LogisticRegression(), api_only=True) # NuSVC has some _xfail_checks. They should be skipped regardless of - # strict_mode + # api_only with pytest.warns(SkipTestWarning, match='fails for the decision_function method'): - check_estimator(NuSVC(), strict_mode=True) - # When strict mode is off, check_n_features_in is skipped along with the - # rest of the xfail_checks - with pytest.warns(SkipTestWarning, match=msg_check_n_features_in): - check_estimator(NuSVC(), strict_mode=False) - - # MyNMF will fail check_fit_non_negative() in strict mode because it yields - # a bad error message + check_estimator(NuSVC(), api_only=False) + # When api_only is True, check_fit2d_1sample is skipped along + # with the rest of the xfail_checks + with pytest.warns(SkipTestWarning, match=skip_match): + check_estimator(NuSVC(), api_only=True) + + # MyNMF will fail check_fit_non_negative() with api_only=False because it + # yields a bad error message with pytest.raises( AssertionError, match="The error message should contain" ): - check_estimator(MyNMFWithBadErrorMessage(), strict_mode=True) - # However, it should pass the test suite in non-strict mode because when - # strict mode is off, check_fit_non_negative() will not check the exact - # error messsage. (We still assert that the warning from - # check_n_features_in is raised) - with pytest.warns(SkipTestWarning, match=msg_check_n_features_in): - check_estimator(MyNMFWithBadErrorMessage(), strict_mode=False) + check_estimator(MyNMFWithBadErrorMessage(), api_only=False) + # However, it should pass the test suite with api_only=True because in + # this case, check_fit_non_negative() will not check the exact error + # message. (We still assert that the warning from + # check_fit2d_1sample is raised) + with pytest.warns(SkipTestWarning, match=skip_match): + check_estimator(MyNMFWithBadErrorMessage(), api_only=True) @parametrize_with_checks([LogisticRegression(), NuSVC(), MyNMFWithBadErrorMessage()], - strict_mode=False) -def test_strict_mode_parametrize_with_checks(estimator, check): - # Ideally we should assert that the strict checks are Xfailed... + api_only=True) +def test_api_only_parametrize_with_checks(estimator, check): + # Ideally we should assert that the non-API checks are either Xfailed or + # Xpassed check(estimator) diff --git a/sklearn/utils/estimator_checks.py b/sklearn/utils/estimator_checks.py index c423f1fe8c37a..ff3e32978dbb8 100644 --- a/sklearn/utils/estimator_checks.py +++ b/sklearn/utils/estimator_checks.py @@ -14,7 +14,6 @@ from . import IS_PYPY from ..
import config_context from ._testing import _get_args -from ._testing import assert_raise_message from ._testing import assert_array_equal from ._testing import assert_array_almost_equal from ._testing import assert_allclose @@ -32,6 +31,7 @@ from ..base import ( clone, ClusterMixin, + _DEFAULT_TAGS, is_classifier, is_regressor, is_outlier_detector, @@ -66,9 +66,39 @@ CROSS_DECOMPOSITION = ['PLSCanonical', 'PLSRegression', 'CCA', 'PLSSVD'] +def _safe_tags(estimator, key=None): + """Safely get estimator tags for common checks. + + :class:`~sklearn.BaseEstimator` provides the estimator tags machinery. + However, if a compatible estimator does not inherit from this base class, + we should default to the default tag. + + Parameters + ---------- + estimator : estimator object + The estimator from which to get the tag. + key : str, default=None + Tag name to get. By default (`None`), all tags are returned. + + Returns + ------- + tags : dict + The estimator tags. + """ + if hasattr(estimator, "_get_tags"): + if key is not None: + return estimator._get_tags().get(key, _DEFAULT_TAGS[key]) + tags = estimator._get_tags() + return {key: tags.get(key, _DEFAULT_TAGS[key]) + for key in _DEFAULT_TAGS.keys()} + if key is not None: + return _DEFAULT_TAGS[key] + return _DEFAULT_TAGS + + def _yield_checks(estimator): name = estimator.__class__.__name__ - tags = estimator._get_tags() + tags = _safe_tags(estimator) pairwise = _is_pairwise(estimator) yield check_no_attributes_set_in_init @@ -116,7 +146,7 @@ def _yield_checks(estimator): def _yield_classifier_checks(classifier): - tags = classifier._get_tags() + tags = _safe_tags(classifier) # test classifiers can handle non-array data and pandas objects yield check_classifier_data_not_an_array @@ -149,7 +179,7 @@ def _yield_classifier_checks(classifier): @ignore_warnings(category=FutureWarning) -def check_supervised_y_no_nan(name, estimator_orig, strict_mode=True): +def check_supervised_y_no_nan(name, estimator_orig, api_only=False): # Checks that the Estimator targets are not NaN. estimator = clone(estimator_orig) rng = np.random.RandomState(888) @@ -170,7 +200,7 @@ def check_supervised_y_no_nan(name, estimator_orig, strict_mode=True): def _yield_regressor_checks(regressor): - tags = regressor._get_tags() + tags = _safe_tags(regressor) # TODO: test with intercept # TODO: test with multiple responses # basic testing @@ -196,7 +226,7 @@ def _yield_regressor_checks(regressor): def _yield_transformer_checks(transformer): - tags = transformer._get_tags() + tags = _safe_tags(transformer) # All transformers should either deal with sparse data or raise an # exception with type TypeError and an intelligible error message if not tags["no_validation"]: @@ -206,7 +236,7 @@ def _yield_transformer_checks(transformer): if tags["preserves_dtype"]: yield check_transformer_preserve_dtypes yield partial(check_transformer_general, readonly_memmap=True) - if not transformer._get_tags()["stateless"]: + if not _safe_tags(transformer, key="stateless"): yield check_transformers_unfitted # Dependent on external solvers and hence accessing the iter # param is non-trivial. 
@@ -243,13 +273,13 @@ def _yield_outliers_checks(estimator): # test outlier detectors can handle non-array data yield check_classifier_data_not_an_array # test if NotFittedError is raised - if estimator._get_tags()["requires_fit"]: + if _safe_tags(estimator, key="requires_fit"): yield check_estimators_unfitted def _yield_all_checks(estimator): name = estimator.__class__.__name__ - tags = estimator._get_tags() + tags = _safe_tags(estimator) if "2darray" not in tags["X_types"]: warnings.warn("Can't test estimator {} which requires input " " of type {}".format(name, tags["X_types"]), @@ -369,14 +399,14 @@ def _construct_instance(Estimator): return estimator -def _maybe_mark_xfail(estimator, check, strict_mode, pytest): +def _maybe_mark_xfail(estimator, check, api_only, pytest): # Mark (estimator, check) pairs as XFAIL if needed (see conditions in # _should_be_skipped_or_marked()) # This is similar to _maybe_skip(), but this one is used by # @parametrize_with_checks() instead of check_estimator() should_be_marked, reason = _should_be_skipped_or_marked(estimator, check, - strict_mode) + api_only) if not should_be_marked: return estimator, check else: @@ -384,14 +414,14 @@ def _maybe_mark_xfail(estimator, check, strict_mode, pytest): marks=pytest.mark.xfail(reason=reason)) -def _maybe_skip(estimator, check, strict_mode): +def _maybe_skip(estimator, check, api_only): # Wrap a check so that it's skipped if needed (see conditions in # _should_be_skipped_or_marked()) # This is similar to _maybe_mark_xfail(), but this one is used by # check_estimator() instead of @parametrize_with_checks which requires # pytest should_be_skipped, reason = _should_be_skipped_or_marked(estimator, check, - strict_mode) + api_only) if not should_be_skipped: return check @@ -408,30 +438,30 @@ def wrapped(*args, **kwargs): return wrapped -def _should_be_skipped_or_marked(estimator, check, strict_mode): +def _should_be_skipped_or_marked(estimator, check, api_only): # Return whether a check should be skipped (when using check_estimator()) # or marked as XFAIL (when using @parametrize_with_checks()), along with a # reason. # A check should be skipped or marked if either: # - the check is in the _xfail_checks tag of the estimator - # - the check is fully strict and strict mode is off - # Checks that are only partially strict will not be skipped since we want - # to run their non-strict parts. + # - the check is not an API check and api_only is True + # Checks that are a mix of API and non-API checks will not be skipped since + # we want to run their API-checking parts. check_name = (check.func.__name__ if isinstance(check, partial) else check.__name__) - xfail_checks = estimator._get_tags()['_xfail_checks'] or {} + xfail_checks = _safe_tags(estimator, key='_xfail_checks') or {} if check_name in xfail_checks: return True, xfail_checks[check_name] - if check_name in _FULLY_STRICT_CHECKS and not strict_mode: - return True, f'{check_name} is fully strict and strict mode is off' + if check_name in _NON_API_CHECKS and api_only: + return True, f'{check_name} is not an API check and api_only is True.' return False, 'placeholder reason that will never be used' -def parametrize_with_checks(estimators, strict_mode=True): +def parametrize_with_checks(estimators, api_only=False): """Pytest specific decorator for parametrizing estimator checks. 
The `id` of each check is set to be a pprint version of the estimator @@ -449,18 +479,18 @@ Passing a class was deprecated in version 0.23, and support for classes was removed in 0.24. Pass an instance instead. - strict_mode : bool, default=True - If True, the full check suite is run. - If False, only the non-strict part of the check suite is run. + api_only : bool, default=False + If True, the check suite will only ensure pure API compatibility, and + will ignore other checks, such as those verifying error messages or + prediction performance on easy datasets. + By default, the entire check suite is run. - In non-strict mode, some checks will be easier to pass: e.g., they - will only make sure an error is raised instead of also checking the - full error message. - Some checks are considered completely strict, in which case they are - treated as if they were in the estimators' `_xfails_checks` tag: they - will be marked as `xfail` for pytest. See :ref:`estimator_tags` for - more info on the `_xfails_check` tag. The set of strict checks is in - `sklearn.utils.estimator_checks._FULLY_STRICT_CHECKS`. + When True, some checks will be easier to pass. Some other checks will + be treated as if they were in the estimators' `_xfails_checks` tag: + they will be marked as `xfail` for pytest, but they will still be + run. If they pass, pytest will label them as `xpass`. These checks + are in `sklearn.utils.estimator_checks._NON_API_CHECKS`. See + :ref:`estimator_tags` for more info on the `_xfails_check` tag. .. versionadded:: 0.24 @@ -492,14 +522,14 @@ def checks_generator(): for estimator in estimators: name = type(estimator).__name__ for check in _yield_all_checks(estimator): - check = partial(check, name, strict_mode=strict_mode) - yield _maybe_mark_xfail(estimator, check, strict_mode, pytest) + check = partial(check, name, api_only=api_only) + yield _maybe_mark_xfail(estimator, check, api_only, pytest) return pytest.mark.parametrize("estimator, check", checks_generator(), ids=_get_check_estimator_ids) -def check_estimator(Estimator, generate_only=False, strict_mode=True): +def check_estimator(Estimator, generate_only=False, api_only=False): """Check if estimator adheres to scikit-learn conventions. This estimator will run an extensive test-suite for input validation, @@ -535,18 +565,17 @@ Passing a class was deprecated in version 0.23, and support for classes was removed in 0.24. Pass an instance instead. .. versionadded:: 0.22 - strict_mode : bool, default=True - If True, the full check suite is run. - If False, only the non-strict part of the check suite is run. + api_only : bool, default=False + If True, the check suite will only ensure pure API compatibility, and + will ignore other checks, such as those verifying error messages or + prediction performance on easy datasets. + By default, the entire check suite is run. - In non-strict mode, some checks will be easier to pass: e.g., they - will only make sure an error is raised instead of also checking the - full error message. - Some checks are considered completely strict, in which case they are - treated as if they were in the estimators' `_xfails_checks` tag: they - will be ignored with a warning. See :ref:`estimator_tags` for more - info on the `_xfails_check` tag. The set of strict checks is in - `sklearn.utils.estimator_checks._FULLY_STRICT_CHECKS`. + When True, some checks will be easier to pass. Some other checks will + be treated as if they were in the estimators' `_xfails_checks` tag: + they will be ignored with a warning.
These checks are in + `sklearn.utils.estimator_checks._NON_API_CHECKS`. See + :ref:`estimator_tags` for more info on the `_xfails_check` tag. .. versionadded:: 0.24 @@ -567,8 +596,8 @@ def check_estimator(Estimator, generate_only=False, strict_mode=True): def checks_generator(): for check in _yield_all_checks(estimator): - check = _maybe_skip(estimator, check, strict_mode) - yield estimator, partial(check, name, strict_mode=strict_mode) + check = _maybe_skip(estimator, check, api_only) + yield estimator, partial(check, name, api_only=api_only) if generate_only: return checks_generator() @@ -761,7 +790,9 @@ def _generate_sparse_matrix(X_csr): yield sparse_format + "_64", X -def check_estimator_sparse_data(name, estimator_orig, strict_mode=True): +def check_estimator_sparse_data(name, estimator_orig, api_only=False): + # Make sure that the estimator either accepts sparse data in fit and + # predict, or that it fails with a helpful error message. rng = np.random.RandomState(0) X = rng.rand(40, 10) X[X < .8] = 0 @@ -772,7 +803,7 @@ def check_estimator_sparse_data(name, estimator_orig, strict_mode=True): with ignore_warnings(category=FutureWarning): estimator = clone(estimator_orig) y = _enforce_estimator_tags_y(estimator, y) - tags = estimator_orig._get_tags() + tags = _safe_tags(estimator_orig) for matrix_format, X in _generate_sparse_matrix(X_csr): # catch deprecation warnings with ignore_warnings(category=FutureWarning): @@ -816,7 +847,7 @@ def check_estimator_sparse_data(name, estimator_orig, strict_mode=True): @ignore_warnings(category=FutureWarning) -def check_sample_weights_pandas_series(name, estimator_orig, strict_mode=True): +def check_sample_weights_pandas_series(name, estimator_orig, api_only=False): # check that estimators will accept a 'sample_weight' parameter of # type pandas.Series in the 'fit' function. estimator = clone(estimator_orig) @@ -829,7 +860,7 @@ def check_sample_weights_pandas_series(name, estimator_orig, strict_mode=True): X = pd.DataFrame(_pairwise_estimator_convert_X(X, estimator_orig)) y = pd.Series([1, 1, 1, 1, 2, 2, 2, 2, 1, 1, 2, 2]) weights = pd.Series([1] * 12) - if estimator._get_tags()["multioutput_only"]: + if _safe_tags(estimator, key="multioutput_only"): y = pd.DataFrame(y) try: estimator.fit(X, y, sample_weight=weights) @@ -843,7 +874,7 @@ def check_sample_weights_pandas_series(name, estimator_orig, strict_mode=True): @ignore_warnings(category=(FutureWarning)) -def check_sample_weights_not_an_array(name, estimator_orig, strict_mode=True): +def check_sample_weights_not_an_array(name, estimator_orig, api_only=False): # check that estimators will accept a 'sample_weight' parameter of # type _NotAnArray in the 'fit' function. estimator = clone(estimator_orig) @@ -854,13 +885,13 @@ def check_sample_weights_not_an_array(name, estimator_orig, strict_mode=True): X = _NotAnArray(_pairwise_estimator_convert_X(X, estimator_orig)) y = _NotAnArray([1, 1, 1, 1, 2, 2, 2, 2, 1, 1, 2, 2]) weights = _NotAnArray([1] * 12) - if estimator._get_tags()["multioutput_only"]: + if _safe_tags(estimator, key="multioutput_only"): y = _NotAnArray(y.data.reshape(-1, 1)) estimator.fit(X, y, sample_weight=weights) @ignore_warnings(category=(FutureWarning)) -def check_sample_weights_list(name, estimator_orig, strict_mode=True): +def check_sample_weights_list(name, estimator_orig, api_only=False): # check that estimators will accept a 'sample_weight' parameter of # type list in the 'fit' function. 
if has_fit_parameter(estimator_orig, "sample_weight"): @@ -877,7 +908,7 @@ def check_sample_weights_list(name, estimator_orig, strict_mode=True): @ignore_warnings(category=FutureWarning) -def check_sample_weights_shape(name, estimator_orig, strict_mode=True): +def check_sample_weights_shape(name, estimator_orig, api_only=False): # check that estimators raise an error if sample_weight # shape mismatches the input if (has_fit_parameter(estimator_orig, "sample_weight") and @@ -902,7 +933,7 @@ def check_sample_weights_shape(name, estimator_orig, strict_mode=True): @ignore_warnings(category=FutureWarning) def check_sample_weights_invariance(name, estimator_orig, kind="ones", - strict_mode=True): + api_only=False): # For kind="ones" check that the estimators yield same results for # unit weights and no weights # For kind="zeros" check that setting sample_weight to 0 is equivalent @@ -954,12 +985,12 @@ def check_sample_weights_invariance(name, estimator_orig, kind="ones", @ignore_warnings(category=(FutureWarning, UserWarning)) -def check_dtype_object(name, estimator_orig, strict_mode=True): +def check_dtype_object(name, estimator_orig, api_only=False): # check that estimators treat dtype object as numeric if possible rng = np.random.RandomState(0) X = _pairwise_estimator_convert_X(rng.rand(40, 10), estimator_orig) X = X.astype(object) - tags = estimator_orig._get_tags() + tags = _safe_tags(estimator_orig) y = (X[:, 0] * 4).astype(int) estimator = clone(estimator_orig) y = _enforce_estimator_tags_y(estimator, y) @@ -976,8 +1007,8 @@ def check_dtype_object(name, estimator_orig, strict_mode=True): if 'string' not in tags['X_types']: X[0, 0] = {'foo': 'bar'} - msg = "argument must be a string.* number" - with raises(TypeError, match=msg): + match = None if api_only else "argument must be a string.* number" + with raises(TypeError, match=match): estimator.fit(X, y) else: # Estimators supporting string will not call np.asarray to convert the @@ -987,7 +1018,7 @@ def check_dtype_object(name, estimator_orig, strict_mode=True): estimator.fit(X, y) -def check_complex_data(name, estimator_orig, strict_mode=True): +def check_complex_data(name, estimator_orig, api_only=False): # check that estimators raise an exception on providing complex data X = np.random.sample(10) + 1j * np.random.sample(10) X = X.reshape(-1, 1) @@ -998,13 +1029,9 @@ def check_complex_data(name, estimator_orig, strict_mode=True): @ignore_warnings -def check_dict_unchanged(name, estimator_orig, strict_mode=True): - # this estimator raises - # ValueError: Found array with 0 feature(s) (shape=(23, 0)) - # while a minimum of 1 is required. - # error - if name in ['SpectralCoclustering']: - return +def check_dict_unchanged(name, estimator_orig, api_only=False): + # check that calling the prediction method does not alter the __dict__ + # attribute of the estimator. 
rnd = np.random.RandomState(0) if name in ['RANSACRegressor']: X = 3 * rnd.uniform(size=(20, 3)) @@ -1042,7 +1069,7 @@ def _is_public_parameter(attr): @ignore_warnings(category=FutureWarning) -def check_dont_overwrite_parameters(name, estimator_orig, strict_mode=True): +def check_dont_overwrite_parameters(name, estimator_orig, api_only=False): # check that fit method only changes or sets private attributes if hasattr(estimator_orig.__init__, "deprecated_original"): # to not check deprecated classes @@ -1094,8 +1121,8 @@ def check_dont_overwrite_parameters(name, estimator_orig, strict_mode=True): @ignore_warnings(category=FutureWarning) -def check_fit2d_predict1d(name, estimator_orig, strict_mode=True): - # check by fitting a 2d array and predicting with a 1d array +def check_fit2d_predict1d(name, estimator_orig, api_only=False): + # check that predicting with a 1d array raises an error rnd = np.random.RandomState(0) X = 3 * rnd.uniform(size=(20, 3)) X = _pairwise_estimator_convert_X(X, estimator_orig) @@ -1114,8 +1141,9 @@ def check_fit2d_predict1d(name, estimator_orig, strict_mode=True): for method in ["predict", "transform", "decision_function", "predict_proba"]: if hasattr(estimator, method): - assert_raise_message(ValueError, "Reshape your data", - getattr(estimator, method), X[0]) + match = None if api_only else "Reshape your data" + with raises(ValueError, match=match): + getattr(estimator, method)(X[0]) def _apply_on_subsets(func, X): @@ -1138,7 +1166,7 @@ def _apply_on_subsets(func, X): @ignore_warnings(category=FutureWarning) -def check_methods_subset_invariance(name, estimator_orig, strict_mode=True): +def check_methods_subset_invariance(name, estimator_orig, api_only=False): # check that method gives invariant results if applied # on mini batches or the whole set rnd = np.random.RandomState(0) @@ -1171,7 +1199,7 @@ def check_methods_subset_invariance(name, estimator_orig, strict_mode=True): @ignore_warnings(category=FutureWarning) def check_methods_sample_order_invariance( - name, estimator_orig, strict_mode=True + name, estimator_orig, api_only=False ): # check that method gives invariant results if applied # on a subset with different sample order @@ -1179,7 +1207,7 @@ def check_methods_sample_order_invariance( X = 3 * rnd.uniform(size=(20, 3)) X = _pairwise_estimator_convert_X(X, estimator_orig) y = X[:, 0].astype(np.int64) - if estimator_orig._get_tags()['binary_only']: + if _safe_tags(estimator_orig, key='binary_only'): y[y == 2] = 1 estimator = clone(estimator_orig) y = _enforce_estimator_tags_y(estimator, y) @@ -1207,7 +1235,7 @@ def check_methods_sample_order_invariance( @ignore_warnings -def check_fit2d_1sample(name, estimator_orig, strict_mode=True): +def check_fit2d_1sample(name, estimator_orig, api_only=False): # Check that fitting a 2d array with only one sample either works or # returns an informative message. The error message should either mention # the number of samples or the number of classes. 
@@ -1238,7 +1266,7 @@ def check_fit2d_1sample(name, estimator_orig, strict_mode=True): @ignore_warnings -def check_fit2d_1feature(name, estimator_orig, strict_mode=True): +def check_fit2d_1feature(name, estimator_orig, api_only=False): # check fitting a 2d array with only 1 feature either works or returns # informative message rnd = np.random.RandomState(0) @@ -1269,7 +1297,7 @@ def check_fit2d_1feature(name, estimator_orig, strict_mode=True): @ignore_warnings -def check_fit1d(name, estimator_orig, strict_mode=True): +def check_fit1d(name, estimator_orig, api_only=False): # check fitting 1d X array raises a ValueError rnd = np.random.RandomState(0) X = 3 * rnd.uniform(size=(20)) @@ -1289,7 +1317,7 @@ def check_fit1d(name, estimator_orig, strict_mode=True): @ignore_warnings(category=FutureWarning) def check_transformer_general(name, transformer, readonly_memmap=False, - strict_mode=True): + api_only=False): X, y = make_blobs(n_samples=30, centers=[[0, 0, 0], [1, 1, 1]], random_state=0, n_features=2, cluster_std=0.1) X = StandardScaler().fit_transform(X) @@ -1299,11 +1327,11 @@ def check_transformer_general(name, transformer, readonly_memmap=False, if readonly_memmap: X, y = create_memmap_backed_data([X, y]) - _check_transformer(name, transformer, X, y) + _check_transformer(name, transformer, X, y, api_only=api_only) @ignore_warnings(category=FutureWarning) -def check_transformer_data_not_an_array(name, transformer, strict_mode=True): +def check_transformer_data_not_an_array(name, transformer, api_only=False): X, y = make_blobs(n_samples=30, centers=[[0, 0, 0], [1, 1, 1]], random_state=0, n_features=2, cluster_std=0.1) X = StandardScaler().fit_transform(X) @@ -1313,13 +1341,16 @@ def check_transformer_data_not_an_array(name, transformer, strict_mode=True): X = _pairwise_estimator_convert_X(X, transformer) this_X = _NotAnArray(X) this_y = _NotAnArray(np.asarray(y)) - _check_transformer(name, transformer, this_X, this_y) + _check_transformer(name, transformer, this_X, this_y, api_only=api_only) # try the same with some list - _check_transformer(name, transformer, X.tolist(), y.tolist()) + _check_transformer(name, transformer, X.tolist(), y.tolist(), + api_only=api_only) @ignore_warnings(category=FutureWarning) -def check_transformers_unfitted(name, transformer, strict_mode=True): +def check_transformers_unfitted(name, transformer, api_only=False): + # Make sure the unfitted transformer raises an error when transform is + # called X, y = _regression_dataset() transformer = clone(transformer) @@ -1333,7 +1364,13 @@ def check_transformers_unfitted(name, transformer, strict_mode=True): transformer.transform(X) -def _check_transformer(name, transformer_orig, X, y, strict_mode=True): +def _check_transformer(name, transformer_orig, X, y, api_only=False): + # Check that: + # - fit_transform returns n_samples transformed samples + # - an error is raised if transform is called with an incorrect number of + # features + # - fit_transform and transform give equivalent results. 
+ # - fit_transform gives the same results twice n_samples, n_features = np.asarray(X).shape transformer = clone(transformer_orig) set_random_state(transformer) @@ -1354,6 +1391,7 @@ def _check_transformer(name, transformer_orig, X, y, strict_mode=True): X_pred = transformer_clone.fit_transform(X, y=y_) if isinstance(X_pred, tuple): + # for cross-decomposition estimators that transform both X and y for x_pred in X_pred: assert x_pred.shape[0] == n_samples else: @@ -1361,6 +1399,23 @@ def _check_transformer(name, transformer_orig, X, y, strict_mode=True): assert X_pred.shape[0] == n_samples if hasattr(transformer, 'transform'): + + # raises error on malformed input for transform + if hasattr(X, 'shape') and \ + not _safe_tags(transformer, key="stateless") and \ + X.ndim == 2 and X.shape[1] > 1: + + with raises( + ValueError, + err_msg=f"The transformer {name} does not raise an error " + "when the number of features in transform is different from " + "the number of features in fit." + ): + transformer.transform(X[:, :-1]) + if api_only: + # The remaining asserts are non-API asserts + return + if name in CROSS_DECOMPOSITION: X_pred2 = transformer.transform(X, y_) X_pred3 = transformer.fit_transform(X, y=y_) @@ -1368,7 +1423,7 @@ def _check_transformer(name, transformer_orig, X, y, strict_mode=True): X_pred2 = transformer.transform(X) X_pred3 = transformer.fit_transform(X, y=y_) - if transformer_orig._get_tags()['non_deterministic']: + if _safe_tags(transformer_orig, key='non_deterministic'): msg = name + ' is non deterministic' raise SkipTest(msg) if isinstance(X_pred, tuple) and isinstance(X_pred2, tuple): @@ -1397,28 +1452,16 @@ def _check_transformer(name, transformer_orig, X, y, strict_mode=True): assert _num_samples(X_pred2) == n_samples assert _num_samples(X_pred3) == n_samples - # raises error on malformed input for transform - if hasattr(X, 'shape') and \ - not transformer._get_tags()["stateless"] and \ - X.ndim == 2 and X.shape[1] > 1: - - # If it's not an array, it does not have a 'T' property - with raises( - ValueError, - err_msg=f"The transformer {name} does not raise an error " - "when the number of features in transform is different from " - "the number of features in fit." 
- ): - transformer.transform(X[:, :-1]) - @ignore_warnings -def check_pipeline_consistency(name, estimator_orig, strict_mode=True): - if estimator_orig._get_tags()['non_deterministic']: +def check_pipeline_consistency(name, estimator_orig, api_only=False): + # check that make_pipeline(est) gives results as est for scores and + # transforms + + if _safe_tags(estimator_orig, key='non_deterministic'): msg = name + ' is non deterministic' raise SkipTest(msg) - # check that make_pipeline(est) gives same score as est X, y = make_blobs(n_samples=30, centers=[[0, 0, 0], [1, 1, 1]], random_state=0, n_features=2, cluster_std=0.1) X -= X.min() @@ -1442,7 +1485,7 @@ def check_pipeline_consistency(name, estimator_orig, strict_mode=True): @ignore_warnings -def check_fit_score_takes_y(name, estimator_orig, strict_mode=True): +def check_fit_score_takes_y(name, estimator_orig, api_only=False): # check that all estimators accept an optional y # in fit and score so they can be used in pipelines rnd = np.random.RandomState(0) @@ -1471,7 +1514,8 @@ def check_fit_score_takes_y(name, estimator_orig, strict_mode=True): @ignore_warnings -def check_estimators_dtypes(name, estimator_orig, strict_mode=True): +def check_estimators_dtypes(name, estimator_orig, api_only=False): + # Check that methods can handle X input of different float and int dtypes rnd = np.random.RandomState(0) X_train_32 = 3 * rnd.uniform(size=(20, 5)).astype(np.float32) X_train_32 = _pairwise_estimator_convert_X(X_train_32, estimator_orig) @@ -1494,7 +1538,7 @@ def check_estimators_dtypes(name, estimator_orig, strict_mode=True): def check_transformer_preserve_dtypes( - name, transformer_orig, strict_mode=True + name, transformer_orig, api_only=False ): # check that dtype are preserved meaning if input X is of some dtype # X_transformed should be from the same dtype. @@ -1508,7 +1552,7 @@ def check_transformer_preserve_dtypes( X -= X.min() X = _pairwise_estimator_convert_X(X, transformer_orig) - for dtype in transformer_orig._get_tags()["preserves_dtype"]: + for dtype in _safe_tags(transformer_orig, key="preserves_dtype"): X_cast = X.astype(dtype) transformer = clone(transformer_orig) set_random_state(transformer) @@ -1528,7 +1572,9 @@ def check_transformer_preserve_dtypes( @ignore_warnings(category=FutureWarning) def check_estimators_empty_data_messages(name, estimator_orig, - strict_mode=True): + api_only=False): + # Make sure that a ValueError is raised when fit is called on data with no + # sample or no features. e = clone(estimator_orig) set_random_state(e, 1) @@ -1551,14 +1597,15 @@ def check_estimators_empty_data_messages(name, estimator_orig, msg = ( r"0 feature\(s\) \(shape=\(\d*, 0\)\) while a minimum of \d* " "is required." - ) + ) if not api_only else None with raises(ValueError, match=msg): e.fit(X_zero_features, y) @ignore_warnings(category=FutureWarning) -def check_estimators_nan_inf(name, estimator_orig, strict_mode=True): - # Checks that Estimator X's do not contain NaN or inf. +def check_estimators_nan_inf(name, estimator_orig, api_only=False): + # Checks that fit, predict and transform raise an error if X contains nans + # or inf. 
rnd = np.random.RandomState(0) X_train_finite = _pairwise_estimator_convert_X(rnd.uniform(size=(10, 3)), estimator_orig) @@ -1607,8 +1654,9 @@ def check_estimators_nan_inf(name, estimator_orig, strict_mode=True): @ignore_warnings -def check_nonsquare_error(name, estimator_orig, strict_mode=True): - """Test that error is thrown when non-square data provided.""" +def check_nonsquare_error(name, estimator_orig, api_only=False): + # Check that error is raised when non-square data is provided in fit for a + # pairwise estimator X, y = make_blobs(n_samples=20, n_features=10) estimator = clone(estimator_orig) @@ -1622,8 +1670,9 @@ def check_nonsquare_error(name, estimator_orig, strict_mode=True): @ignore_warnings -def check_estimators_pickle(name, estimator_orig, strict_mode=True): - """Test that we can pickle all estimators.""" +def check_estimators_pickle(name, estimator_orig, api_only=False): + # Test that we can pickle all estimators and that the pickled estimator + # gives the same predictions check_methods = ["predict", "transform", "decision_function", "predict_proba"] @@ -1634,7 +1683,7 @@ def check_estimators_pickle(name, estimator_orig, strict_mode=True): X -= X.min() X = _pairwise_estimator_convert_X(X, estimator_orig, kernel=rbf_kernel) - tags = estimator_orig._get_tags() + tags = _safe_tags(estimator_orig) # include NaN values when the estimator should deal with them if tags['allow_nan']: # set randomly 10 elements to np.nan @@ -1651,7 +1700,8 @@ def check_estimators_pickle(name, estimator_orig, strict_mode=True): # pickle and unpickle! pickled_estimator = pickle.dumps(estimator) - if estimator.__module__.startswith('sklearn.'): + module_name = estimator.__module__ + if module_name.startswith('sklearn.') and "test_" not in module_name: assert b"version" in pickled_estimator unpickled_estimator = pickle.loads(pickled_estimator) @@ -1667,8 +1717,9 @@ def check_estimators_pickle(name, estimator_orig, strict_mode=True): @ignore_warnings(category=FutureWarning) def check_estimators_partial_fit_n_features(name, estimator_orig, - strict_mode=True): - # check if number of features changes between calls to partial_fit. + api_only=False): + # check that an error is raised when number of features changes between + # calls to partial_fit. if not hasattr(estimator_orig, 'partial_fit'): return estimator = clone(estimator_orig) @@ -1694,9 +1745,13 @@ def check_estimators_partial_fit_n_features(name, estimator_orig, @ignore_warnings(category=FutureWarning) -def check_classifier_multioutput(name, estimator, strict_mode=True): +def check_classifier_multioutput(name, estimator, api_only=False): + # Make sure that the output of predict_proba and decision_function is + # correct for multiouput classification (multilabel, multiclass). Also + # checks that predict_proba and decision_function have consistent + # predictions, i.e. the orders are consistent. n_samples, n_labels, n_classes = 42, 5, 3 - tags = estimator._get_tags() + tags = _safe_tags(estimator) estimator = clone(estimator) X, y = make_multilabel_classification(random_state=42, n_samples=n_samples, @@ -1719,9 +1774,10 @@ def check_classifier_multioutput(name, estimator, strict_mode=True): "multioutput data is incorrect. Expected {}, got {}." 
.format((n_samples, n_classes), decision.shape)) - dec_pred = (decision > 0).astype(int) - dec_exp = estimator.classes_[dec_pred] - assert_array_equal(dec_exp, y_pred) + if not api_only: + dec_pred = (decision > 0).astype(int) + dec_exp = estimator.classes_[dec_pred] + assert_array_equal(dec_exp, y_pred) if hasattr(estimator, "predict_proba"): y_prob = estimator.predict_proba(X) @@ -1732,16 +1788,22 @@ def check_classifier_multioutput(name, estimator, strict_mode=True): "The shape of the probability for multioutput data is" " incorrect. Expected {}, got {}." .format((n_samples, 2), y_prob[i].shape)) - assert_array_equal( - np.argmax(y_prob[i], axis=1).astype(int), - y_pred[:, i] - ) + if not api_only: + assert_array_equal( + np.argmax(y_prob[i], axis=1).astype(int), + y_pred[:, i] + ) elif not tags['poor_score']: assert y_prob.shape == (n_samples, n_classes), ( "The shape of the probability for multioutput data is" " incorrect. Expected {}, got {}." .format((n_samples, n_classes), y_prob.shape)) - assert_array_equal(y_prob.round().astype(int), y_pred) + if not api_only: + assert_array_equal(y_prob.round().astype(int), y_pred) + + if api_only: + # The remaining asserts are non-API asserts + return if (hasattr(estimator, "decision_function") and hasattr(estimator, "predict_proba")): @@ -1752,7 +1814,9 @@ def check_classifier_multioutput(name, estimator, strict_mode=True): @ignore_warnings(category=FutureWarning) -def check_regressor_multioutput(name, estimator, strict_mode=True): +def check_regressor_multioutput(name, estimator, api_only=False): + # Make sure that multioutput regressors output float64 predictions and that + # the shape is correct. estimator = clone(estimator) n_samples = n_features = 10 @@ -1766,9 +1830,10 @@ def check_regressor_multioutput(name, estimator, strict_mode=True): estimator.fit(X, y) y_pred = estimator.predict(X) - assert y_pred.dtype == np.dtype('float64'), ( - "Multioutput predictions by a regressor are expected to be" - " floating-point precision. Got {} instead".format(y_pred.dtype)) + if not api_only: + assert y_pred.dtype == np.dtype('float64'), ( + "Multioutput predictions by a regressor are expected to be" + " floating-point precision. Got {} instead".format(y_pred.dtype)) assert y_pred.shape == y.shape, ( "The shape of the prediction for multioutput data is incorrect." 
" Expected {}, got {}.") @@ -1776,7 +1841,7 @@ def check_regressor_multioutput(name, estimator, strict_mode=True): @ignore_warnings(category=FutureWarning) def check_clustering(name, clusterer_orig, readonly_memmap=False, - strict_mode=True): + api_only=False): clusterer = clone(clusterer_orig) X, y = make_blobs(n_samples=50, random_state=1) X, y = shuffle(X, y, random_state=7) @@ -1803,8 +1868,14 @@ def check_clustering(name, clusterer_orig, readonly_memmap=False, pred = clusterer.labels_ assert pred.shape == (n_samples,) + + if api_only: + # The remaining asserts are non-API asserts + + return + assert adjusted_rand_score(pred, y) > 0.4 - if clusterer._get_tags()['non_deterministic']: + if _safe_tags(clusterer, key='non_deterministic'): return set_random_state(clusterer) with warnings.catch_warnings(record=True): @@ -1836,8 +1907,8 @@ def check_clustering(name, clusterer_orig, readonly_memmap=False, @ignore_warnings(category=FutureWarning) def check_clusterer_compute_labels_predict(name, clusterer_orig, - strict_mode=True): - """Check that predict is invariant of compute_labels.""" + api_only=False): + # Check that predict is invariant of compute_labels X, y = make_blobs(n_samples=20, random_state=0) clusterer = clone(clusterer_orig) set_random_state(clusterer) @@ -1851,7 +1922,10 @@ def check_clusterer_compute_labels_predict(name, clusterer_orig, @ignore_warnings(category=FutureWarning) -def check_classifiers_one_label(name, classifier_orig, strict_mode=True): +def check_classifiers_one_label(name, classifier_orig, api_only=False): + # Check that a classifier can fit when there's only 1 class, or that it + # raises a proper error. If it can fit, we also make sure that it can + # predict. error_string_fit = "Classifier can't train when only one class is present." 
error_string_predict = ("Classifier can't predict when only one class is " "present.") @@ -1878,7 +1952,7 @@ def check_classifiers_one_label(name, classifier_orig, strict_mode=True): @ignore_warnings # Warnings are raised by decision function def check_classifiers_train(name, classifier_orig, readonly_memmap=False, - X_dtype='float64', strict_mode=True): + X_dtype='float64', api_only=False): X_m, y_m = make_blobs(n_samples=300, random_state=0) X_m = X_m.astype(X_dtype) X_m, y_m = shuffle(X_m, y_m, random_state=7) @@ -1896,7 +1970,7 @@ def check_classifiers_train(name, classifier_orig, readonly_memmap=False, X_m, y_m, X_b, y_b = create_memmap_backed_data([X_m, y_m, X_b, y_b]) problems = [(X_b, y_b)] - tags = classifier_orig._get_tags() + tags = _safe_tags(classifier_orig) if not tags['binary_only']: problems.append((X_m, y_m)) @@ -1929,7 +2003,7 @@ def check_classifiers_train(name, classifier_orig, readonly_memmap=False, assert y_pred.shape == (n_samples,) # training set performance - if not tags['poor_score']: + if not tags['poor_score'] and not api_only: assert accuracy_score(y, y_pred) > 0.83 # raises error on malformed input for predict @@ -1959,11 +2033,13 @@ def check_classifiers_train(name, classifier_orig, readonly_memmap=False, assert decision.shape == (n_samples,) else: assert decision.shape == (n_samples, 1) - dec_pred = (decision.ravel() > 0).astype(int) - assert_array_equal(dec_pred, y_pred) + if not api_only: + dec_pred = (decision.ravel() > 0).astype(int) + assert_array_equal(dec_pred, y_pred) else: assert decision.shape == (n_samples, n_classes) - assert_array_equal(np.argmax(decision, axis=1), y_pred) + if not api_only: + assert_array_equal(np.argmax(decision, axis=1), y_pred) # raises error on malformed input for decision_function if not tags["no_validation"]: @@ -1988,10 +2064,11 @@ def check_classifiers_train(name, classifier_orig, readonly_memmap=False, # predict_proba agrees with predict y_prob = classifier.predict_proba(X) assert y_prob.shape == (n_samples, n_classes) - assert_array_equal(np.argmax(y_prob, axis=1), y_pred) - # check that probas for all classes sum to one - assert_array_almost_equal(np.sum(y_prob, axis=1), - np.ones(n_samples)) + if not api_only: + assert_array_equal(np.argmax(y_prob, axis=1), y_pred) + # check that probas for all classes sum to one + assert_array_almost_equal(np.sum(y_prob, axis=1), + np.ones(n_samples)) if not tags["no_validation"]: # raises error on malformed input for predict_proba if _is_pairwise(classifier_orig): @@ -2006,7 +2083,7 @@ def check_classifiers_train(name, classifier_orig, readonly_memmap=False, err_msg=msg.format(name, "predict_proba"), ): classifier.predict_proba(X.T) - if hasattr(classifier, "predict_log_proba"): + if hasattr(classifier, "predict_log_proba") and not api_only: # predict_log_proba is a transformation of predict_proba y_log_prob = classifier.predict_log_proba(X) assert_allclose(y_log_prob, np.log(y_prob), 8, atol=1e-9) @@ -2014,7 +2091,7 @@ def check_classifiers_train(name, classifier_orig, readonly_memmap=False, def check_outlier_corruption(num_outliers, expected_outliers, decision, - strict_mode=True): + api_only=False): # Check for deviation from the precise given contamination level that may # be due to ties in the anomaly scores. 
if num_outliers < expected_outliers: @@ -2035,7 +2112,7 @@ def check_outlier_corruption(num_outliers, expected_outliers, decision, def check_outliers_train(name, estimator_orig, readonly_memmap=True, - strict_mode=True): + api_only=False): n_samples = 300 X, _ = make_blobs(n_samples=n_samples, random_state=0) X = shuffle(X, random_state=7) @@ -2068,9 +2145,10 @@ def check_outliers_train(name, estimator_orig, readonly_memmap=True, estimator.predict(X.T) # decision_function agrees with predict - dec_pred = (decision >= 0).astype(int) - dec_pred[dec_pred == 0] = -1 - assert_array_equal(dec_pred, y_pred) + if not api_only: + dec_pred = (decision >= 0).astype(int) + dec_pred[dec_pred == 0] = -1 + assert_array_equal(dec_pred, y_pred) # raises error on malformed input for decision_function with raises(ValueError): @@ -2086,7 +2164,8 @@ def check_outliers_train(name, estimator_orig, readonly_memmap=True, # contamination parameter (not for OneClassSVM which has the nu parameter) if (hasattr(estimator, 'contamination') - and not hasattr(estimator, 'novelty')): + and not hasattr(estimator, 'novelty') + and not api_only): # proportion of outliers equal to contamination parameter when not # set to 'auto'. This is true for the training set and cannot thus be # checked as follows for estimators with a novelty parameter such as @@ -2116,7 +2195,8 @@ def check_outliers_train(name, estimator_orig, readonly_memmap=True, @ignore_warnings(category=(FutureWarning)) def check_classifiers_multilabel_representation_invariance( - name, classifier_orig, strict_mode=True): + name, classifier_orig, api_only=False): + # check different target representations for multilabel classifiers X, y = make_multilabel_classification(n_samples=100, n_features=20, n_classes=5, n_labels=3, @@ -2151,8 +2231,8 @@ def check_classifiers_multilabel_representation_invariance( @ignore_warnings(category=FutureWarning) def check_estimators_fit_returns_self(name, estimator_orig, - readonly_memmap=False, strict_mode=True): - """Check if self is returned when calling fit.""" + readonly_memmap=False, api_only=False): + # Check that self is returned when calling fit. X, y = make_blobs(random_state=0, n_samples=21) # some want non-negative input X -= X.min() @@ -2169,11 +2249,9 @@ def check_estimators_fit_returns_self(name, estimator_orig, @ignore_warnings -def check_estimators_unfitted(name, estimator_orig, strict_mode=True): - """Check that predict raises an exception in an unfitted estimator. - - Unfitted estimators should raise a NotFittedError. - """ +def check_estimators_unfitted(name, estimator_orig, api_only=False): + # Check that predict raises an exception in an unfitted estimator. + # Unfitted estimators should raise a NotFittedError. 
# Common test for Regressors, Classifiers and Outlier detection estimators X, y = _regression_dataset() @@ -2186,8 +2264,10 @@ def check_estimators_unfitted(name, estimator_orig, strict_mode=True): @ignore_warnings(category=FutureWarning) -def check_supervised_y_2d(name, estimator_orig, strict_mode=True): - tags = estimator_orig._get_tags() +def check_supervised_y_2d(name, estimator_orig, api_only=False): + # Check that estimators that don't support multi-ouput raise a warning if y + # is not 1d, and that they just ravel y + tags = _safe_tags(estimator_orig) rnd = np.random.RandomState(0) n_samples = 30 X = _pairwise_estimator_convert_X( @@ -2221,7 +2301,7 @@ def check_supervised_y_2d(name, estimator_orig, strict_mode=True): @ignore_warnings def check_classifiers_predictions(X, y, name, classifier_orig, - strict_mode=True): + api_only=False): classes = np.unique(y) classifier = clone(classifier_orig) if name == 'BernoulliNB': @@ -2251,8 +2331,7 @@ def check_classifiers_predictions(X, y, name, classifier_orig, (classifier, ", ".join(map(str, y_exp)), ", ".join(map(str, y_pred)))) - # training set performance - if name != "ComplementNB": + if not api_only and name != "ComplementNB": # This is a pathological data set for ComplementNB. # For some specific cases 'ComplementNB' predicts less classes # than expected @@ -2272,7 +2351,9 @@ def _choose_check_classifiers_labels(name, y, y_names): "SelfTrainingClassifier"] else y_names -def check_classifiers_classes(name, classifier_orig, strict_mode=True): +def check_classifiers_classes(name, classifier_orig, api_only=False): + # Check that decision function > 0 => pos class + # Also checks the classes_ attribute. X_multiclass, y_multiclass = make_blobs(n_samples=30, random_state=0, cluster_std=0.1) X_multiclass, y_multiclass = shuffle(X_multiclass, y_multiclass, @@ -2295,22 +2376,28 @@ def check_classifiers_classes(name, classifier_orig, strict_mode=True): y_names_binary = np.take(labels_binary, y_binary) problems = [(X_binary, y_binary, y_names_binary)] - if not classifier_orig._get_tags()['binary_only']: + if not _safe_tags(classifier_orig, key='binary_only'): problems.append((X_multiclass, y_multiclass, y_names_multiclass)) for X, y, y_names in problems: for y_names_i in [y_names, y_names.astype('O')]: y_ = _choose_check_classifiers_labels(name, y, y_names_i) - check_classifiers_predictions(X, y_, name, classifier_orig) + check_classifiers_predictions( + X, y_, name, classifier_orig, api_only + ) labels_binary = [-1, 1] y_names_binary = np.take(labels_binary, y_binary) y_binary = _choose_check_classifiers_labels(name, y_binary, y_names_binary) - check_classifiers_predictions(X_binary, y_binary, name, classifier_orig) + check_classifiers_predictions( + X_binary, y_binary, name, classifier_orig, api_only + ) @ignore_warnings(category=FutureWarning) -def check_regressors_int(name, regressor_orig, strict_mode=True): +def check_regressors_int(name, regressor_orig, api_only=False): + # Check that regressors give same prediction when y is encoded as int or + # float X, _ = _regression_dataset() X = _pairwise_estimator_convert_X(X[:50], regressor_orig) rnd = np.random.RandomState(0) @@ -2339,7 +2426,12 @@ def check_regressors_int(name, regressor_orig, strict_mode=True): @ignore_warnings(category=FutureWarning) def check_regressors_train(name, regressor_orig, readonly_memmap=False, - X_dtype=np.float64, strict_mode=True): + X_dtype=np.float64, api_only=False): + # Check that regressors: + # - raise an error when X and y have different number of samples + 
# - accept lists as input to fit + # - predict n_samples predictions + # - have a score > .5 on simple data X, y = _regression_dataset() X = X.astype(X_dtype) X = _pairwise_estimator_convert_X(X, regressor_orig) @@ -2381,13 +2473,13 @@ def check_regressors_train(name, regressor_orig, readonly_memmap=False, # TODO: find out why PLS and CCA fail. RANSAC is random # and furthermore assumes the presence of outliers, hence # skipped - if not regressor._get_tags()["poor_score"]: + if not _safe_tags(regressor, key="poor_score") and not api_only: assert regressor.score(X, y_) > 0.5 @ignore_warnings def check_regressors_no_decision_function(name, regressor_orig, - strict_mode=True): + api_only=False): # check that regressors don't have a decision_function, predict_proba, or # predict_log_proba method. rng = np.random.RandomState(0) @@ -2404,9 +2496,12 @@ def check_regressors_no_decision_function(name, regressor_orig, @ignore_warnings(category=FutureWarning) -def check_class_weight_classifiers(name, classifier_orig, strict_mode=True): +def check_class_weight_classifiers(name, classifier_orig, api_only=False): + # Make sure that classifiers take class_weight into account by creating a + # very noisy balanced dataset. We make sure that passing a very imbalanced + # class_weights helps recovering a good score. - if classifier_orig._get_tags()['binary_only']: + if _safe_tags(classifier_orig, key='binary_only'): problems = [2] else: problems = [2, 3] @@ -2445,14 +2540,14 @@ def check_class_weight_classifiers(name, classifier_orig, strict_mode=True): y_pred = classifier.predict(X_test) # XXX: Generally can use 0.89 here. On Windows, LinearSVC gets # 0.88 (Issue #9111) - if not classifier_orig._get_tags()['poor_score']: + if not _safe_tags(classifier_orig, key='poor_score'): assert np.mean(y_pred == 0) > 0.87 @ignore_warnings(category=FutureWarning) def check_class_weight_balanced_classifiers(name, classifier_orig, X_train, y_train, X_test, y_test, weights, - strict_mode=True): + api_only=False): classifier = clone(classifier_orig) if hasattr(classifier, "n_iter"): classifier.set_params(n_iter=100) @@ -2472,8 +2567,9 @@ def check_class_weight_balanced_classifiers(name, classifier_orig, X_train, @ignore_warnings(category=FutureWarning) def check_class_weight_balanced_linear_classifier(name, Classifier, - strict_mode=True): - """Test class weights with non-contiguous class labels.""" + api_only=False): + # Check that class_weight='balanced' is equivalent to manually passing + # class proportions. 
# this is run on classes, not instances, though this should be changed X = np.array([[-1.0, -1.0], [-1.0, 0], [-.8, -1.0], [1.0, 1.0], [1.0, 0.0]]) @@ -2511,7 +2607,8 @@ def check_class_weight_balanced_linear_classifier(name, Classifier, @ignore_warnings(category=FutureWarning) -def check_estimators_overwrite_params(name, estimator_orig, strict_mode=True): +def check_estimators_overwrite_params(name, estimator_orig, api_only=False): + # Check that calling fit does not alter the output of get_params X, y = make_blobs(random_state=0, n_samples=21) # some want non-negative input X -= X.min() @@ -2546,8 +2643,10 @@ def check_estimators_overwrite_params(name, estimator_orig, strict_mode=True): @ignore_warnings(category=FutureWarning) -def check_no_attributes_set_in_init(name, estimator_orig, strict_mode=True): - """Check setting during init.""" +def check_no_attributes_set_in_init(name, estimator_orig, api_only=False): + # Check that: + # - init does not set any attribute apart from the parameters + # - all parameters of init are set as attributes try: # Clone fails if the estimator does not store # all parameters as an attribute during init @@ -2580,7 +2679,9 @@ def check_no_attributes_set_in_init(name, estimator_orig, strict_mode=True): @ignore_warnings(category=FutureWarning) -def check_sparsify_coefficients(name, estimator_orig, strict_mode=True): +def check_sparsify_coefficients(name, estimator_orig, api_only=False): + # Check that sparsified coefs produce the same predictions as the + # originals coefs X = np.array([[-2, -1], [-1, -1], [-1, -2], [1, 1], [1, 2], [2, 1], [-1, -2], [2, 2], [-2, -2]]) y = np.array([1, 1, 1, 2, 2, 2, 3, 3, 3]) @@ -2604,7 +2705,9 @@ def check_sparsify_coefficients(name, estimator_orig, strict_mode=True): @ignore_warnings(category=FutureWarning) -def check_classifier_data_not_an_array(name, estimator_orig, strict_mode=True): +def check_classifier_data_not_an_array(name, estimator_orig, api_only=False): + # Check that estimator yields same predictions whether an array was passed + # or not X = np.array([[3, 0], [0, 1], [0, 2], [1, 1], [1, 2], [2, 1], [0, 3], [1, 0], [2, 0], [4, 4], [2, 3], [3, 2]]) X = _pairwise_estimator_convert_X(X, estimator_orig) @@ -2616,7 +2719,9 @@ def check_classifier_data_not_an_array(name, estimator_orig, strict_mode=True): @ignore_warnings(category=FutureWarning) -def check_regressor_data_not_an_array(name, estimator_orig, strict_mode=True): +def check_regressor_data_not_an_array(name, estimator_orig, api_only=False): + # Check that estimator yields same predictions whether an array was passed + # or not X, y = _regression_dataset() X = _pairwise_estimator_convert_X(X, estimator_orig) y = _enforce_estimator_tags_y(estimator_orig, y) @@ -2627,7 +2732,7 @@ def check_regressor_data_not_an_array(name, estimator_orig, strict_mode=True): @ignore_warnings(category=FutureWarning) def check_estimators_data_not_an_array(name, estimator_orig, X, y, obj_type, - strict_mode=True): + api_only=False): if name in CROSS_DECOMPOSITION: raise SkipTest("Skipping check_estimators_data_not_an_array " "for cross decomposition module as estimators " @@ -2669,9 +2774,10 @@ def check_estimators_data_not_an_array(name, estimator_orig, X, y, obj_type, assert_allclose(pred1, pred2, atol=1e-2, err_msg=name) -def check_parameters_default_constructible(name, Estimator, strict_mode=True): - # test default-constructibility - # get rid of deprecation warnings +def check_parameters_default_constructible(name, Estimator, api_only=False): + # Check that the estimator's 
default parameters are immutable (sort of). + # Also check that get_params returns exactly the default parameters values + # on an unfitted estimator Estimator = Estimator.__class__ @@ -2765,16 +2871,16 @@ def param_filter(p): def _enforce_estimator_tags_y(estimator, y): # Estimators with a `requires_positive_y` tag only accept strictly positive # data - if estimator._get_tags()["requires_positive_y"]: + if _safe_tags(estimator, key="requires_positive_y"): # Create strictly positive y. The minimal increment above 0 is 1, as # y could be of integer dtype. y += 1 + abs(y.min()) # Estimators with a `binary_only` tag only accept up to two unique y values - if estimator._get_tags()["binary_only"] and y.size > 0: + if _safe_tags(estimator, key="binary_only") and y.size > 0: y = np.where(y == y.flat[0], y, y.flat[0] + 1) # Estimators in mono_output_task_error raise ValueError if y is of 1-D # Convert into a 2-D y for those estimators. - if estimator._get_tags()["multioutput_only"]: + if _safe_tags(estimator, key="multioutput_only"): return np.reshape(y, (-1, 1)) return y @@ -2786,18 +2892,18 @@ def _enforce_estimator_tags_x(estimator, X): X = X.dot(X.T) # Estimators with `1darray` in `X_types` tag only accept # X of shape (`n_samples`,) - if '1darray' in estimator._get_tags()['X_types']: + if '1darray' in _safe_tags(estimator, key='X_types'): X = X[:, 0] # Estimators with a `requires_positive_X` tag only accept # strictly positive data - if estimator._get_tags()['requires_positive_X']: + if _safe_tags(estimator, key='requires_positive_X'): X -= X.min() return X @ignore_warnings(category=FutureWarning) def check_non_transformer_estimators_n_iter(name, estimator_orig, - strict_mode=True): + api_only=False): # Test that estimators that are not transformers with a parameter # max_iter, return the attribute of n_iter_ at least 1. @@ -2833,7 +2939,7 @@ def check_non_transformer_estimators_n_iter(name, estimator_orig, @ignore_warnings(category=FutureWarning) -def check_transformer_n_iter(name, estimator_orig, strict_mode=True): +def check_transformer_n_iter(name, estimator_orig, api_only=False): # Test that transformers with a parameter max_iter, return the # attribute of n_iter_ at least 1. 
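# An illustrative sketch, not part of this patch, of the estimator
# conventions that ``check_estimators_overwrite_params``,
# ``check_no_attributes_set_in_init`` and
# ``check_parameters_default_constructible`` above enforce: ``__init__``
# only stores its arguments unmodified, and ``fit`` writes learned state to
# new trailing-underscore attributes instead of touching the constructor
# parameters.  ``ShiftTransformer`` is a made-up example estimator.
import numpy as np
from sklearn.base import BaseEstimator, TransformerMixin


class ShiftTransformer(TransformerMixin, BaseEstimator):
    def __init__(self, shift=1.0):
        # store the parameter as-is: no validation, no derived attributes
        self.shift = shift

    def fit(self, X, y=None):
        X = np.asarray(X, dtype=float)
        # learned state goes into ``offset_``, never back into ``self.shift``
        self.offset_ = X.mean(axis=0) + self.shift
        return self

    def transform(self, X):
        return np.asarray(X, dtype=float) - self.offset_

# With this layout, ``get_params()`` returns exactly ``{'shift': 1.0}``
# before and after ``fit``, which is what the overwrite-params check
# compares, and ``ShiftTransformer()`` is default-constructible.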
estimator = clone(estimator_orig) @@ -2859,8 +2965,8 @@ def check_transformer_n_iter(name, estimator_orig, strict_mode=True): @ignore_warnings(category=FutureWarning) -def check_get_params_invariance(name, estimator_orig, strict_mode=True): - # Checks if get_params(deep=False) is a subset of get_params(deep=True) +def check_get_params_invariance(name, estimator_orig, api_only=False): + # Checks that get_params(deep=False) is a subset of get_params(deep=True) e = clone(estimator_orig) shallow_params = e.get_params(deep=False) @@ -2871,7 +2977,7 @@ def check_get_params_invariance(name, estimator_orig, strict_mode=True): @ignore_warnings(category=FutureWarning) -def check_set_params(name, estimator_orig, strict_mode=True): +def check_set_params(name, estimator_orig, api_only=False): # Check that get_params() returns the same thing # before and after set_params() with some fuzz estimator = clone(estimator_orig) @@ -2926,21 +3032,21 @@ def check_set_params(name, estimator_orig, strict_mode=True): @ignore_warnings(category=FutureWarning) def check_classifiers_regression_target(name, estimator_orig, - strict_mode=True): + api_only=False): # Check if classifier throws an exception when fed regression targets X, y = _regression_dataset() X = X + 1 + abs(X.min(axis=0)) # be sure that X is non-negative e = clone(estimator_orig) - msg = "Unknown label type: " - if not e._get_tags()["no_validation"]: - with raises(ValueError, match=msg): + match = None if api_only else "Unknown label type: " + if not _safe_tags(e, key="no_validation"): + with raises(ValueError, match=match): e.fit(X, y) @ignore_warnings(category=FutureWarning) -def check_decision_proba_consistency(name, estimator_orig, strict_mode=True): +def check_decision_proba_consistency(name, estimator_orig, api_only=False): # Check whether an estimator having both decision_function and # predict_proba methods has outputs with perfect rank correlation. @@ -2962,7 +3068,7 @@ def check_decision_proba_consistency(name, estimator_orig, strict_mode=True): assert_array_equal(rankdata(a), rankdata(b)) -def check_outliers_fit_predict(name, estimator_orig, strict_mode=True): +def check_outliers_fit_predict(name, estimator_orig, api_only=False): # Check fit_predict for outlier detectors. 
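# A short sketch, not part of this patch, of the ``match = None if api_only
# else ...`` relaxation used in ``check_classifiers_regression_target``
# above (and in ``check_fit_non_negative`` further down): with
# ``api_only=True`` a ``ValueError`` is still required, but nothing is
# asserted about its message.  The suite uses its own ``raises`` helper;
# ``pytest.raises`` behaves the same way for this illustration.
import numpy as np
import pytest
from sklearn.linear_model import LogisticRegression

rng = np.random.RandomState(0)
X = rng.normal(size=(20, 3))
y_continuous = rng.normal(size=20)   # regression targets fed to a classifier

for api_only in (False, True):
    match = None if api_only else "Unknown label type: "
    with pytest.raises(ValueError, match=match):
        LogisticRegression().fit(X, y_continuous)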
n_samples = 300 @@ -2985,7 +3091,7 @@ def check_outliers_fit_predict(name, estimator_orig, strict_mode=True): y_pred_2 = estimator.fit(X).predict(X) assert_array_equal(y_pred, y_pred_2) - if hasattr(estimator, "contamination"): + if hasattr(estimator, "contamination") and not api_only: # proportion of outliers equal to contamination parameter when not # set to 'auto' expected_outliers = 30 @@ -3010,21 +3116,18 @@ def check_outliers_fit_predict(name, estimator_orig, strict_mode=True): estimator.fit_predict(X) -def check_fit_non_negative(name, estimator_orig, strict_mode=True): - # Check that proper warning is raised for non-negative X +def check_fit_non_negative(name, estimator_orig, api_only=False): + # Check that proper error is raised for non-negative X # when tag requires_positive_X is present X = np.array([[-1., 1], [-1., 1]]) y = np.array([1, 2]) estimator = clone(estimator_orig) - if strict_mode: - with raises(ValueError, match="Negative values in data passed to"): - estimator.fit(X, y) - else: # Don't check error message if strict mode is off - with raises(ValueError): - estimator.fit(X, y) + match = None if api_only else "Negative values in data passed to" + with raises(ValueError, match=match): + estimator.fit(X, y) -def check_fit_idempotent(name, estimator_orig, strict_mode=True): +def check_fit_idempotent(name, estimator_orig, api_only=False): # Check that est.fit(X) is the same as est.fit(X).fit(X). Ideally we would # check that the estimated parameters during training (e.g. coefs_) are # the same, but having a universal comparison function for those @@ -3079,7 +3182,7 @@ def check_fit_idempotent(name, estimator_orig, strict_mode=True): ) -def check_n_features_in(name, estimator_orig, strict_mode=True): +def check_n_features_in(name, estimator_orig, api_only=False): # Make sure that n_features_in_ attribute doesn't exist until fit is # called, and that its value is correct. @@ -3117,7 +3220,7 @@ def check_n_features_in(name, estimator_orig, strict_mode=True): ) -def check_requires_y_none(name, estimator_orig, strict_mode=True): +def check_requires_y_none(name, estimator_orig, api_only=False): # Make sure that an estimator with requires_y=True fails gracefully when # given y=None @@ -3149,9 +3252,9 @@ def check_requires_y_none(name, estimator_orig, strict_mode=True): warnings.warn(warning_msg, FutureWarning) -def check_n_features_in_after_fitting(name, estimator_orig, strict_mode=True): +def check_n_features_in_after_fitting(name, estimator_orig, api_only=False): # Make sure that n_features_in are checked after fitting - tags = estimator_orig._get_tags() + tags = _safe_tags(estimator_orig) if "2darray" not in tags["X_types"] or tags["no_validation"]: return @@ -3203,7 +3306,23 @@ def check_n_features_in_after_fitting(name, estimator_orig, strict_mode=True): estimator.partial_fit(X_bad, y) -# set of checks that are completely strict, i.e. they have no non-strict part -_FULLY_STRICT_CHECKS = set([ - 'check_n_features_in', +# set of checks that do not check API-compatibility. They are ignored when +# api_only is True. 
+_NON_API_CHECKS = set([ + 'check_estimator_sparse_data', + 'check_sample_weights_invariance', + 'check_complex_data', + 'check_methods_subset_invariance', + 'check_methods_sample_order_invariance', + 'check_fit2d_1sample', + 'check_fit2d_1feature', + 'check_transformer_preserve_dtypes', + 'check_estimators_nan_inf', + 'check_clusterer_compute_labels_predict', + 'check_classifiers_one_label', + 'check_regressors_int', + 'check_class_weight_classifiers', + 'check_class_weight_balanced_linear_classifier', + 'check_sparsify_coefficients', + 'check_decision_proba_consistency', ]) diff --git a/sklearn/utils/tests/test_estimator_checks.py b/sklearn/utils/tests/test_estimator_checks.py index ecbf7cb7be7f4..61e7099d64084 100644 --- a/sklearn/utils/tests/test_estimator_checks.py +++ b/sklearn/utils/tests/test_estimator_checks.py @@ -21,8 +21,8 @@ from sklearn.utils.estimator_checks import check_no_attributes_set_in_init from sklearn.utils.estimator_checks import check_classifier_data_not_an_array from sklearn.utils.estimator_checks import check_regressor_data_not_an_array -from sklearn.utils.validation import check_is_fitted from sklearn.utils.estimator_checks import check_outlier_corruption +from sklearn.utils.estimator_checks import parametrize_with_checks from sklearn.utils.fixes import np_version, parse_version from sklearn.ensemble import RandomForestClassifier from sklearn.linear_model import LinearRegression, SGDClassifier @@ -32,7 +32,12 @@ from sklearn.linear_model import MultiTaskElasticNet, LogisticRegression from sklearn.svm import SVC, NuSVC from sklearn.neighbors import KNeighborsRegressor -from sklearn.utils.validation import check_array +from sklearn.utils.multiclass import check_classification_targets +from sklearn.utils.validation import ( + check_array, + check_is_fitted, + check_X_y, +) from sklearn.utils import all_estimators from sklearn.exceptions import SkipTestWarning @@ -418,8 +423,8 @@ def test_check_estimator(): # check that we have a set_params and can clone msg = "Passing a class was deprecated" assert_raises_regex(TypeError, msg, check_estimator, object) - msg = "object has no attribute '_get_tags'" - assert_raises_regex(AttributeError, msg, check_estimator, object()) + # msg = "object has no attribute '_get_tags'" + # assert_raises_regex(AttributeError, msg, check_estimator, object()) msg = ( "Parameter 'p' of estimator 'HasMutableParameters' is of type " "object which is not allowed" @@ -620,6 +625,138 @@ def test_check_estimator_pairwise(): check_estimator(est) +class MinimalClassifier: + _estimator_type = "classifier" + + def __repr__(self): + # Only required when using pytest-xdist to get an id not associated + # with the memory location + return self.__class__.__name__ + + def get_params(self, **params): + return {} + + def set_params(self, deep=True): + return self + + def __getstate__(self): + return self.__dict__.copy() + + def __setstate__(self, state): + self.__dict__.update(state) + + def fit(self, X, y): + X, y = check_X_y(X, y) + check_classification_targets(y) + self.n_features_in_ = X.shape[1] + self.classes_, counts = np.unique(y, return_counts=True) + self._most_frequent_class_idx = counts.argmax() + return self + + def predict_proba(self, X): + check_is_fitted(self) + X = check_array(X) + if X.shape[1] != self.n_features_in_: + raise ValueError + proba_shape = (X.shape[0], self.classes_.size) + y_proba = np.zeros(shape=proba_shape, dtype=np.float64) + y_proba[:, self._most_frequent_class_idx] = 1.0 + return y_proba + + def predict(self, X): + 
y_proba = self.predict_proba(X) + y_pred = y_proba.argmax(axis=1) + return self.classes_[y_pred] + + def score(self, X, y): + from sklearn.metrics import accuracy_score + return accuracy_score(y, self.predict(X)) + + +class MinimalRegressor: + _estimator_type = "regressor" + + def __repr__(self): + # Only required when using pytest-xdist to get an id not associated + # with the memory location + return self.__class__.__name__ + + def get_params(self, **params): + return {} + + def set_params(self, deep=True): + return self + + def __getstate__(self): + return self.__dict__.copy() + + def __setstate__(self, state): + self.__dict__.update(state) + + def fit(self, X, y): + X, y = check_X_y(X, y) + self.n_features_in_ = X.shape[1] + self._mean = np.mean(y) + return self + + def predict(self, X): + check_is_fitted(self) + X = check_array(X) + if X.shape[1] != self.n_features_in_: + raise ValueError + return np.ones(shape=(X.shape[0],)) * self._mean + + def score(self, X, y): + from sklearn.metrics import r2_score + return r2_score(y, self.predict(X)) + + +class MinimalTransformer: + + def __repr__(self): + # Only required when using pytest-xdist to get an id not associated + # with the memory location + return self.__class__.__name__ + + def get_params(self, **params): + return {} + + def set_params(self, deep=True): + return self + + def __getstate__(self): + return self.__dict__.copy() + + def __setstate__(self, state): + self.__dict__.update(state) + + def fit(self, X, y=None): + X = check_array(X) + self.n_features_in_ = X.shape[1] + return self + + def transform(self, X, y=None): + check_is_fitted(self) + X = check_array(X) + if X.shape[1] != self.n_features_in_: + raise ValueError + return X + + def inverse_transform(self, X, y=None): + return self.transform(X) + + def fit_transform(self, X, y=None): + return self.fit(X, y).transform(X, y) + + +@parametrize_with_checks( + [MinimalClassifier(), MinimalRegressor(), MinimalTransformer()], + api_only=True +) +def test_check_estimator_minimal(estimator, check): + check(estimator) + + def test_check_classifier_data_not_an_array(): assert_raises_regex(AssertionError, 'Not equal to tolerance',