MAINT conversion old->new/new->old tags #30302

Closed
wants to merge 25 commits into from

Member

@glemaitre glemaitre commented Nov 18, 2024

Towards #30298

This PR provides a way to convert old tags into new tags and new tags into old tags.
We should make sure that:

  • An estimator implementing _get_tags or _more_tags gets a warning, but check_estimator should still work reasonably.
  • _safe_tags should raise a deprecation warning.
  • get_tags can read old tags and convert them to new tags. The tricky part is to detect whether __sklearn_tags__, if present, is implemented only by scikit-learn's BaseEstimator or also by the child class. If only BaseEstimator implements it, we should raise a warning asking to move to the new API and temporarily use _get_tags/_safe_tags; otherwise we are fine.
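
The detection described in the last bullet could be sketched as follows. This is a minimal illustration, not the PR's code: `BaseEstimatorSketch` is a stand-in for scikit-learn's `BaseEstimator`, and `find_tags_provider_sketch` is a hypothetical helper. It assumes that looking at each class's own namespace with `vars()` is enough to tell who implements which hook.

```python
class BaseEstimatorSketch:
    """Stand-in for sklearn.base.BaseEstimator (assumption, not the real class)."""

    def __sklearn_tags__(self):
        return {"allow_nan": False}


def find_tags_provider_sketch(estimator):
    """Return which tag API the estimator's own classes implement.

    Hypothetical helper: walks the MRO from the most derived class and stops
    at the base class, so a `__sklearn_tags__` that only comes from the base
    does not count as "the child class implements the new API".
    """
    for klass in type(estimator).__mro__:
        if klass is BaseEstimatorSketch:
            break
        namespace = vars(klass)  # only names defined on this class itself
        if "__sklearn_tags__" in namespace:
            return "__sklearn_tags__"
        if "_more_tags" in namespace or "_get_tags" in namespace:
            return "_get_tags"
    # No child class defines either hook: the base implementation is enough.
    return "__sklearn_tags__"


class NewStyle(BaseEstimatorSketch):
    def __sklearn_tags__(self):
        tags = super().__sklearn_tags__()
        tags["allow_nan"] = True
        return tags


class OldStyle(BaseEstimatorSketch):
    def _more_tags(self):
        return {"allow_nan": True}


class Plain(BaseEstimatorSketch):
    pass
```

With these toy classes, `find_tags_provider_sketch(OldStyle())` returns `"_get_tags"`, which is the case where a warning and the old-tag fallback would apply.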

github-actions bot commented Nov 18, 2024

✔️ Linting Passed

All linting checks passed. Your pull request is in excellent shape! ☀️

Generated for commit: b41d1a6. Link to the linter CI: here

@glemaitre glemaitre changed the title MAINT conversion old->new new->old tags MAINT conversion old->new/new->old tags Nov 18, 2024
@@ -318,6 +318,14 @@ def get_tags(estimator) -> Tags:

if hasattr(estimator, "__sklearn_tags__"):
    tags = estimator.__sklearn_tags__()
elif hasattr(estimator, "_get_tags"):
    warnings.warn("BROKEN SOON, IT WILL BE", FutureWarning)
Member

😄

Member Author

I'm entertaining myself as I can ;)

@glemaitre
Member Author

OK, so I'm converging on something that works. I'm sorry for the reviewers of this PR... Trust the tests ;)

ping @jeremiedbb @ogrisel @adrinjalali

Member

@adrinjalali adrinjalali left a comment

a few comments

# TODO(1.7): Remove this block
if _find_tags_provider(self) == "_get_tags":
    # one of the children classes only implements `_get_tags` so we need to
    # warn and and mix old-style and new-style tags.
Member

Suggested change
- # warn and and mix old-style and new-style tags.
+ # warn and mix old-style and new-style tags.


collected_tags = {}
for base_class in reversed(inspect.getmro(self.__class__)):
    if hasattr(base_class, "_more_tags"):
Member

is this equivalent to checking vars?

Member Author

This is the old implementation of _get_tags. So I would not touch it.
The only change here is to fall back to __sklearn_tags__ in case someone in the chain of inheritance implemented it (like scikit-learn implementing the new tags and removing the old tags).

Member Author

@glemaitre glemaitre Nov 21, 2024

Now, I get your remark. I think that _more_tags was the API trick to avoid going through the MRO and thus an equivalent of vars(__sklearn_tags__) but by design. This should solve my bug below :) (edit: or not because I don't know how to avoid the climbing of the inheritance).
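
To make the `hasattr` versus `vars` question concrete, here is a small sketch using toy classes only (none of these names come from scikit-learn). `hasattr` follows inheritance, so it is true for every subclass of a class that defines the method, while `vars` only looks at the class's own namespace.

```python
class Parent:
    def _more_tags(self):
        return {"allow_nan": True}


class Child(Parent):
    pass


def classes_with_attr(cls, name):
    """Classes in the reversed MRO for which hasattr() is true (inherited counts)."""
    return [k.__name__ for k in reversed(cls.__mro__) if hasattr(k, name)]


def classes_defining(cls, name):
    """Classes in the reversed MRO that define the name in their own namespace."""
    return [k.__name__ for k in reversed(cls.__mro__) if name in vars(k)]


print(classes_with_attr(Child, "_more_tags"))  # ['Parent', 'Child']
print(classes_defining(Child, "_more_tags"))   # ['Parent']
```

A loop based on `hasattr` therefore visits `Child` even though it adds no tags of its own, whereas the `vars` check visits only the class that actually defines `_more_tags`.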

Comment on lines +463 to +464
more_tags = base_class.__sklearn_tags__(self)
collected_tags.update(_to_old_tags(more_tags))
Member

pretty sure this is gonna be buggy since __sklearn_tags__ collects all tags from base_class upwards.

Member Author

Can you provide me a case where it would break?

Member Author

Since this is not covered, I'm not doing the case that I was thinking of :).

Member Author

So this would be:

class Mixin:
    def __sklearn_tags__(self):
        tags = super().__sklearn_tags__()
        tags.input_tags.allow_nan = True
        return tags

class MyClass(Mixin, BaseEstimator):

    def _more_tags(self):
        return {
            "multioutput": True
        }

and indeed it ends up with a recursion error.
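
One plausible way such a loop can arise is sketched below with toy classes. This is an assumption-laden model, not the PR's implementation: `BaseSketch` stands in for `BaseEstimator`, tags are plain dicts, and the fallback condition in `__sklearn_tags__` is invented for illustration. The point is only that a new-style method delegating to an old-style collector, which in turn calls new-style methods found in the MRO, can re-enter itself.

```python
class BaseSketch:
    """Toy stand-in for BaseEstimator; NOT scikit-learn's implementation."""

    def __sklearn_tags__(self):
        # Assumed fallback: if the old hook exists anywhere, use the old collector.
        if hasattr(self, "_more_tags"):
            return self._get_tags()
        return {"allow_nan": False}

    def _get_tags(self):
        collected = {}
        for klass in reversed(type(self).__mro__):
            if "_more_tags" in vars(klass):
                collected.update(klass._more_tags(self))
            elif "__sklearn_tags__" in vars(klass):
                # Re-enters BaseSketch.__sklearn_tags__, which calls _get_tags again.
                collected.update(klass.__sklearn_tags__(self))
        return collected


class Mixin:
    def __sklearn_tags__(self):
        tags = super().__sklearn_tags__()
        tags["allow_nan"] = True
        return tags


class MyClass(Mixin, BaseSketch):
    def _more_tags(self):
        return {"multioutput": True}


class NewOnly(Mixin, BaseSketch):
    pass


def triggers_recursion(estimator):
    """Return True if collecting tags raises RecursionError."""
    try:
        estimator.__sklearn_tags__()
    except RecursionError:
        return True
    return False
```

In this toy model, `triggers_recursion(MyClass())` is true because the old and new code paths call each other, while `NewOnly()` resolves normally since it never enters the fallback.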

Comment on lines +557 to +560
def __sklearn_tags__(self):
    tags = super().__sklearn_tags__()
    tags.transformer_tags = TransformerTags(preserves_dtype=[])
    return tags
Member

this class is not a transformer, why do we attach transformer tags here?

Member Author

A transformer is something that has a fit + transform method, isn't it?

Member

Yes, and the _VectorizerMixin mixin has neither of them.

@@ -572,6 +573,7 @@ def __init__(

def __sklearn_tags__(self):
    tags = super().__sklearn_tags__()
    tags.transformer_tags = TransformerTags(preserves_dtype=["float64"])
Member

why isn't this inheriting from TransformerMixin?

Member Author

Because we don't want to implement transform but only fit_transform such that you can only provide the same data at fit and transform.

At least this is what I recall.

Member

TransformerMixin doesn't do anything to transform. Here we'd use its tags, and override fit_transform from the mixin.
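
The pattern suggested here can be sketched with stand-in classes. `TransformerMixinSketch` and `BaseEstimatorSketch` are simplified assumptions modelled on the idea of the scikit-learn classes, with dict-based tags; `EmbeddingSketch` is a hypothetical estimator. The mixin contributes the transformer tags and a default `fit_transform`, and the estimator overrides `fit_transform` without ever defining `transform`.

```python
class BaseEstimatorSketch:
    """Stand-in base class with no transformer tags (assumption)."""

    def __sklearn_tags__(self):
        return {"transformer_tags": None}


class TransformerMixinSketch:
    """Stand-in mixin: adds transformer tags and a default fit_transform."""

    def __sklearn_tags__(self):
        tags = super().__sklearn_tags__()
        tags["transformer_tags"] = {"preserves_dtype": ["float64"]}
        return tags

    def fit_transform(self, X, y=None):
        # Default behaviour relies on a transform() supplied by the subclass.
        return self.fit(X, y).transform(X)


class EmbeddingSketch(TransformerMixinSketch, BaseEstimatorSketch):
    """Hypothetical estimator that only supports fit_transform."""

    def fit(self, X, y=None):
        self.embedding_ = [[float(v) for v in row] for row in X]
        return self

    def fit_transform(self, X, y=None):
        # Overrides the mixin default, so no transform() method is needed.
        return self.fit(X, y).embedding_
```

Under these assumptions the estimator gets the transformer tags from the mixin without setting them by hand, and calling `transform` on it still fails because the method does not exist.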

@@ -654,6 +655,7 @@ def __sklearn_tags__(self):
"precomputed",
"precomputed_nearest_neighbors",
]
tags.transformer_tags = TransformerTags(preserves_dtype=["float64"])
Member

same here

Member Author

same answer as above ;)

Comment on lines 101 to +103
# Neither of the following two estimators inherit from BaseEstimator,
# to test hyperparameter search on user-defined classifiers.
- class MockClassifier(ClassifierMixin, BaseEstimator):
+ class MockClassifier(TransformerMixin, ClassifierMixin, BaseEstimator):
Member

the comment is not valid anymore, and we have the same estimator in test_rfe

@@ -243,7 +245,7 @@ class Tags:
input_tags: InputTags = field(default_factory=InputTags)


- # TODO(1.8): Remove this function
+ # TODO(1.7): Remove this function
Member

then we also need to change the warnings / comments about _estimator_type as well

@glemaitre
Member Author

I'll add more tests to cover the missing bits. I'll do it in a way that emulates old estimators.

@glemaitre
Member Author

I opened #30327, which is a simpler alternative: when we cannot rely solely on __sklearn_tags__, we patch all classes in the inheritance tree by detecting the tags to be added from an old-API _more_tags.
