Skip to content

Commit cb49ad4

Browse files
authored
MNT removed _safe_tags utility (#16950)
1 parent 5abd22f commit cb49ad4

File tree

3 files changed

+43
-56
lines changed

3 files changed

+43
-56
lines changed

sklearn/tests/test_docstring_parameters.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
from sklearn.utils._testing import _get_func_name
1818
from sklearn.utils._testing import ignore_warnings
1919
from sklearn.utils._testing import all_estimators
20-
from sklearn.utils.estimator_checks import _safe_tags
2120
from sklearn.utils.estimator_checks import _enforce_estimator_tags_y
2221
from sklearn.utils.estimator_checks import _enforce_estimator_tags_x
2322
from sklearn.utils.deprecation import _is_deprecated
@@ -206,9 +205,9 @@ def test_fit_docstring_attributes(name, Estimator):
206205
y = _enforce_estimator_tags_y(est, y)
207206
X = _enforce_estimator_tags_x(est, X)
208207

209-
if '1dlabels' in _safe_tags(est, 'X_types'):
208+
if '1dlabels' in est._get_tags()['X_types']:
210209
est.fit(y)
211-
elif '2dlabels' in _safe_tags(est, 'X_types'):
210+
elif '2dlabels' in est._get_tags()['X_types']:
212211
est.fit(np.c_[y, y])
213212
else:
214213
est.fit(X, y)

sklearn/utils/estimator_checks.py

Lines changed: 39 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333
from ..linear_model import Ridge
3434

3535
from ..base import (clone, ClusterMixin, is_classifier, is_regressor,
36-
_DEFAULT_TAGS, RegressorMixin, is_outlier_detector)
36+
RegressorMixin, is_outlier_detector)
3737

3838
from ..metrics import accuracy_score, adjusted_rand_score, f1_score
3939
from ..random_projection import BaseRandomProjection
@@ -58,22 +58,9 @@
5858
BOSTON = None
5959
CROSS_DECOMPOSITION = ['PLSCanonical', 'PLSRegression', 'CCA', 'PLSSVD']
6060

61-
def _safe_tags(estimator, key=None):
62-
# if estimator doesn't have _get_tags, use _DEFAULT_TAGS
63-
# if estimator has tags but not key, use _DEFAULT_TAGS[key]
64-
if hasattr(estimator, "_get_tags"):
65-
if key is not None:
66-
return estimator._get_tags().get(key, _DEFAULT_TAGS[key])
67-
tags = estimator._get_tags()
68-
return {key: tags.get(key, _DEFAULT_TAGS[key])
69-
for key in _DEFAULT_TAGS.keys()}
70-
if key is not None:
71-
return _DEFAULT_TAGS[key]
72-
return _DEFAULT_TAGS
73-
7461

7562
def _yield_checks(name, estimator):
76-
tags = _safe_tags(estimator)
63+
tags = estimator._get_tags()
7764
yield check_no_attributes_set_in_init
7865
yield check_estimators_dtypes
7966
yield check_fit_score_takes_y
@@ -116,7 +103,7 @@ def _yield_checks(name, estimator):
116103

117104

118105
def _yield_classifier_checks(name, classifier):
119-
tags = _safe_tags(classifier)
106+
tags = classifier._get_tags()
120107

121108
# test classifiers can handle non-array data and pandas objects
122109
yield check_classifier_data_not_an_array
@@ -171,7 +158,7 @@ def check_supervised_y_no_nan(name, estimator_orig):
171158

172159

