Skip to content

[MRG] Remove class support check estimator and parametrize_with_checks #17134

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
8 changes: 6 additions & 2 deletions examples/release_highlights/plot_release_highlights_0_22_0.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,17 +225,21 @@
# ---------------------------------------------------
# Developers can check the compatibility of their scikit-learn compatible
# estimators using :func:`~utils.estimator_checks.check_estimator`. For
# instance, the ``check_estimator(LinearSVC)`` passes.
# instance, the ``check_estimator(LinearSVC())`` passes.
#
# We now provide a ``pytest`` specific decorator which allows ``pytest``
# to run all checks independently and report the checks that are failing.
#
# .. note::
# This entry was slightly updated in version 0.24, where passing classes
# isn't supported anymore: pass instances instead.

from sklearn.linear_model import LogisticRegression
from sklearn.tree import DecisionTreeRegressor
from sklearn.utils.estimator_checks import parametrize_with_checks


@parametrize_with_checks([LogisticRegression, DecisionTreeRegressor])
@parametrize_with_checks([LogisticRegression(), DecisionTreeRegressor()])
def test_sklearn_compatible_estimator(estimator, check):
check(estimator)

Expand Down
64 changes: 5 additions & 59 deletions sklearn/tests/test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,15 +27,12 @@

from sklearn.linear_model._base import LinearClassifierMixin
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import GridSearchCV
from sklearn.utils import IS_PYPY
from sklearn.utils._testing import SkipTest
from sklearn.utils.estimator_checks import (
_mark_xfail_checks,
_construct_instance,
_set_checking_parameters,
_set_check_estimator_ids,
check_parameters_default_constructible,
check_class_weight_balanced_linear_classifier,
parametrize_with_checks)

Expand All @@ -48,35 +45,6 @@ def test_all_estimator_no_base_class():
assert not name.lower().startswith('base'), msg


@ignore_warnings("Passing a class is depr", category=FutureWarning) # 0.24
def test_estimator_cls_parameterize_with_checks():
# TODO: remove test in 0.24
# Non-regression test for #16707 to ensure that parametrize_with_checks
# works with estimator classes
param_checks = parametrize_with_checks([LogisticRegression])
# Using the generator does not raise
list(param_checks.args[1])


def test_mark_xfail_checks_with_unconsructable_estimator():
class MyEstimator:
def __init__(self):
raise ValueError("This is bad")

estimator, check = _mark_xfail_checks(MyEstimator, 42, None)
assert estimator == MyEstimator
assert check == 42


@pytest.mark.parametrize(
'name, Estimator',
all_estimators()
)
def test_parameters_default_constructible(name, Estimator):
# Test that estimators are default-constructible
check_parameters_default_constructible(name, Estimator)


def _sample_func(x, y=1):
pass

Expand Down Expand Up @@ -117,29 +85,11 @@ def test_estimators(estimator, check, request):
check(estimator)


@ignore_warnings("Passing a class is depr", category=FutureWarning) # 0.24
def test_check_estimator_generate_only():
# TODO in 0.24: remove checks on passing a class
estimator_cls_gen_checks = check_estimator(LogisticRegression,
generate_only=True)
all_instance_gen_checks = check_estimator(LogisticRegression(),
generate_only=True)
assert isgenerator(estimator_cls_gen_checks)
assert isgenerator(all_instance_gen_checks)

estimator_cls_checks = list(estimator_cls_gen_checks)
all_instance_checks = list(all_instance_gen_checks)

# all classes checks include check_parameters_default_constructible
assert len(estimator_cls_checks) == len(all_instance_checks) + 1

# TODO: meta-estimators like GridSearchCV has required parameters
# that do not have default values. This is expected to change in the future
with pytest.raises(SkipTest):
for estimator, check in check_estimator(GridSearchCV,
generate_only=True):
check(estimator)


@ignore_warnings(category=(DeprecationWarning, FutureWarning))
# ignore deprecated open(.., 'U') in numpy distutils
Expand Down Expand Up @@ -244,17 +194,13 @@ def test_all_tests_are_importable():
'setup.py'.format(missing_tests))


