Skip to content

Commit fc00415

Browse files
authored
Deprecate class support for check_estimator (#17032)
and parametrize_with_checks
1 parent 41b18fe commit fc00415

File tree

5 files changed

+74
-10
lines changed

5 files changed

+74
-10
lines changed

doc/developers/develop.rst

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -246,13 +246,13 @@ whether it is just for you or for contributing it to scikit-learn, there are
246246
several internals of scikit-learn that you should be aware of in addition to
247247
the scikit-learn API outlined above. You can check whether your estimator
248248
adheres to the scikit-learn interface and standards by running
249-
:func:`utils.estimator_checks.check_estimator` on the class or using
250-
:func:`~sklearn.utils.parametrize_with_checks` pytest decorator (see its
251-
docstring for details and possible interactions with `pytest`)::
249+
:func:`~sklearn.utils.estimator_checks.check_estimator` on an instance. The
250+
:func:`~sklearn.utils.parametrize_with_checks` pytest decorator can also be
251+
used (see its docstring for details and possible interactions with `pytest`)::
252252

253253
>>> from sklearn.utils.estimator_checks import check_estimator
254254
>>> from sklearn.svm import LinearSVC
255-
>>> check_estimator(LinearSVC) # passes
255+
>>> check_estimator(LinearSVC()) # passes
256256

257257
The main motivation to make a class compatible to the scikit-learn estimator
258258
interface might be that you want to use it together with model evaluation and

doc/whats_new/v0.23.rst

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -534,6 +534,11 @@ Changelog
534534
matrix from a pandas DataFrame that contains only `SparseArray`s.
535535
:pr:`16728` by `Thomas Fan`_.
536536

537+
- |API| Passing classes to :func:`utils.estimator_checks.check_estimator` and
538+
:func:`utils.estimator_checks.parametrize_with_checks` is now deprecated,
539+
and support for classes will be removed in 0.24. Pass instances instead.
540+
:pr:`17032` by `Nicolas Hug`_.
541+
537542
:mod:`sklearn.cluster`
538543
......................
539544

sklearn/tests/test_common.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,9 @@ def test_all_estimator_no_base_class():
4848
assert not name.lower().startswith('base'), msg
4949

5050

51+
@ignore_warnings("Passing a class is depr", category=FutureWarning) # 0.24
5152
def test_estimator_cls_parameterize_with_checks():
53+
# TODO: remove test in 0.24
5254
# Non-regression test for #16707 to ensure that parametrize_with_checks
5355
# works with estimator classes
5456
param_checks = parametrize_with_checks([LogisticRegression])
@@ -105,7 +107,7 @@ def _tested_estimators():
105107
yield estimator
106108

107109

108-
@parametrize_with_checks(_tested_estimators())
110+
@parametrize_with_checks(list(_tested_estimators()))
109111
def test_estimators(estimator, check, request):
110112
# Common tests for estimator instances
111113
with ignore_warnings(category=(FutureWarning,
@@ -115,7 +117,9 @@ def test_estimators(estimator, check, request):
115117
check(estimator)
116118

117119

120+
@ignore_warnings("Passing a class is depr", category=FutureWarning) # 0.24
118121
def test_check_estimator_generate_only():
122+
# TODO in 0.24: remove checks on passing a class
119123
estimator_cls_gen_checks = check_estimator(LogisticRegression,
120124
generate_only=True)
121125
all_instance_gen_checks = check_estimator(LogisticRegression(),
@@ -238,3 +242,19 @@ def test_all_tests_are_importable():
238242
'__init__.py or an add_subpackage directive '
239243
'in the parent '
240244
'setup.py'.format(missing_tests))
245+
246+
247+
# TODO: remove in 0.24
248+
def test_class_support_deprecated():
249+
# Make sure passing classes to check_estimator or parametrize_with_checks
250+
# is deprecated
251+
252+
msg = "Passing a class is deprecated"
253+
with pytest.warns(FutureWarning, match=msg):
254+
check_estimator(LogisticRegression)
255+
256+
with pytest.warns(FutureWarning, match=msg):
257+
parametrize_with_checks([LogisticRegression])
258+
259+
# Make sure check_parameters_default_constructible accepts instances now
260+
check_parameters_default_constructible('name', LogisticRegression())

sklearn/utils/estimator_checks.py

Lines changed: 40 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333
from ..linear_model import Ridge
3434

3535
from ..base import (clone, ClusterMixin, is_classifier, is_regressor,
36-
RegressorMixin, is_outlier_detector)
36+
RegressorMixin, is_outlier_detector, BaseEstimator)
3737

3838
from ..metrics import accuracy_score, adjusted_rand_score, f1_score
3939
from ..random_projection import BaseRandomProjection
@@ -333,12 +333,15 @@ def _construct_instance(Estimator):
333333
return estimator
334334

335335

336+
# TODO: probably not needed anymore in 0.24 since _generate_class_checks should
337+
# be removed too. Just put this in check_estimator()
336338
def _generate_instance_checks(name, estimator):
337339
"""Generate instance checks."""
338340
yield from ((estimator, partial(check, name))
339341
for check in _yield_all_checks(name, estimator))
340342

341343

344+
# TODO: remove this in 0.24
342345
def _generate_class_checks(Estimator):
343346
"""Generate class checks."""
344347
name = Estimator.__name__
@@ -353,6 +356,8 @@ def _mark_xfail_checks(estimator, check, pytest):
353356
if isinstance(estimator, type):
354357
# try to construct estimator instance, if it is unable to then
355358
# return the estimator class, ignoring the tag
359+
# TODO: remove this if block in 0.24 since passing instances isn't
360+
# supported anymore
356361
try:
357362
estimator = _construct_instance(estimator)
358363
except Exception:
@@ -385,6 +390,10 @@ def parametrize_with_checks(estimators):
385390
estimators : list of estimators objects or classes
386391
Estimators to generated checks for.
387392
393+
.. deprecated:: 0.23
394+
Passing a class is deprecated from version 0.23, and won't be
395+
supported in 0.24. Pass an instance instead.
396+
388397
Returns
389398
-------
390399
decorator : `pytest.mark.parametrize`
@@ -395,13 +404,21 @@ def parametrize_with_checks(estimators):
395404
>>> from sklearn.linear_model import LogisticRegression
396405
>>> from sklearn.tree import DecisionTreeRegressor
397406
398-
>>> @parametrize_with_checks([LogisticRegression, DecisionTreeRegressor])
407+
>>> @parametrize_with_checks([LogisticRegression(),
408+
... DecisionTreeRegressor()])
399409
... def test_sklearn_compatible_estimator(estimator, check):
400410
... check(estimator)
401411
402412
"""
403413
import pytest
404414

415+
if any(isinstance(est, type) for est in estimators):
416+
# TODO: remove class support in 0.24 and update docstrings
417+
msg = ("Passing a class is deprecated since version 0.23 "
418+
"and won't be supported in 0.24."
419+
"Please pass an instance instead.")
420+
warnings.warn(msg, FutureWarning)
421+
405422
checks_generator = chain.from_iterable(
406423
check_estimator(estimator, generate_only=True)
407424
for estimator in estimators)
@@ -418,15 +435,17 @@ def check_estimator(Estimator, generate_only=False):
418435
"""Check if estimator adheres to scikit-learn conventions.
419436
420437
This estimator will run an extensive test-suite for input validation,
421-
shapes, etc, making sure that the estimator complies with `scikit-leanrn`
438+
shapes, etc, making sure that the estimator complies with `scikit-learn`
422439
conventions as detailed in :ref:`rolling_your_own_estimator`.
423440
Additional tests for classifiers, regressors, clustering or transformers
424441
will be run if the Estimator class inherits from the corresponding mixin
425442
from sklearn.base.
426443
427444
This test can be applied to classes or instances.
428445
Classes currently have some additional tests that related to construction,
429-
while passing instances allows the testing of multiple options.
446+
while passing instances allows the testing of multiple options. However,
447+
support for classes is deprecated since version 0.23 and will be removed
448+
in version 0.24 (class checks will still be run on the instances).
430449
431450
Setting `generate_only=True` returns a generator that yields (estimator,
432451
check) tuples where the check can be called independently from each
@@ -439,9 +458,13 @@ def check_estimator(Estimator, generate_only=False):
439458
440459
Parameters
441460
----------
442-
estimator : estimator object or class
461+
estimator : estimator object
443462
Estimator to check. Estimator is a class object or instance.
444463
464+
.. deprecated:: 0.23
465+
Passing a class is deprecated from version 0.23, and won't be
466+
supported in 0.24. Pass an instance instead.
467+
445468
generate_only : bool, optional (default=False)
446469
When `False`, checks are evaluated when `check_estimator` is called.
447470
When `True`, `check_estimator` returns a generator that yields
@@ -456,8 +479,14 @@ def check_estimator(Estimator, generate_only=False):
456479
Generator that yields (estimator, check) tuples. Returned when
457480
`generate_only=True`.
458481
"""
482+
# TODO: remove class support in 0.24 and update docstrings
459483
if isinstance(Estimator, type):
460484
# got a class
485+
msg = ("Passing a class is deprecated since version 0.23 "
486+
"and won't be supported in 0.24."
487+
"Please pass an instance instead.")
488+
warnings.warn(msg, FutureWarning)
489+
461490
checks_generator = _generate_class_checks(Estimator)
462491
else:
463492
# got an instance
@@ -2570,6 +2599,12 @@ def check_parameters_default_constructible(name, Estimator):
25702599
# this check works on classes, not instances
25712600
# test default-constructibility
25722601
# get rid of deprecation warnings
2602+
if isinstance(Estimator, BaseEstimator):
2603+
# Convert estimator instance to its class
2604+
# TODO: Always convert to class in 0.24, because check_estimator() will
2605+
# only accept instances, not classes
2606+
Estimator = Estimator.__class__
2607+
25732608
with ignore_warnings(category=FutureWarning):
25742609
estimator = _construct_instance(Estimator)
25752610
# test cloning

sklearn/utils/tests/test_estimator_checks.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -356,6 +356,7 @@ def fit(self, X, y):
356356
check_fit_score_takes_y("test", TestEstimatorWithDeprecatedFitMethod())
357357

358358

359+
@ignore_warnings("Passing a class is depr", category=FutureWarning) # 0.24
359360
def test_check_estimator():
360361
# tests that the estimator actually fails on "bad" estimators.
361362
# not a complete test of all checks, which are very extensive.
@@ -579,7 +580,10 @@ def test_check_regressor_data_not_an_array():
579580
EstimatorInconsistentForPandas())
580581

581582

583+
@ignore_warnings("Passing a class is depr", category=FutureWarning) # 0.24
582584
def test_check_estimator_required_parameters_skip():
585+
# TODO: remove whole test in 0.24 since passes classes to check_estimator()
586+
# isn't supported anymore
583587
class MyEstimator(BaseEstimator):
584588
_required_parameters = ["special_parameter"]
585589

0 commit comments

Comments
 (0)