Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions doc/developers/develop.rst
Original file line number Diff line number Diff line change
Expand Up @@ -246,13 +246,13 @@ whether it is just for you or for contributing it to scikit-learn, there are
several internals of scikit-learn that you should be aware of in addition to
the scikit-learn API outlined above. You can check whether your estimator
adheres to the scikit-learn interface and standards by running
:func:`utils.estimator_checks.check_estimator` on the class or using
:func:`~sklearn.utils.parametrize_with_checks` pytest decorator (see its
docstring for details and possible interactions with `pytest`)::
:func:`~sklearn.utils.estimator_checks.check_estimator` on an instance. The
:func:`~sklearn.utils.parametrize_with_checks` pytest decorator can also be
used (see its docstring for details and possible interactions with `pytest`)::

>>> from sklearn.utils.estimator_checks import check_estimator
>>> from sklearn.svm import LinearSVC
>>> check_estimator(LinearSVC) # passes
>>> check_estimator(LinearSVC()) # passes

The main motivation to make a class compatible with the scikit-learn estimator
interface might be that you want to use it together with model evaluation and
Expand Down
5 changes: 5 additions & 0 deletions doc/whats_new/v0.23.rst
Original file line number Diff line number Diff line change
Expand Up @@ -511,6 +511,11 @@ Changelog
matrix from a pandas DataFrame that contains only `SparseArray`s.
:pr:`16728` by `Thomas Fan`_.

- |API| Passing classes to :func:`utils.estimator_checks.check_estimator` and
:func:`utils.estimator_checks.parametrize_with_checks` is now deprecated,
and support for classes will be removed in 0.24. Pass instances instead.
:pr:`17032` by `Nicolas Hug`_.

:mod:`sklearn.cluster`
......................

Expand Down
22 changes: 21 additions & 1 deletion sklearn/tests/test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,9 @@ def test_all_estimator_no_base_class():
assert not name.lower().startswith('base'), msg


@ignore_warnings("Passing a class is depr", category=FutureWarning) # 0.24
def test_estimator_cls_parameterize_with_checks():
# TODO: remove test in 0.24
# Non-regression test for #16707 to ensure that parametrize_with_checks
# works with estimator classes
param_checks = parametrize_with_checks([LogisticRegression])
Expand Down Expand Up @@ -105,7 +107,7 @@ def _tested_estimators():
yield estimator


@parametrize_with_checks(_tested_estimators())
@parametrize_with_checks(list(_tested_estimators()))
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

without this, the iterator would be consumed in parametrize_with_checks due to the addition of if any(isinstance(est, type) for est in estimators):

Note that the docstring of parametrize_with_checks says that its argument should be a list.

