diff --git a/doc/developers/develop.rst b/doc/developers/develop.rst index f17c58cee0d7f..13d2010ca7319 100644 --- a/doc/developers/develop.rst +++ b/doc/developers/develop.rst @@ -246,13 +246,13 @@ whether it is just for you or for contributing it to scikit-learn, there are several internals of scikit-learn that you should be aware of in addition to the scikit-learn API outlined above. You can check whether your estimator adheres to the scikit-learn interface and standards by running -:func:`utils.estimator_checks.check_estimator` on the class or using -:func:`~sklearn.utils.parametrize_with_checks` pytest decorator (see its -docstring for details and possible interactions with `pytest`):: +:func:`~sklearn.utils.estimator_checks.check_estimator` on an instance. The +:func:`~sklearn.utils.parametrize_with_checks` pytest decorator can also be +used (see its docstring for details and possible interactions with `pytest`):: >>> from sklearn.utils.estimator_checks import check_estimator >>> from sklearn.svm import LinearSVC - >>> check_estimator(LinearSVC) # passes + >>> check_estimator(LinearSVC()) # passes The main motivation to make a class compatible to the scikit-learn estimator interface might be that you want to use it together with model evaluation and diff --git a/doc/whats_new/v0.23.rst b/doc/whats_new/v0.23.rst index 4357845885e3f..f9b8e25176265 100644 --- a/doc/whats_new/v0.23.rst +++ b/doc/whats_new/v0.23.rst @@ -511,6 +511,11 @@ Changelog matrix from a pandas DataFrame that contains only `SparseArray`s. :pr:`16728` by `Thomas Fan`_. +- |API| Passing classes to :func:`utils.estimator_checks.check_estimator` and + :func:`utils.estimator_checks.parametrize_with_checks` is now deprecated, + and support for classes will be removed in 0.24. Pass instances instead. + :pr:`17032` by `Nicolas Hug`_. + :mod:`sklearn.cluster` ...................... diff --git a/sklearn/tests/test_common.py b/sklearn/tests/test_common.py index af98c1bc50a74..73c99b0483de8 100644 --- a/sklearn/tests/test_common.py +++ b/sklearn/tests/test_common.py @@ -48,7 +48,9 @@ def test_all_estimator_no_base_class(): assert not name.lower().startswith('base'), msg +@ignore_warnings("Passing a class is depr", category=FutureWarning) # 0.24 def test_estimator_cls_parameterize_with_checks(): + # TODO: remove test in 0.24 # Non-regression test for #16707 to ensure that parametrize_with_checks # works with estimator classes param_checks = parametrize_with_checks([LogisticRegression]) @@ -105,7 +107,7 @@ def _tested_estimators(): yield estimator -@parametrize_with_checks(_tested_estimators()) +@parametrize_with_checks(list(_tested_estimators())) def test_estimators(estimator, check, request): # Common tests for estimator instances with ignore_warnings(category=(FutureWarning, @@ -115,7 +117,9 @@ def test_estimators(estimator, check, request): check(estimator) +@ignore_warnings("Passing a class is depr", category=FutureWarning) # 0.24 def test_check_estimator_generate_only(): + # TODO in 0.24: remove checks on passing a class estimator_cls_gen_checks = check_estimator(LogisticRegression, generate_only=True) all_instance_gen_checks = check_estimator(LogisticRegression(), @@ -238,3 +242,19 @@ def test_all_tests_are_importable(): '__init__.py or an add_subpackage directive ' 'in the parent ' 'setup.py'.format(missing_tests)) + + +# TODO: remove in 0.24 +def test_class_support_deprecated(): + # Make sure passing classes to check_estimator or parametrize_with_checks + # is deprecated + + msg = "Passing a class is deprecated" + with pytest.warns(FutureWarning, match=msg): + check_estimator(LogisticRegression) + + with pytest.warns(FutureWarning, match=msg): + parametrize_with_checks([LogisticRegression]) + + # Make sure check_parameters_default_constructible accepts instances now + check_parameters_default_constructible('name', LogisticRegression()) diff --git a/sklearn/utils/estimator_checks.py b/sklearn/utils/estimator_checks.py index efac2aca2a2df..ec28cb22919f0 100644 --- a/sklearn/utils/estimator_checks.py +++ b/sklearn/utils/estimator_checks.py @@ -33,7 +33,7 @@ from ..linear_model import Ridge from ..base import (clone, ClusterMixin, is_classifier, is_regressor, - RegressorMixin, is_outlier_detector) + RegressorMixin, is_outlier_detector, BaseEstimator) from ..metrics import accuracy_score, adjusted_rand_score, f1_score from ..random_projection import BaseRandomProjection @@ -333,12 +333,15 @@ def _construct_instance(Estimator): return estimator +# TODO: probably not needed anymore in 0.24 since _generate_class_checks should +# be removed too. Just put this in check_estimator() def _generate_instance_checks(name, estimator): """Generate instance checks.""" yield from ((estimator, partial(check, name)) for check in _yield_all_checks(name, estimator)) +# TODO: remove this in 0.24 def _generate_class_checks(Estimator): """Generate class checks.""" name = Estimator.__name__ @@ -353,6 +356,8 @@ def _mark_xfail_checks(estimator, check, pytest): if isinstance(estimator, type): # try to construct estimator instance, if it is unable to then # return the estimator class, ignoring the tag + # TODO: remove this if block in 0.24 since passing instances isn't + # supported anymore try: estimator = _construct_instance(estimator) except Exception: @@ -385,6 +390,10 @@ def parametrize_with_checks(estimators): estimators : list of estimators objects or classes Estimators to generated checks for. + .. deprecated:: 0.23 + Passing a class is deprecated from version 0.23, and won't be + supported in 0.24. Pass an instance instead. + Returns ------- decorator : `pytest.mark.parametrize` @@ -395,13 +404,21 @@ def parametrize_with_checks(estimators): >>> from sklearn.linear_model import LogisticRegression >>> from sklearn.tree import DecisionTreeRegressor - >>> @parametrize_with_checks([LogisticRegression, DecisionTreeRegressor]) + >>> @parametrize_with_checks([LogisticRegression(), + ... DecisionTreeRegressor()]) ... def test_sklearn_compatible_estimator(estimator, check): ... check(estimator) """ import pytest + if any(isinstance(est, type) for est in estimators): + # TODO: remove class support in 0.24 and update docstrings + msg = ("Passing a class is deprecated since version 0.23 " + "and won't be supported in 0.24." + "Please pass an instance instead.") + warnings.warn(msg, FutureWarning) + checks_generator = chain.from_iterable( check_estimator(estimator, generate_only=True) for estimator in estimators) @@ -418,7 +435,7 @@ def check_estimator(Estimator, generate_only=False): """Check if estimator adheres to scikit-learn conventions. This estimator will run an extensive test-suite for input validation, - shapes, etc, making sure that the estimator complies with `scikit-leanrn` + shapes, etc, making sure that the estimator complies with `scikit-learn` conventions as detailed in :ref:`rolling_your_own_estimator`. Additional tests for classifiers, regressors, clustering or transformers will be run if the Estimator class inherits from the corresponding mixin @@ -426,7 +443,9 @@ def check_estimator(Estimator, generate_only=False): This test can be applied to classes or instances. Classes currently have some additional tests that related to construction, - while passing instances allows the testing of multiple options. + while passing instances allows the testing of multiple options. However, + support for classes is deprecated since version 0.23 and will be removed + in version 0.24 (class checks will still be run on the instances). Setting `generate_only=True` returns a generator that yields (estimator, check) tuples where the check can be called independently from each @@ -439,9 +458,13 @@ def check_estimator(Estimator, generate_only=False): Parameters ---------- - estimator : estimator object or class + estimator : estimator object Estimator to check. Estimator is a class object or instance. + .. deprecated:: 0.23 + Passing a class is deprecated from version 0.23, and won't be + supported in 0.24. Pass an instance instead. + generate_only : bool, optional (default=False) When `False`, checks are evaluated when `check_estimator` is called. When `True`, `check_estimator` returns a generator that yields @@ -456,8 +479,14 @@ def check_estimator(Estimator, generate_only=False): Generator that yields (estimator, check) tuples. Returned when `generate_only=True`. """ + # TODO: remove class support in 0.24 and update docstrings if isinstance(Estimator, type): # got a class + msg = ("Passing a class is deprecated since version 0.23 " + "and won't be supported in 0.24." + "Please pass an instance instead.") + warnings.warn(msg, FutureWarning) + checks_generator = _generate_class_checks(Estimator) else: # got an instance @@ -2570,6 +2599,12 @@ def check_parameters_default_constructible(name, Estimator): # this check works on classes, not instances # test default-constructibility # get rid of deprecation warnings + if isinstance(Estimator, BaseEstimator): + # Convert estimator instance to its class + # TODO: Always convert to class in 0.24, because check_estimator() will + # only accept instances, not classes + Estimator = Estimator.__class__ + with ignore_warnings(category=FutureWarning): estimator = _construct_instance(Estimator) # test cloning diff --git a/sklearn/utils/tests/test_estimator_checks.py b/sklearn/utils/tests/test_estimator_checks.py index a755daa842ef5..594ff65f9e889 100644 --- a/sklearn/utils/tests/test_estimator_checks.py +++ b/sklearn/utils/tests/test_estimator_checks.py @@ -356,6 +356,7 @@ def fit(self, X, y): check_fit_score_takes_y("test", TestEstimatorWithDeprecatedFitMethod()) +@ignore_warnings("Passing a class is depr", category=FutureWarning) # 0.24 def test_check_estimator(): # tests that the estimator actually fails on "bad" estimators. # not a complete test of all checks, which are very extensive. @@ -579,7 +580,10 @@ def test_check_regressor_data_not_an_array(): EstimatorInconsistentForPandas()) +@ignore_warnings("Passing a class is depr", category=FutureWarning) # 0.24 def test_check_estimator_required_parameters_skip(): + # TODO: remove whole test in 0.24 since passes classes to check_estimator() + # isn't supported anymore class MyEstimator(BaseEstimator): _required_parameters = ["special_parameter"]