173160
def _yield_regressor_checks(name, regressor):
174-
tags = _safe_tags(regressor)
161+
tags = regressor._get_tags()
175162
# TODO: test with intercept
176163
# TODO: test with multiple responses
177164
# basic testing
@@ -198,12 +185,12 @@ def _yield_regressor_checks(name, regressor):
198185
def _yield_transformer_checks(name, transformer):
199186
# All transformers should either deal with sparse data or raise an
200187
# exception with type TypeError and an intelligible error message
201-
if not _safe_tags(transformer, "no_validation"):
188+
if not transformer._get_tags()["no_validation"]:
202189
yield check_transformer_data_not_an_array
203190
# these don't actually fit the data, so don't raise errors
204191
yield check_transformer_general
205192
yield partial(check_transformer_general, readonly_memmap=True)
206-
if not _safe_tags(transformer, "stateless"):
193+
if not transformer._get_tags()["stateless"]:
207194
yield check_transformers_unfitted
208195
# Dependent on external solvers and hence accessing the iter
209196
# param is non-trivial.
@@ -237,12 +224,12 @@ def _yield_outliers_checks(name, estimator):
237224
# test outlier detectors can handle non-array data
238225
yield check_classifier_data_not_an_array
239226
# test if NotFittedError is raised
240-
if _safe_tags(estimator, "requires_fit"):
227+
if estimator._get_tags()["requires_fit"]:
241228
yield check_estimators_unfitted
242229

243230

244231
def _yield_all_checks(name, estimator):
245-
tags = _safe_tags(estimator)
232+
tags = estimator._get_tags()
246233
if "2darray" not in tags["X_types"]:
247234
warnings.warn("Can't test estimator {} which requires input "
248235
" of type {}".format(name, tags["X_types"]),
@@ -369,7 +356,7 @@ def _mark_xfail_checks(estimator, check, pytest):
369356
except Exception:
370357
return estimator, check
371358

372-
xfail_checks = _safe_tags(estimator, '_xfail_checks') or {}
359+
xfail_checks = estimator._get_tags()['_xfail_checks'] or {}
373360
check_name = _set_check_estimator_ids(check)
374361

375362
if check_name not in xfail_checks:
@@ -701,7 +688,7 @@ def check_estimator_sparse_data(name, estimator_orig):
701688
X[X < .8] = 0
702689
X = _pairwise_estimator_convert_X(X, estimator_orig)
703690
X_csr = sparse.csr_matrix(X)
704-
tags = _safe_tags(estimator_orig)
691+
tags = estimator_orig._get_tags()
705692
if tags['binary_only']:
706693
y = (2 * rng.rand(40)).astype(np.int)
707694
else:
@@ -767,7 +754,7 @@ def check_sample_weights_pandas_series(name, estimator_orig):
767754
X = pd.DataFrame(_pairwise_estimator_convert_X(X, estimator_orig))
768755
y = pd.Series([1, 1, 1, 1, 2, 2, 2, 2, 1, 1, 2, 2])
769756
weights = pd.Series([1] * 12)
770-
if _safe_tags(estimator, "multioutput_only"):
757+
if estimator._get_tags()["multioutput_only"]:
771758
y = pd.DataFrame(y)
772759
try:
773760
estimator.fit(X, y, sample_weight=weights)
@@ -792,7 +779,7 @@ def check_sample_weights_not_an_array(name, estimator_orig):
792779
X = _NotAnArray(pairwise_estimator_convert_X(X, estimator_orig))
793780
y = _NotAnArray([1, 1, 1, 1, 2, 2, 2, 2, 1, 1, 2, 2])
794781
weights = _NotAnArray([1] * 12)
795-
if _safe_tags(estimator, "multioutput_only"):
782+
if estimator._get_tags()["multioutput_only"]:
796783
y = _NotAnArray(y.data.reshape(-1, 1))
797784
estimator.fit(X, y, sample_weight=weights)
798785

@@ -806,8 +793,8 @@ def check_sample_weights_list(name, estimator_orig):
806793
rnd = np.random.RandomState(0)
807794
n_samples = 30
808795
X = _pairwise_estimator_convert_X(rnd.uniform(size=(n_samples, 3)),
809-
estimator_orig)
810-
if _safe_tags(estimator, 'binary_only'):
796+
estimator_orig)
797+
if estimator._get_tags()['binary_only']:
811798
y = np.arange(n_samples) % 2
812799
else:
813800
y = np.arange(n_samples) % 3
@@ -886,7 +873,7 @@ def check_dtype_object(name, estimator_orig):
886873
rng = np.random.RandomState(0)
887874
X = _pairwise_estimator_convert_X(rng.rand(40, 10), estimator_orig)
888875
X = X.astype(object)
889-
tags = _safe_tags(estimator_orig)
876+
tags = estimator_orig._get_tags()
890877
if tags['binary_only']:
891878
y = (X[:, 0] * 2).astype(np.int)
892879
else:
@@ -990,7 +977,7 @@ def check_dont_overwrite_parameters(name, estimator_orig):
990977
X = 3 * rnd.uniform(size=(20, 3))
991978
X = _pairwise_estimator_convert_X(X, estimator_orig)
992979
y = X[:, 0].astype(np.int)
993-
if _safe_tags(estimator, 'binary_only'):
980+
if estimator._get_tags()['binary_only']:
994981
y[y == 2] = 1
995982
y = _enforce_estimator_tags_y(estimator, y)
996983

