Skip to content

[WIP] Verbose flag displaying progress bar for check_estimator in sklearn.utils.estimator_checks #13843

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

14 changes: 9 additions & 5 deletions sklearn/tests/test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import re
import pkgutil
import functools
import itertools

import pytest

Expand All @@ -28,7 +29,8 @@
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
from sklearn.utils import IS_PYPY
from sklearn.utils.estimator_checks import (
_yield_all_checks,
check_estimator,
yield_all_checks,
_safe_tags,
set_checking_parameters,
check_parameters_default_constructible,
Expand Down Expand Up @@ -95,12 +97,14 @@ def _rename_partial(val):
if hasattr(val, "get_params") and not isinstance(val, type):
return type(val).__name__

ALL_ESTIMATORS = all_estimators()

@pytest.mark.parametrize(
"estimator, check",
_generate_checks_per_estimator(_yield_all_checks,
_tested_estimators()),
ids=_rename_partial
"check, name, estimator",
itertools.chain.from_iterable(
check_estimator(Estimator, evaluate=False)
for Estimator in ALL_ESTIMATORS
)
)
def test_estimators(estimator, check):
# Common tests for estimator instances
Expand Down
31 changes: 18 additions & 13 deletions sklearn/utils/estimator_checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,7 +223,7 @@ def _yield_outliers_checks(name, estimator):
yield check_estimators_unfitted


def _yield_all_checks(name, estimator):
def yield_all_checks(name, estimator):
tags = _safe_tags(estimator)
if "2darray" not in tags["X_types"]:
warnings.warn("Can't test estimator {} which requires input "
Expand Down Expand Up @@ -265,7 +265,7 @@ def _yield_all_checks(name, estimator):
yield check_fit_idempotent


def check_estimator(Estimator):
def check_estimator(Estimator, evaluate=True):
"""Check if estimator adheres to scikit-learn conventions.

This estimator will run an extensive test-suite for input validation,
Expand All @@ -283,19 +283,24 @@ def check_estimator(Estimator):
estimator : estimator object or class
Estimator to check. Estimator is a class object or instance.

evaluate : bool
Flag to indicate whether or not to evaluate the passed
estimator.
"""
if isinstance(Estimator, type):
# got a class
name = Estimator.__name__
estimator = Estimator()
check_parameters_default_constructible(name, Estimator)
check_no_attributes_set_in_init(name, estimator)
else:
# got an instance
estimator = Estimator
name = type(estimator).__name__
if evaluate:
if isinstance(Estimator, type):
# got a class
name = Estimator.__name__
estimator = Estimator()
# Generate tests for pytest hooks collector
check_parameters_default_constructible(name, Estimator)
check_no_attributes_set_in_init(name, estimator)
else:
# got an instance
estimator = Estimator
name = type(estimator).__name__

for check in _yield_all_checks(name, estimator):
for check in yield_all_checks(name, estimator):
try:
check(name, estimator)
except SkipTest as exception:
Expand Down