-
-
Notifications
You must be signed in to change notification settings - Fork 26k
MAINT conversion old->new/new->old tags #30302
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
d74c0e8
4d70067
6366366
4ea106d
109c214
b3c7be9
bb24091
3d107eb
ae20c68
fe55288
7aa9981
98ac133
7f6b14b
c68f368
a46e593
55d7603
822e9c8
c914e09
521718c
dfe7c1f
5c9f23f
8dbb358
424ed90
2bf876e
b41d1a6
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -26,6 +26,8 @@ | |
Tags, | ||
TargetTags, | ||
TransformerTags, | ||
_to_old_tags, | ||
default_tags, | ||
get_tags, | ||
) | ||
from .utils.fixes import _IS_32BIT | ||
|
@@ -390,6 +392,22 @@ | |
self.__dict__.update(state) | ||
|
||
def __sklearn_tags__(self): | ||
from sklearn.utils._tags import _find_tags_provider, _to_new_tags | ||
|
||
# TODO(1.7): Remove this block | ||
if _find_tags_provider(self) == "_get_tags": | ||
# one of the child classes only implements `_get_tags` so we need to | ||
# warn and mix old-style and new-style tags. | ||
with warnings.catch_warnings(): | ||
warnings.filterwarnings( | ||
"ignore", | ||
category=FutureWarning, | ||
message="The `_get_tags` tag provider is deprecated", | ||
) | ||
# silence the warning to avoid a false positive warning since this | ||
# is not a direct user call but an internal one. | ||
return _to_new_tags(self._get_tags()) | ||
|
||
return Tags( | ||
estimator_type=None, | ||
target_tags=TargetTags(required=False), | ||
|
@@ -398,6 +416,54 @@ | |
classifier_tags=None, | ||
) | ||
|
||
# TODO(1.7): Remove this method | ||
def _more_tags(self): | ||
warnings.warn( | ||
"The `_more_tags` method is deprecated in 1.6 and will be removed in " | ||
"1.7. Please implement the `__sklearn_tags__` method.", | ||
category=FutureWarning, | ||
) | ||
return _to_old_tags(default_tags(self)) | ||
|
||
# TODO(1.7): Remove this method | ||
def _get_tags(self): | ||
warnings.warn( | ||
"The `_get_tags` tag provider is deprecated in 1.6 and will be removed in " | ||
"1.7. Please implement the `__sklearn_tags__` method.", | ||
category=FutureWarning, | ||
) | ||
# In case a user called `_get_tags` but the estimator already did the job | ||
# implementing `__sklearn_tags__` completely and removed `_more_tags`, let's | ||
# default back to the future behaviour. Otherwise, we will get the default tags. | ||
from sklearn.utils._tags import _find_tags_provider, _to_old_tags, get_tags | ||
|
||
if _find_tags_provider(self, warn=False) == "__sklearn_tags__": | ||
return _to_old_tags(get_tags(self)) | ||
|
||
collected_tags = {} | ||
for base_class in reversed(inspect.getmro(self.__class__)): | ||
if hasattr(base_class, "_more_tags"): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. is this equivalent to checking There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is the old implementation of There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Now, I get your remark. I think that |
||
# need the if because mixins might not have _more_tags | ||
# but might do redundant work in estimators | ||
# (i.e. calling more tags on BaseEstimator multiple times) | ||
with warnings.catch_warnings(): | ||
warnings.filterwarnings( | ||
"ignore", | ||
category=FutureWarning, | ||
message="The `_more_tags` method is deprecated", | ||
) | ||
# silence the warning to avoid a false positive warning since this | ||
# is not a direct user call but an internal one. | ||
more_tags = base_class._more_tags(self) | ||
collected_tags.update(more_tags) | ||
elif hasattr(base_class, "__sklearn_tags__"): | ||
# Since some people will inherit from scikit-learn that implements | ||
# the new infrastructure, we need to collect it and merge it with | ||
# the old tags. | ||
more_tags = base_class.__sklearn_tags__(self) | ||
collected_tags.update(_to_old_tags(more_tags)) | ||
Comment on lines
+463
to
+464
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. pretty sure this is gonna be buggy since There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can you provide me a case where it would break? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Since this is not covered, I'm not doing the case that I was thinking of :). There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. So this would be: class Mixin:
def __sklearn_tags__(self):
tags = super().__sklearn_tags__()
tags.input_tags.allow_nan = True
return tags
class MyClass(Mixin, BaseEstimator):
def _more_tags(self):
return {
"multioutput": True
} and indeed it end up with a recursion error. |
||
return collected_tags | ||
|
||
def _validate_params(self): | ||
"""Validate types and values of constructor parameters | ||
|
||
|
@@ -509,6 +575,10 @@ | |
# TODO(1.8): Remove this attribute | ||
_estimator_type = "classifier" | ||
|
||
# TODO(1.7): Remove this method | ||
def _more_tags(self): | ||
return {"requires_y": True} | ||
|
||
def __sklearn_tags__(self): | ||
tags = super().__sklearn_tags__() | ||
tags.estimator_type = "classifier" | ||
|
@@ -582,6 +652,10 @@ | |
# TODO(1.8): Remove this attribute | ||
_estimator_type = "regressor" | ||
|
||
# TODO(1.7): Remove this method | ||
def _more_tags(self): | ||
return {"requires_y": True} | ||
|
||
def __sklearn_tags__(self): | ||
tags = super().__sklearn_tags__() | ||
tags.estimator_type = "regressor" | ||
|
@@ -658,6 +732,10 @@ | |
# TODO(1.8): Remove this attribute | ||
_estimator_type = "clusterer" | ||
|
||
# TODO(1.7): Remove this method | ||
def _more_tags(self): | ||
return {"preserves_dtype": []} | ||
|
||
def __sklearn_tags__(self): | ||
tags = super().__sklearn_tags__() | ||
tags.estimator_type = "clusterer" | ||
|
@@ -1153,6 +1231,10 @@ | |
class MultiOutputMixin: | ||
"""Mixin to mark estimators that support multioutput.""" | ||
|
||
# TODO(1.7): Remove this method | ||
def _more_tags(self): | ||
return {"multioutput": True} | ||
|
||
def __sklearn_tags__(self): | ||
tags = super().__sklearn_tags__() | ||
tags.target_tags.multi_output = True | ||
|
@@ -1162,6 +1244,13 @@ | |
class _UnstableArchMixin: | ||
"""Mark estimators that are non-deterministic on 32bit or PowerPC""" | ||
|
||
# TODO(1.7): Remove this method | ||
def _more_tags(self): | ||
return { | ||
"non_deterministic": _IS_32BIT | ||
or platform.machine().startswith(("ppc", "powerpc")) | ||
} | ||
|
||
def __sklearn_tags__(self): | ||
tags = super().__sklearn_tags__() | ||
tags.non_deterministic = _IS_32BIT or platform.machine().startswith( | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -16,7 +16,7 @@ | |
import numpy as np | ||
import scipy.sparse as sp | ||
|
||
from sklearn.utils import metadata_routing | ||
from sklearn.utils import TransformerTags, metadata_routing | ||
|
||
from ..base import BaseEstimator, OneToOneFeatureMixin, TransformerMixin, _fit_context | ||
from ..exceptions import NotFittedError | ||
|
@@ -554,6 +554,11 @@ def _warn_for_unused_params(self): | |
" since 'analyzer' != 'word'" | ||
) | ||
|
||
def __sklearn_tags__(self): | ||
tags = super().__sklearn_tags__() | ||
tags.transformer_tags = TransformerTags(preserves_dtype=[]) | ||
return tags | ||
Comment on lines
+557
to
+560
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. this class is not a transformer, why do we attach transformer tags here? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. A transformer is something that has a There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes, and the |
||
|
||
|
||
class HashingVectorizer( | ||
TransformerMixin, _VectorizerMixin, BaseEstimator, auto_wrap_output_keys=None | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -16,6 +16,7 @@ | |
from ..metrics import euclidean_distances | ||
from ..utils import check_array, check_random_state, check_symmetric | ||
from ..utils._param_validation import Interval, StrOptions, validate_params | ||
from ..utils._tags import TransformerTags | ||
from ..utils.parallel import Parallel, delayed | ||
from ..utils.validation import validate_data | ||
|
||
|
@@ -572,6 +573,7 @@ def __init__( | |
|
||
def __sklearn_tags__(self): | ||
tags = super().__sklearn_tags__() | ||
tags.transformer_tags = TransformerTags(preserves_dtype=["float64"]) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. why isn't this inheriting from There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Because we don't want to implement At least this is what I recall. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
tags.input_tags.pairwise = self.dissimilarity == "precomputed" | ||
return tags | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -23,6 +23,7 @@ | |
) | ||
from ..utils._arpack import _init_arpack_v0 | ||
from ..utils._param_validation import Interval, StrOptions, validate_params | ||
from ..utils._tags import TransformerTags | ||
from ..utils.extmath import _deterministic_vector_sign_flip | ||
from ..utils.fixes import laplacian as csgraph_laplacian | ||
from ..utils.fixes import parse_version, sp_version | ||
|
@@ -654,6 +655,7 @@ def __sklearn_tags__(self): | |
"precomputed", | ||
"precomputed_nearest_neighbors", | ||
] | ||
tags.transformer_tags = TransformerTags(preserves_dtype=["float64"]) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. same here There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. same answer than above ;) |
||
return tags | ||
|
||
def _get_affinity_matrix(self, X, Y=None): | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -15,7 +15,7 @@ | |
from scipy.stats import bernoulli, expon, uniform | ||
|
||
from sklearn import config_context | ||
from sklearn.base import BaseEstimator, ClassifierMixin, is_classifier | ||
from sklearn.base import BaseEstimator, ClassifierMixin, TransformerMixin, is_classifier | ||
from sklearn.cluster import KMeans | ||
from sklearn.compose import ColumnTransformer | ||
from sklearn.datasets import ( | ||
|
@@ -100,7 +100,7 @@ | |
|
||
# Neither of the following two estimators inherit from BaseEstimator, | ||
# to test hyperparameter search on user-defined classifiers. | ||
class MockClassifier(ClassifierMixin, BaseEstimator): | ||
class MockClassifier(TransformerMixin, ClassifierMixin, BaseEstimator): | ||
Comment on lines
101
to
+103
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. the comment is not valid anymore, and we have the same estimator in |
||
"""Dummy classifier to test the parameter search algorithms""" | ||
|
||
def __init__(self, foo_param=0): | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.