@@ -1041,7 +1028,7 @@ def check_fit2d_predict1d(name, estimator_orig):
10411028
X = 3 * rnd.uniform(size=(20, 3))
10421029
X = _pairwise_estimator_convert_X(X, estimator_orig)
10431030
y = X[:, 0].astype(np.int)
1044-
tags = _safe_tags(estimator_orig)
1031+
tags = estimator_orig._get_tags()
10451032
if tags['binary_only']:
10461033
y[y == 2] = 1
10471034
estimator = clone(estimator_orig)
@@ -1092,7 +1079,7 @@ def check_methods_subset_invariance(name, estimator_orig):
10921079
X = 3 * rnd.uniform(size=(20, 3))
10931080
X = _pairwise_estimator_convert_X(X, estimator_orig)
10941081
y = X[:, 0].astype(np.int)
1095-
if _safe_tags(estimator_orig, 'binary_only'):
1082+
if estimator_orig._get_tags()['binary_only']:
10961083
y[y == 2] = 1
10971084
estimator = clone(estimator_orig)
10981085
y = _enforce_estimator_tags_y(estimator, y)
@@ -1193,7 +1180,7 @@ def check_fit1d(name, estimator_orig):
11931180
X = 3 * rnd.uniform(size=(20))
11941181
y = X.astype(np.int)
11951182
estimator = clone(estimator_orig)
1196-
tags = _safe_tags(estimator)
1183+
tags = estimator._get_tags()
11971184
if tags["no_validation"]:
11981185
# FIXME this is a bit loose
11991186
return
@@ -1285,7 +1272,7 @@ def _check_transformer(name, transformer_orig, X, y):
12851272
X_pred2 = transformer.transform(X)
12861273
X_pred3 = transformer.fit_transform(X, y=y_)
12871274

1288-
if _safe_tags(transformer_orig, 'non_deterministic'):
1275+
if transformer_orig._get_tags()['non_deterministic']:
12891276
msg = name + ' is non deterministic'
12901277
raise SkipTest(msg)
12911278
if isinstance(X_pred, tuple) and isinstance(X_pred2, tuple):
@@ -1316,7 +1303,7 @@ def _check_transformer(name, transformer_orig, X, y):
13161303

13171304
# raises error on malformed input for transform
13181305
if hasattr(X, 'shape') and \
1319-
not _safe_tags(transformer, "stateless") and \
1306+
not transformer._get_tags()["stateless"] and \
13201307
X.ndim == 2 and X.shape[1] > 1:
13211308

13221309
# If it's not an array, it does not have a 'T' property
@@ -1330,7 +1317,7 @@ def _check_transformer(name, transformer_orig, X, y):
13301317

13311318
@ignore_warnings
13321319
def check_pipeline_consistency(name, estimator_orig):
1333-
if _safe_tags(estimator_orig, 'non_deterministic'):
1320+
if estimator_orig._get_tags()['non_deterministic']:
13341321
msg = name + ' is non deterministic'
13351322
raise SkipTest(msg)
13361323

