Skip to content

MAINT conversion old->new/new->old tags #30302

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 25 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
89 changes: 89 additions & 0 deletions sklearn/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@
Tags,
TargetTags,
TransformerTags,
_to_old_tags,
default_tags,
get_tags,
)
from .utils.fixes import _IS_32BIT
Expand Down Expand Up @@ -390,6 +392,22 @@
self.__dict__.update(state)

def __sklearn_tags__(self):
from sklearn.utils._tags import _find_tags_provider, _to_new_tags

# TODO(1.7): Remove this block
if _find_tags_provider(self) == "_get_tags":
# one of the children classes only implements `_get_tags` so we need to
# warn and and mix old-style and new-style tags.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
# warn and and mix old-style and new-style tags.
# warn and mix old-style and new-style tags.

with warnings.catch_warnings():
warnings.filterwarnings(
"ignore",
category=FutureWarning,
message="The `_get_tags` tag provider is deprecated",
)
# silence the warning to avoid a false positive warning since this
# is not a direct user call but an internal one.
return _to_new_tags(self._get_tags())

return Tags(
estimator_type=None,
target_tags=TargetTags(required=False),
Expand All @@ -398,6 +416,54 @@
classifier_tags=None,
)

# TODO(1.7): Remove this method
def _more_tags(self):
warnings.warn(
"The `_more_tags` method is deprecated in 1.6 and will be removed in "
"1.7. Please implement the `__sklearn_tags__` method.",
category=FutureWarning,
)
return _to_old_tags(default_tags(self))

# TODO(1.7): Remove this method
def _get_tags(self):
warnings.warn(
"The `_get_tags` tag provider is deprecated in 1.6 and will be removed in "
"1.7. Please implement the `__sklearn_tags__` method.",
category=FutureWarning,
)
# If a user calls `_get_tags` on an estimator that already implements
# `__sklearn_tags__` completely and has removed `_more_tags`, fall back to the
# future behaviour. Otherwise, we will get the default tags.
from sklearn.utils._tags import _find_tags_provider, _to_old_tags, get_tags

if _find_tags_provider(self, warn=False) == "__sklearn_tags__":
return _to_old_tags(get_tags(self))

collected_tags = {}
for base_class in reversed(inspect.getmro(self.__class__)):
if hasattr(base_class, "_more_tags"):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is this equivalent to checking vars?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is the old implementation of _get_tags. So I would not touch it.
The only change here is to fall back to __sklearn_tags__ in case someone in the chain of inheritance implemented it (like scikit-learn implementing the new tags and removing the old tags).

Copy link
Member Author

