-
-
Notifications
You must be signed in to change notification settings - Fork 26.2k
[MRG] Deprecate class support for check_estimator and parametrize_with_checks #17032
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
d53ec50
ccf9963
c013f4f
ddd67db
a4bca6c
e63e97c
6b36364
04a4d23
fe2f4b6
2fdf3d9
b307f2d
04df1f9
dcc41a9
e7396db
268728b
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -33,7 +33,7 @@ | |
from ..linear_model import Ridge | ||
|
||
from ..base import (clone, ClusterMixin, is_classifier, is_regressor, | ||
RegressorMixin, is_outlier_detector) | ||
RegressorMixin, is_outlier_detector, BaseEstimator) | ||
|
||
from ..metrics import accuracy_score, adjusted_rand_score, f1_score | ||
from ..random_projection import BaseRandomProjection | ||
|
@@ -333,12 +333,15 @@ def _construct_instance(Estimator): | |
return estimator | ||
|
||
|
||
# TODO: probably not needed anymore in 0.24 since _generate_class_checks should | ||
# be removed too. Just put this in check_estimator() | ||
def _generate_instance_checks(name, estimator): | ||
"""Generate instance checks.""" | ||
yield from ((estimator, partial(check, name)) | ||
for check in _yield_all_checks(name, estimator)) | ||
|
||
|
||
# TODO: remove this in 0.24 | ||
def _generate_class_checks(Estimator): | ||
"""Generate class checks.""" | ||
name = Estimator.__name__ | ||
|
@@ -353,6 +356,8 @@ def _mark_xfail_checks(estimator, check, pytest): | |
if isinstance(estimator, type): | ||
# try to construct estimator instance, if it is unable to then | ||
# return the estimator class, ignoring the tag | ||
# TODO: remove this if block in 0.24 since passing instances isn't | ||
# supported anymore | ||
try: | ||
estimator = _construct_instance(estimator) | ||
except Exception: | ||
|
@@ -385,6 +390,10 @@ def parametrize_with_checks(estimators): | |
estimators : list of estimators objects or classes | ||
Estimators to generated checks for. | ||
|
||
.. deprecated:: 0.23 | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. the There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I guess it should be |
||
Passing a class is deprecated from version 0.23, and won't be | ||
supported in 0.24. Pass an instance instead. | ||
|
||
Returns | ||
------- | ||
decorator : `pytest.mark.parametrize` | ||
|
@@ -395,13 +404,21 @@ def parametrize_with_checks(estimators): | |
>>> from sklearn.linear_model import LogisticRegression | ||
>>> from sklearn.tree import DecisionTreeRegressor | ||
|
||
>>> @parametrize_with_checks([LogisticRegression, DecisionTreeRegressor]) | ||
>>> @parametrize_with_checks([LogisticRegression(), | ||
... DecisionTreeRegressor()]) | ||
... def test_sklearn_compatible_estimator(estimator, check): | ||
... check(estimator) | ||
|
||
""" | ||
import pytest | ||
|
||
if any(isinstance(est, type) for est in estimators): | ||
# TODO: remove class support in 0.24 and update docstrings | ||
msg = ("Passing a class is deprecated since version 0.23 " | ||
"and won't be supported in 0.24." | ||
"Please pass an instance instead.") | ||
warnings.warn(msg, FutureWarning) | ||
|
||
checks_generator = chain.from_iterable( | ||
check_estimator(estimator, generate_only=True) | ||
for estimator in estimators) | ||
|
@@ -418,15 +435,17 @@ def check_estimator(Estimator, generate_only=False): | |
"""Check if estimator adheres to scikit-learn conventions. | ||
|
||
This estimator will run an extensive test-suite for input validation, | ||
shapes, etc, making sure that the estimator complies with `scikit-leanrn` | ||
shapes, etc, making sure that the estimator complies with `scikit-learn` | ||
conventions as detailed in :ref:`rolling_your_own_estimator`. | ||
Additional tests for classifiers, regressors, clustering or transformers | ||
will be run if the Estimator class inherits from the corresponding mixin | ||
from sklearn.base. | ||
|
||
This test can be applied to classes or instances. | ||
Classes currently have some additional tests that related to construction, | ||
while passing instances allows the testing of multiple options. | ||
while passing instances allows the testing of multiple options. However, | ||
support for classes is deprecated since version 0.23 and will be removed | ||
in version 0.24 (class checks will still be run on the instances). | ||
|
||
Setting `generate_only=True` returns a generator that yields (estimator, | ||
check) tuples where the check can be called independently from each | ||
|
@@ -439,9 +458,13 @@ def check_estimator(Estimator, generate_only=False): | |
|
||
Parameters | ||
---------- | ||
estimator : estimator object or class | ||
estimator : estimator object | ||
Estimator to check. Estimator is a class object or instance. | ||
|
||
.. deprecated:: 0.23 | ||
Passing a class is deprecated from version 0.23, and won't be | ||
supported in 0.24. Pass an instance instead. | ||
|
||
generate_only : bool, optional (default=False) | ||
When `False`, checks are evaluated when `check_estimator` is called. | ||
When `True`, `check_estimator` returns a generator that yields | ||
|
@@ -456,8 +479,14 @@ def check_estimator(Estimator, generate_only=False): | |
Generator that yields (estimator, check) tuples. Returned when | ||
`generate_only=True`. | ||
""" | ||
# TODO: remove class support in 0.24 and update docstrings | ||
if isinstance(Estimator, type): | ||
# got a class | ||
msg = ("Passing a class is deprecated since version 0.23 " | ||
"and won't be supported in 0.24." | ||
"Please pass an instance instead.") | ||
warnings.warn(msg, FutureWarning) | ||
|
||
checks_generator = _generate_class_checks(Estimator) | ||
else: | ||
# got an instance | ||
|
@@ -2570,6 +2599,12 @@ def check_parameters_default_constructible(name, Estimator): | |
# this check works on classes, not instances | ||
# test default-constructibility | ||
# get rid of deprecation warnings | ||
if isinstance(Estimator, BaseEstimator): | ||
# Convert estimator instance to its class | ||
# TODO: Always convert to class in 0.24, because check_estimator() will | ||
# only accept instances, not classes | ||
Estimator = Estimator.__class__ | ||
|
||
with ignore_warnings(category=FutureWarning): | ||
estimator = _construct_instance(Estimator) | ||
# test cloning | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
without this, the iterator would be consumed in
parametrize_with_checks
due to the addition ofif any(isinstance(est, type) for est in estimators):
Note that the docstring of
parametrize_with_checks
says that its argument should be a list.