@@ -1365,7 +1352,7 @@ def check_fit_score_takes_y(name, estimator_orig):
13651352
n_samples = 30
13661353
X = rnd.uniform(size=(n_samples, 3))
13671354
X = _pairwise_estimator_convert_X(X, estimator_orig)
1368-
if _safe_tags(estimator_orig, 'binary_only'):
1355+
if estimator_orig._get_tags()['binary_only']:
13691356
y = np.arange(n_samples) % 2
13701357
else:
13711358
y = np.arange(n_samples) % 3
@@ -1398,7 +1385,7 @@ def check_estimators_dtypes(name, estimator_orig):
13981385
X_train_int_64 = X_train_32.astype(np.int64)
13991386
X_train_int_32 = X_train_32.astype(np.int32)
14001387
y = X_train_int_64[:, 0]
1401-
if _safe_tags(estimator_orig, 'binary_only'):
1388+
if estimator_orig._get_tags()['binary_only']:
14021389
y[y == 2] = 1
14031390
y = _enforce_estimator_tags_y(estimator_orig, y)
14041391

@@ -1534,7 +1521,7 @@ def check_estimators_pickle(name, estimator_orig):
15341521
X -= X.min()
15351522
X = _pairwise_estimator_convert_X(X, estimator_orig, kernel=rbf_kernel)
15361523

