diff --git a/sklearn/base.py b/sklearn/base.py index d646f8d3e56bf..a5ffe50c1f50f 100644 --- a/sklearn/base.py +++ b/sklearn/base.py @@ -26,6 +26,8 @@ Tags, TargetTags, TransformerTags, + _to_old_tags, + default_tags, get_tags, ) from .utils.fixes import _IS_32BIT @@ -390,6 +392,22 @@ def __setstate__(self, state): self.__dict__.update(state) def __sklearn_tags__(self): + from sklearn.utils._tags import _find_tags_provider, _to_new_tags + + # TODO(1.7): Remove this block + if _find_tags_provider(self) == "_get_tags": + # one of the children classes only implements `_get_tags` so we need to + # warn and mix old-style and new-style tags. + with warnings.catch_warnings(): + warnings.filterwarnings( + "ignore", + category=FutureWarning, + message="The `_get_tags` tag provider is deprecated", + ) + # silence the warning to avoid a false positive warning since this + # is not a direct user call but an internal one. + return _to_new_tags(self._get_tags()) + return Tags( estimator_type=None, target_tags=TargetTags(required=False), @@ -398,6 +416,54 @@ def __sklearn_tags__(self): classifier_tags=None, ) + # TODO(1.7): Remove this method + def _more_tags(self): + warnings.warn( + "The `_more_tags` method is deprecated in 1.6 and will be removed in " + "1.7. Please implement the `__sklearn_tags__` method.", + category=FutureWarning, + ) + return _to_old_tags(default_tags(self)) + + # TODO(1.7): Remove this method + def _get_tags(self): + warnings.warn( + "The `_get_tags` tag provider is deprecated in 1.6 and will be removed in " + "1.7. Please implement the `__sklearn_tags__` method.", + category=FutureWarning, + ) + # In case a user called `_get_tags` but the estimator already did the job + # of implementing `__sklearn_tags__` completely and removed `_more_tags`, let's + # default back to the future behaviour. Otherwise, we will get the default tags. 
+ from sklearn.utils._tags import _find_tags_provider, _to_old_tags, get_tags + + if _find_tags_provider(self, warn=False) == "__sklearn_tags__": + return _to_old_tags(get_tags(self)) + + collected_tags = {} + for base_class in reversed(inspect.getmro(self.__class__)): + if hasattr(base_class, "_more_tags"): + # need the if because mixins might not have _more_tags + # but might do redundant work in estimators + # (i.e. calling more tags on BaseEstimator multiple times) + with warnings.catch_warnings(): + warnings.filterwarnings( + "ignore", + category=FutureWarning, + message="The `_more_tags` method is deprecated", + ) + # silence the warning to avoid a false positive warning since this + # is not a direct user call but an internal one. + more_tags = base_class._more_tags(self) + collected_tags.update(more_tags) + elif hasattr(base_class, "__sklearn_tags__"): + # Since some people will inherit from scikit-learn classes that implement + # the new infrastructure, we need to collect it and merge it with + # the old tags. 
+ more_tags = base_class.__sklearn_tags__(self) + collected_tags.update(_to_old_tags(more_tags)) + return collected_tags + def _validate_params(self): """Validate types and values of constructor parameters @@ -509,6 +575,10 @@ class ClassifierMixin: # TODO(1.8): Remove this attribute _estimator_type = "classifier" + # TODO(1.7): Remove this method + def _more_tags(self): + return {"requires_y": True} + def __sklearn_tags__(self): tags = super().__sklearn_tags__() tags.estimator_type = "classifier" @@ -582,6 +652,10 @@ class RegressorMixin: # TODO(1.8): Remove this attribute _estimator_type = "regressor" + # TODO(1.7): Remove this method + def _more_tags(self): + return {"requires_y": True} + def __sklearn_tags__(self): tags = super().__sklearn_tags__() tags.estimator_type = "regressor" @@ -658,6 +732,10 @@ class ClusterMixin: # TODO(1.8): Remove this attribute _estimator_type = "clusterer" + # TODO(1.7): Remove this method + def _more_tags(self): + return {"preserves_dtype": []} + def __sklearn_tags__(self): tags = super().__sklearn_tags__() tags.estimator_type = "clusterer" @@ -1153,6 +1231,10 @@ class MetaEstimatorMixin: class MultiOutputMixin: """Mixin to mark estimators that support multioutput.""" + # TODO(1.7): Remove this method + def _more_tags(self): + return {"multioutput": True} + def __sklearn_tags__(self): tags = super().__sklearn_tags__() tags.target_tags.multi_output = True @@ -1162,6 +1244,13 @@ def __sklearn_tags__(self): class _UnstableArchMixin: """Mark estimators that are non-determinstic on 32bit or PowerPC""" + # TODO(1.7): Remove this method + def _more_tags(self): + return { + "non_deterministic": _IS_32BIT + or platform.machine().startswith(("ppc", "powerpc")) + } + def __sklearn_tags__(self): tags = super().__sklearn_tags__() tags.non_deterministic = _IS_32BIT or platform.machine().startswith( diff --git a/sklearn/feature_extraction/text.py b/sklearn/feature_extraction/text.py index e1bdfd5a7dee5..a94920024ff98 100644 --- 
a/sklearn/feature_extraction/text.py +++ b/sklearn/feature_extraction/text.py @@ -16,7 +16,7 @@ import numpy as np import scipy.sparse as sp -from sklearn.utils import metadata_routing +from sklearn.utils import TransformerTags, metadata_routing from ..base import BaseEstimator, OneToOneFeatureMixin, TransformerMixin, _fit_context from ..exceptions import NotFittedError @@ -554,6 +554,11 @@ def _warn_for_unused_params(self): " since 'analyzer' != 'word'" ) + def __sklearn_tags__(self): + tags = super().__sklearn_tags__() + tags.transformer_tags = TransformerTags(preserves_dtype=[]) + return tags + class HashingVectorizer( TransformerMixin, _VectorizerMixin, BaseEstimator, auto_wrap_output_keys=None diff --git a/sklearn/feature_selection/tests/test_rfe.py b/sklearn/feature_selection/tests/test_rfe.py index ae11de2fadf59..dc54bdd6560a7 100644 --- a/sklearn/feature_selection/tests/test_rfe.py +++ b/sklearn/feature_selection/tests/test_rfe.py @@ -9,7 +9,7 @@ from joblib import parallel_backend from numpy.testing import assert_allclose, assert_array_almost_equal, assert_array_equal -from sklearn.base import BaseEstimator, ClassifierMixin, is_classifier +from sklearn.base import BaseEstimator, ClassifierMixin, TransformerMixin, is_classifier from sklearn.compose import TransformedTargetRegressor from sklearn.cross_decomposition import CCA, PLSCanonical, PLSRegression from sklearn.datasets import load_iris, make_classification, make_friedman1 @@ -27,7 +27,7 @@ from sklearn.utils.fixes import CSR_CONTAINERS -class MockClassifier(ClassifierMixin, BaseEstimator): +class MockClassifier(TransformerMixin, ClassifierMixin, BaseEstimator): """ Dummy classifier to test recursive feature elimination """ diff --git a/sklearn/manifold/_mds.py b/sklearn/manifold/_mds.py index dc9f88b502da5..07e7f541fa41a 100644 --- a/sklearn/manifold/_mds.py +++ b/sklearn/manifold/_mds.py @@ -16,6 +16,7 @@ from ..metrics import euclidean_distances from ..utils import check_array, check_random_state, 
check_symmetric from ..utils._param_validation import Interval, StrOptions, validate_params +from ..utils._tags import TransformerTags from ..utils.parallel import Parallel, delayed from ..utils.validation import validate_data @@ -572,6 +573,7 @@ def __init__( def __sklearn_tags__(self): tags = super().__sklearn_tags__() + tags.transformer_tags = TransformerTags(preserves_dtype=["float64"]) tags.input_tags.pairwise = self.dissimilarity == "precomputed" return tags diff --git a/sklearn/manifold/_spectral_embedding.py b/sklearn/manifold/_spectral_embedding.py index ebd5d7c5b651b..3d7ede8fb9358 100644 --- a/sklearn/manifold/_spectral_embedding.py +++ b/sklearn/manifold/_spectral_embedding.py @@ -23,6 +23,7 @@ ) from ..utils._arpack import _init_arpack_v0 from ..utils._param_validation import Interval, StrOptions, validate_params +from ..utils._tags import TransformerTags from ..utils.extmath import _deterministic_vector_sign_flip from ..utils.fixes import laplacian as csgraph_laplacian from ..utils.fixes import parse_version, sp_version @@ -654,6 +655,7 @@ def __sklearn_tags__(self): "precomputed", "precomputed_nearest_neighbors", ] + tags.transformer_tags = TransformerTags(preserves_dtype=["float64"]) return tags def _get_affinity_matrix(self, X, Y=None): diff --git a/sklearn/model_selection/_search.py b/sklearn/model_selection/_search.py index 7515436af33da..3ee877e1f7597 100644 --- a/sklearn/model_selection/_search.py +++ b/sklearn/model_selection/_search.py @@ -487,6 +487,7 @@ def __sklearn_tags__(self): tags.estimator_type = sub_estimator_tags.estimator_type tags.classifier_tags = deepcopy(sub_estimator_tags.classifier_tags) tags.regressor_tags = deepcopy(sub_estimator_tags.regressor_tags) + tags.transformer_tags = deepcopy(sub_estimator_tags.transformer_tags) # allows cross-validation to see 'precomputed' metrics tags.input_tags.pairwise = get_tags(self.estimator).input_tags.pairwise tags.array_api_support = get_tags(self.estimator).array_api_support diff --git 
a/sklearn/model_selection/tests/test_search.py b/sklearn/model_selection/tests/test_search.py index 5313e5d28a1a7..6e38536d1f509 100644 --- a/sklearn/model_selection/tests/test_search.py +++ b/sklearn/model_selection/tests/test_search.py @@ -15,7 +15,7 @@ from scipy.stats import bernoulli, expon, uniform from sklearn import config_context -from sklearn.base import BaseEstimator, ClassifierMixin, is_classifier +from sklearn.base import BaseEstimator, ClassifierMixin, TransformerMixin, is_classifier from sklearn.cluster import KMeans from sklearn.compose import ColumnTransformer from sklearn.datasets import ( @@ -100,7 +100,7 @@ # Neither of the following two estimators inherit from BaseEstimator, # to test hyperparameter search on user-defined classifiers. -class MockClassifier(ClassifierMixin, BaseEstimator): +class MockClassifier(TransformerMixin, ClassifierMixin, BaseEstimator): """Dummy classifier to test the parameter search algorithms""" def __init__(self, foo_param=0): diff --git a/sklearn/pipeline.py b/sklearn/pipeline.py index 9ff8a3549ef28..0b52559a92334 100644 --- a/sklearn/pipeline.py +++ b/sklearn/pipeline.py @@ -7,6 +7,7 @@ from collections import Counter, defaultdict from contextlib import contextmanager from copy import deepcopy +from functools import reduce from itertools import chain, islice import numpy as np @@ -23,7 +24,7 @@ _get_container_adapter, _safe_set_output, ) -from .utils._tags import get_tags +from .utils._tags import TransformerTags, get_tags from .utils._user_interface import _print_elapsed_time from .utils.deprecation import _deprecate_Xt_in_inverse_transform from .utils.metadata_routing import ( @@ -1230,6 +1231,18 @@ def __sklearn_tags__(self): pass try: + # dtype preservation will depend on the intersection of all steps + preserves_dtype = [] + for step in self.steps: + if step[1] is not None and step[1] != "passthrough": + step_tags = get_tags(step[1]) + if step_tags.transformer_tags is not None: + preserves_dtype.append( + 
set(step_tags.transformer_tags.preserves_dtype) + ) + if preserves_dtype: + preserves_dtype = list(reduce(set.intersection, preserves_dtype)) + if self.steps[-1][1] is not None and self.steps[-1][1] != "passthrough": last_step_tags = get_tags(self.steps[-1][1]) tags.estimator_type = last_step_tags.estimator_type @@ -1237,6 +1250,11 @@ def __sklearn_tags__(self): tags.classifier_tags = deepcopy(last_step_tags.classifier_tags) tags.regressor_tags = deepcopy(last_step_tags.regressor_tags) tags.transformer_tags = deepcopy(last_step_tags.transformer_tags) + if tags.transformer_tags is not None: + tags.transformer_tags.preserves_dtype = preserves_dtype + elif self.steps[-1][1] is None or self.steps[-1][1] == "passthrough": + # None and "passthrough" behave like a transformer + tags.transformer_tags = TransformerTags(preserves_dtype=[]) except (ValueError, AttributeError, TypeError): # This happens when the `steps` is not a list of (name, estimator) # tuples and `fit` is not called yet to validate the steps. diff --git a/sklearn/tests/test_common.py b/sklearn/tests/test_common.py index 59b45b93a7e24..c74c69fb27eab 100644 --- a/sklearn/tests/test_common.py +++ b/sklearn/tests/test_common.py @@ -17,7 +17,7 @@ from scipy.linalg import LinAlgWarning import sklearn -from sklearn.base import BaseEstimator +from sklearn.base import BaseEstimator, TransformerMixin from sklearn.compose import ColumnTransformer from sklearn.datasets import make_classification from sklearn.exceptions import ConvergenceWarning @@ -412,7 +412,7 @@ def test_transition_public_api_deprecations(): to the new developer public API from 1.5 to 1.6. 
""" - class OldEstimator(BaseEstimator): + class OldEstimator(TransformerMixin, BaseEstimator): def fit(self, X, y=None): X = self._validate_data(X) self._check_n_features(X, reset=True) diff --git a/sklearn/tests/test_metaestimators.py b/sklearn/tests/test_metaestimators.py index 214fc75a68364..4cf2a05a325dd 100644 --- a/sklearn/tests/test_metaestimators.py +++ b/sklearn/tests/test_metaestimators.py @@ -19,6 +19,7 @@ from sklearn.preprocessing import MaxAbsScaler, StandardScaler from sklearn.semi_supervised import SelfTrainingClassifier from sklearn.utils import all_estimators +from sklearn.utils._tags import TransformerTags from sklearn.utils._test_common.instance_generator import _construct_instances from sklearn.utils._testing import SkipTest, set_random_state from sklearn.utils.estimator_checks import ( @@ -143,6 +144,12 @@ def score(self, X, y, *args, **kwargs): self._check_fit() return 1.0 + def __sklearn_tags__(self): + tags = super().__sklearn_tags__() + if hasattr(self, "transform"): + tags.transformer_tags = TransformerTags(preserves_dtype=[]) + return tags + methods = [ k for k in SubEstimator.__dict__.keys() diff --git a/sklearn/tests/test_pipeline.py b/sklearn/tests/test_pipeline.py index d7a201f3abf6f..130793b7d9a46 100644 --- a/sklearn/tests/test_pipeline.py +++ b/sklearn/tests/test_pipeline.py @@ -2058,7 +2058,7 @@ def transform(self, X): ], ) def test_pipeline_warns_not_fitted(method): - class StatelessEstimator(BaseEstimator): + class StatelessEstimator(TransformerMixin, ClassifierMixin, BaseEstimator): """Stateless estimator that doesn't check if it's fitted. Stateless estimators that don't require fit, should properly set the @@ -2102,7 +2102,7 @@ def inverse_transform(self, X): # ===================================================================== -class SimpleEstimator(BaseEstimator): +class SimpleEstimator(TransformerMixin, ClassifierMixin, BaseEstimator): # This class is used in this section for testing routing in the pipeline. 
# This class should have every set_{method}_request def __sklearn_is_fitted__(self): diff --git a/sklearn/utils/_tags.py b/sklearn/utils/_tags.py index ccbc9d2438268..2bf43a1f0401c 100644 --- a/sklearn/utils/_tags.py +++ b/sklearn/utils/_tags.py @@ -1,7 +1,9 @@ from __future__ import annotations import warnings +from collections import OrderedDict from dataclasses import dataclass, field +from itertools import chain from .fixes import _dataclass_args @@ -243,7 +245,7 @@ class Tags: input_tags: InputTags = field(default_factory=InputTags) -# TODO(1.8): Remove this function +# TODO(1.7): Remove this function def default_tags(estimator) -> Tags: """Get the default tags for an estimator. @@ -290,6 +292,83 @@ def default_tags(estimator) -> Tags: ) +# TODO(1.7): Remove this function +def _find_tags_provider(estimator, warn=True): + """Find the tags provider for an estimator. + + Parameters + ---------- + estimator : estimator object + The estimator to find the tags provider for. + + warn : bool, default=True + Whether to emit the `FutureWarning` about classes only using the old tags API. + + Returns + ------- + tag_provider : str + The tags provider for the estimator. 
Can be one of: + - "_get_tags": to use the old tags infrastructure + - "__sklearn_tags__": to use the new tags infrastructure + """ + mro_model = type(estimator).mro() + tags_mro = OrderedDict() + for klass in mro_model: + tags_provider = [] + if "_more_tags" in vars(klass): + tags_provider.append("_more_tags") + if "_get_tags" in vars(klass): + tags_provider.append("_get_tags") + if "__sklearn_tags__" in vars(klass): + tags_provider.append("__sklearn_tags__") + tags_mro[klass.__name__] = tags_provider + + all_providers = set(chain.from_iterable(tags_mro.values())) + if "__sklearn_tags__" not in all_providers: + # default on the old tags infrastructure + return "_get_tags" + + tag_provider = "__sklearn_tags__" + encounter_sklearn_tags = False + err_msg = ( + f"Some classes from which {estimator.__class__.__name__} inherits only " + "use `_get_tags` and `_more_tags` while others implement the new " + "`__sklearn_tags__` method. There is no safe way to resolve the tags. " + "Please make sure to implement the `__sklearn_tags__` method in all " + "classes in the hierarchy." + ) + for klass in tags_mro: + has_get_or_more_tags = any( + provider in tags_mro[klass] for provider in ("_get_tags", "_more_tags") + ) + has_sklearn_tags = "__sklearn_tags__" in tags_mro[klass] + + if tags_mro[klass] and tag_provider == "__sklearn_tags__": # is it empty + if has_get_or_more_tags and not has_sklearn_tags: + if encounter_sklearn_tags: + # One of the child class already implemented __sklearn_tags__ + # We cannot anymore fallback to _get_tags + raise ValueError(err_msg) + # Case where a class does not implement __sklearn_tags__ and we fallback + # to _get_tags. We should therefore warn for implementing + # __sklearn_tags__. + tag_provider = "_get_tags" + encounter_sklearn_tags = True + + if warn and tag_provider == "_get_tags": + warnings.warn( + f"The {estimator.__class__.__name__} or classes from which it inherits " + "only use `_get_tags` and `_more_tags`. 
Please define the " + "`__sklearn_tags__` method, or inherit from `sklearn.base.BaseEstimator` " + "and other appropriate mixins such as `sklearn.base.TransformerMixin`, " + "`sklearn.base.ClassifierMixin`, `sklearn.base.RegressorMixin`, and " + "`sklearn.base.OutlierMixin`. From scikit-learn 1.7, not defining " + "`__sklearn_tags__` will raise an error.", + category=FutureWarning, + ) + return tag_provider + + def get_tags(estimator) -> Tags: """Get estimator tags. @@ -316,19 +395,190 @@ def get_tags(estimator) -> Tags: The estimator tags. """ - if hasattr(estimator, "__sklearn_tags__"): + tag_provider = _find_tags_provider(estimator) + if tag_provider == "__sklearn_tags__": tags = estimator.__sklearn_tags__() + + # TODO (1.7): Remove this block + # Catch the corner case where a transformer inheriting from BaseEstimator but + # that does not inherit from TransformerMixin ends up without the + # transformer_tags set properly. + if ( + hasattr(estimator, "transform") or hasattr(estimator, "fit_transform") + ) and tags.transformer_tags is None: + warnings.warn( + "The transformer tags are not set properly for the estimator " + f"{estimator.__class__.__name__}. This will raise an error in " + "scikit-learn 1.7. Inherit from `TransformerMixin` or properly set " + "the `transformer_tags` attribute in `__sklearn_tags__`.", + category=FutureWarning, + ) + tags.transformer_tags = TransformerTags() + # TODO(1.7): Remove this block + elif tag_provider == "_get_tags": + if hasattr(estimator, "_get_tags"): + with warnings.catch_warnings(): + warnings.filterwarnings( + "ignore", + category=FutureWarning, + message="The `_get_tags` tag provider is deprecated", + ) + # silence the warning to avoid a false positive warning since this + # is not a direct user call but an internal one. 
+ tags = _to_new_tags(estimator._get_tags()) + elif hasattr(estimator, "_more_tags"): + tags = _to_old_tags(default_tags(estimator)) + with warnings.catch_warnings(): + warnings.filterwarnings( + "ignore", + category=FutureWarning, + message="The `_more_tags` method is deprecated", + ) + # silence the warning to avoid a false positive warning since this + # is not a direct user call but an internal one. + tags = {**tags, **estimator._more_tags()} + tags = _to_new_tags(tags) + else: + tags = default_tags(estimator) else: - warnings.warn( - f"Estimator {estimator} has no __sklearn_tags__ attribute, which is " - "defined in `sklearn.base.BaseEstimator`. This will raise an error in " - "scikit-learn 1.8. Please define the __sklearn_tags__ method, or inherit " - "from `sklearn.base.BaseEstimator` and other appropriate mixins such as " - "`sklearn.base.TransformerMixin`, `sklearn.base.ClassifierMixin`, " - "`sklearn.base.RegressorMixin`, and `sklearn.base.ClusterMixin`, and " - "`sklearn.base.OutlierMixin`.", - category=FutureWarning, - ) tags = default_tags(estimator) return tags + + +# TODO(1.7): Remove this function +def _to_new_tags(old_tags, estimator_type=None): + """Utility function convert old tags (dictionary) to new tags (dataclass).""" + input_tags = InputTags( + one_d_array="1darray" in old_tags["X_types"], + two_d_array="2darray" in old_tags["X_types"], + three_d_array="3darray" in old_tags["X_types"], + sparse="sparse" in old_tags["X_types"], + categorical="categorical" in old_tags["X_types"], + string="string" in old_tags["X_types"], + dict="dict" in old_tags["X_types"], + positive_only=old_tags["requires_positive_X"], + allow_nan=old_tags["allow_nan"], + pairwise=old_tags["pairwise"], + ) + target_tags = TargetTags( + required=old_tags["requires_y"], + one_d_labels="1dlabels" in old_tags["X_types"], + two_d_labels="2dlabels" in old_tags["X_types"], + positive_only=old_tags["requires_positive_y"], + multi_output=old_tags["multioutput"] or 
old_tags["multioutput_only"], + single_output=not old_tags["multioutput_only"], + ) + transformer_tags = TransformerTags( + preserves_dtype=old_tags["preserves_dtype"], + ) + classifier_tags = ClassifierTags( + poor_score=old_tags["poor_score"], + multi_class=not old_tags["binary_only"], + multi_label=old_tags["multilabel"], + ) + regressor_tags = RegressorTags( + poor_score=old_tags["poor_score"], + multi_label=old_tags["multilabel"], + ) + return Tags( + estimator_type=estimator_type, + target_tags=target_tags, + transformer_tags=transformer_tags, + classifier_tags=classifier_tags, + regressor_tags=regressor_tags, + input_tags=input_tags, + array_api_support=old_tags["array_api_support"], + no_validation=old_tags["no_validation"], + non_deterministic=old_tags["non_deterministic"], + requires_fit=old_tags["requires_fit"], + _skip_test=old_tags["_skip_test"], + ) + + +# TODO(1.7): Remove this function +def _to_old_tags(new_tags): + """Utility function to convert new tags (dataclass) to old tags (dictionary).""" + if new_tags.classifier_tags: + binary_only = not new_tags.classifier_tags.multi_class + multilabel_clf = new_tags.classifier_tags.multi_label + poor_score_clf = new_tags.classifier_tags.poor_score + else: + binary_only = False + multilabel_clf = False + poor_score_clf = False + + if new_tags.regressor_tags: + multilabel_reg = new_tags.regressor_tags.multi_label + poor_score_reg = new_tags.regressor_tags.poor_score + else: + multilabel_reg = False + poor_score_reg = False + + if new_tags.transformer_tags: + preserves_dtype = new_tags.transformer_tags.preserves_dtype + else: + preserves_dtype = ["float64"] + + tags = { + "allow_nan": new_tags.input_tags.allow_nan, + "array_api_support": new_tags.array_api_support, + "binary_only": binary_only, + "multilabel": multilabel_clf or multilabel_reg, + "multioutput": new_tags.target_tags.multi_output, + "multioutput_only": ( + not new_tags.target_tags.single_output and new_tags.target_tags.multi_output + ), + 
"no_validation": new_tags.no_validation, + "non_deterministic": new_tags.non_deterministic, + "pairwise": new_tags.input_tags.pairwise, + "preserves_dtype": preserves_dtype, + "poor_score": poor_score_clf or poor_score_reg, + "requires_fit": new_tags.requires_fit, + "requires_positive_X": new_tags.input_tags.positive_only, + "requires_y": new_tags.target_tags.required, + "requires_positive_y": new_tags.target_tags.positive_only, + "_skip_test": new_tags._skip_test, + "stateless": new_tags.requires_fit, + } + X_types = [] + if new_tags.input_tags.one_d_array: + X_types.append("1darray") + if new_tags.input_tags.two_d_array: + X_types.append("2darray") + if new_tags.input_tags.three_d_array: + X_types.append("3darray") + if new_tags.input_tags.sparse: + X_types.append("sparse") + if new_tags.input_tags.categorical: + X_types.append("categorical") + if new_tags.input_tags.string: + X_types.append("string") + if new_tags.input_tags.dict: + X_types.append("dict") + if new_tags.target_tags.one_d_labels: + X_types.append("1dlabels") + if new_tags.target_tags.two_d_labels: + X_types.append("2dlabels") + tags["X_types"] = X_types + return tags + + +# TODO(1.7): Remove this function +def _safe_tags(estimator, key=None): + warnings.warn( + "The `_safe_tags` utility function is deprecated in 1.6 and will be removed in " + "1.7. Use the public `get_tags` function instead and make sure to implement " + "the `__sklearn_tags__` method.", + category=FutureWarning, + ) + tags = _to_old_tags(get_tags(estimator)) + + if key is not None: + if key not in tags: + raise ValueError( + f"The key {key} is not defined for the class " + f"{estimator.__class__.__name__}." 
+ ) + return tags[key] + return tags diff --git a/sklearn/utils/tests/test_estimator_checks.py b/sklearn/utils/tests/test_estimator_checks.py index d09b3e7f366ec..4cb03bbfc4aa9 100644 --- a/sklearn/utils/tests/test_estimator_checks.py +++ b/sklearn/utils/tests/test_estimator_checks.py @@ -39,6 +39,7 @@ from sklearn.svm import SVC, NuSVC from sklearn.utils import _array_api, all_estimators, deprecated from sklearn.utils._param_validation import Interval, StrOptions +from sklearn.utils._tags import TransformerTags from sklearn.utils._test_common.instance_generator import ( _construct_instances, _get_expected_failed_checks, @@ -314,7 +315,31 @@ def fit(self, X, y): return self -class BadTransformerWithoutMixin(BaseEstimator): +class BadTransformerWithoutMixinWithTags(BaseEstimator): + def fit(self, X, y=None): + X = validate_data(self, X) + return self + + def transform(self, X): + check_is_fitted(self) + X = validate_data(self, X, reset=False) + return X + + def __sklearn_tags__(self): + tags = super().__sklearn_tags__() + tags.transformer_tags = TransformerTags() + return tags + + +class BadTransformerWithoutMixinWithoutTags(BaseEstimator): + """Transformer that does not implement `fit_transform` and the tags. + + TODO(1.7): + In 1.6, it will raise an AttributeError for `fit_transform` and a warning to + mention that the `transformer_tags` tag is not set. + As for 1.7, it will raise a RuntimeError because the tag is not set. 
+ """ + def fit(self, X, y=None): X = validate_data(self, X) return self @@ -844,11 +869,20 @@ def test_check_outlier_corruption(): check_outlier_corruption(1, 2, decision) -def test_check_estimator_transformer_no_mixin(): - # check that TransformerMixin is not required for transformer tests to run - # but it fails since the tag is not set - with raises(RuntimeError, "the `transformer_tags` tag is not set"): - check_estimator(BadTransformerWithoutMixin()) +def test_check_estimator_transformer_no_mixin_with_tags(): + with raises(AttributeError, ".*fit_transform.*"): + check_estimator(BadTransformerWithoutMixinWithTags()) + + +def test_check_estimator_transformer_no_mixin_without_tags(): + # TODO(1.7): replace the type of exception raised and remove the warning + with raises(AttributeError, ".*fit_transform.*"): + with warnings.catch_warnings(record=True) as record: + warnings.filterwarnings("ignore", category=FutureWarning) + check_estimator(BadTransformerWithoutMixinWithoutTags()) + for rec in record: + assert issubclass(rec.category, FutureWarning) + assert "The transformer tags are not set properly" in str(rec.message) def test_check_estimator_clones(): diff --git a/sklearn/utils/tests/test_tags.py b/sklearn/utils/tests/test_tags.py index 413fbc6bbd3de..1ddefcaf17489 100644 --- a/sklearn/utils/tests/test_tags.py +++ b/sklearn/utils/tests/test_tags.py @@ -8,6 +8,7 @@ TransformerMixin, ) from sklearn.utils import Tags, get_tags +from sklearn.utils._tags import _safe_tags, _to_new_tags from sklearn.utils.estimator_checks import ( check_estimator_tags_renamed, check_valid_tag_types, @@ -78,3 +79,131 @@ def __sklearn_tags__(self): return tags check_valid_tag_types("MyEstimator", MyEstimator()) + + +######################################################################################## +# Test for the deprecation +# TODO(1.7): Remove this +######################################################################################## + + +def test_tags_deprecation(): + class 
ChildClass(RegressorMixin, BaseEstimator): + """Child implementing the old tags API together with our new API.""" + + def _more_tags(self): + return {"allow_nan": True} + + main_warn_msg = "only use `_get_tags` and `_more_tags`" + with pytest.warns(FutureWarning, match=main_warn_msg): + tags = ChildClass().__sklearn_tags__() + assert tags.input_tags.allow_nan + + with pytest.warns(FutureWarning) as warning_list: + tags = _safe_tags(ChildClass()) + assert len(warning_list) == 2, len(warning_list) + assert str(warning_list[0].message).startswith( + "The `_safe_tags` utility function is deprecated" + ) + assert main_warn_msg in str(warning_list[1].message) + + assert isinstance(tags, dict) + assert _to_new_tags(tags).input_tags.allow_nan + + with pytest.warns(FutureWarning, match=main_warn_msg): + tags = get_tags(ChildClass()) + assert tags.input_tags.allow_nan + + class ChildClass(RegressorMixin, BaseEstimator): + """Child implementing the old and new tags API during the transition period.""" + + def _more_tags(self): + return {"allow_nan": True} + + def __sklearn_tags__(self): + tags = super().__sklearn_tags__() + tags.input_tags.allow_nan = True + return tags + + tags = get_tags(ChildClass()) + assert tags.input_tags.allow_nan + + warn_msg = "`_get_tags` tag provider is deprecated" + with pytest.warns(FutureWarning, match=warn_msg): + tags = ChildClass()._get_tags() + assert isinstance(tags, dict) + assert _to_new_tags(tags).input_tags.allow_nan + + warn_msg = "`_safe_tags` utility function is deprecated" + with pytest.warns(FutureWarning, match=warn_msg): + tags = _safe_tags(ChildClass()) + assert isinstance(tags, dict) + assert _to_new_tags(tags).input_tags.allow_nan + + class ChildClass(RegressorMixin, BaseEstimator): + """Child not setting any tags.""" + + tags = get_tags(ChildClass()) + assert tags.target_tags.required + + warn_msg = "`_get_tags` tag provider is deprecated" + with pytest.warns(FutureWarning, match=warn_msg): + tags = ChildClass()._get_tags() + 
assert isinstance(tags, dict) + assert _to_new_tags(tags).target_tags.required + + warn_msg = "`_safe_tags` utility function is deprecated" + with pytest.warns(FutureWarning, match=warn_msg): + tags = _safe_tags(ChildClass()) + assert isinstance(tags, dict) + assert _to_new_tags(tags).target_tags.required + + class Mixin: + def _more_tags(self): + return {"allow_nan": True} + + class ChildClass(Mixin, BaseEstimator): + """Child following the new API with mixin following the old API.""" + + def __sklearn_tags__(self): + tags = super().__sklearn_tags__() + tags.target_tags.required = True + return tags + + err_msg = ( + "Some classes from which ChildClass inherits only use `_get_tags` and " + "`_more_tags`" + ) + with pytest.raises(ValueError, match=err_msg): + tags = get_tags(ChildClass()) + with pytest.raises(ValueError, match=err_msg): + with pytest.warns(FutureWarning): + tags = ChildClass()._get_tags() + with pytest.raises(ValueError, match=err_msg): + with pytest.warns(FutureWarning): + tags = _safe_tags(ChildClass()) + + class Mixin: + def _more_tags(self): + return {"allow_nan": True} + + class ChildClass(Mixin, BaseEstimator): + """Child following the old API with mixin following the old API.""" + + def _more_tags(self): + return {"requires_y": True} + + with pytest.warns(FutureWarning, match=main_warn_msg): + tags = ChildClass().__sklearn_tags__() + assert tags.input_tags.allow_nan + + with pytest.warns(FutureWarning) as warning_list: + tags = _safe_tags(ChildClass()) + assert len(warning_list) == 2, len(warning_list) + assert str(warning_list[0].message).startswith( + "The `_safe_tags` utility function is deprecated" + ) + assert main_warn_msg in str(warning_list[1].message) + + assert isinstance(tags, dict) + assert _to_new_tags(tags).input_tags.allow_nan diff --git a/sklearn/utils/tests/test_validation.py b/sklearn/utils/tests/test_validation.py index 669e40e137e17..e64f0ac4efb99 100644 --- a/sklearn/utils/tests/test_validation.py +++ 
b/sklearn/utils/tests/test_validation.py @@ -15,7 +15,7 @@ import sklearn from sklearn._config import config_context from sklearn._min_dependencies import dependent_packages -from sklearn.base import BaseEstimator +from sklearn.base import BaseEstimator, TransformerMixin from sklearn.datasets import make_blobs from sklearn.ensemble import RandomForestRegressor from sklearn.exceptions import NotFittedError, PositiveSpectrumWarning @@ -1990,7 +1990,7 @@ def test_get_feature_names_invalid_dtypes(names, dtypes): names = _get_feature_names(X) -class PassthroughTransformer(BaseEstimator): +class PassthroughTransformer(TransformerMixin, BaseEstimator): def fit(self, X, y=None): validate_data(self, X, reset=True) return self