-
-
Notifications
You must be signed in to change notification settings - Fork 25.8k
MNT Clean-up deprecations for 1.7: old tags #31134
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
MNT Clean-up deprecations for 1.7: old tags #31134
Conversation
"alternatively define `__sklearn_tags__` but we don't recommend " | ||
"this approach). Note that `BaseEstimator` needs to be on the " | ||
"right side of other Mixins in the inheritance order. The " | ||
"default are now used instead since retrieving tags failed. " | ||
"This warning will be replaced by an error in 1.7.", | ||
category=DeprecationWarning, | ||
) | ||
tags = default_tags(estimator) | ||
else: | ||
raise | ||
else: | ||
# TODO(1.7): Remove this branch of the code | ||
# Let's go through the MRO and patch each class implementing _more_tags | ||
sklearn_tags_provider = {} | ||
more_tags_provider = {} | ||
class_order = [] | ||
for klass in reversed(type(estimator).mro()): | ||
if "__sklearn_tags__" in vars(klass): | ||
sklearn_tags_provider[klass] = klass.__sklearn_tags__(estimator) # type: ignore[attr-defined] | ||
class_order.append(klass) | ||
elif "_more_tags" in vars(klass): | ||
more_tags_provider[klass] = klass._more_tags(estimator) # type: ignore[attr-defined] | ||
class_order.append(klass) | ||
|
||
# Find differences between consecutive in the case of __sklearn_tags__ | ||
# inheritance | ||
sklearn_tags_diff = {} | ||
items = list(sklearn_tags_provider.items()) | ||
for current_item, next_item in pairwise(items): | ||
current_name, current_tags = current_item | ||
next_name, next_tags = next_item | ||
current_tags = _to_old_tags(current_tags) | ||
next_tags = _to_old_tags(next_tags) | ||
|
||
# Compare tags and store differences | ||
diff = {} | ||
for key in current_tags: | ||
if current_tags[key] != next_tags[key]: | ||
diff[key] = next_tags[key] | ||
|
||
sklearn_tags_diff[next_name] = diff | ||
|
||
tags = {} | ||
for klass in class_order: | ||
if klass in sklearn_tags_diff: | ||
tags.update(sklearn_tags_diff[klass]) | ||
elif klass in more_tags_provider: | ||
tags.update(more_tags_provider[klass]) | ||
|
||
tags = _to_new_tags( | ||
{**_to_old_tags(default_tags(estimator)), **tags}, estimator | ||
) | ||
|
||
return tags | ||
|
||
|
||
# TODO(1.7): Remove this function | ||
def _safe_tags(estimator, key=None): | ||
warnings.warn( | ||
"The `_safe_tags` function is deprecated in 1.6 and will be removed in " | ||
"1.7. Use the public `get_tags` function instead and make sure to implement " | ||
"the `__sklearn_tags__` method.", | ||
category=DeprecationWarning, | ||
) | ||
tags = _to_old_tags(get_tags(estimator)) | ||
|
||
if key is not None: | ||
if key not in tags: | ||
raise ValueError( | ||
f"The key {key} is not defined for the class " | ||
f"{estimator.__class__.__name__}." | ||
try: | ||
tags = estimator.__sklearn_tags__() | ||
except AttributeError as exc: | ||
# TODO(1.8): turn the warning into an error | ||
if "object has no attribute '__sklearn_tags__'" in str(exc): | ||
# Fall back to the default tags if the estimator does not | ||
# implement __sklearn_tags__. | ||
# In particular, workaround the regression reported in | ||
# https://github.com/scikit-learn/scikit-learn/issues/30479 | ||
# `__sklearn_tags__` is implemented by calling | ||
# `super().__sklearn_tags__()` but there is no `__sklearn_tags__` | ||
# method in the base class. Typically happens when only inheriting | ||
# from Mixins. | ||
|
||
warnings.warn( | ||
f"The following error was raised: {exc}. It seems that " | ||
"there are no classes that implement `__sklearn_tags__` " | ||
"in the MRO and/or all classes in the MRO call " | ||
"`super().__sklearn_tags__()`. Make sure to inherit from " | ||
"`BaseEstimator` which implements `__sklearn_tags__` (or " | ||
"alternatively define `__sklearn_tags__` but we don't recommend " | ||
"this approach). Note that `BaseEstimator` needs to be on the " | ||
"right side of other Mixins in the inheritance order. The " | ||
"default are now used instead since retrieving tags failed. " | ||
"This warning will be replaced by an error in 1.8.", | ||
category=DeprecationWarning, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I did not follow the original plan here see my following comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Interesting approach. I don't mind.
In So the following doesn't raise any warning in main: import numpy as np
from sklearn.pipeline import Pipeline
class MyEstimator:
def __init__(self, *, param=1):
self.param = param
def fit(self, X, y=None):
self.is_fitted_ = True
return self
def predict(self, X):
return np.full(shape=X.shape[0], fill_value=self.param)
my_pipeline = Pipeline([("estimator", MyEstimator(param=1))])
X = np.array([[1, 2], [2, 3], [3, 4]])
y = np.array([1, 0, 1])
my_pipeline.fit(X, y).predict(X) Cleaning-up |
@@ -1221,7 +1221,7 @@ def __sklearn_tags__(self): | |||
tags.input_tags.sparse = all( | |||
get_tags(step).input_tags.sparse | |||
for name, step in self.steps | |||
if step != "passthrough" | |||
if step is not None and step != "passthrough" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
That was missed because get_tags(None)
used to return default tags.
# TODO(1.8): Update when implementing __sklearn_tags__ is required | ||
@pytest.mark.filterwarnings( | ||
"ignore:.*no __sklearn_tags__ attribute.*:DeprecationWarning" | ||
"ignore:.*no attribute '__sklearn_tags__'.*:DeprecationWarning" | ||
) | ||
@pytest.mark.parametrize( | ||
"estimator, value", |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This filter used to not filter anything but didn't have to :)
Now that I've made it so that a warning is raised for estimator not implementing __sklearn_tags__
, it triggers and showed that the filter was wrong.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM. I'll let @glemaitre have a look too.
"alternatively define `__sklearn_tags__` but we don't recommend " | ||
"this approach). Note that `BaseEstimator` needs to be on the " | ||
"right side of other Mixins in the inheritance order. The " | ||
"default are now used instead since retrieving tags failed. " | ||
"This warning will be replaced by an error in 1.7.", | ||
category=DeprecationWarning, | ||
) | ||
tags = default_tags(estimator) | ||
else: | ||
raise | ||
else: | ||
# TODO(1.7): Remove this branch of the code | ||
# Let's go through the MRO and patch each class implementing _more_tags | ||
sklearn_tags_provider = {} | ||
more_tags_provider = {} | ||
class_order = [] | ||
for klass in reversed(type(estimator).mro()): | ||
if "__sklearn_tags__" in vars(klass): | ||
sklearn_tags_provider[klass] = klass.__sklearn_tags__(estimator) # type: ignore[attr-defined] | ||
class_order.append(klass) | ||
elif "_more_tags" in vars(klass): | ||
more_tags_provider[klass] = klass._more_tags(estimator) # type: ignore[attr-defined] | ||
class_order.append(klass) | ||
|
||
# Find differences between consecutive in the case of __sklearn_tags__ | ||
# inheritance | ||
sklearn_tags_diff = {} | ||
items = list(sklearn_tags_provider.items()) | ||
for current_item, next_item in pairwise(items): | ||
current_name, current_tags = current_item | ||
next_name, next_tags = next_item | ||
current_tags = _to_old_tags(current_tags) | ||
next_tags = _to_old_tags(next_tags) | ||
|
||
# Compare tags and store differences | ||
diff = {} | ||
for key in current_tags: | ||
if current_tags[key] != next_tags[key]: | ||
diff[key] = next_tags[key] | ||
|
||
sklearn_tags_diff[next_name] = diff | ||
|
||
tags = {} | ||
for klass in class_order: | ||
if klass in sklearn_tags_diff: | ||
tags.update(sklearn_tags_diff[klass]) | ||
elif klass in more_tags_provider: | ||
tags.update(more_tags_provider[klass]) | ||
|
||
tags = _to_new_tags( | ||
{**_to_old_tags(default_tags(estimator)), **tags}, estimator | ||
) | ||
|
||
return tags | ||
|
||
|
||
# TODO(1.7): Remove this function | ||
def _safe_tags(estimator, key=None): | ||
warnings.warn( | ||
"The `_safe_tags` function is deprecated in 1.6 and will be removed in " | ||
"1.7. Use the public `get_tags` function instead and make sure to implement " | ||
"the `__sklearn_tags__` method.", | ||
category=DeprecationWarning, | ||
) | ||
tags = _to_old_tags(get_tags(estimator)) | ||
|
||
if key is not None: | ||
if key not in tags: | ||
raise ValueError( | ||
f"The key {key} is not defined for the class " | ||
f"{estimator.__class__.__name__}." | ||
try: | ||
tags = estimator.__sklearn_tags__() | ||
except AttributeError as exc: | ||
# TODO(1.8): turn the warning into an error | ||
if "object has no attribute '__sklearn_tags__'" in str(exc): | ||
# Fall back to the default tags if the estimator does not | ||
# implement __sklearn_tags__. | ||
# In particular, workaround the regression reported in | ||
# https://github.com/scikit-learn/scikit-learn/issues/30479 | ||
# `__sklearn_tags__` is implemented by calling | ||
# `super().__sklearn_tags__()` but there is no `__sklearn_tags__` | ||
# method in the base class. Typically happens when only inheriting | ||
# from Mixins. | ||
|
||
warnings.warn( | ||
f"The following error was raised: {exc}. It seems that " | ||
"there are no classes that implement `__sklearn_tags__` " | ||
"in the MRO and/or all classes in the MRO call " | ||
"`super().__sklearn_tags__()`. Make sure to inherit from " | ||
"`BaseEstimator` which implements `__sklearn_tags__` (or " | ||
"alternatively define `__sklearn_tags__` but we don't recommend " | ||
"this approach). Note that `BaseEstimator` needs to be on the " | ||
"right side of other Mixins in the inheritance order. The " | ||
"default are now used instead since retrieving tags failed. " | ||
"This warning will be replaced by an error in 1.8.", | ||
category=DeprecationWarning, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Interesting approach. I don't mind.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It looks good as well. I prefer this softer deprecation for one more release.
Removed deprecated old tags infrastructure.