1537-
tags = _safe_tags(estimator_orig)
1524+
tags = estimator_orig._get_tags()
15381525
# include NaN values when the estimator should deal with them
15391526
if tags['allow_nan']:
15401527
# set randomly 10 elements to np.nan
@@ -1599,7 +1586,7 @@ def check_estimators_partial_fit_n_features(name, estimator_orig):
15991586
@ignore_warnings(category=FutureWarning)
16001587
def check_classifier_multioutput(name, estimator):
16011588
n_samples, n_labels, n_classes = 42, 5, 3
1602-
tags = _safe_tags(estimator)
1589+
tags = estimator._get_tags()
16031590
estimator = clone(estimator)
16041591
X, y = make_multilabel_classification(random_state=42,
16051592
n_samples=n_samples,
@@ -1706,7 +1693,7 @@ def check_clustering(name, clusterer_orig, readonly_memmap=False):
17061693
pred = clusterer.labels_
17071694
assert pred.shape == (n_samples,)
17081695
assert adjusted_rand_score(pred, y) > 0.4
1709-
if _safe_tags(clusterer, 'non_deterministic'):
1696+
if clusterer._get_tags()['non_deterministic']:
17101697
return
17111698
set_random_state(clusterer)
17121699
with warnings.catch_warnings(record=True):
@@ -1805,7 +1792,7 @@ def check_classifiers_train(name, classifier_orig, readonly_memmap=False,
18051792
X_m, y_m, X_b, y_b = create_memmap_backed_data([X_m, y_m, X_b, y_b])
18061793

18071794
problems = [(X_b, y_b)]
1808-
tags = _safe_tags(classifier_orig)
1795+
tags = classifier_orig._get_tags()
18091796
if not tags['binary_only']:
18101797
problems.append((X_m, y_m))
18111798

@@ -2044,7 +2031,7 @@ def check_classifiers_multilabel_representation_invariance(name,
20442031
def check_estimators_fit_returns_self(name, estimator_orig,
20452032
readonly_memmap=False):
20462033
"""Check if self is returned when calling fit"""
2047-
if _safe_tags(estimator_orig, 'binary_only'):
2034+
if estimator_orig._get_tags()['binary_only']:
20482035
n_centers = 2
20492036
else:
20502037
n_centers = 3
@@ -2081,7 +2068,7 @@ def check_estimators_unfitted(name, estimator_orig):
20812068

20822069
@ignore_warnings(category=FutureWarning)
20832070
def check_supervised_y_2d(name, estimator_orig):
2084-
tags = _safe_tags(estimator_orig)
2071+
tags = estimator_orig._get_tags()
20852072
if tags['multioutput_only']:
20862073
# These only work on 2d, so this test makes no sense
20872074
return
@@ -2197,7 +2184,7 @@ def check_classifiers_classes(name, classifier_orig):
21972184
y_names_binary = np.take(labels_binary, y_binary)
21982185

21992186
problems = [(X_binary, y_binary, y_names_binary)]
2200-
if not _safe_tags(classifier_orig, 'binary_only'):
2187+
if not classifier_orig._get_tags()['binary_only']:
22012188
problems.append((X_multiclass, y_multiclass, y_names_multiclass))
22022189

22032190
for X, y, y_names in problems:
@@ -2282,7 +2269,7 @@ def check_regressors_train(name, regressor_orig, readonly_memmap=False,
22822269
# TODO: find out why PLS and CCA fail. RANSAC is random
22832270
# and furthermore assumes the presence of outliers, hence
22842271
# skipped
2285-
if not _safe_tags(regressor, "poor_score"):
2272+
if not regressor._get_tags()["poor_score"]:
22862273
assert regressor.score(X, y_) > 0.5
22872274

22882275

@@ -2315,7 +2302,7 @@ def check_regressors_no_decision_function(name, regressor_orig):
23152302
@ignore_warnings(category=FutureWarning)
23162303
def check_class_weight_classifiers(name, classifier_orig):
23172304

2318-
if _safe_tags(classifier_orig, 'binary_only'):
2305+
if classifier_orig._get_tags()['binary_only']:
23192306
problems = [2]
23202307
else:
23212308
problems = [2, 3]
@@ -2418,7 +2405,7 @@ def check_class_weight_balanced_linear_classifier(name, Classifier):
24182405

24192406
@ignore_warnings(category=FutureWarning)
24202407
def check_estimators_overwrite_params(name, estimator_orig):
2421-
if _safe_tags(estimator_orig, 'binary_only'):
2408+
if estimator_orig._get_tags()['binary_only']:
24222409
n_centers = 2
24232410
else:
24242411
n_centers = 3
@@ -2654,13 +2641,13 @@ def enforce_estimator_tags_y(estimator, y):
26542641
def _enforce_estimator_tags_y(estimator, y):
26552642
# Estimators with a `requires_positive_y` tag only accept strictly positive
26562643
# data
2657-
if _safe_tags(estimator, "requires_positive_y"):
2644+
if estimator._get_tags()["requires_positive_y"]:
26582645
# Create strictly positive y. The minimal increment above 0 is 1, as
26592646
# y could be of integer dtype.
26602647
y += 1 + abs(y.min())
26612648
# Estimators in mono_output_task_error raise ValueError if y is of 1-D
26622649
# Convert into a 2-D y for those estimators.
2663-
if _safe_tags(estimator, "multioutput_only"):
2650+
if estimator._get_tags()["multioutput_only"]:
26642651
return np.reshape(y, (-1, 1))
26652652
return y
26662653

@@ -2672,11 +2659,11 @@ def _enforce_estimator_tags_x(estimator, X):
26722659
X = X.dot(X.T)
26732660
# Estimators with `1darray` in `X_types` tag only accept
26742661
# X of shape (`n_samples`,)
2675-
if '1darray' in _safe_tags(estimator, 'X_types'):
2662+
if '1darray' in estimator._get_tags()['X_types']:
26762663
X = X[:, 0]
26772664
# Estimators with a `requires_positive_X` tag only accept
26782665
# strictly positive data
2679-
if _safe_tags(estimator, 'requires_positive_X'):
2666+
if estimator._get_tags()['requires_positive_X']:
26802667
X -= X.min()
26812668
return X
26822669

@@ -2814,7 +2801,7 @@ def check_classifiers_regression_target(name, estimator_orig):
28142801
X, y = load_boston(return_X_y=True)
28152802
e = clone(estimator_orig)
28162803
msg = 'Unknown label type: '
2817-
if not _safe_tags(e, "no_validation"):
2804+
if not e._get_tags()["no_validation"]:
28182805
assert_raises_regex(ValueError, msg, e.fit, X, y)
28192806

28202807

sklearn/utils/tests/test_estimator_checks.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -363,7 +363,8 @@ def test_check_estimator():
363363
# check that we have a set_params and can clone
364364
msg = "it does not implement a 'get_params' method"
365365
assert_raises_regex(TypeError, msg, check_estimator, object)
366-
assert_raises_regex(TypeError, msg, check_estimator, object())
366+
msg = "object has no attribute '_get_tags'"
367+
assert_raises_regex(AttributeError, msg, check_estimator, object())
367368
# check that values returned by get_params match set_params
368369
msg = "get_params result does not match what was passed to set_params"
369370
assert_raises_regex(AssertionError, msg, check_estimator,

0 commit comments

Comments
 (0)