# TODO: remove in 0.24
def test_class_support_deprecated():
def test_class_support_removed():
# Make sure passing classes to check_estimator or parametrize_with_checks
# is deprecated
# raises an error

msg = "Passing a class is deprecated"
with pytest.warns(FutureWarning, match=msg):
msg = "Passing a class was deprecated.* isn't supported anymore"
with pytest.raises(TypeError, match=msg):
check_estimator(LogisticRegression)

with pytest.warns(FutureWarning, match=msg):
with pytest.raises(TypeError, match=msg):
parametrize_with_checks([LogisticRegression])

# Make sure check_parameters_default_constructible accepts instances now
check_parameters_default_constructible('name', LogisticRegression())
86 changes: 23 additions & 63 deletions sklearn/utils/estimator_checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
from ..linear_model import Ridge

from ..base import (clone, ClusterMixin, is_classifier, is_regressor,
RegressorMixin, is_outlier_detector, BaseEstimator)
RegressorMixin, is_outlier_detector)

from ..metrics import accuracy_score, adjusted_rand_score, f1_score
from ..random_projection import BaseRandomProjection
Expand Down Expand Up @@ -258,6 +258,7 @@ def _yield_all_checks(name, estimator):
if is_outlier_detector(estimator):
for check in _yield_outliers_checks(name, estimator):
yield check
yield check_parameters_default_constructible
yield check_fit2d_predict1d
yield check_methods_subset_invariance
yield check_fit2d_1sample
Expand Down Expand Up @@ -333,36 +334,9 @@ def _construct_instance(Estimator):
return estimator


# TODO: probably not needed anymore in 0.24 since _generate_class_checks should
# be removed too. Just put this in check_estimator()
def _generate_instance_checks(name, estimator):
"""Generate instance checks."""
yield from ((estimator, partial(check, name))
for check in _yield_all_checks(name, estimator))


# TODO: remove this in 0.24
def _generate_class_checks(Estimator):
"""Generate class checks."""
name = Estimator.__name__
yield (Estimator, partial(check_parameters_default_constructible, name))
estimator = _construct_instance(Estimator)
yield from _generate_instance_checks(name, estimator)


def _mark_xfail_checks(estimator, check, pytest):
"""Mark (estimator, check) pairs with xfail according to the
_xfail_checks_ tag"""
if isinstance(estimator, type):
# try to construct estimator instance, if it is unable to then
# return the estimator class, ignoring the tag
# TODO: remove this if block in 0.24 since passing instances isn't
# supported anymore
try:
estimator = _construct_instance(estimator)
except Exception:
return estimator, check

xfail_checks = estimator._get_tags()['_xfail_checks'] or {}
check_name = _set_check_estimator_ids(check)

Expand All @@ -387,12 +361,12 @@ def parametrize_with_checks(estimators):

Parameters
----------
estimators : list of estimators objects or classes
estimators : list of estimator instances
Estimators to generate checks for.

.. deprecated:: 0.23
Passing a class is deprecated from version 0.23, and won't be
supported in 0.24. Pass an instance instead.
.. versionchanged:: 0.24
Passing a class was deprecated in version 0.23, and support for
classes was removed in 0.24. Pass an instance instead.

Returns
-------
Expand All @@ -413,11 +387,10 @@ def parametrize_with_checks(estimators):
import pytest

