-
-
Notifications
You must be signed in to change notification settings - Fork 25.8k
MAINT conversion old->new/new->old tags #30302
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
f14ac4f
to
d74c0e8
Compare
sklearn/utils/_tags.py
Outdated
@@ -318,6 +318,14 @@ def get_tags(estimator) -> Tags: | |||
|
|||
if hasattr(estimator, "__sklearn_tags__"): | |||
tags = estimator.__sklearn_tags__() | |||
elif hasattr(estimator, "_get_tags"): | |||
warnings.warn("BROKEN SOON, IT WILL BE", FutureWarning) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
😄
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm entertaining myself as I can ;)
OK, so I'm converging with something working. I'm sorry for the reviewers of this PR... Trust the tests ;) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
a few comments
# TODO(1.7): Remove this block | ||
if _find_tags_provider(self) == "_get_tags": | ||
# one of the children classes only implements `_get_tags` so we need to | ||
# warn and and mix old-style and new-style tags. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
# warn and and mix old-style and new-style tags. | |
# warn and mix old-style and new-style tags. |
|
||
collected_tags = {} | ||
for base_class in reversed(inspect.getmro(self.__class__)): | ||
if hasattr(base_class, "_more_tags"): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
is this equivalent to checking vars
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is the old implementation of _get_tags
. So I would not touch it.
The only change here is to fallback to __sklearn_tags__
in case someone in the chain of inheritance implemented it (like scikit-learn implementing the new tags and removing the old tags).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Now, I get your remark. I think that _more_tags
was the API trick to avoid going through the MRO and thus an equivalent of vars(__sklearn_tags__)
but by design. This should solve my bug below :) (edit: or not because I don't know how to avoid the climbing of the inheritance).
more_tags = base_class.__sklearn_tags__(self) | ||
collected_tags.update(_to_old_tags(more_tags)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
pretty sure this is gonna be buggy since __sklearn_tags__
collects all tags from base_class
upwards.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you provide me a case where it would break?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Since this is not covered, I'm not doing the case that I was thinking of :).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
So this would be:
class Mixin:
def __sklearn_tags__(self):
tags = super().__sklearn_tags__()
tags.input_tags.allow_nan = True
return tags
class MyClass(Mixin, BaseEstimator):
def _more_tags(self):
return {
"multioutput": True
}
and indeed it end up with a recursion error.
def __sklearn_tags__(self): | ||
tags = super().__sklearn_tags__() | ||
tags.transformer_tags = TransformerTags(preserves_dtype=[]) | ||
return tags |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this class is not a transformer, why do we attach transformer tags here?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
A transformer is something that has a fit
+ transform
method, isn't it?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, and the _VectorizerMixin
mixin has neither of them.
@@ -572,6 +573,7 @@ def __init__( | |||
|
|||
def __sklearn_tags__(self): | |||
tags = super().__sklearn_tags__() | |||
tags.transformer_tags = TransformerTags(preserves_dtype=["float64"]) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
why isn't this inheriting from TransformerMixin
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Because we don't want to implement transform
but only fit_transform
such that you can only provide the same data at fit
and transform
.
At least this is what I recall.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
TransformerMixin
doesn't do anything to transform
. Here we'd use its tags, and override fit_transform
from the mixin.
@@ -654,6 +655,7 @@ def __sklearn_tags__(self): | |||
"precomputed", | |||
"precomputed_nearest_neighbors", | |||
] | |||
tags.transformer_tags = TransformerTags(preserves_dtype=["float64"]) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
same here
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
same answer than above ;)
# Neither of the following two estimators inherit from BaseEstimator, | ||
# to test hyperparameter search on user-defined classifiers. | ||
class MockClassifier(ClassifierMixin, BaseEstimator): | ||
class MockClassifier(TransformerMixin, ClassifierMixin, BaseEstimator): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
the comment is not valid anymore, and we have the same estimator in test_rfe
@@ -243,7 +245,7 @@ class Tags: | |||
input_tags: InputTags = field(default_factory=InputTags) | |||
|
|||
|
|||
# TODO(1.8): Remove this function | |||
# TODO(1.7): Remove this function |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
then we also need to change the warnings / comments about _estimator_type
as well
I'll add more test to cover the missing bits. I'll do it in a way to emulate old estimators. |
I open #30327 that is a simpler alternative: when we cannot rely solely on |
Towards #30298
This PR provides a way to convert old tags into new tags and new tags into old tags.
We should make sure that:
_get_tags
or_more_tags
get a warning but that the check_estimator should be working reasonably_safe_tags
should raise a deprecation warningget_tags
can get old tags and convert it to new tags. The tricky part is to detect whether__sklearn_tags__
if present is implemented by theBaseEstimator
from scikit-learn or the child class. If this is only theBaseEstimator
, we should raise a warning to move to the new API and we should temporary use the_get_tags
/_safe_tags
, otherwise we are fine.