def test_estimators(estimator, check, request):
# Common tests for estimator instances
with ignore_warnings(category=(FutureWarning,
Expand All @@ -115,7 +117,9 @@ def test_estimators(estimator, check, request):
check(estimator)


@ignore_warnings("Passing a class is depr", category=FutureWarning) # 0.24
def test_check_estimator_generate_only():
# TODO in 0.24: remove checks on passing a class
estimator_cls_gen_checks = check_estimator(LogisticRegression,
generate_only=True)
all_instance_gen_checks = check_estimator(LogisticRegression(),
Expand Down Expand Up @@ -238,3 +242,19 @@ def test_all_tests_are_importable():
'__init__.py or an add_subpackage directive '
'in the parent '
'setup.py'.format(missing_tests))


# TODO: remove in 0.24
def test_class_support_deprecated():
    """Check that passing an estimator class emits a FutureWarning.

    Both ``check_estimator`` and ``parametrize_with_checks`` are called with
    a class and are expected to warn. ``check_parameters_default_constructible``
    is then called with an instance to make sure instances are accepted.
    """
    expected_warning = "Passing a class is deprecated"

    # Calling check_estimator with a class must warn
    with pytest.warns(FutureWarning, match=expected_warning):
        check_estimator(LogisticRegression)

    # Building the parametrization from a list of classes must warn too
    with pytest.warns(FutureWarning, match=expected_warning):
        parametrize_with_checks([LogisticRegression])

    # Make sure check_parameters_default_constructible accepts instances now
    check_parameters_default_constructible('name', LogisticRegression())
45 changes: 40 additions & 5 deletions sklearn/utils/estimator_checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
from ..linear_model import Ridge

from ..base import (clone, ClusterMixin, is_classifier, is_regressor,
RegressorMixin, is_outlier_detector)
RegressorMixin, is_outlier_detector, BaseEstimator)

from ..metrics import accuracy_score, adjusted_rand_score, f1_score
from ..random_projection import BaseRandomProjection
Expand Down Expand Up @@ -333,12 +333,15 @@ def _construct_instance(Estimator):
return estimator


# TODO: probably not needed anymore in 0.24 since _generate_class_checks should
# be removed too. Just put this in check_estimator()
def _generate_instance_checks(name, estimator):
    """Yield an ``(estimator, check)`` pair for each check of an instance.

    Every check produced by ``_yield_all_checks(name, estimator)`` is bound
    to ``name`` with ``functools.partial``, so that only the estimator is
    left to be passed when the check is eventually called.
    """
    for check in _yield_all_checks(name, estimator):
        yield estimator, partial(check, name)


# TODO: remove this in 0.24
def _generate_class_checks(Estimator):
"""Generate class checks."""
name = Estimator.__name__
Expand All @@ -353,6 +356,8 @@ def _mark_xfail_checks(estimator, check, pytest):
if isinstance(estimator, type):
# try to construct estimator instance, if it is unable to then
# return the estimator class, ignoring the tag
# TODO: remove this if block in 0.24 since passing classes isn't
# supported anymore
try:
estimator = _construct_instance(estimator)
except Exception:
Expand Down Expand Up @@ -385,6 +390,10 @@ def parametrize_with_checks(estimators):
estimators : list of estimator objects or classes
Estimators to generate checks for.

.. deprecated:: 0.23
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the estimator is not deprecated, but changed? I'm not sure if versionchanged works better here or not.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess it should be deprecated for now, and become versionchanged in 0.24

Passing a class is deprecated from version 0.23, and won't be
supported in 0.24. Pass an instance instead.

Returns
-------
decorator : `pytest.mark.parametrize`
Expand All @@ -395,13 +404,21 @@ def parametrize_with_checks(estimators):
>>> from sklearn.linear_model import LogisticRegression
>>> from sklearn.tree import DecisionTreeRegressor

>>> @parametrize_with_checks([LogisticRegression, DecisionTreeRegressor])
>>> @parametrize_with_checks([LogisticRegression(),
... DecisionTreeRegressor()])
... def test_sklearn_compatible_estimator(estimator, check):
... check(estimator)

"""
import pytest

if any(isinstance(est, type) for est in estimators):
# TODO: remove class support in 0.24 and update docstrings
# Adjacent string literals are concatenated verbatim, so each fragment
# must carry its own trailing space to keep the sentences separated.
msg = ("Passing a class is deprecated since version 0.23 "
       "and won't be supported in 0.24. "
       "Please pass an instance instead.")
warnings.warn(msg, FutureWarning)

checks_generator = chain.from_iterable(
check_estimator(estimator, generate_only=True)
for estimator in estimators)
Expand All @@ -418,15 +435,17 @@ def check_estimator(Estimator, generate_only=False):
"""Check if estimator adheres to scikit-learn conventions.

This function will run an extensive test-suite for input validation,
shapes, etc, making sure that the estimator complies with `scikit-leanrn`
shapes, etc, making sure that the estimator complies with `scikit-learn`
conventions as detailed in :ref:`rolling_your_own_estimator`.
Additional tests for classifiers, regressors, clustering or transformers
will be run if the Estimator class inherits from the corresponding mixin
from sklearn.base.

This test can be applied to classes or instances.
Classes currently have some additional tests that are related to construction,
while passing instances allows the testing of multiple options.
while passing instances allows the testing of multiple options. However,
support for classes is deprecated since version 0.23 and will be removed
in version 0.24 (class checks will still be run on the instances).

Setting `generate_only=True` returns a generator that yields (estimator,
check) tuples where the check can be called independently from each
Expand All @@ -439,9 +458,13 @@ def check_estimator(Estimator, generate_only=False):

Parameters
----------
estimator : estimator object or class
estimator : estimator object
Estimator to check. Estimator is a class object or instance.

.. deprecated:: 0.23
Passing a class is deprecated from version 0.23, and won't be
supported in 0.24. Pass an instance instead.

generate_only : bool, optional (default=False)
When `False`, checks are evaluated when `check_estimator` is called.
When `True`, `check_estimator` returns a generator that yields
Expand All @@ -456,8 +479,14 @@ def check_estimator(Estimator, generate_only=False):
Generator that yields (estimator, check) tuples. Returned when
`generate_only=True`.
"""
# TODO: remove class support in 0.24 and update docstrings
if isinstance(Estimator, type):
# got a class
# Adjacent string literals are concatenated verbatim, so each fragment
# must carry its own trailing space to keep the sentences separated.
msg = ("Passing a class is deprecated since version 0.23 "
       "and won't be supported in 0.24. "
       "Please pass an instance instead.")
warnings.warn(msg, FutureWarning)

checks_generator = _generate_class_checks(Estimator)
else:
# got an instance
Expand Down Expand Up @@ -2570,6 +2599,12 @@ def check_parameters_default_constructible(name, Estimator):
# this check works on classes, not instances
# test default-constructibility
# get rid of deprecation warnings
if isinstance(Estimator, BaseEstimator):
# Convert estimator instance to its class
# TODO: Always convert to class in 0.24, because check_estimator() will
# only accept instances, not classes
Estimator = Estimator.__class__

with ignore_warnings(category=FutureWarning):
estimator = _construct_instance(Estimator)
# test cloning
Expand Down
4 changes: 4 additions & 0 deletions sklearn/utils/tests/test_estimator_checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -356,6 +356,7 @@ def fit(self, X, y):
check_fit_score_takes_y("test", TestEstimatorWithDeprecatedFitMethod())


@ignore_warnings("Passing a class is depr", category=FutureWarning) # 0.24
def test_check_estimator():
# tests that the estimator actually fails on "bad" estimators.
# not a complete test of all checks, which are very extensive.
Expand Down Expand Up @@ -579,7 +580,10 @@ def test_check_regressor_data_not_an_array():
EstimatorInconsistentForPandas())


@ignore_warnings("Passing a class is depr", category=FutureWarning) # 0.24
def test_check_estimator_required_parameters_skip():
# TODO: remove whole test in 0.24 since passing classes to
# check_estimator() isn't supported anymore
class MyEstimator(BaseEstimator):
_required_parameters = ["special_parameter"]

Expand Down