diff --git a/sklearn/tests/test_common.py b/sklearn/tests/test_common.py index 51f71f2f7919b..4cf4ca9e310ef 100644 --- a/sklearn/tests/test_common.py +++ b/sklearn/tests/test_common.py @@ -12,6 +12,7 @@ import re import pkgutil import functools +import itertools import pytest @@ -28,7 +29,8 @@ from sklearn.discriminant_analysis import LinearDiscriminantAnalysis from sklearn.utils import IS_PYPY from sklearn.utils.estimator_checks import ( - _yield_all_checks, + check_estimator, + yield_all_checks, _safe_tags, set_checking_parameters, check_parameters_default_constructible, @@ -95,12 +97,14 @@ def _rename_partial(val): if hasattr(val, "get_params") and not isinstance(val, type): return type(val).__name__ +ALL_ESTIMATORS = all_estimators() @pytest.mark.parametrize( - "estimator, check", - _generate_checks_per_estimator(_yield_all_checks, - _tested_estimators()), - ids=_rename_partial + "check, name, estimator", + itertools.chain.from_iterable( + check_estimator(Estimator, evaluate=False) + for Estimator in ALL_ESTIMATORS + ) ) def test_estimators(estimator, check): # Common tests for estimator instances diff --git a/sklearn/utils/estimator_checks.py b/sklearn/utils/estimator_checks.py index e922a7c0b4d48..39d5dc830c536 100644 --- a/sklearn/utils/estimator_checks.py +++ b/sklearn/utils/estimator_checks.py @@ -223,7 +223,7 @@ def _yield_outliers_checks(name, estimator): yield check_estimators_unfitted -def _yield_all_checks(name, estimator): +def yield_all_checks(name, estimator): tags = _safe_tags(estimator) if "2darray" not in tags["X_types"]: warnings.warn("Can't test estimator {} which requires input " @@ -265,7 +265,7 @@ def _yield_all_checks(name, estimator): yield check_fit_idempotent -def check_estimator(Estimator): +def check_estimator(Estimator, evaluate=True): """Check if estimator adheres to scikit-learn conventions. This estimator will run an extensive test-suite for input validation, @@ -283,19 +283,24 @@ def check_estimator(Estimator): estimator : estimator object or class Estimator to check. Estimator is a class object or instance. + evaluate : bool + Flag to indicate whether or not to evaluate the passed + estimator. """ - if isinstance(Estimator, type): - # got a class - name = Estimator.__name__ - estimator = Estimator() - check_parameters_default_constructible(name, Estimator) - check_no_attributes_set_in_init(name, estimator) - else: - # got an instance - estimator = Estimator - name = type(estimator).__name__ + if evaluate: + if isinstance(Estimator, type): + # got a class + name = Estimator.__name__ + estimator = Estimator() + # Generate tests for pytest hooks collector + check_parameters_default_constructible(name, Estimator) + check_no_attributes_set_in_init(name, estimator) + else: + # got an instance + estimator = Estimator + name = type(estimator).__name__ - for check in _yield_all_checks(name, estimator): + for check in yield_all_checks(name, estimator): try: check(name, estimator) except SkipTest as exception: