33
33
from ..linear_model import Ridge
34
34
35
35
from ..base import (clone , ClusterMixin , is_classifier , is_regressor ,
36
- RegressorMixin , is_outlier_detector )
36
+ RegressorMixin , is_outlier_detector , BaseEstimator )
37
37
38
38
from ..metrics import accuracy_score , adjusted_rand_score , f1_score
39
39
from ..random_projection import BaseRandomProjection
@@ -333,12 +333,15 @@ def _construct_instance(Estimator):
333
333
return estimator
334
334
335
335
336
+ # TODO: probably not needed anymore in 0.24 since _generate_class_checks should
337
+ # be removed too. Just put this in check_estimator()
336
338
def _generate_instance_checks (name , estimator ):
337
339
"""Generate instance checks."""
338
340
yield from ((estimator , partial (check , name ))
339
341
for check in _yield_all_checks (name , estimator ))
340
342
341
343
344
+ # TODO: remove this in 0.24
342
345
def _generate_class_checks (Estimator ):
343
346
"""Generate class checks."""
344
347
name = Estimator .__name__
@@ -353,6 +356,8 @@ def _mark_xfail_checks(estimator, check, pytest):
353
356
if isinstance (estimator , type ):
354
357
# try to construct estimator instance, if it is unable to then
355
358
# return the estimator class, ignoring the tag
359
+ # TODO: remove this if block in 0.24 since passing instances isn't
360
+ # supported anymore
356
361
try :
357
362
estimator = _construct_instance (estimator )
358
363
except Exception :
@@ -385,6 +390,10 @@ def parametrize_with_checks(estimators):
385
390
estimators : list of estimators objects or classes
386
391
Estimators to generated checks for.
387
392
393
+ .. deprecated:: 0.23
394
+ Passing a class is deprecated from version 0.23, and won't be
395
+ supported in 0.24. Pass an instance instead.
396
+
388
397
Returns
389
398
-------
390
399
decorator : `pytest.mark.parametrize`
@@ -395,13 +404,21 @@ def parametrize_with_checks(estimators):
395
404
>>> from sklearn.linear_model import LogisticRegression
396
405
>>> from sklearn.tree import DecisionTreeRegressor
397
406
398
- >>> @parametrize_with_checks([LogisticRegression, DecisionTreeRegressor])
407
+ >>> @parametrize_with_checks([LogisticRegression(),
408
+ ... DecisionTreeRegressor()])
399
409
... def test_sklearn_compatible_estimator(estimator, check):
400
410
... check(estimator)
401
411
402
412
"""
403
413
import pytest
404
414
415
+ if any (isinstance (est , type ) for est in estimators ):
416
+ # TODO: remove class support in 0.24 and update docstrings
417
+ msg = ("Passing a class is deprecated since version 0.23 "
418
+ "and won't be supported in 0.24."
419
+ "Please pass an instance instead." )
420
+ warnings .warn (msg , FutureWarning )
421
+
405
422
checks_generator = chain .from_iterable (
406
423
check_estimator (estimator , generate_only = True )
407
424
for estimator in estimators )
@@ -418,15 +435,17 @@ def check_estimator(Estimator, generate_only=False):
418
435
"""Check if estimator adheres to scikit-learn conventions.
419
436
420
437
This estimator will run an extensive test-suite for input validation,
421
- shapes, etc, making sure that the estimator complies with `scikit-leanrn `
438
+ shapes, etc, making sure that the estimator complies with `scikit-learn `
422
439
conventions as detailed in :ref:`rolling_your_own_estimator`.
423
440
Additional tests for classifiers, regressors, clustering or transformers
424
441
will be run if the Estimator class inherits from the corresponding mixin
425
442
from sklearn.base.
426
443
427
444
This test can be applied to classes or instances.
428
445
Classes currently have some additional tests that related to construction,
429
- while passing instances allows the testing of multiple options.
446
+ while passing instances allows the testing of multiple options. However,
447
+ support for classes is deprecated since version 0.23 and will be removed
448
+ in version 0.24 (class checks will still be run on the instances).
430
449
431
450
Setting `generate_only=True` returns a generator that yields (estimator,
432
451
check) tuples where the check can be called independently from each
@@ -439,9 +458,13 @@ def check_estimator(Estimator, generate_only=False):
439
458
440
459
Parameters
441
460
----------
442
- estimator : estimator object or class
461
+ estimator : estimator object
443
462
Estimator to check. Estimator is a class object or instance.
444
463
464
+ .. deprecated:: 0.23
465
+ Passing a class is deprecated from version 0.23, and won't be
466
+ supported in 0.24. Pass an instance instead.
467
+
445
468
generate_only : bool, optional (default=False)
446
469
When `False`, checks are evaluated when `check_estimator` is called.
447
470
When `True`, `check_estimator` returns a generator that yields
@@ -456,8 +479,14 @@ def check_estimator(Estimator, generate_only=False):
456
479
Generator that yields (estimator, check) tuples. Returned when
457
480
`generate_only=True`.
458
481
"""
482
+ # TODO: remove class support in 0.24 and update docstrings
459
483
if isinstance (Estimator , type ):
460
484
# got a class
485
+ msg = ("Passing a class is deprecated since version 0.23 "
486
+ "and won't be supported in 0.24."
487
+ "Please pass an instance instead." )
488
+ warnings .warn (msg , FutureWarning )
489
+
461
490
checks_generator = _generate_class_checks (Estimator )
462
491
else :
463
492
# got an instance
@@ -2570,6 +2599,12 @@ def check_parameters_default_constructible(name, Estimator):
2570
2599
# this check works on classes, not instances
2571
2600
# test default-constructibility
2572
2601
# get rid of deprecation warnings
2602
+ if isinstance (Estimator , BaseEstimator ):
2603
+ # Convert estimator instance to its class
2604
+ # TODO: Always convert to class in 0.24, because check_estimator() will
2605
+ # only accept instances, not classes
2606
+ Estimator = Estimator .__class__
2607
+
2573
2608
with ignore_warnings (category = FutureWarning ):
2574
2609
estimator = _construct_instance (Estimator )
2575
2610
# test cloning
0 commit comments