Skip to content

Commit d74c0e8

Browse files
committed
MAINT conversion old->new new->old tags
1 parent 6bf2061 commit d74c0e8

File tree

2 files changed

+163
-6
lines changed

2 files changed

+163
-6
lines changed

sklearn/utils/_tags.py

Lines changed: 153 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -318,6 +318,14 @@ def get_tags(estimator) -> Tags:
318318

319319
if hasattr(estimator, "__sklearn_tags__"):
320320
tags = estimator.__sklearn_tags__()
321+
elif hasattr(estimator, "_get_tags"):
322+
warnings.warn("BROKEN SOON, IT WILL BE", FutureWarning)
323+
tags = _to_new_tags(estimator._get_tags())
324+
elif hasattr(estimator, "_more_tags"):
325+
warnings.warn("BROKEN SOON, IT WILL BE", FutureWarning)
326+
tags = _to_old_tags(default_tags(estimator))
327+
tags = {**tags, **estimator._more_tags()}
328+
tags = _to_new_tags(tags)
321329
else:
322330
warnings.warn(
323331
f"Estimator {estimator} has no __sklearn_tags__ attribute, which is "
@@ -332,3 +340,148 @@ def get_tags(estimator) -> Tags:
332340
tags = default_tags(estimator)
333341

334342
return tags
343+
344+
345+
def _to_new_tags(old_tags, estimator_type=None):
    """Build a new-style ``Tags`` dataclass from an old-style tags dictionary.

    Parameters
    ----------
    old_tags : dict
        Old-style tags. Every key read below must be present, otherwise a
        ``KeyError`` propagates to the caller. ``old_tags["X_types"]`` is
        the collection of supported type names; it drives the boolean
        input flags and the ``1dlabels`` / ``2dlabels`` target flags.

    estimator_type : object, default=None
        Forwarded unchanged as ``Tags.estimator_type``.

    Returns
    -------
    Tags
        The dataclass equivalent of ``old_tags``.
    """
    # Old tags describe every accepted input/label kind in one collection.
    x_types = old_tags["X_types"]

    inputs = InputTags(
        one_d_array="1darray" in x_types,
        two_d_array="2darray" in x_types,
        three_d_array="3darray" in x_types,
        sparse="sparse" in x_types,
        categorical="categorical" in x_types,
        string="string" in x_types,
        dict="dict" in x_types,
        positive_only=old_tags["requires_positive_X"],
        allow_nan=old_tags["allow_nan"],
        pairwise=old_tags["pairwise"],
    )

    targets = TargetTags(
        required=old_tags["requires_y"],
        # Label kinds are stored next to the input kinds in the old format.
        one_d_labels="1dlabels" in x_types,
        two_d_labels="2dlabels" in x_types,
        positive_only=old_tags["requires_positive_y"],
        # "multioutput_only" implies multi-output support as well.
        multi_output=old_tags["multioutput"] or old_tags["multioutput_only"],
        single_output=not old_tags["multioutput_only"],
    )

    transformer = TransformerTags(preserves_dtype=old_tags["preserves_dtype"])

    # The old format has a single "poor_score"/"multilabel" pair which is
    # shared by the classifier and the regressor sub-tags.
    classifier = ClassifierTags(
        poor_score=old_tags["poor_score"],
        multi_class=not old_tags["binary_only"],
        multi_label=old_tags["multilabel"],
    )
    regressor = RegressorTags(
        poor_score=old_tags["poor_score"],
        multi_label=old_tags["multilabel"],
    )

    return Tags(
        estimator_type=estimator_type,
        target_tags=targets,
        transformer_tags=transformer,
        classifier_tags=classifier,
        regressor_tags=regressor,
        input_tags=inputs,
        array_api_support=old_tags["array_api_support"],
        no_validation=old_tags["no_validation"],
        non_deterministic=old_tags["non_deterministic"],
        requires_fit=old_tags["requires_fit"],
        _skip_test=old_tags["_skip_test"],
    )
392+
393+
394+
def _to_old_tags(new_tags):
395+
"""Utility function convert old tags (dictionary) to new tags (dataclass)."""
396+
if new_tags.classifier_tags:
397+
binary_only = not new_tags.classifier_tags.multi_class
398+
multilabel_clf = new_tags.classifier_tags.multi_label
399+
poor_score_clf = new_tags.classifier_tags.poor_score
400+
else:
401+
binary_only = False
402+
multilabel_clf = False
403+
poor_score_clf = False
404+
405+
if new_tags.regressor_tags:
406+
multilabel_reg = new_tags.regressor_tags.multi_label
407+
poor_score_reg = new_tags.regressor_tags.poor_score
408+
else:
409+
multilabel_reg = False
410+
poor_score_reg = False
411+
412+
if new_tags.transformer_tags:
413+
preserves_dtype = new_tags.transformer_tags.preserves_dtype
414+
else:
415+
preserves_dtype = ["float64"]
416+
417+
tags = {
418+
"allow_nan": new_tags.input_tags.allow_nan,
419+
"array_api_support": new_tags.array_api_support,
420+
"binary_only": binary_only,
421+
"multilabel": multilabel_clf or multilabel_reg,
422+
"multioutput": new_tags.target_tags.multi_output,
423+
"multioutput_only": (
424+
not new_tags.target_tags.single_output
425+
and new_tags.target_tags.multi_output
426+
),
427+
"no_validation": new_tags.no_validation,
428+
"non_deterministic": new_tags.non_deterministic,
429+
"pairwise": new_tags.input_tags.pairwise,
430+
"preserves_dtype": preserves_dtype,
431+
"poor_score": poor_score_clf or poor_score_reg,
432+
"requires_fit": new_tags.requires_fit,
433+
"requires_positive_X": new_tags.input_tags.positive_only,
434+
"requires_y": new_tags.target_tags.required,
435+
"requires_positive_y": new_tags.target_tags.positive_only,
436+
"_skip_test": new_tags._skip_test,
437+
"stateless": new_tags.requires_fit,
438+
}
439+
X_types = []
440+
if new_tags.input_tags.one_d_array:
441+
X_types.append("1darray")
442+
if new_tags.input_tags.two_d_array:
443+
X_types.append("2darray")
444+
if new_tags.input_tags.three_d_array:
445+
X_types.append("3darray")
446+
if new_tags.input_tags.sparse:
447+
X_types.append("sparse")
448+
if new_tags.input_tags.categorical:
449+
X_types.append("categorical")
450+
if new_tags.input_tags.string:
451+
X_types.append("string")
452+
if new_tags.input_tags.dict:
453+
X_types.append("dict")
454+
if new_tags.target_tags.one_d_labels:
455+
X_types.append("1dlabels")
456+
if new_tags.target_tags.two_d_labels:
457+
X_types.append("2dlabels")
458+
tags["X_types"] = X_types
459+
return tags
460+
461+
462+
def _safe_tags(estimator, key=None):
463+
warnings.warn(
464+
"The `_safe_tags` utility function is deprecated in 1.6 and will be removed in "
465+
"1.7. Use the public `get_tags` function instead and make sure to implement "
466+
"the `__sklearn_tags__` method.",
467+
category=FutureWarning,
468+
)
469+
if hasattr(estimator, "_get_tags"):
470+
tags_provider = "_get_tags()"
471+
tags = estimator._get_tags()
472+
elif hasattr(estimator, "_more_tags"):
473+
tags_provider = "_more_tags()"
474+
tags = _to_old_tags(default_tags(estimator))
475+
tags = {**tags, **estimator._more_tags()}
476+
else:
477+
tags_provider = "_DEFAULT_TAGS"
478+
tags = _to_old_tags(default_tags(estimator))
479+
480+
if key is not None:
481+
if key not in tags:
482+
raise ValueError(
483+
f"The key {key} is not defined in {tags_provider} for the "
484+
f"class {estimator.__class__.__name__}."
485+
)
486+
return tags[key]
487+
return tags

sklearn/utils/estimator_checks.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -115,12 +115,14 @@ def _raise_for_missing_tags(estimator, tag_name, Mixin):
115115
tags = get_tags(estimator)
116116
estimator_type = Mixin.__name__.replace("Mixin", "")
117117
if getattr(tags, tag_name) is None:
118-
raise RuntimeError(
119-
f"Estimator {estimator.__class__.__name__} seems to be a {estimator_type},"
120-
f" but the `{tag_name}` tag is not set. Either set the tag manually"
121-
f" or inherit from the {Mixin.__name__}. Note that the order of inheritance"
122-
f" matters, the {Mixin.__name__} should come before BaseEstimator."
123-
)
118+
# raise RuntimeError(
119+
# f"Estimator {estimator.__class__.__name__} seems to be a {estimator_type},"
120+
# f" but the `{tag_name}` tag is not set. Either set the tag manually"
121+
# f" or inherit from the {Mixin.__name__}. Note that the order of inheritance"
122+
# f" matters, the {Mixin.__name__} should come before BaseEstimator."
123+
# )
124+
warnings.warn("BROKEN SOON, IT WILL BE", FutureWarning)
125+
124126

125127

126128
def _yield_api_checks(estimator):
@@ -276,6 +278,8 @@ def _yield_regressor_checks(regressor):
276278
def _yield_transformer_checks(transformer):
277279
_raise_for_missing_tags(transformer, "transformer_tags", TransformerMixin)
278280
tags = get_tags(transformer)
281+
print(transformer)
282+
print(tags)
279283
# All transformers should either deal with sparse data or raise an
280284
# exception with type TypeError and an intelligible error message
281285
if not tags.no_validation:

0 commit comments

Comments
 (0)