5
5
6
6
import pickle
7
7
import re
8
+ import textwrap
8
9
import warnings
9
10
from contextlib import nullcontext
10
11
from copy import deepcopy
@@ -95,6 +96,7 @@ def _yield_api_checks(estimator):
95
96
if tags .requires_fit :
96
97
yield check_estimators_unfitted
97
98
yield check_do_not_raise_errors_in_init_or_set_params
99
+ yield check_n_features_in_after_fitting
98
100
99
101
100
102
def _yield_checks (estimator ):
@@ -441,7 +443,7 @@ def _should_be_skipped_or_marked(estimator, check):
441
443
return False , "placeholder reason that will never be used"
442
444
443
445
444
- def parametrize_with_checks (estimators , * , legacy = True ):
446
+ def parametrize_with_checks (estimators , * , legacy : bool = True ):
445
447
"""Pytest specific decorator for parametrizing estimator checks.
446
448
447
449
Checks are categorised into the following groups:
@@ -468,6 +470,7 @@ def parametrize_with_checks(estimators, *, legacy=True):
468
470
469
471
.. versionadded:: 0.24
470
472
473
+
471
474
legacy : bool, default=True
472
475
Whether to include legacy checks. Over time we remove checks from this category
473
476
and move them into their specific category.
@@ -520,7 +523,7 @@ def checks_generator():
520
523
)
521
524
522
525
523
- def check_estimator (estimator = None , generate_only = False , * , legacy = True ):
526
+ def check_estimator (estimator = None , generate_only = False , * , legacy : bool = True ):
524
527
"""Check if estimator adheres to scikit-learn conventions.
525
528
526
529
This function will run an extensive test-suite for input validation,
@@ -2009,13 +2012,14 @@ def check_estimators_partial_fit_n_features(name, estimator_orig):
2009
2012
2010
2013
2011
2014
@ignore_warnings (category = FutureWarning )
2012
- def check_classifier_multioutput (name , estimator ):
2015
+ def check_classifier_multioutput (name , estimator_orig ):
2013
2016
n_samples , n_labels , n_classes = 42 , 5 , 3
2014
- tags = get_tags (estimator )
2015
- estimator = clone (estimator )
2017
+ tags = get_tags (estimator_orig )
2018
+ estimator = clone (estimator_orig )
2016
2019
X , y = make_multilabel_classification (
2017
2020
random_state = 42 , n_samples = n_samples , n_labels = n_labels , n_classes = n_classes
2018
2021
)
2022
+ X = _enforce_estimator_tags_X (estimator , X )
2019
2023
estimator .fit (X , y )
2020
2024
y_pred = estimator .predict (X )
2021
2025
@@ -2174,13 +2178,14 @@ def check_clusterer_compute_labels_predict(name, clusterer_orig):
2174
2178
def check_classifiers_one_label (name , classifier_orig ):
2175
2179
error_string_fit = "Classifier can't train when only one class is present."
2176
2180
error_string_predict = "Classifier can't predict when only one class is present."
2181
+ classifier = clone (classifier_orig )
2177
2182
rnd = np .random .RandomState (0 )
2178
2183
X_train = rnd .uniform (size = (10 , 3 ))
2179
2184
X_test = rnd .uniform (size = (10 , 3 ))
2185
+ X_train , X_test = _enforce_estimator_tags_X (classifier , X_train , X_test = X_test )
2180
2186
y = np .ones (10 )
2181
2187
# catch deprecation warnings
2182
2188
with ignore_warnings (category = FutureWarning ):
2183
- classifier = clone (classifier_orig )
2184
2189
with raises (
2185
2190
ValueError , match = "class" , may_pass = True , err_msg = error_string_fit
2186
2191
) as cm :
@@ -2505,6 +2510,7 @@ def check_classifiers_multilabel_representation_invariance(name, classifier_orig
2505
2510
2506
2511
X_train , y_train = X [:80 ], y [:80 ]
2507
2512
X_test = X [80 :]
2513
+ X_train , X_test = _enforce_estimator_tags_X (classifier_orig , X_train , X_test = X_test )
2508
2514
2509
2515
y_train_list_of_lists = y_train .tolist ()
2510
2516
y_train_list_of_arrays = list (y_train )
@@ -2552,6 +2558,7 @@ def check_classifiers_multilabel_output_format_predict(name, classifier_orig):
2552
2558
2553
2559
X_train , X_test = X [:- test_size ], X [- test_size :]
2554
2560
y_train , y_test = y [:- test_size ], y [- test_size :]
2561
+ X_train , X_test = _enforce_estimator_tags_X (classifier_orig , X_train , X_test = X_test )
2555
2562
classifier .fit (X_train , y_train )
2556
2563
2557
2564
response_method_name = "predict"
@@ -2597,6 +2604,7 @@ def check_classifiers_multilabel_output_format_predict_proba(name, classifier_or
2597
2604
2598
2605
X_train , X_test = X [:- test_size ], X [- test_size :]
2599
2606
y_train = y [:- test_size ]
2607
+ X_train , X_test = _enforce_estimator_tags_X (classifier_orig , X_train , X_test = X_test )
2600
2608
classifier .fit (X_train , y_train )
2601
2609
2602
2610
response_method_name = "predict_proba"
@@ -2681,6 +2689,7 @@ def check_classifiers_multilabel_output_format_decision_function(name, classifie
2681
2689
2682
2690
X_train , X_test = X [:- test_size ], X [- test_size :]
2683
2691
y_train = y [:- test_size ]
2692
+ X_train , X_test = _enforce_estimator_tags_X (classifier_orig , X_train , X_test = X_test )
2684
2693
classifier .fit (X_train , y_train )
2685
2694
2686
2695
response_method_name = "decision_function"
@@ -3474,30 +3483,48 @@ def _enforce_estimator_tags_y(estimator, y):
3474
3483
return y
3475
3484
3476
3485
3477
- def _enforce_estimator_tags_X (estimator , X , kernel = linear_kernel ):
3486
+ def _enforce_estimator_tags_X (estimator , X , X_test = None , kernel = linear_kernel ):
3478
3487
# Estimators with `1darray` in `X_types` tag only accept
3479
3488
# X of shape (`n_samples`,)
3480
3489
if get_tags (estimator ).input_tags .one_d_array :
3481
3490
X = X [:, 0 ]
3491
+ if X_test is not None :
3492
+ X_test = X_test [:, 0 ] # pragma: no cover
3482
3493
# Estimators with a `requires_positive_X` tag only accept
3483
3494
# strictly positive data
3484
3495
if get_tags (estimator ).input_tags .positive_only :
3485
3496
X = X - X .min ()
3497
+ if X_test is not None :
3498
+ X_test = X_test - X_test .min () # pragma: no cover
3486
3499
if get_tags (estimator ).input_tags .categorical :
3487
3500
dtype = np .float64 if get_tags (estimator ).input_tags .allow_nan else np .int32
3488
3501
X = np .round ((X - X .min ())).astype (dtype )
3502
+ if X_test is not None :
3503
+ X_test = np .round ((X_test - X_test .min ())).astype (dtype ) # pragma: no cover
3489
3504
3490
3505
if estimator .__class__ .__name__ == "SkewedChi2Sampler" :
3491
3506
# SkewedChi2Sampler requires X > -skewdness in transform
3492
3507
X = X - X .min ()
3508
+ if X_test is not None :
3509
+ X_test = X_test - X_test .min () # pragma: no cover
3510
+
3511
+ X_res = X
3493
3512
3494
3513
# Pairwise estimators only accept
3495
3514
# X of shape (`n_samples`, `n_samples`)
3496
3515
if _is_pairwise_metric (estimator ):
3497
- X = pairwise_distances (X , metric = "euclidean" )
3516
+ X_res = pairwise_distances (X , metric = "euclidean" )
3517
+ if X_test is not None :
3518
+ X_test = pairwise_distances (
3519
+ X_test , X , metric = "euclidean"
3520
+ ) # pragma: no cover
3498
3521
elif get_tags (estimator ).input_tags .pairwise :
3499
- X = kernel (X , X )
3500
- return X
3522
+ X_res = kernel (X , X )
3523
+ if X_test is not None :
3524
+ X_test = kernel (X_test , X ) # pragma: no cover
3525
+ if X_test is not None :
3526
+ return X_res , X_test
3527
+ return X_res
3501
3528
3502
3529
3503
3530
@ignore_warnings (category = FutureWarning )
@@ -3913,8 +3940,16 @@ def check_n_features_in_after_fitting(name, estimator_orig):
3913
3940
y = rng .randint (low = 0 , high = 2 , size = n_samples )
3914
3941
y = _enforce_estimator_tags_y (estimator , y )
3915
3942
3943
+ err_msg = (
3944
+ "`{name}.fit()` does not set the `n_features_in_` attribute. "
3945
+ "You might want to use `sklearn.utils.validation.validate_data` instead "
3946
+ "of `check_array` in `{name}.fit()` which takes care of setting the "
3947
+ "attribute." .format (name = name )
3948
+ )
3949
+
3916
3950
estimator .fit (X , y )
3917
- assert estimator .n_features_in_ == X .shape [1 ]
3951
+ assert hasattr (estimator , "n_features_in_" ), err_msg
3952
+ assert estimator .n_features_in_ == X .shape [1 ], err_msg
3918
3953
3919
3954
# check methods will check n_features_in_
3920
3955
check_methods = [
@@ -3926,6 +3961,28 @@ def check_n_features_in_after_fitting(name, estimator_orig):
3926
3961
]
3927
3962
X_bad = X [:, [1 ]]
3928
3963
3964
+ err_msg = """\
3965
+ `{name}.{method}()` does not check for consistency between input number
3966
+ of features with {name}.fit(), via the `n_features_in_` attribute.
3967
+ You might want to use `sklearn.utils.validation.validate_data` instead
3968
+ of `check_array` in `{name}.fit()` and {name}.{method}()`. This can be done
3969
+ like the following:
3970
+ from sklearn.utils.validation import validate_data
3971
+ ...
3972
+ class MyEstimator(BaseEstimator):
3973
+ ...
3974
+ def fit(self, X, y):
3975
+ X, y = validate_data(self, X, y, ...)
3976
+ ...
3977
+ return self
3978
+ ...
3979
+ def {method}(self, X):
3980
+ X = validate_data(self, X, ..., reset=False)
3981
+ ...
3982
+ return X
3983
+ """
3984
+ err_msg = textwrap .dedent (err_msg )
3985
+
3929
3986
msg = f"X has 1 features, but \\ w+ is expecting { X .shape [1 ]} features as input"
3930
3987
for method in check_methods :
3931
3988
if not hasattr (estimator , method ):
@@ -3935,7 +3992,9 @@ def check_n_features_in_after_fitting(name, estimator_orig):
3935
3992
if method == "score" :
3936
3993
callable_method = partial (callable_method , y = y )
3937
3994
3938
- with raises (ValueError , match = msg ):
3995
+ with raises (
3996
+ ValueError , match = msg , err_msg = err_msg .format (name = name , method = method )
3997
+ ):
3939
3998
callable_method (X_bad )
3940
3999
3941
4000
# partial_fit will check in the second call
0 commit comments