@glemaitre glemaitre Nov 21, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Now, I get your remark. I think that _more_tags was the API trick to avoid going through the MRO and thus an equivalent of vars(__sklearn_tags__) but by design. This should solve my bug below :) (edit: or not because I don't know how to avoid the climbing of the inheritance).

# need the if because mixins might not have _more_tags
# but might do redundant work in estimators
# (i.e. calling more tags on BaseEstimator multiple times)
with warnings.catch_warnings():
warnings.filterwarnings(
"ignore",
category=FutureWarning,
message="The `_more_tags` method is deprecated",
)
# silence the warning to avoid a false positive warning since this
# is not a direct user call but an internal one.
more_tags = base_class._more_tags(self)
collected_tags.update(more_tags)
elif hasattr(base_class, "__sklearn_tags__"):
# Since some people will inherit from scikit-learn classes that implement
# the new infrastructure, we need to collect those tags and merge them with
# the old tags.
more_tags = base_class.__sklearn_tags__(self)
collected_tags.update(_to_old_tags(more_tags))

Check warning on line 464 in sklearn/base.py

View check run for this annotation

Codecov / codecov/patch

sklearn/base.py#L463-L464

Added lines #L463 - L464 were not covered by tests
Comment on lines +463 to +464
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

pretty sure this is gonna be buggy since __sklearn_tags__ collects all tags from base_class upwards.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you provide me a case where it would break?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since this is not covered, I'm not doing the case that I was thinking of :).

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So this would be:

class Mixin:
    def __sklearn_tags__(self):
        tags = super().__sklearn_tags__()
        tags.input_tags.allow_nan = True
        return tags

class MyClass(Mixin, BaseEstimator):

    def _more_tags(self):
        return {
            "multioutput": True
        }

and indeed it ends up with a recursion error.

return collected_tags

def _validate_params(self):
"""Validate types and values of constructor parameters

Expand Down Expand Up @@ -509,6 +575,10 @@
# TODO(1.8): Remove this attribute
_estimator_type = "classifier"

# TODO(1.7): Remove this method
def _more_tags(self):
return {"requires_y": True}

Check warning on line 580 in sklearn/base.py

View check run for this annotation

Codecov / codecov/patch

sklearn/base.py#L580

Added line #L580 was not covered by tests

def __sklearn_tags__(self):
tags = super().__sklearn_tags__()
tags.estimator_type = "classifier"
Expand Down Expand Up @@ -582,6 +652,10 @@
# TODO(1.8): Remove this attribute
_estimator_type = "regressor"

# TODO(1.7): Remove this method
def _more_tags(self):
return {"requires_y": True}

def __sklearn_tags__(self):
tags = super().__sklearn_tags__()
tags.estimator_type = "regressor"
Expand Down Expand Up @@ -658,6 +732,10 @@
# TODO(1.8): Remove this attribute
_estimator_type = "clusterer"

# TODO(1.7): Remove this method
def _more_tags(self):
return {"preserves_dtype": []}

Check warning on line 737 in sklearn/base.py

View check run for this annotation

Codecov / codecov/patch

sklearn/base.py#L737

Added line #L737 was not covered by tests

def __sklearn_tags__(self):
tags = super().__sklearn_tags__()
tags.estimator_type = "clusterer"
Expand Down Expand Up @@ -1153,6 +1231,10 @@
class MultiOutputMixin:
"""Mixin to mark estimators that support multioutput."""

# TODO(1.7): Remove this method
def _more_tags(self):
return {"multioutput": True}

Check warning on line 1236 in sklearn/base.py

View check run for this annotation

Codecov / codecov/patch

sklearn/base.py#L1236

Added line #L1236 was not covered by tests

def __sklearn_tags__(self):
tags = super().__sklearn_tags__()
tags.target_tags.multi_output = True
Expand All @@ -1162,6 +1244,13 @@
class _UnstableArchMixin:
"""Mark estimators that are non-determinstic on 32bit or PowerPC"""

# TODO(1.7): Remove this method
def _more_tags(self):
return {

Check warning on line 1249 in sklearn/base.py

View check run for this annotation

Codecov / codecov/patch

sklearn/base.py#L1249

Added line #L1249 was not covered by tests
"non_deterministic": _IS_32BIT
or platform.machine().startswith(("ppc", "powerpc"))
}

def __sklearn_tags__(self):
tags = super().__sklearn_tags__()
tags.non_deterministic = _IS_32BIT or platform.machine().startswith(
Expand Down
7 changes: 6 additions & 1 deletion sklearn/feature_extraction/text.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import numpy as np
import scipy.sparse as sp

from sklearn.utils import metadata_routing
from sklearn.utils import TransformerTags, metadata_routing

from ..base import BaseEstimator, OneToOneFeatureMixin, TransformerMixin, _fit_context
from ..exceptions import NotFittedError
Expand Down Expand Up @@ -554,6 +554,11 @@ def _warn_for_unused_params(self):
" since 'analyzer' != 'word'"
)

def __sklearn_tags__(self):
tags = super().__sklearn_tags__()
tags.transformer_tags = TransformerTags(preserves_dtype=[])
return tags
Comment on lines +557 to +560
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this class is not a transformer, why do we attach transformer tags here?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A transformer is something that has a fit + transform method, isn't it?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, and the _VectorizerMixin mixin has neither of them.



class HashingVectorizer(
TransformerMixin, _VectorizerMixin, BaseEstimator, auto_wrap_output_keys=None
Expand Down
4 changes: 2 additions & 2 deletions sklearn/feature_selection/tests/test_rfe.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from joblib import parallel_backend
from numpy.testing import assert_allclose, assert_array_almost_equal, assert_array_equal

from sklearn.base import BaseEstimator, ClassifierMixin, is_classifier
from sklearn.base import BaseEstimator, ClassifierMixin, TransformerMixin, is_classifier
from sklearn.compose import TransformedTargetRegressor
from sklearn.cross_decomposition import CCA, PLSCanonical, PLSRegression
from sklearn.datasets import load_iris, make_classification, make_friedman1
Expand All @@ -27,7 +27,7 @@
from sklearn.utils.fixes import CSR_CONTAINERS


class MockClassifier(ClassifierMixin, BaseEstimator):
class MockClassifier(TransformerMixin, ClassifierMixin, BaseEstimator):
"""
Dummy classifier to test recursive feature elimination
"""
Expand Down
2 changes: 2 additions & 0 deletions sklearn/manifold/_mds.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from ..metrics import euclidean_distances
from ..utils import check_array, check_random_state, check_symmetric
from ..utils._param_validation import Interval, StrOptions, validate_params
from ..utils._tags import TransformerTags
from ..utils.parallel import Parallel, delayed
from ..utils.validation import validate_data

Expand Down Expand Up @@ -572,6 +573,7 @@ def __init__(

def __sklearn_tags__(self):
tags = super().__sklearn_tags__()
tags.transformer_tags = TransformerTags(preserves_dtype=["float64"])
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why isn't this inheriting from TransformerMixin?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Because we don't want to implement transform but only fit_transform such that you can only provide the same data at fit and transform.

At least this is what I recall.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

TransformerMixin doesn't do anything to transform. Here we'd use its tags, and override fit_transform from the mixin.

tags.input_tags.pairwise = self.dissimilarity == "precomputed"
return tags

Expand Down
2 changes: 2 additions & 0 deletions sklearn/manifold/_spectral_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
)
from ..utils._arpack import _init_arpack_v0
from ..utils._param_validation import Interval, StrOptions, validate_params
from ..utils._tags import TransformerTags
from ..utils.extmath import _deterministic_vector_sign_flip
from ..utils.fixes import laplacian as csgraph_laplacian
from ..utils.fixes import parse_version, sp_version
Expand Down Expand Up @@ -654,6 +655,7 @@ def __sklearn_tags__(self):
"precomputed",
"precomputed_nearest_neighbors",
]
tags.transformer_tags = TransformerTags(preserves_dtype=["float64"])
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same here

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same answer as above ;)

return tags

def _get_affinity_matrix(self, X, Y=None):
Expand Down
1 change: 1 addition & 0 deletions sklearn/model_selection/_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -487,6 +487,7 @@ def __sklearn_tags__(self):
tags.estimator_type = sub_estimator_tags.estimator_type
tags.classifier_tags = deepcopy(sub_estimator_tags.classifier_tags)
tags.regressor_tags = deepcopy(sub_estimator_tags.regressor_tags)
tags.transformer_tags = deepcopy(sub_estimator_tags.transformer_tags)
# allows cross-validation to see 'precomputed' metrics
tags.input_tags.pairwise = get_tags(self.estimator).input_tags.pairwise
tags.array_api_support = get_tags(self.estimator).array_api_support
Expand Down
4 changes: 2 additions & 2 deletions sklearn/model_selection/tests/test_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from scipy.stats import bernoulli, expon, uniform

from sklearn import config_context
from sklearn.base import BaseEstimator, ClassifierMixin, is_classifier
from sklearn.base import BaseEstimator, ClassifierMixin, TransformerMixin, is_classifier
from sklearn.cluster import KMeans
from sklearn.compose import ColumnTransformer
from sklearn.datasets import (
Expand Down Expand Up @@ -100,7 +100,7 @@

# Neither of the following two estimators inherit from BaseEstimator,
# to test hyperparameter search on user-defined classifiers.
class MockClassifier(ClassifierMixin, BaseEstimator):
class MockClassifier(TransformerMixin, ClassifierMixin, BaseEstimator):
Comment on lines 101 to +103
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the comment is not valid anymore, and we have the same estimator in test_rfe

"""Dummy classifier to test the parameter search algorithms"""

def __init__(self, foo_param=0):
Expand Down
20 changes: 19 additions & 1 deletion sklearn/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from collections import Counter, defaultdict
from contextlib import contextmanager
from copy import deepcopy
from functools import reduce
from itertools import chain, islice

import numpy as np
Expand All @@ -23,7 +24,7 @@
_get_container_adapter,
_safe_set_output,
)
from .utils._tags import get_tags
from .utils._tags import TransformerTags, get_tags
from .utils._user_interface import _print_elapsed_time
from .utils.deprecation import _deprecate_Xt_in_inverse_transform
from .utils.metadata_routing import (
Expand Down Expand Up @@ -1230,13 +1231,30 @@ def __sklearn_tags__(self):
pass

try:
# dtype preservation will depend on the intersection of all steps
preserves_dtype = []
for step in self.steps:
if step[1] is not None and step[1] != "passthrough":
step_tags = get_tags(step[1])
if step_tags.transformer_tags is not None:
preserves_dtype.append(
set(step_tags.transformer_tags.preserves_dtype)
)
if preserves_dtype:
preserves_dtype = list(reduce(set.intersection, preserves_dtype))

if self.steps[-1][1] is not None and self.steps[-1][1] != "passthrough":
last_step_tags = get_tags(self.steps[-1][1])
tags.estimator_type = last_step_tags.estimator_type
tags.target_tags.multi_output = last_step_tags.target_tags.multi_output
tags.classifier_tags = deepcopy(last_step_tags.classifier_tags)
tags.regressor_tags = deepcopy(last_step_tags.regressor_tags)
tags.transformer_tags = deepcopy(last_step_tags.transformer_tags)
if tags.transformer_tags is not None:
tags.transformer_tags.preserves_dtype = preserves_dtype
elif self.steps[-1][1] is None or self.steps[-1][1] == "passthrough":
# None and "passthrough" behave like a transformer
tags.transformer_tags = TransformerTags(preserves_dtype=[])
except (ValueError, AttributeError, TypeError):
# This happens when the `steps` is not a list of (name, estimator)
# tuples and `fit` is not called yet to validate the steps.
Expand Down
4 changes: 2 additions & 2 deletions sklearn/tests/test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from scipy.linalg import LinAlgWarning

import sklearn
from sklearn.base import BaseEstimator
from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.compose import ColumnTransformer
from sklearn.datasets import make_classification
from sklearn.exceptions import ConvergenceWarning
Expand Down Expand Up @@ -412,7 +412,7 @@ def test_transition_public_api_deprecations():
to the new developer public API from 1.5 to 1.6.
"""

class OldEstimator(BaseEstimator):
class OldEstimator(TransformerMixin, BaseEstimator):
def fit(self, X, y=None):
X = self._validate_data(X)
self._check_n_features(X, reset=True)
Expand Down
7 changes: 7 additions & 0 deletions sklearn/tests/test_metaestimators.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from sklearn.preprocessing import MaxAbsScaler, StandardScaler
from sklearn.semi_supervised import SelfTrainingClassifier
from sklearn.utils import all_estimators
from sklearn.utils._tags import TransformerTags
from sklearn.utils._test_common.instance_generator import _construct_instances
from sklearn.utils._testing import SkipTest, set_random_state
from sklearn.utils.estimator_checks import (
Expand Down Expand Up @@ -143,6 +144,12 @@ def score(self, X, y, *args, **kwargs):
self._check_fit()
return 1.0

def __sklearn_tags__(self):
tags = super().__sklearn_tags__()
if hasattr(self, "transform"):
tags.transformer_tags = TransformerTags(preserves_dtype=[])
return tags

methods = [
k
for k in SubEstimator.__dict__.keys()
Expand Down
4 changes: 2 additions & 2 deletions sklearn/tests/test_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -2058,7 +2058,7 @@ def transform(self, X):
],
)
def test_pipeline_warns_not_fitted(method):
class StatelessEstimator(BaseEstimator):
class StatelessEstimator(TransformerMixin, ClassifierMixin, BaseEstimator):
"""Stateless estimator that doesn't check if it's fitted.

Stateless estimators that don't require fit, should properly set the
Expand Down Expand Up @@ -2102,7 +2102,7 @@ def inverse_transform(self, X):
# =====================================================================


class SimpleEstimator(BaseEstimator):
class SimpleEstimator(TransformerMixin, ClassifierMixin, BaseEstimator):
# This class is used in this section for testing routing in the pipeline.
# This class should have every set_{method}_request
def __sklearn_is_fitted__(self):
Expand Down
Loading