33
33
from ..linear_model import Ridge
34
34
35
35
from ..base import (clone , ClusterMixin , is_classifier , is_regressor ,
36
- _DEFAULT_TAGS , RegressorMixin , is_outlier_detector )
36
+ RegressorMixin , is_outlier_detector )
37
37
38
38
from ..metrics import accuracy_score , adjusted_rand_score , f1_score
39
39
from ..random_projection import BaseRandomProjection
58
58
BOSTON = None
59
59
CROSS_DECOMPOSITION = ['PLSCanonical' , 'PLSRegression' , 'CCA' , 'PLSSVD' ]
60
60
61
- def _safe_tags (estimator , key = None ):
62
- # if estimator doesn't have _get_tags, use _DEFAULT_TAGS
63
- # if estimator has tags but not key, use _DEFAULT_TAGS[key]
64
- if hasattr (estimator , "_get_tags" ):
65
- if key is not None :
66
- return estimator ._get_tags ().get (key , _DEFAULT_TAGS [key ])
67
- tags = estimator ._get_tags ()
68
- return {key : tags .get (key , _DEFAULT_TAGS [key ])
69
- for key in _DEFAULT_TAGS .keys ()}
70
- if key is not None :
71
- return _DEFAULT_TAGS [key ]
72
- return _DEFAULT_TAGS
73
-
74
61
75
62
def _yield_checks (name , estimator ):
76
- tags = _safe_tags ( estimator )
63
+ tags = estimator . _get_tags ( )
77
64
yield check_no_attributes_set_in_init
78
65
yield check_estimators_dtypes
79
66
yield check_fit_score_takes_y
@@ -116,7 +103,7 @@ def _yield_checks(name, estimator):
116
103
117
104
118
105
def _yield_classifier_checks (name , classifier ):
119
- tags = _safe_tags ( classifier )
106
+ tags = classifier . _get_tags ( )
120
107
121
108
# test classifiers can handle non-array data and pandas objects
122
109
yield check_classifier_data_not_an_array
@@ -171,7 +158,7 @@ def check_supervised_y_no_nan(name, estimator_orig):
171
158
172
159
173
160
def _yield_regressor_checks (name , regressor ):
174
- tags = _safe_tags ( regressor )
161
+ tags = regressor . _get_tags ( )
175
162
# TODO: test with intercept
176
163
# TODO: test with multiple responses
177
164
# basic testing
@@ -198,12 +185,12 @@ def _yield_regressor_checks(name, regressor):
198
185
def _yield_transformer_checks (name , transformer ):
199
186
# All transformers should either deal with sparse data or raise an
200
187
# exception with type TypeError and an intelligible error message
201
- if not _safe_tags ( transformer , "no_validation" ) :
188
+ if not transformer . _get_tags ()[ "no_validation" ] :
202
189
yield check_transformer_data_not_an_array
203
190
# these don't actually fit the data, so don't raise errors
204
191
yield check_transformer_general
205
192
yield partial (check_transformer_general , readonly_memmap = True )
206
- if not _safe_tags ( transformer , "stateless" ) :
193
+ if not transformer . _get_tags ()[ "stateless" ] :
207
194
yield check_transformers_unfitted
208
195
# Dependent on external solvers and hence accessing the iter
209
196
# param is non-trivial.
@@ -237,12 +224,12 @@ def _yield_outliers_checks(name, estimator):
237
224
# test outlier detectors can handle non-array data
238
225
yield check_classifier_data_not_an_array
239
226
# test if NotFittedError is raised
240
- if _safe_tags ( estimator , "requires_fit" ) :
227
+ if estimator . _get_tags ()[ "requires_fit" ] :
241
228
yield check_estimators_unfitted
242
229
243
230
244
231
def _yield_all_checks (name , estimator ):
245
- tags = _safe_tags ( estimator )
232
+ tags = estimator . _get_tags ( )
246
233
if "2darray" not in tags ["X_types" ]:
247
234
warnings .warn ("Can't test estimator {} which requires input "
248
235
" of type {}" .format (name , tags ["X_types" ]),
@@ -369,7 +356,7 @@ def _mark_xfail_checks(estimator, check, pytest):
369
356
except Exception :
370
357
return estimator , check
371
358
372
- xfail_checks = _safe_tags ( estimator , '_xfail_checks' ) or {}
359
+ xfail_checks = estimator . _get_tags ()[ '_xfail_checks' ] or {}
373
360
check_name = _set_check_estimator_ids (check )
374
361
375
362
if check_name not in xfail_checks :
@@ -701,7 +688,7 @@ def check_estimator_sparse_data(name, estimator_orig):
701
688
X [X < .8 ] = 0
702
689
X = _pairwise_estimator_convert_X (X , estimator_orig )
703
690
X_csr = sparse .csr_matrix (X )
704
- tags = _safe_tags ( estimator_orig )
691
+ tags = estimator_orig . _get_tags ( )
705
692
if tags ['binary_only' ]:
706
693
y = (2 * rng .rand (40 )).astype (np .int )
707
694
else :
@@ -767,7 +754,7 @@ def check_sample_weights_pandas_series(name, estimator_orig):
767
754
X = pd .DataFrame (_pairwise_estimator_convert_X (X , estimator_orig ))
768
755
y = pd .Series ([1 , 1 , 1 , 1 , 2 , 2 , 2 , 2 , 1 , 1 , 2 , 2 ])
769
756
weights = pd .Series ([1 ] * 12 )
770
- if _safe_tags ( estimator , "multioutput_only" ) :
757
+ if estimator . _get_tags ()[ "multioutput_only" ] :
771
758
y = pd .DataFrame (y )
772
759
try :
773
760
estimator .fit (X , y , sample_weight = weights )
@@ -792,7 +779,7 @@ def check_sample_weights_not_an_array(name, estimator_orig):
792
779
X = _NotAnArray (pairwise_estimator_convert_X (X , estimator_orig ))
793
780
y = _NotAnArray ([1 , 1 , 1 , 1 , 2 , 2 , 2 , 2 , 1 , 1 , 2 , 2 ])
794
781
weights = _NotAnArray ([1 ] * 12 )
795
- if _safe_tags ( estimator , "multioutput_only" ) :
782
+ if estimator . _get_tags ()[ "multioutput_only" ] :
796
783
y = _NotAnArray (y .data .reshape (- 1 , 1 ))
797
784
estimator .fit (X , y , sample_weight = weights )
798
785
@@ -806,8 +793,8 @@ def check_sample_weights_list(name, estimator_orig):
806
793
rnd = np .random .RandomState (0 )
807
794
n_samples = 30
808
795
X = _pairwise_estimator_convert_X (rnd .uniform (size = (n_samples , 3 )),
809
- estimator_orig )
810
- if _safe_tags ( estimator , 'binary_only' ) :
796
+ estimator_orig )
797
+ if estimator . _get_tags ()[ 'binary_only' ] :
811
798
y = np .arange (n_samples ) % 2
812
799
else :
813
800
y = np .arange (n_samples ) % 3
@@ -886,7 +873,7 @@ def check_dtype_object(name, estimator_orig):
886
873
rng = np .random .RandomState (0 )
887
874
X = _pairwise_estimator_convert_X (rng .rand (40 , 10 ), estimator_orig )
888
875
X = X .astype (object )
889
- tags = _safe_tags ( estimator_orig )
876
+ tags = estimator_orig . _get_tags ( )
890
877
if tags ['binary_only' ]:
891
878
y = (X [:, 0 ] * 2 ).astype (np .int )
892
879
else :
@@ -990,7 +977,7 @@ def check_dont_overwrite_parameters(name, estimator_orig):
990
977
X = 3 * rnd .uniform (size = (20 , 3 ))
991
978
X = _pairwise_estimator_convert_X (X , estimator_orig )
992
979
y = X [:, 0 ].astype (np .int )
993
- if _safe_tags ( estimator , 'binary_only' ) :
980
+ if estimator . _get_tags ()[ 'binary_only' ] :
994
981
y [y == 2 ] = 1
995
982
y = _enforce_estimator_tags_y (estimator , y )
996
983
@@ -1041,7 +1028,7 @@ def check_fit2d_predict1d(name, estimator_orig):
1041
1028
X = 3 * rnd .uniform (size = (20 , 3 ))
1042
1029
X = _pairwise_estimator_convert_X (X , estimator_orig )
1043
1030
y = X [:, 0 ].astype (np .int )
1044
- tags = _safe_tags ( estimator_orig )
1031
+ tags = estimator_orig . _get_tags ( )
1045
1032
if tags ['binary_only' ]:
1046
1033
y [y == 2 ] = 1
1047
1034
estimator = clone (estimator_orig )
@@ -1092,7 +1079,7 @@ def check_methods_subset_invariance(name, estimator_orig):
1092
1079
X = 3 * rnd .uniform (size = (20 , 3 ))
1093
1080
X = _pairwise_estimator_convert_X (X , estimator_orig )
1094
1081
y = X [:, 0 ].astype (np .int )
1095
- if _safe_tags ( estimator_orig , 'binary_only' ) :
1082
+ if estimator_orig . _get_tags ()[ 'binary_only' ] :
1096
1083
y [y == 2 ] = 1
1097
1084
estimator = clone (estimator_orig )
1098
1085
y = _enforce_estimator_tags_y (estimator , y )
@@ -1193,7 +1180,7 @@ def check_fit1d(name, estimator_orig):
1193
1180
X = 3 * rnd .uniform (size = (20 ))
1194
1181
y = X .astype (np .int )
1195
1182
estimator = clone (estimator_orig )
1196
- tags = _safe_tags ( estimator )
1183
+ tags = estimator . _get_tags ( )
1197
1184
if tags ["no_validation" ]:
1198
1185
# FIXME this is a bit loose
1199
1186
return
@@ -1285,7 +1272,7 @@ def _check_transformer(name, transformer_orig, X, y):
1285
1272
X_pred2 = transformer .transform (X )
1286
1273
X_pred3 = transformer .fit_transform (X , y = y_ )
1287
1274
1288
- if _safe_tags ( transformer_orig , 'non_deterministic' ) :
1275
+ if transformer_orig . _get_tags ()[ 'non_deterministic' ] :
1289
1276
msg = name + ' is non deterministic'
1290
1277
raise SkipTest (msg )
1291
1278
if isinstance (X_pred , tuple ) and isinstance (X_pred2 , tuple ):
@@ -1316,7 +1303,7 @@ def _check_transformer(name, transformer_orig, X, y):
1316
1303
1317
1304
# raises error on malformed input for transform
1318
1305
if hasattr (X , 'shape' ) and \
1319
- not _safe_tags ( transformer , "stateless" ) and \
1306
+ not transformer . _get_tags ()[ "stateless" ] and \
1320
1307
X .ndim == 2 and X .shape [1 ] > 1 :
1321
1308
1322
1309
# If it's not an array, it does not have a 'T' property
@@ -1330,7 +1317,7 @@ def _check_transformer(name, transformer_orig, X, y):
1330
1317
1331
1318
@ignore_warnings
1332
1319
def check_pipeline_consistency (name , estimator_orig ):
1333
- if _safe_tags ( estimator_orig , 'non_deterministic' ) :
1320
+ if estimator_orig . _get_tags ()[ 'non_deterministic' ] :
1334
1321
msg = name + ' is non deterministic'
1335
1322
raise SkipTest (msg )
1336
1323
@@ -1365,7 +1352,7 @@ def check_fit_score_takes_y(name, estimator_orig):
1365
1352
n_samples = 30
1366
1353
X = rnd .uniform (size = (n_samples , 3 ))
1367
1354
X = _pairwise_estimator_convert_X (X , estimator_orig )
1368
- if _safe_tags ( estimator_orig , 'binary_only' ) :
1355
+ if estimator_orig . _get_tags ()[ 'binary_only' ] :
1369
1356
y = np .arange (n_samples ) % 2
1370
1357
else :
1371
1358
y = np .arange (n_samples ) % 3
@@ -1398,7 +1385,7 @@ def check_estimators_dtypes(name, estimator_orig):
1398
1385
X_train_int_64 = X_train_32 .astype (np .int64 )
1399
1386
X_train_int_32 = X_train_32 .astype (np .int32 )
1400
1387
y = X_train_int_64 [:, 0 ]
1401
- if _safe_tags ( estimator_orig , 'binary_only' ) :
1388
+ if estimator_orig . _get_tags ()[ 'binary_only' ] :
1402
1389
y [y == 2 ] = 1
1403
1390
y = _enforce_estimator_tags_y (estimator_orig , y )
1404
1391
@@ -1534,7 +1521,7 @@ def check_estimators_pickle(name, estimator_orig):
1534
1521
X -= X .min ()
1535
1522
X = _pairwise_estimator_convert_X (X , estimator_orig , kernel = rbf_kernel )
1536
1523
1537
- tags = _safe_tags ( estimator_orig )
1524
+ tags = estimator_orig . _get_tags ( )
1538
1525
# include NaN values when the estimator should deal with them
1539
1526
if tags ['allow_nan' ]:
1540
1527
# set randomly 10 elements to np.nan
@@ -1599,7 +1586,7 @@ def check_estimators_partial_fit_n_features(name, estimator_orig):
1599
1586
@ignore_warnings (category = FutureWarning )
1600
1587
def check_classifier_multioutput (name , estimator ):
1601
1588
n_samples , n_labels , n_classes = 42 , 5 , 3
1602
- tags = _safe_tags ( estimator )
1589
+ tags = estimator . _get_tags ( )
1603
1590
estimator = clone (estimator )
1604
1591
X , y = make_multilabel_classification (random_state = 42 ,
1605
1592
n_samples = n_samples ,
@@ -1706,7 +1693,7 @@ def check_clustering(name, clusterer_orig, readonly_memmap=False):
1706
1693
pred = clusterer .labels_
1707
1694
assert pred .shape == (n_samples ,)
1708
1695
assert adjusted_rand_score (pred , y ) > 0.4
1709
- if _safe_tags ( clusterer , 'non_deterministic' ) :
1696
+ if clusterer . _get_tags ()[ 'non_deterministic' ] :
1710
1697
return
1711
1698
set_random_state (clusterer )
1712
1699
with warnings .catch_warnings (record = True ):
@@ -1805,7 +1792,7 @@ def check_classifiers_train(name, classifier_orig, readonly_memmap=False,
1805
1792
X_m , y_m , X_b , y_b = create_memmap_backed_data ([X_m , y_m , X_b , y_b ])
1806
1793
1807
1794
problems = [(X_b , y_b )]
1808
- tags = _safe_tags ( classifier_orig )
1795
+ tags = classifier_orig . _get_tags ( )
1809
1796
if not tags ['binary_only' ]:
1810
1797
problems .append ((X_m , y_m ))
1811
1798
@@ -2044,7 +2031,7 @@ def check_classifiers_multilabel_representation_invariance(name,
2044
2031
def check_estimators_fit_returns_self (name , estimator_orig ,
2045
2032
readonly_memmap = False ):
2046
2033
"""Check if self is returned when calling fit"""
2047
- if _safe_tags ( estimator_orig , 'binary_only' ) :
2034
+ if estimator_orig . _get_tags ()[ 'binary_only' ] :
2048
2035
n_centers = 2
2049
2036
else :
2050
2037
n_centers = 3
@@ -2081,7 +2068,7 @@ def check_estimators_unfitted(name, estimator_orig):
2081
2068
2082
2069
@ignore_warnings (category = FutureWarning )
2083
2070
def check_supervised_y_2d (name , estimator_orig ):
2084
- tags = _safe_tags ( estimator_orig )
2071
+ tags = estimator_orig . _get_tags ( )
2085
2072
if tags ['multioutput_only' ]:
2086
2073
# These only work on 2d, so this test makes no sense
2087
2074
return
@@ -2197,7 +2184,7 @@ def check_classifiers_classes(name, classifier_orig):
2197
2184
y_names_binary = np .take (labels_binary , y_binary )
2198
2185
2199
2186
problems = [(X_binary , y_binary , y_names_binary )]
2200
- if not _safe_tags ( classifier_orig , 'binary_only' ) :
2187
+ if not classifier_orig . _get_tags ()[ 'binary_only' ] :
2201
2188
problems .append ((X_multiclass , y_multiclass , y_names_multiclass ))
2202
2189
2203
2190
for X , y , y_names in problems :
@@ -2282,7 +2269,7 @@ def check_regressors_train(name, regressor_orig, readonly_memmap=False,
2282
2269
# TODO: find out why PLS and CCA fail. RANSAC is random
2283
2270
# and furthermore assumes the presence of outliers, hence
2284
2271
# skipped
2285
- if not _safe_tags ( regressor , "poor_score" ) :
2272
+ if not regressor . _get_tags ()[ "poor_score" ] :
2286
2273
assert regressor .score (X , y_ ) > 0.5
2287
2274
2288
2275
@@ -2315,7 +2302,7 @@ def check_regressors_no_decision_function(name, regressor_orig):
2315
2302
@ignore_warnings (category = FutureWarning )
2316
2303
def check_class_weight_classifiers (name , classifier_orig ):
2317
2304
2318
- if _safe_tags ( classifier_orig , 'binary_only' ) :
2305
+ if classifier_orig . _get_tags ()[ 'binary_only' ] :
2319
2306
problems = [2 ]
2320
2307
else :
2321
2308
problems = [2 , 3 ]
@@ -2418,7 +2405,7 @@ def check_class_weight_balanced_linear_classifier(name, Classifier):
2418
2405
2419
2406
@ignore_warnings (category = FutureWarning )
2420
2407
def check_estimators_overwrite_params (name , estimator_orig ):
2421
- if _safe_tags ( estimator_orig , 'binary_only' ) :
2408
+ if estimator_orig . _get_tags ()[ 'binary_only' ] :
2422
2409
n_centers = 2
2423
2410
else :
2424
2411
n_centers = 3
@@ -2654,13 +2641,13 @@ def enforce_estimator_tags_y(estimator, y):
2654
2641
def _enforce_estimator_tags_y (estimator , y ):
2655
2642
# Estimators with a `requires_positive_y` tag only accept strictly positive
2656
2643
# data
2657
- if _safe_tags ( estimator , "requires_positive_y" ) :
2644
+ if estimator . _get_tags ()[ "requires_positive_y" ] :
2658
2645
# Create strictly positive y. The minimal increment above 0 is 1, as
2659
2646
# y could be of integer dtype.
2660
2647
y += 1 + abs (y .min ())
2661
2648
# Estimators in mono_output_task_error raise ValueError if y is of 1-D
2662
2649
# Convert into a 2-D y for those estimators.
2663
- if _safe_tags ( estimator , "multioutput_only" ) :
2650
+ if estimator . _get_tags ()[ "multioutput_only" ] :
2664
2651
return np .reshape (y , (- 1 , 1 ))
2665
2652
return y
2666
2653
@@ -2672,11 +2659,11 @@ def _enforce_estimator_tags_x(estimator, X):
2672
2659
X = X .dot (X .T )
2673
2660
# Estimators with `1darray` in `X_types` tag only accept
2674
2661
# X of shape (`n_samples`,)
2675
- if '1darray' in _safe_tags ( estimator , 'X_types' ) :
2662
+ if '1darray' in estimator . _get_tags ()[ 'X_types' ] :
2676
2663
X = X [:, 0 ]
2677
2664
# Estimators with a `requires_positive_X` tag only accept
2678
2665
# strictly positive data
2679
- if _safe_tags ( estimator , 'requires_positive_X' ) :
2666
+ if estimator . _get_tags ()[ 'requires_positive_X' ] :
2680
2667
X -= X .min ()
2681
2668
return X
2682
2669
@@ -2814,7 +2801,7 @@ def check_classifiers_regression_target(name, estimator_orig):
2814
2801
X , y = load_boston (return_X_y = True )
2815
2802
e = clone (estimator_orig )
2816
2803
msg = 'Unknown label type: '
2817
- if not _safe_tags ( e , "no_validation" ) :
2804
+ if not e . _get_tags ()[ "no_validation" ] :
2818
2805
assert_raises_regex (ValueError , msg , e .fit , X , y )
2819
2806
2820
2807
0 commit comments