MNT Clean-up deprecations for 1.7: old tags #31134

Merged: 4 commits, Apr 17, 2025

Conversation

jeremiedbb
Member

Removed deprecated old tags infrastructure.

jeremiedbb added this to the 1.7 milestone on Apr 2, 2025
Comment on lines -410 to +349
-                    "alternatively define `__sklearn_tags__` but we don't recommend "
-                    "this approach). Note that `BaseEstimator` needs to be on the "
-                    "right side of other Mixins in the inheritance order. The "
-                    "default are now used instead since retrieving tags failed. "
-                    "This warning will be replaced by an error in 1.7.",
-                    category=DeprecationWarning,
-                )
-                tags = default_tags(estimator)
-            else:
-                raise
-    else:
-        # TODO(1.7): Remove this branch of the code
-        # Let's go through the MRO and patch each class implementing _more_tags
-        sklearn_tags_provider = {}
-        more_tags_provider = {}
-        class_order = []
-        for klass in reversed(type(estimator).mro()):
-            if "__sklearn_tags__" in vars(klass):
-                sklearn_tags_provider[klass] = klass.__sklearn_tags__(estimator)  # type: ignore[attr-defined]
-                class_order.append(klass)
-            elif "_more_tags" in vars(klass):
-                more_tags_provider[klass] = klass._more_tags(estimator)  # type: ignore[attr-defined]
-                class_order.append(klass)
-
-        # Find differences between consecutive in the case of __sklearn_tags__
-        # inheritance
-        sklearn_tags_diff = {}
-        items = list(sklearn_tags_provider.items())
-        for current_item, next_item in pairwise(items):
-            current_name, current_tags = current_item
-            next_name, next_tags = next_item
-            current_tags = _to_old_tags(current_tags)
-            next_tags = _to_old_tags(next_tags)
-
-            # Compare tags and store differences
-            diff = {}
-            for key in current_tags:
-                if current_tags[key] != next_tags[key]:
-                    diff[key] = next_tags[key]
-
-            sklearn_tags_diff[next_name] = diff
-
-        tags = {}
-        for klass in class_order:
-            if klass in sklearn_tags_diff:
-                tags.update(sklearn_tags_diff[klass])
-            elif klass in more_tags_provider:
-                tags.update(more_tags_provider[klass])
-
-        tags = _to_new_tags(
-            {**_to_old_tags(default_tags(estimator)), **tags}, estimator
-        )
-
-    return tags
-
-
-# TODO(1.7): Remove this function
-def _safe_tags(estimator, key=None):
-    warnings.warn(
-        "The `_safe_tags` function is deprecated in 1.6 and will be removed in "
-        "1.7. Use the public `get_tags` function instead and make sure to implement "
-        "the `__sklearn_tags__` method.",
-        category=DeprecationWarning,
-    )
-    tags = _to_old_tags(get_tags(estimator))
-
-    if key is not None:
-        if key not in tags:
-            raise ValueError(
-                f"The key {key} is not defined for the class "
-                f"{estimator.__class__.__name__}."
+    try:
+        tags = estimator.__sklearn_tags__()
+    except AttributeError as exc:
+        # TODO(1.8): turn the warning into an error
+        if "object has no attribute '__sklearn_tags__'" in str(exc):
+            # Fall back to the default tags if the estimator does not
+            # implement __sklearn_tags__.
+            # In particular, workaround the regression reported in
+            # https://github.com/scikit-learn/scikit-learn/issues/30479
+            # `__sklearn_tags__` is implemented by calling
+            # `super().__sklearn_tags__()` but there is no `__sklearn_tags__`
+            # method in the base class. Typically happens when only inheriting
+            # from Mixins.
+
+            warnings.warn(
+                f"The following error was raised: {exc}. It seems that "
+                "there are no classes that implement `__sklearn_tags__` "
+                "in the MRO and/or all classes in the MRO call "
+                "`super().__sklearn_tags__()`. Make sure to inherit from "
+                "`BaseEstimator` which implements `__sklearn_tags__` (or "
+                "alternatively define `__sklearn_tags__` but we don't recommend "
+                "this approach). Note that `BaseEstimator` needs to be on the "
+                "right side of other Mixins in the inheritance order. The "
+                "default are now used instead since retrieving tags failed. "
+                "This warning will be replaced by an error in 1.8.",
+                category=DeprecationWarning,
Member Author

I did not follow the original plan here; see my comment below.

Member

Interesting approach. I don't mind.
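
The warning text in the diff above says that `BaseEstimator` needs to be on the right side of other mixins. The following is only a sketch of why the order matters, using plain dicts rather than the real tags object; `Base`, `ClassifierLikeMixin`, and the `estimator_type` key are hypothetical stand-ins, not scikit-learn code.

```python
class Base:
    """Hypothetical stand-in for `BaseEstimator`: ends the super() chain."""

    def __sklearn_tags__(self):
        return {"estimator_type": None}


class ClassifierLikeMixin:
    """Hypothetical mixin that refines the tags of the next class in the MRO."""

    def __sklearn_tags__(self):
        tags = super().__sklearn_tags__()
        tags["estimator_type"] = "classifier"
        return tags


class GoodOrder(ClassifierLikeMixin, Base):
    """Mixin on the left, base on the right: the mixin's refinement applies."""


class BadOrder(Base, ClassifierLikeMixin):
    """Base on the left: `Base` never calls super(), so the mixin is skipped."""


class MixinOnly(ClassifierLikeMixin):
    """No base class at all: super() has nothing to delegate to."""
```

In this sketch, `GoodOrder().__sklearn_tags__()` picks up the mixin's value, `BadOrder().__sklearn_tags__()` silently keeps the base value, and `MixinOnly().__sklearn_tags__()` raises the `AttributeError` that the new `except` branch turns into a warning.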

github-actions bot commented Apr 2, 2025

✔️ Linting Passed

All linting checks passed. Your pull request is in excellent shape! ☀️

Generated for commit: de801a4. Link to the linter CI: here

@jeremiedbb
Member Author

In get_tags, we added a deprecation warning for estimators that inherit from classes whose __sklearn_tags__ only calls super().__sklearn_tags__(). However, we did nothing for classes that don't implement __sklearn_tags__ at all.

So the following doesn't raise any warning in main:

import numpy as np
from sklearn.pipeline import Pipeline


class MyEstimator:
    def __init__(self, *, param=1):
        self.param = param

    def fit(self, X, y=None):
        self.is_fitted_ = True
        return self

    def predict(self, X):
        return np.full(shape=X.shape[0], fill_value=self.param)

my_pipeline = Pipeline([("estimator", MyEstimator(param=1))])

X = np.array([[1, 2], [2, 3], [3, 4]])
y = np.array([1, 0, 1])

my_pipeline.fit(X, y).predict(X)

Cleaning up get_tags according to the original plan would make this case raise an error in 1.7, which I don't find very user-friendly. So I think it's better to raise the warning for this case as well, for at least one more release. What do you think, @glemaitre @adrinjalali?
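
A rough sketch of the softer behaviour proposed here, under stated assumptions: `get_tags_sketch` and `default_tags_sketch` are hypothetical simplifications of the logic in the diff, and a plain dict stands in for the real tags object. Both an estimator with no `__sklearn_tags__` at all and one that only calls `super().__sklearn_tags__()` end up on the same warning path.

```python
import warnings


def default_tags_sketch():
    """Hypothetical stand-in for the library's default tags."""
    return {"requires_fit": True}


def get_tags_sketch(estimator):
    """Return the estimator's tags, falling back to defaults with a warning."""
    try:
        return estimator.__sklearn_tags__()
    except AttributeError as exc:
        if "object has no attribute '__sklearn_tags__'" not in str(exc):
            raise  # unrelated AttributeError raised inside __sklearn_tags__
        warnings.warn(
            f"The following error was raised: {exc}. Default tags are used.",
            category=DeprecationWarning,
        )
        return default_tags_sketch()


class NoTags:
    """Defines fit only, like MyEstimator above: no __sklearn_tags__ at all."""

    def fit(self, X, y=None):
        return self


class SuperOnly:
    """Only delegates to super(), which has no __sklearn_tags__ to offer."""

    def __sklearn_tags__(self):
        return super().__sklearn_tags__()
```

Matching on a substring of the message, rather than on the exact `'super' object ...` text, is what lets one branch cover both situations.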

@@ -1221,7 +1221,7 @@ def __sklearn_tags__(self):
             tags.input_tags.sparse = all(
                 get_tags(step).input_tags.sparse
                 for name, step in self.steps
-                if step != "passthrough"
+                if step is not None and step != "passthrough"
Member Author

That was missed because get_tags(None) used to return default tags.
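
To illustrate the guard, here is a small sketch that skips both kinds of placeholder step before any tag lookup. `Step`, `sparse_ok`, and `pipeline_accepts_sparse` are hypothetical names; `sparse_ok` merely stands in for `get_tags(step).input_tags.sparse`.

```python
class Step:
    """Hypothetical pipeline step that declares whether it accepts sparse input."""

    def __init__(self, accepts_sparse):
        self.accepts_sparse = accepts_sparse


def sparse_ok(step):
    """Stand-in for `get_tags(step).input_tags.sparse`; fails on None or str."""
    return step.accepts_sparse


def pipeline_accepts_sparse(steps):
    """True if every real step accepts sparse input; placeholders are skipped."""
    return all(
        sparse_ok(step)
        for name, step in steps
        if step is not None and step != "passthrough"
    )


steps = [("scale", Step(True)), ("skipped", None), ("noop", "passthrough")]
result = pipeline_accepts_sparse(steps)
```

Without the `step is not None` check, the `None` entry would reach the lookup, which is the case the one-line fix above addresses.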

Comment on lines +40 to 45
+# TODO(1.8): Update when implementing __sklearn_tags__ is required
 @pytest.mark.filterwarnings(
-    "ignore:.*no __sklearn_tags__ attribute.*:DeprecationWarning"
+    "ignore:.*no attribute '__sklearn_tags__'.*:DeprecationWarning"
 )
 @pytest.mark.parametrize(
     "estimator, value",
Member Author

This filter didn't match anything before, but it didn't have to :)
Now that a warning is raised for estimators that don't implement __sklearn_tags__, the warning triggers, which showed that the filter pattern was wrong.
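
The mismatch can be reproduced with the standard `warnings` module alone. This is a sketch that assumes the pytest filter string maps onto `warnings.filterwarnings`, where the message part is a regular expression matched from the start of the warning text. `MESSAGE` is a shortened version of the warning in the diff, with a hypothetical class name, and `is_filtered` is a helper invented for the illustration.

```python
import warnings

# Shortened form of the warning text from the diff (class name is made up).
MESSAGE = (
    "The following error was raised: 'MyEstimator' object has no attribute "
    "'__sklearn_tags__'. It seems that there are no classes that implement "
    "`__sklearn_tags__` in the MRO."
)

OLD_PATTERN = r".*no __sklearn_tags__ attribute.*"
NEW_PATTERN = r".*no attribute '__sklearn_tags__'.*"


def is_filtered(pattern):
    """Return True if an "ignore" filter built from `pattern` hides MESSAGE."""
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        # Inserted in front of the "always" rule, so it is consulted first.
        warnings.filterwarnings(
            "ignore", message=pattern, category=DeprecationWarning
        )
        warnings.warn(MESSAGE, category=DeprecationWarning)
    return len(caught) == 0


old_hides_warning = is_filtered(OLD_PATTERN)
new_hides_warning = is_filtered(NEW_PATTERN)
```

Only the new pattern contains a substring that actually occurs in the message, so only it suppresses the warning.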

Member

adrinjalali left a comment

LGTM. I'll let @glemaitre have a look too.

Member

glemaitre left a comment

It looks good to me as well. I prefer this softer deprecation for one more release.

glemaitre merged commit 32aa82d into scikit-learn:main on Apr 17, 2025
36 checks passed
lucyleeow pushed a commit to EmilyXinyi/scikit-learn that referenced this pull request Apr 23, 2025