if any(isinstance(est, type) for est in estimators):
# TODO: remove class support in 0.24 and update docstrings
msg = ("Passing a class is deprecated since version 0.23 "
"and won't be supported in 0.24."
msg = ("Passing a class was deprecated in version 0.23 "
Copy link
Member

@thomasjpfan thomasjpfan May 8, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I do not think we normally keep the deprecation warning version in the error message. Can this be:

Starting from 0.24, only instances are supported.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK, I'll remove if others feel the same way. No strong opinion on my side, I just thought it might be useful since we only went for a 1 version deprecation cycle on this

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thoughts on this @rth?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Starting from 0.24, only instances are supported for the "estimators" parameter.

Yes +1 for this error message, mostly because it tells the user what to do when they get this error. It could be a TypeError instead of a ValueError maybe as well?

"and isn't supported anymore from 0.24."
"Please pass an instance instead.")
warnings.warn(msg, FutureWarning)
raise TypeError(msg)

checks_generator = chain.from_iterable(
check_estimator(estimator, generate_only=True)
Expand All @@ -441,12 +414,6 @@ def check_estimator(Estimator, generate_only=False):
will be run if the Estimator class inherits from the corresponding mixin
from sklearn.base.

This test can be applied to classes or instances.
Classes currently have some additional tests that related to construction,
while passing instances allows the testing of multiple options. However,
support for classes is deprecated since version 0.23 and will be removed
in version 0.24 (class checks will still be run on the instances).

Setting `generate_only=True` returns a generator that yields (estimator,
check) tuples where the check can be called independently from each
other, i.e. `check(estimator)`. This allows all checks to be run
Expand All @@ -459,11 +426,11 @@ def check_estimator(Estimator, generate_only=False):
Parameters
----------
estimator : estimator object
Estimator to check. Estimator is a class object or instance.
Estimator instance to check.

.. deprecated:: 0.23
Passing a class is deprecated from version 0.23, and won't be
supported in 0.24. Pass an instance instead.
.. versionchanged:: 0.24
Passing a class was deprecated in version 0.23, and support for
classes was removed in 0.24.

generate_only : bool, optional (default=False)
When `False`, checks are evaluated when `check_estimator` is called.
Expand All @@ -479,20 +446,17 @@ def check_estimator(Estimator, generate_only=False):
Generator that yields (estimator, check) tuples. Returned when
`generate_only=True`.
"""
# TODO: remove class support in 0.24 and update docstrings
if isinstance(Estimator, type):
# got a class
msg = ("Passing a class is deprecated since version 0.23 "
"and won't be supported in 0.24."
msg = ("Passing a class was deprecated in version 0.23 "
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same here.

"and isn't supported anymore from 0.24."
"Please pass an instance instead.")
warnings.warn(msg, FutureWarning)
raise TypeError(msg)

checks_generator = _generate_class_checks(Estimator)
else:
# got an instance
estimator = Estimator
name = type(estimator).__name__
checks_generator = _generate_instance_checks(name, estimator)
estimator = Estimator
name = type(estimator).__name__

checks_generator = ((estimator, partial(check, name))
for check in _yield_all_checks(name, estimator))

if generate_only:
return checks_generator
Expand Down Expand Up @@ -2591,14 +2555,10 @@ def check_estimators_data_not_an_array(name, estimator_orig, X, y, obj_type):


def check_parameters_default_constructible(name, Estimator):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we rename it to estimator or do you think it would be considered an API breakage?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yeah it's not ideal but I'm concerned about backward compat :/

# this check works on classes, not instances
# test default-constructibility
# get rid of deprecation warnings
if isinstance(Estimator, BaseEstimator):
# Convert estimator instance to its class
# TODO: Always convert to class in 0.24, because check_estimator() will
# only accept instances, not classes
Estimator = Estimator.__class__

Estimator = Estimator.__class__
Copy link
Member

@rth rth May 10, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This would actually fail in mypy. The reason it doesn't currently is that conda activate fails, and the CI step is silently marked as successful. So the linting CI job is broken. Fix in #17177

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are you sure?

mypy --ignore-missing-import sklearn passes on this branch (v 0.770)

I've had mypy and linting issues on PRs recently so it seemed to "work" fine

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ahh, yes you are right, it just installs everything in the main env. Well, I would have thought that it would error, but never mind.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We still need to fix the activation issues in #17177 though.

I thought it was broken because I keep getting a few new mypy errors with code on master (that I haven't touched) occasionally.


with ignore_warnings(category=FutureWarning):
estimator = _construct_instance(Estimator)
Expand Down
Loading