From d74c0e80865e7cb5dc2adccaef05e3e98eaa78d3 Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Mon, 18 Nov 2024 22:45:59 +0100 Subject: [PATCH 01/24] MAINT conversion old->new new->old tags --- sklearn/utils/_tags.py | 153 ++++++++++++++++++++++++++++++ sklearn/utils/estimator_checks.py | 16 ++-- 2 files changed, 163 insertions(+), 6 deletions(-) diff --git a/sklearn/utils/_tags.py b/sklearn/utils/_tags.py index ccbc9d2438268..07f9b6162cdf3 100644 --- a/sklearn/utils/_tags.py +++ b/sklearn/utils/_tags.py @@ -318,6 +318,14 @@ def get_tags(estimator) -> Tags: if hasattr(estimator, "__sklearn_tags__"): tags = estimator.__sklearn_tags__() + elif hasattr(estimator, "_get_tags"): + warnings.warn("BROKEN SOON, IT WILL BE", FutureWarning) + tags = _to_new_tags(estimator._get_tags()) + elif hasattr(estimator, "_more_tags"): + warnings.warn("BROKEN SOON, IT WILL BE", FutureWarning) + tags = _to_old_tags(default_tags(estimator)) + tags = {**tags, **estimator._more_tags()} + tags = _to_new_tags(tags) else: warnings.warn( f"Estimator {estimator} has no __sklearn_tags__ attribute, which is " @@ -332,3 +340,148 @@ def get_tags(estimator) -> Tags: tags = default_tags(estimator) return tags + + +def _to_new_tags(old_tags, estimator_type=None): + """Utility function convert old tags (dictionary) to new tags (dataclass).""" + input_tags = InputTags( + one_d_array="1darray" in old_tags["X_types"], + two_d_array="2darray" in old_tags["X_types"], + three_d_array="3darray" in old_tags["X_types"], + sparse="sparse" in old_tags["X_types"], + categorical="categorical" in old_tags["X_types"], + string="string" in old_tags["X_types"], + dict="dict" in old_tags["X_types"], + positive_only=old_tags["requires_positive_X"], + allow_nan=old_tags["allow_nan"], + pairwise=old_tags["pairwise"], + ) + target_tags = TargetTags( + required=old_tags["requires_y"], + one_d_labels="1dlabels" in old_tags["X_types"], + two_d_labels="2dlabels" in old_tags["X_types"], + 
positive_only=old_tags["requires_positive_y"], + multi_output=old_tags["multioutput"] or old_tags["multioutput_only"], + single_output=not old_tags["multioutput_only"], + ) + transformer_tags = TransformerTags( + preserves_dtype=old_tags["preserves_dtype"], + ) + classifier_tags = ClassifierTags( + poor_score=old_tags["poor_score"], + multi_class=not old_tags["binary_only"], + multi_label=old_tags["multilabel"], + ) + regressor_tags = RegressorTags( + poor_score=old_tags["poor_score"], + multi_label=old_tags["multilabel"], + ) + return Tags( + estimator_type=estimator_type, + target_tags=target_tags, + transformer_tags=transformer_tags, + classifier_tags=classifier_tags, + regressor_tags=regressor_tags, + input_tags=input_tags, + array_api_support=old_tags["array_api_support"], + no_validation=old_tags["no_validation"], + non_deterministic=old_tags["non_deterministic"], + requires_fit=old_tags["requires_fit"], + _skip_test=old_tags["_skip_test"], + ) + + +def _to_old_tags(new_tags): + """Utility function convert old tags (dictionary) to new tags (dataclass).""" + if new_tags.classifier_tags: + binary_only = not new_tags.classifier_tags.multi_class + multilabel_clf = new_tags.classifier_tags.multi_label + poor_score_clf = new_tags.classifier_tags.poor_score + else: + binary_only = False + multilabel_clf = False + poor_score_clf = False + + if new_tags.regressor_tags: + multilabel_reg = new_tags.regressor_tags.multi_label + poor_score_reg = new_tags.regressor_tags.poor_score + else: + multilabel_reg = False + poor_score_reg = False + + if new_tags.transformer_tags: + preserves_dtype = new_tags.transformer_tags.preserves_dtype + else: + preserves_dtype = ["float64"] + + tags = { + "allow_nan": new_tags.input_tags.allow_nan, + "array_api_support": new_tags.array_api_support, + "binary_only": binary_only, + "multilabel": multilabel_clf or multilabel_reg, + "multioutput": new_tags.target_tags.multi_output, + "multioutput_only": ( + not new_tags.target_tags.single_output 
+ and new_tags.target_tags.multi_output + ), + "no_validation": new_tags.no_validation, + "non_deterministic": new_tags.non_deterministic, + "pairwise": new_tags.input_tags.pairwise, + "preserves_dtype": preserves_dtype, + "poor_score": poor_score_clf or poor_score_reg, + "requires_fit": new_tags.requires_fit, + "requires_positive_X": new_tags.input_tags.positive_only, + "requires_y": new_tags.target_tags.required, + "requires_positive_y": new_tags.target_tags.positive_only, + "_skip_test": new_tags._skip_test, + "stateless": new_tags.requires_fit, + } + X_types = [] + if new_tags.input_tags.one_d_array: + X_types.append("1darray") + if new_tags.input_tags.two_d_array: + X_types.append("2darray") + if new_tags.input_tags.three_d_array: + X_types.append("3darray") + if new_tags.input_tags.sparse: + X_types.append("sparse") + if new_tags.input_tags.categorical: + X_types.append("categorical") + if new_tags.input_tags.string: + X_types.append("string") + if new_tags.input_tags.dict: + X_types.append("dict") + if new_tags.target_tags.one_d_labels: + X_types.append("1dlabels") + if new_tags.target_tags.two_d_labels: + X_types.append("2dlabels") + tags["X_types"] = X_types + return tags + + +def _safe_tags(estimator, key=None): + warnings.warn( + "The `_safe_tags` utility function is deprecated in 1.6 and will be removed in " + "1.7. 
Use the public `get_tags` function instead and make sure to implement " + "the `__sklearn_tags__` method.", + category=FutureWarning, + ) + if hasattr(estimator, "_get_tags"): + tags_provider = "_get_tags()" + tags = estimator._get_tags() + elif hasattr(estimator, "_more_tags"): + tags_provider = "_more_tags()" + tags = _to_old_tags(default_tags(estimator)) + tags = {**tags, **estimator._more_tags()} + else: + tags_provider = "_DEFAULT_TAGS" + tags = _to_old_tags(default_tags(estimator)) + + if key is not None: + if key not in tags: + raise ValueError( + f"The key {key} is not defined in {tags_provider} for the " + f"class {estimator.__class__.__name__}." + ) + return tags[key] + return tags diff --git a/sklearn/utils/estimator_checks.py b/sklearn/utils/estimator_checks.py index abf272e955bc2..622c8f04212ec 100644 --- a/sklearn/utils/estimator_checks.py +++ b/sklearn/utils/estimator_checks.py @@ -115,12 +115,14 @@ def _raise_for_missing_tags(estimator, tag_name, Mixin): tags = get_tags(estimator) estimator_type = Mixin.__name__.replace("Mixin", "") if getattr(tags, tag_name) is None: - raise RuntimeError( - f"Estimator {estimator.__class__.__name__} seems to be a {estimator_type}," - f" but the `{tag_name}` tag is not set. Either set the tag manually" - f" or inherit from the {Mixin.__name__}. Note that the order of inheritance" - f" matters, the {Mixin.__name__} should come before BaseEstimator." - ) + # raise RuntimeError( + # f"Estimator {estimator.__class__.__name__} seems to be a {estimator_type}," + # f" but the `{tag_name}` tag is not set. Either set the tag manually" + # f" or inherit from the {Mixin.__name__}. Note that the order of inheritance" + # f" matters, the {Mixin.__name__} should come before BaseEstimator." 
+ # ) + warnings.warn("BROKEN SOON, IT WILL BE", FutureWarning) + def _yield_api_checks(estimator): @@ -276,6 +278,8 @@ def _yield_regressor_checks(regressor): def _yield_transformer_checks(transformer): _raise_for_missing_tags(transformer, "transformer_tags", TransformerMixin) tags = get_tags(transformer) + print(transformer) + print(tags) # All transformers should either deal with sparse data or raise an # exception with type TypeError and an intelligible error message if not tags.no_validation: From 4d70067308e28ed769c87e07ed924bc63c86f8eb Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Tue, 19 Nov 2024 19:30:33 +0100 Subject: [PATCH 02/24] iter --- sklearn/base.py | 49 +++++++++++++++ sklearn/utils/_tags.py | 135 +++++++++++++++++++++++++++++++---------- 2 files changed, 151 insertions(+), 33 deletions(-) diff --git a/sklearn/base.py b/sklearn/base.py index bd5e07c2167dd..5ae53898f862e 100644 --- a/sklearn/base.py +++ b/sklearn/base.py @@ -26,7 +26,9 @@ Tags, TargetTags, TransformerTags, + default_tags, get_tags, + _to_old_tags, ) from .utils.fixes import _IS_32BIT from .utils.validation import ( @@ -395,6 +397,30 @@ def __sklearn_tags__(self): classifier_tags=None, ) + def _more_tags(self): + warnings.warn( + "The `_more_tags` method is deprecated in 1.8 and will be removed in " + "1.9. Please implement the `__sklearn_tags__` method.", + category=FutureWarning, + ) + return _to_old_tags(default_tags(self)) + + def _get_tags(self): + warnings.warn( + "The `_get_tags` tag provider is deprecated in 1.8 and will be removed in " + "1.9. Please implement the `__sklearn_tags__` method.", + category=FutureWarning, + ) + collected_tags = {} + for base_class in reversed(inspect.getmro(self.__class__)): + if hasattr(base_class, "_more_tags"): + # need the if because mixins might not have _more_tags + # but might do redundant work in estimators + # (i.e. 
calling more tags on BaseEstimator multiple times) + more_tags = base_class._more_tags(self) + collected_tags.update(more_tags) + return collected_tags + def _validate_params(self): """Validate types and values of constructor parameters @@ -477,6 +503,10 @@ class ClassifierMixin: # TODO(1.8): Remove this attribute _estimator_type = "classifier" + # TODO(1.8): Remove this method + def _more_tags(self): + return {"requires_y": True} + def __sklearn_tags__(self): tags = super().__sklearn_tags__() tags.estimator_type = "classifier" @@ -550,6 +580,10 @@ class RegressorMixin: # TODO(1.8): Remove this attribute _estimator_type = "regressor" + # TODO(1.8): Remove this method + def _more_tags(self): + return {"requires_y": True} + def __sklearn_tags__(self): tags = super().__sklearn_tags__() tags.estimator_type = "regressor" @@ -626,6 +660,10 @@ class ClusterMixin: # TODO(1.8): Remove this attribute _estimator_type = "clusterer" + # TODO(1.8): Remove this method + def _more_tags(self): + return {"preserves_dtype": []} + def __sklearn_tags__(self): tags = super().__sklearn_tags__() tags.estimator_type = "clusterer" @@ -1121,6 +1159,10 @@ class MetaEstimatorMixin: class MultiOutputMixin: """Mixin to mark estimators that support multioutput.""" + # TODO(1.8): Remove this method + def _more_tags(self): + return {"multioutput": True} + def __sklearn_tags__(self): tags = super().__sklearn_tags__() tags.target_tags.multi_output = True @@ -1130,6 +1172,13 @@ def __sklearn_tags__(self): class _UnstableArchMixin: """Mark estimators that are non-determinstic on 32bit or PowerPC""" + # TODO(1.8): Remove this method + def _more_tags(self): + return { + "non_deterministic": _IS_32BIT + or platform.machine().startswith(("ppc", "powerpc")) + } + def __sklearn_tags__(self): tags = super().__sklearn_tags__() tags.non_deterministic = _IS_32BIT or platform.machine().startswith( diff --git a/sklearn/utils/_tags.py b/sklearn/utils/_tags.py index 07f9b6162cdf3..dd75c6670939c 100644 --- 
a/sklearn/utils/_tags.py +++ b/sklearn/utils/_tags.py @@ -1,7 +1,9 @@ from __future__ import annotations import warnings +from collections import OrderedDict from dataclasses import dataclass, field +from itertools import chain from .fixes import _dataclass_args @@ -290,6 +292,89 @@ def default_tags(estimator) -> Tags: ) +# TODO(1.7): Remove this function +def _find_tags_provider(estimator): + """Find the tags provider for an estimator. + + Parameters + ---------- + estimator : estimator object + The estimator to find the tags provider for. + + Returns + ------- + tag_provider : str + The tags provider for the estimator. Can be one of: + - "_get_tags": to use the old tags infrastructure + - "__sklearn_tags__": to use the new tags infrastructure + """ + mro_model = type(estimator).mro() + tags_mro = OrderedDict() + for klass in mro_model: + tags_provider = [] + if "_more_tags" in vars(klass): + tags_provider.append("_more_tags") + if "_get_tags" in vars(klass): + tags_provider.append("_get_tags") + if "__sklearn_tags__" in vars(klass): + tags_provider.append("__sklearn_tags__") + tags_mro[klass.__name__] = tags_provider + + all_providers = set(chain.from_iterable(tags_mro.values())) + if "__sklearn_tags__" not in all_providers: + # default on the old tags infrastructure + return "_get_tags" + + tag_provider = "__sklearn_tags__" + encounter_sklearn_tags = False + err_msg = ( + f"Some classes from which {estimator.__class__.__name__} inherits only " + "use `_get_tags` and `_more_tags` while others implement the new " + "`__sklearn_tags__` method. There is no safe way to resolve the tags. " + "Please make sure to implement the `__sklearn_tags__` method in all " + "classes in the hierarchy." 
+ ) + for klass in tags_mro: + has_get_or_more_tags = any( + provider in tags_mro[klass] for provider in ("_get_tags", "_more_tags") + ) + has_sklearn_tags = "__sklearn_tags__" in tags_mro[klass] + + if ( + tags_mro[klass] # is it empty + and tag_provider == "_get_tags" + and not has_get_or_more_tags + and has_sklearn_tags + ): + # Case where a class in the middle implements only __sklearn_tags__ but we + # already fallback to _get_tags. There is no safe way to resolve the tags. + raise ValueError(err_msg) + elif tags_mro[klass] and tag_provider == "__sklearn_tags__": # is it empty + if has_get_or_more_tags and not has_sklearn_tags: + if encounter_sklearn_tags: + # One of the child class already implemented __sklearn_tags__ + # We cannot anymore fallback to _get_tags + raise ValueError(err_msg) + # Case where a class does not implement __sklearn_tags__ and we fallback + # to _get_tags. We should therefore warn for implementing + # __sklearn_tags__. + tag_provider = "_get_tags" + encounter_sklearn_tags = True + + if tag_provider == "_get_tags": + warnings.warn( + f"The {estimator.__class__.__name__} or classes from which it inherits " + "only use `_get_tags` and `_more_tags`. Please define the " + "`__sklearn_tags__` method, or inherit from `sklearn.base.BaseEstimator` " + "and other appropriate mixins such as `sklearn.base.TransformerMixin`, " + "`sklearn.base.ClassifierMixin`, `sklearn.base.RegressorMixin`, and " + "`sklearn.base.OutlierMixin`. From scikit-learn 1.7, not defining " + "`__sklearn_tags__` will raise an error.", + category=FutureWarning, + ) + return tag_provider + + def get_tags(estimator) -> Tags: """Get estimator tags. @@ -316,32 +401,24 @@ def get_tags(estimator) -> Tags: The estimator tags. 
""" - if hasattr(estimator, "__sklearn_tags__"): + tag_provider = _find_tags_provider(estimator) + if tag_provider == "__sklearn_tags__": tags = estimator.__sklearn_tags__() - elif hasattr(estimator, "_get_tags"): - warnings.warn("BROKEN SOON, IT WILL BE", FutureWarning) - tags = _to_new_tags(estimator._get_tags()) - elif hasattr(estimator, "_more_tags"): - warnings.warn("BROKEN SOON, IT WILL BE", FutureWarning) - tags = _to_old_tags(default_tags(estimator)) - tags = {**tags, **estimator._more_tags()} - tags = _to_new_tags(tags) + # TODO(1.7): Remove this block + elif tag_provider == "_get_tags": + if hasattr(estimator, "_get_tags"): + tags = _to_new_tags(estimator._get_tags()) + elif hasattr(estimator, "_more_tags"): + tags = _to_old_tags(default_tags(estimator)) + tags = {**tags, **estimator._more_tags()} + tags = _to_new_tags(tags) else: - warnings.warn( - f"Estimator {estimator} has no __sklearn_tags__ attribute, which is " - "defined in `sklearn.base.BaseEstimator`. This will raise an error in " - "scikit-learn 1.8. 
Please define the __sklearn_tags__ method, or inherit " - "from `sklearn.base.BaseEstimator` and other appropriate mixins such as " - "`sklearn.base.TransformerMixin`, `sklearn.base.ClassifierMixin`, " - "`sklearn.base.RegressorMixin`, and `sklearn.base.ClusterMixin`, and " - "`sklearn.base.OutlierMixin`.", - category=FutureWarning, - ) tags = default_tags(estimator) return tags +# TODO(1.7): Remove this function def _to_new_tags(old_tags, estimator_type=None): """Utility function convert old tags (dictionary) to new tags (dataclass).""" input_tags = InputTags( @@ -391,6 +468,7 @@ def _to_new_tags(old_tags, estimator_type=None): ) +# TODO(1.7): Remove this function def _to_old_tags(new_tags): """Utility function convert old tags (dictionary) to new tags (dataclass).""" if new_tags.classifier_tags: @@ -421,8 +499,7 @@ def _to_old_tags(new_tags): "multilabel": multilabel_clf or multilabel_reg, "multioutput": new_tags.target_tags.multi_output, "multioutput_only": ( - not new_tags.target_tags.single_output - and new_tags.target_tags.multi_output + not new_tags.target_tags.single_output and new_tags.target_tags.multi_output ), "no_validation": new_tags.no_validation, "non_deterministic": new_tags.non_deterministic, @@ -459,6 +536,7 @@ def _to_old_tags(new_tags): return tags +# TODO(1.7): Remove this function def _safe_tags(estimator, key=None): warnings.warn( "The `_safe_tags` utility function is deprecated in 1.6 and will be removed in " @@ -466,22 +544,13 @@ def _safe_tags(estimator, key=None): "the `__sklearn_tags__` method.", category=FutureWarning, ) - if hasattr(estimator, "_get_tags"): - tags_provider = "_get_tags()" - tags = estimator._get_tags() - elif hasattr(estimator, "_more_tags"): - tags_provider = "_more_tags()" - tags = _to_old_tags(default_tags(estimator)) - tags = {**tags, **estimator._more_tags()} - else: - tags_provider = "_DEFAULT_TAGS" - tags = _to_old_tags(default_tags(estimator)) + tags = _to_old_tags(get_tags(estimator)) if key is not None: if 
key not in tags: raise ValueError( - f"The key {key} is not defined in {tags_provider} for the " - f"class {estimator.__class__.__name__}." + f"The key {key} is not defined for the class " + f"{estimator.__class__.__name__}." ) return tags[key] return tags From 6366366b7e3852340d67c425b2cb6e67937cbd64 Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Tue, 19 Nov 2024 19:32:16 +0100 Subject: [PATCH 03/24] version --- sklearn/base.py | 8 ++++---- sklearn/utils/_tags.py | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/sklearn/base.py b/sklearn/base.py index 5ae53898f862e..1c66ca64c6087 100644 --- a/sklearn/base.py +++ b/sklearn/base.py @@ -399,16 +399,16 @@ def __sklearn_tags__(self): def _more_tags(self): warnings.warn( - "The `_more_tags` method is deprecated in 1.8 and will be removed in " - "1.9. Please implement the `__sklearn_tags__` method.", + "The `_more_tags` method is deprecated in 1.6 and will be removed in " + "1.7. Please implement the `__sklearn_tags__` method.", category=FutureWarning, ) return _to_old_tags(default_tags(self)) def _get_tags(self): warnings.warn( - "The `_get_tags` tag provider is deprecated in 1.8 and will be removed in " - "1.9. Please implement the `__sklearn_tags__` method.", + "The `_get_tags` tag provider is deprecated in 1.6 and will be removed in " + "1.7. Please implement the `__sklearn_tags__` method.", category=FutureWarning, ) collected_tags = {} diff --git a/sklearn/utils/_tags.py b/sklearn/utils/_tags.py index dd75c6670939c..4e7fded9e26f8 100644 --- a/sklearn/utils/_tags.py +++ b/sklearn/utils/_tags.py @@ -245,7 +245,7 @@ class Tags: input_tags: InputTags = field(default_factory=InputTags) -# TODO(1.8): Remove this function +# TODO(1.7): Remove this function def default_tags(estimator) -> Tags: """Get the default tags for an estimator. 
From 4ea106d59d5bbb374dcd73cbda9621c2e3163a7d Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Tue, 19 Nov 2024 19:33:19 +0100 Subject: [PATCH 04/24] correct version removal --- sklearn/base.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/sklearn/base.py b/sklearn/base.py index 1c66ca64c6087..645719a3d9806 100644 --- a/sklearn/base.py +++ b/sklearn/base.py @@ -503,7 +503,7 @@ class ClassifierMixin: # TODO(1.8): Remove this attribute _estimator_type = "classifier" - # TODO(1.8): Remove this method + # TODO(1.7): Remove this method def _more_tags(self): return {"requires_y": True} @@ -580,7 +580,7 @@ class RegressorMixin: # TODO(1.8): Remove this attribute _estimator_type = "regressor" - # TODO(1.8): Remove this method + # TODO(1.7): Remove this method def _more_tags(self): return {"requires_y": True} @@ -660,7 +660,7 @@ class ClusterMixin: # TODO(1.8): Remove this attribute _estimator_type = "clusterer" - # TODO(1.8): Remove this method + # TODO(1.7): Remove this method def _more_tags(self): return {"preserves_dtype": []} @@ -1159,7 +1159,7 @@ class MetaEstimatorMixin: class MultiOutputMixin: """Mixin to mark estimators that support multioutput.""" - # TODO(1.8): Remove this method + # TODO(1.7): Remove this method def _more_tags(self): return {"multioutput": True} @@ -1172,7 +1172,7 @@ def __sklearn_tags__(self): class _UnstableArchMixin: """Mark estimators that are non-determinstic on 32bit or PowerPC""" - # TODO(1.8): Remove this method + # TODO(1.7): Remove this method def _more_tags(self): return { "non_deterministic": _IS_32BIT From 109c214de0bd911614819b49bd20f4e4b9ea6893 Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Tue, 19 Nov 2024 19:53:55 +0100 Subject: [PATCH 05/24] iter --- sklearn/utils/_tags.py | 14 ++++++++++++++ sklearn/utils/estimator_checks.py | 3 ++- 2 files changed, 16 insertions(+), 1 deletion(-) diff --git a/sklearn/utils/_tags.py b/sklearn/utils/_tags.py index 4e7fded9e26f8..3fd7d68696ddb 
100644 --- a/sklearn/utils/_tags.py +++ b/sklearn/utils/_tags.py @@ -404,6 +404,20 @@ def get_tags(estimator) -> Tags: tag_provider = _find_tags_provider(estimator) if tag_provider == "__sklearn_tags__": tags = estimator.__sklearn_tags__() + + # TODO (1.7): Remove this block + # Catch the corner case where a transformer inheriting from BaseEstimator but + # that does not inherit from TransformerMixin ends up without the + # transformer_tags set properly. + if hasattr(estimator, "transform") or hasattr(estimator, "fit_transform"): + warnings.warn( + "The transformer tags are not set properly for the estimator " + f"{estimator.__class__.__name__}. This will raise an error in " + "scikit-learn 1.7. Inherit from `TransformerMixin` or properly set " + "the `transformer_tags` attribute in `__sklearn_tags__`.", + category=FutureWarning, + ) + tags.transformer_tags = TransformerTags() # TODO(1.7): Remove this block elif tag_provider == "_get_tags": if hasattr(estimator, "_get_tags"): diff --git a/sklearn/utils/estimator_checks.py b/sklearn/utils/estimator_checks.py index 622c8f04212ec..544588b02ea24 100644 --- a/sklearn/utils/estimator_checks.py +++ b/sklearn/utils/estimator_checks.py @@ -121,7 +121,8 @@ def _raise_for_missing_tags(estimator, tag_name, Mixin): # f" or inherit from the {Mixin.__name__}. Note that the order of inheritance" # f" matters, the {Mixin.__name__} should come before BaseEstimator." 
# ) - warnings.warn("BROKEN SOON, IT WILL BE", FutureWarning) + # warnings.warn("BROKEN SOON, IT WILL BE", FutureWarning) + pass From b3c7be9a94c6d7c1ae36c0234492cafcb7bc2d22 Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Tue, 19 Nov 2024 20:08:32 +0100 Subject: [PATCH 06/24] iter --- sklearn/base.py | 6 ++++++ sklearn/utils/_tags.py | 16 +++++----------- 2 files changed, 11 insertions(+), 11 deletions(-) diff --git a/sklearn/base.py b/sklearn/base.py index 645719a3d9806..1eefb5a6bd57f 100644 --- a/sklearn/base.py +++ b/sklearn/base.py @@ -419,6 +419,12 @@ def _get_tags(self): # (i.e. calling more tags on BaseEstimator multiple times) more_tags = base_class._more_tags(self) collected_tags.update(more_tags) + elif hasattr(base_class, "__sklearn_tags__"): + # Since that some people will inherit from scikit-learn that implements + # the new infrastructure, we need to collect it and merge it with + # the old tags. + more_tags = base_class.__sklearn_tags__(self) + collected_tags.update(_to_old_tags(more_tags)) return collected_tags def _validate_params(self): diff --git a/sklearn/utils/_tags.py b/sklearn/utils/_tags.py index 3fd7d68696ddb..8f3f230b01b83 100644 --- a/sklearn/utils/_tags.py +++ b/sklearn/utils/_tags.py @@ -340,16 +340,7 @@ def _find_tags_provider(estimator): ) has_sklearn_tags = "__sklearn_tags__" in tags_mro[klass] - if ( - tags_mro[klass] # is it empty - and tag_provider == "_get_tags" - and not has_get_or_more_tags - and has_sklearn_tags - ): - # Case where a class in the middle implements only __sklearn_tags__ but we - # already fallback to _get_tags. There is no safe way to resolve the tags. 
- raise ValueError(err_msg) - elif tags_mro[klass] and tag_provider == "__sklearn_tags__": # is it empty + if tags_mro[klass] and tag_provider == "__sklearn_tags__": # is it empty if has_get_or_more_tags and not has_sklearn_tags: if encounter_sklearn_tags: # One of the child class already implemented __sklearn_tags__ @@ -403,13 +394,16 @@ def get_tags(estimator) -> Tags: tag_provider = _find_tags_provider(estimator) if tag_provider == "__sklearn_tags__": + from sklearn.base import TransformerMixin # avoid circular dependency tags = estimator.__sklearn_tags__() # TODO (1.7): Remove this block # Catch the corner case where a transformer inheriting from BaseEstimator but # that does not inherit from TransformerMixin ends up without the # transformer_tags set properly. - if hasattr(estimator, "transform") or hasattr(estimator, "fit_transform"): + if ( + hasattr(estimator, "transform") or hasattr(estimator, "fit_transform") + ) and not isinstance(estimator, TransformerMixin): warnings.warn( "The transformer tags are not set properly for the estimator " f"{estimator.__class__.__name__}. This will raise an error in " From bb240911bc5812eacb2a95b352bdccbc18dcb807 Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Tue, 19 Nov 2024 20:41:56 +0100 Subject: [PATCH 07/24] iter --- sklearn/utils/estimator_checks.py | 15 ++++++--------- 1 file changed, 6 insertions(+), 9 deletions(-) diff --git a/sklearn/utils/estimator_checks.py b/sklearn/utils/estimator_checks.py index 544588b02ea24..e1851a0a623e1 100644 --- a/sklearn/utils/estimator_checks.py +++ b/sklearn/utils/estimator_checks.py @@ -115,13 +115,12 @@ def _raise_for_missing_tags(estimator, tag_name, Mixin): tags = get_tags(estimator) estimator_type = Mixin.__name__.replace("Mixin", "") if getattr(tags, tag_name) is None: - # raise RuntimeError( - # f"Estimator {estimator.__class__.__name__} seems to be a {estimator_type}," - # f" but the `{tag_name}` tag is not set. 
Either set the tag manually" - # f" or inherit from the {Mixin.__name__}. Note that the order of inheritance" - # f" matters, the {Mixin.__name__} should come before BaseEstimator." - # ) - # warnings.warn("BROKEN SOON, IT WILL BE", FutureWarning) + raise RuntimeError( + f"Estimator {estimator.__class__.__name__} seems to be a {estimator_type}," + f" but the `{tag_name}` tag is not set. Either set the tag manually" + f" or inherit from the {Mixin.__name__}. Note that the order of inheritance" + f" matters, the {Mixin.__name__} should come before BaseEstimator." + ) pass @@ -279,8 +278,6 @@ def _yield_regressor_checks(regressor): def _yield_transformer_checks(transformer): _raise_for_missing_tags(transformer, "transformer_tags", TransformerMixin) tags = get_tags(transformer) - print(transformer) - print(tags) # All transformers should either deal with sparse data or raise an # exception with type TypeError and an intelligible error message if not tags.no_validation: From 3d107ebd7ee19f9f506092a96ce155e86727fce8 Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Tue, 19 Nov 2024 20:47:20 +0100 Subject: [PATCH 08/24] lint --- sklearn/base.py | 2 +- sklearn/utils/_tags.py | 1 + sklearn/utils/estimator_checks.py | 2 -- 3 files changed, 2 insertions(+), 3 deletions(-) diff --git a/sklearn/base.py b/sklearn/base.py index 1eefb5a6bd57f..47b66da5965ec 100644 --- a/sklearn/base.py +++ b/sklearn/base.py @@ -26,9 +26,9 @@ Tags, TargetTags, TransformerTags, + _to_old_tags, default_tags, get_tags, - _to_old_tags, ) from .utils.fixes import _IS_32BIT from .utils.validation import ( diff --git a/sklearn/utils/_tags.py b/sklearn/utils/_tags.py index 8f3f230b01b83..92e9acd66d0cf 100644 --- a/sklearn/utils/_tags.py +++ b/sklearn/utils/_tags.py @@ -395,6 +395,7 @@ def get_tags(estimator) -> Tags: tag_provider = _find_tags_provider(estimator) if tag_provider == "__sklearn_tags__": from sklearn.base import TransformerMixin # avoid circular dependency + tags = 
estimator.__sklearn_tags__() # TODO (1.7): Remove this block diff --git a/sklearn/utils/estimator_checks.py b/sklearn/utils/estimator_checks.py index e1851a0a623e1..abf272e955bc2 100644 --- a/sklearn/utils/estimator_checks.py +++ b/sklearn/utils/estimator_checks.py @@ -121,8 +121,6 @@ def _raise_for_missing_tags(estimator, tag_name, Mixin): f" or inherit from the {Mixin.__name__}. Note that the order of inheritance" f" matters, the {Mixin.__name__} should come before BaseEstimator." ) - pass - def _yield_api_checks(estimator): From ae20c685e382d001a8b5041d4e17197d35be8074 Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Tue, 19 Nov 2024 21:54:01 +0100 Subject: [PATCH 09/24] call new API in old function when it makes sense --- sklearn/base.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/sklearn/base.py b/sklearn/base.py index 47b66da5965ec..6ed79dcb665bf 100644 --- a/sklearn/base.py +++ b/sklearn/base.py @@ -411,6 +411,13 @@ def _get_tags(self): "1.7. Please implement the `__sklearn_tags__` method.", category=FutureWarning, ) + # In case a user called `_get_tags` but that the estimator already did the job + # implementing `__sklearn_tags__` completely, let's default back to the future + # behaviour. 
+ from sklearn.utils._tags import get_tags, _find_tags_provider, _to_old_tags + if _find_tags_provider(self) == "__sklearn_tags__": + return _to_old_tags(get_tags(self)) + collected_tags = {} for base_class in reversed(inspect.getmro(self.__class__)): if hasattr(base_class, "_more_tags"): From fe55288fc1ad1da7075332af20eb9edbbd79d238 Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Tue, 19 Nov 2024 21:57:53 +0100 Subject: [PATCH 10/24] lint --- sklearn/base.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/sklearn/base.py b/sklearn/base.py index 6ed79dcb665bf..a0e8cd13350ab 100644 --- a/sklearn/base.py +++ b/sklearn/base.py @@ -412,9 +412,10 @@ def _get_tags(self): category=FutureWarning, ) # In case a user called `_get_tags` but that the estimator already did the job - # implementing `__sklearn_tags__` completely, let's default back to the future - # behaviour. - from sklearn.utils._tags import get_tags, _find_tags_provider, _to_old_tags + # implementing `__sklearn_tags__` completely and removed `_more_tags`, let's + # default back to the future behaviour. Otherwise, we will get the default tags. + from sklearn.utils._tags import _find_tags_provider, _to_old_tags, get_tags + if _find_tags_provider(self) == "__sklearn_tags__": return _to_old_tags(get_tags(self)) From 7aa99818b2d253a62249e490fb5537c75ec1cac8 Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Tue, 19 Nov 2024 23:44:24 +0100 Subject: [PATCH 11/24] iter --- sklearn/utils/_tags.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sklearn/utils/_tags.py b/sklearn/utils/_tags.py index 92e9acd66d0cf..3e86e7a7d10b5 100644 --- a/sklearn/utils/_tags.py +++ b/sklearn/utils/_tags.py @@ -404,7 +404,7 @@ def get_tags(estimator) -> Tags: # transformer_tags set properly. 
if ( hasattr(estimator, "transform") or hasattr(estimator, "fit_transform") - ) and not isinstance(estimator, TransformerMixin): + ) and tags.transformer_tags is None: warnings.warn( "The transformer tags are not set properly for the estimator " f"{estimator.__class__.__name__}. This will raise an error in " From 98ac13326fa63c46995779d0b508e67986c069d7 Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Tue, 19 Nov 2024 23:55:44 +0100 Subject: [PATCH 12/24] iter --- sklearn/utils/_tags.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/sklearn/utils/_tags.py b/sklearn/utils/_tags.py index 3e86e7a7d10b5..3b05b536c8351 100644 --- a/sklearn/utils/_tags.py +++ b/sklearn/utils/_tags.py @@ -394,8 +394,6 @@ def get_tags(estimator) -> Tags: tag_provider = _find_tags_provider(estimator) if tag_provider == "__sklearn_tags__": - from sklearn.base import TransformerMixin # avoid circular dependency - tags = estimator.__sklearn_tags__() # TODO (1.7): Remove this block From 7f6b14bf8fc73ebbb85d2b9ecb8eb2ae2de2eb08 Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Wed, 20 Nov 2024 01:02:23 +0100 Subject: [PATCH 13/24] silence some warning and make some tests --- sklearn/base.py | 28 ++++++- sklearn/utils/_tags.py | 27 ++++++- sklearn/utils/tests/test_tags.py | 128 +++++++++++++++++++++++++++++++ 3 files changed, 177 insertions(+), 6 deletions(-) diff --git a/sklearn/base.py b/sklearn/base.py index a0e8cd13350ab..6f26dd84c5e71 100644 --- a/sklearn/base.py +++ b/sklearn/base.py @@ -389,6 +389,22 @@ def __setstate__(self, state): self.__dict__.update(state) def __sklearn_tags__(self): + from sklearn.utils._tags import _find_tags_provider, _to_new_tags + + # TODO(1.7): Remove this block + if _find_tags_provider(self) == "_get_tags": + # one of the children classes only implements `_get_tags` so we need to + # warn and and mix old-style and new-style tags. 
+ with warnings.catch_warnings(): + warnings.filterwarnings( + "ignore", + category=FutureWarning, + message="The `_get_tags` tag provider is deprecated", + ) + # silence the warning to avoid a false positive warning since this + # is not a direct user call but an internal one. + return _to_new_tags(self._get_tags()) + return Tags( estimator_type=None, target_tags=TargetTags(required=False), @@ -416,7 +432,7 @@ def _get_tags(self): # default back to the future behaviour. Otherwise, we will get the default tags. from sklearn.utils._tags import _find_tags_provider, _to_old_tags, get_tags - if _find_tags_provider(self) == "__sklearn_tags__": + if _find_tags_provider(self, warn=False) == "__sklearn_tags__": return _to_old_tags(get_tags(self)) collected_tags = {} @@ -425,7 +441,15 @@ def _get_tags(self): # need the if because mixins might not have _more_tags # but might do redundant work in estimators # (i.e. calling more tags on BaseEstimator multiple times) - more_tags = base_class._more_tags(self) + with warnings.catch_warnings(): + warnings.filterwarnings( + "ignore", + category=FutureWarning, + message="The `_more_tags` method is deprecated", + ) + # silence the warning to avoid a false positive warning since this + # is not a direct user call but an internal one. + more_tags = base_class._more_tags(self) collected_tags.update(more_tags) elif hasattr(base_class, "__sklearn_tags__"): # Since that some people will inherit from scikit-learn that implements diff --git a/sklearn/utils/_tags.py b/sklearn/utils/_tags.py index 3b05b536c8351..1cf8cd2d55afb 100644 --- a/sklearn/utils/_tags.py +++ b/sklearn/utils/_tags.py @@ -293,7 +293,7 @@ def default_tags(estimator) -> Tags: # TODO(1.7): Remove this function -def _find_tags_provider(estimator): +def _find_tags_provider(estimator, warn=True): """Find the tags provider for an estimator. Parameters @@ -301,6 +301,9 @@ def _find_tags_provider(estimator): estimator : estimator object The estimator to find the tags provider for. 
+ warn : bool, default=True + Whether to warn when the estimator only relies on `_get_tags`/`_more_tags`. + Returns ------- tag_provider : str @@ -352,7 +355,7 @@ def _find_tags_provider(estimator): tag_provider = "_get_tags" encounter_sklearn_tags = True - if tag_provider == "_get_tags": + if warn and tag_provider == "_get_tags": warnings.warn( f"The {estimator.__class__.__name__} or classes from which it inherits " "only use `_get_tags` and `_more_tags`. Please define the " @@ -414,10 +417,26 @@ def get_tags(estimator) -> Tags: # TODO(1.7): Remove this block elif tag_provider == "_get_tags": if hasattr(estimator, "_get_tags"): - tags = _to_new_tags(estimator._get_tags()) + with warnings.catch_warnings(): + warnings.filterwarnings( + "ignore", + category=FutureWarning, + message="The `_get_tags` tag provider is deprecated", + ) + # silence the warning to avoid a false positive warning since this + # is not a direct user call but an internal one. + tags = _to_new_tags(estimator._get_tags()) elif hasattr(estimator, "_more_tags"): tags = _to_old_tags(default_tags(estimator)) - tags = {**tags, **estimator._more_tags()} + with warnings.catch_warnings(): + warnings.filterwarnings( + "ignore", + category=FutureWarning, + message="The `_more_tags` method is deprecated", + ) + # silence the warning to avoid a false positive warning since this + # is not a direct user call but an internal one. 
+ tags = {**tags, **estimator._more_tags()} tags = _to_new_tags(tags) else: tags = default_tags(estimator) diff --git a/sklearn/utils/tests/test_tags.py b/sklearn/utils/tests/test_tags.py index 413fbc6bbd3de..e6980b008bb11 100644 --- a/sklearn/utils/tests/test_tags.py +++ b/sklearn/utils/tests/test_tags.py @@ -8,6 +8,7 @@ TransformerMixin, ) from sklearn.utils import Tags, get_tags +from sklearn.utils._tags import _to_new_tags, _to_old_tags, _safe_tags from sklearn.utils.estimator_checks import ( check_estimator_tags_renamed, check_valid_tag_types, @@ -78,3 +79,130 @@ def __sklearn_tags__(self): return tags check_valid_tag_types("MyEstimator", MyEstimator()) + + +######################################################################################## +# Test for the deprecation +# TODO(1.7): Remove this +######################################################################################## + +def test_tags_deprecation(): + class ChildClass(RegressorMixin, BaseEstimator): + """Child implementing the old tags API together with our new API.""" + + def _more_tags(self): + return {"allow_nan": True} + + main_warn_msg = "only use `_get_tags` and `_more_tags`" + with pytest.warns(FutureWarning, match=main_warn_msg): + tags = ChildClass().__sklearn_tags__() + assert tags.input_tags.allow_nan + + with pytest.warns(FutureWarning) as warning_list: + tags = _safe_tags(ChildClass()) + assert len(warning_list) == 2, len(warning_list) + assert str(warning_list[0].message).startswith( + "The `_safe_tags` utility function is deprecated" + ) + assert main_warn_msg in str(warning_list[1].message) + + assert isinstance(tags, dict) + assert _to_new_tags(tags).input_tags.allow_nan + + with pytest.warns(FutureWarning, match=main_warn_msg): + tags = get_tags(ChildClass()) + assert tags.input_tags.allow_nan + + class ChildClass(RegressorMixin, BaseEstimator): + """Child implementing the old and new tags API during the transition period.""" + + def _more_tags(self): + return {"allow_nan": 
True} + + def __sklearn_tags__(self): + tags = super().__sklearn_tags__() + tags.input_tags.allow_nan = True + return tags + + tags = get_tags(ChildClass()) + assert tags.input_tags.allow_nan + + warn_msg = "`_get_tags` tag provider is deprecated" + with pytest.warns(FutureWarning, match=warn_msg): + tags = ChildClass()._get_tags() + assert isinstance(tags, dict) + assert _to_new_tags(tags).input_tags.allow_nan + + warn_msg = "`_safe_tags` utility function is deprecated" + with pytest.warns(FutureWarning, match=warn_msg): + tags = _safe_tags(ChildClass()) + assert isinstance(tags, dict) + assert _to_new_tags(tags).input_tags.allow_nan + + class ChildClass(RegressorMixin, BaseEstimator): + """Child not setting any tags.""" + + tags = get_tags(ChildClass()) + assert tags.target_tags.required + + warn_msg = "`_get_tags` tag provider is deprecated" + with pytest.warns(FutureWarning, match=warn_msg): + tags = ChildClass()._get_tags() + assert isinstance(tags, dict) + assert _to_new_tags(tags).target_tags.required + + warn_msg = "`_safe_tags` utility function is deprecated" + with pytest.warns(FutureWarning, match=warn_msg): + tags = _safe_tags(ChildClass()) + assert isinstance(tags, dict) + assert _to_new_tags(tags).target_tags.required + + class Mixin: + def _more_tags(self): + return {"allow_nan": True} + + class ChildClass(Mixin, BaseEstimator): + """Child following the new API with mixin following the old API.""" + + def __sklearn_tags__(self): + tags = super().__sklearn_tags__() + tags.target_tags.required = True + return tags + + err_msg = ( + "Some classes from which ChildClass inherits only use `_get_tags` and " + "`_more_tags`" + ) + with pytest.raises(ValueError, match=err_msg): + tags = get_tags(ChildClass()) + with pytest.raises(ValueError, match=err_msg): + with pytest.warns(FutureWarning): + tags = ChildClass()._get_tags() + with pytest.raises(ValueError, match=err_msg): + with pytest.warns(FutureWarning): + tags = _safe_tags(ChildClass()) + + class Mixin: 
+ def _more_tags(self): + return {"allow_nan": True} + + class ChildClass(Mixin, BaseEstimator): + """Child following the old API with mixin following the old API.""" + + def _more_tags(self): + return {"requires_y": True} + + with pytest.warns(FutureWarning, match=main_warn_msg): + tags = ChildClass().__sklearn_tags__() + assert tags.input_tags.allow_nan + + with pytest.warns(FutureWarning) as warning_list: + tags = _safe_tags(ChildClass()) + assert len(warning_list) == 2, len(warning_list) + assert str(warning_list[0].message).startswith( + "The `_safe_tags` utility function is deprecated" + ) + assert main_warn_msg in str(warning_list[1].message) + + assert isinstance(tags, dict) + assert _to_new_tags(tags).input_tags.allow_nan \ No newline at end of file From c68f368fc9256ec40409219a98b236b46c1c5fc7 Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Wed, 20 Nov 2024 01:09:04 +0100 Subject: [PATCH 14/24] lint --- sklearn/utils/tests/test_tags.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/sklearn/utils/tests/test_tags.py b/sklearn/utils/tests/test_tags.py index e6980b008bb11..1ddefcaf17489 100644 --- a/sklearn/utils/tests/test_tags.py +++ b/sklearn/utils/tests/test_tags.py @@ -8,7 +8,7 @@ TransformerMixin, ) from sklearn.utils import Tags, get_tags -from sklearn.utils._tags import _to_new_tags, _to_old_tags, _safe_tags +from sklearn.utils._tags import _safe_tags, _to_new_tags from sklearn.utils.estimator_checks import ( check_estimator_tags_renamed, check_valid_tag_types, @@ -86,6 +86,7 @@ def __sklearn_tags__(self): # TODO(1.7): Remove this ######################################################################################## + def test_tags_deprecation(): class ChildClass(RegressorMixin, BaseEstimator): """Child implementing the old tags API together with our new API.""" @@ -205,4 +206,4 @@ def _more_tags(self): assert main_warn_msg in str(warning_list[1].message) assert isinstance(tags, dict) - assert 
_to_new_tags(tags).input_tags.allow_nan \ No newline at end of file + assert _to_new_tags(tags).input_tags.allow_nan From a46e5934cdbfd36a78d8655b37aec83b03fd04d6 Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Wed, 20 Nov 2024 11:08:10 +0100 Subject: [PATCH 15/24] iter --- sklearn/utils/_tags.py | 2 + sklearn/utils/tests/test_estimator_checks.py | 45 +++++++++++++++++--- 2 files changed, 41 insertions(+), 6 deletions(-) diff --git a/sklearn/utils/_tags.py b/sklearn/utils/_tags.py index 1cf8cd2d55afb..2bf43a1f0401c 100644 --- a/sklearn/utils/_tags.py +++ b/sklearn/utils/_tags.py @@ -438,6 +438,8 @@ def get_tags(estimator) -> Tags: # is not a direct user call but an internal one. tags = {**tags, **estimator._more_tags()} tags = _to_new_tags(tags) + else: + tags = default_tags(estimator) else: tags = default_tags(estimator) diff --git a/sklearn/utils/tests/test_estimator_checks.py b/sklearn/utils/tests/test_estimator_checks.py index d09b3e7f366ec..05104c042d3be 100644 --- a/sklearn/utils/tests/test_estimator_checks.py +++ b/sklearn/utils/tests/test_estimator_checks.py @@ -39,6 +39,7 @@ from sklearn.svm import SVC, NuSVC from sklearn.utils import _array_api, all_estimators, deprecated from sklearn.utils._param_validation import Interval, StrOptions +from sklearn.utils._tags import TransformerTags from sklearn.utils._test_common.instance_generator import ( _construct_instances, _get_expected_failed_checks, @@ -314,7 +315,7 @@ def fit(self, X, y): return self -class BadTransformerWithoutMixin(BaseEstimator): +class BadTransformerWithoutMixinWithTags(BaseEstimator): def fit(self, X, y=None): X = validate_data(self, X) return self @@ -324,6 +325,30 @@ def transform(self, X): X = validate_data(self, X, reset=False) return X + def __sklearn_tags__(self): + tags = super().__sklearn_tags__() + tags.transformer_tags = TransformerTags() + return tags + + +class BadTransformerWithoutMixinWithoutTags(BaseEstimator): + """Transformer that does not implement `fit_transform` 
and the tags. + + TODO(1.7): + In 1.6, it will raise an AttributeError for `fit_transform` and a warning to + mention that the `transformer_tags` tag is not set. + As for 1.7, it will raise a RuntimeError because the tag is not set. + """ + def fit(self, X, y=None): + X = validate_data(self, X) + return self + + def transform(self, X): + check_is_fitted(self) + X = validate_data(self, X, reset=False) + return X + + class NotInvariantPredict(BaseEstimator): def fit(self, X, y): @@ -844,11 +869,19 @@ def test_check_outlier_corruption(): check_outlier_corruption(1, 2, decision) -def test_check_estimator_transformer_no_mixin(): - # check that TransformerMixin is not required for transformer tests to run - # but it fails since the tag is not set - with raises(RuntimeError, "the `transformer_tags` tag is not set"): - check_estimator(BadTransformerWithoutMixin()) +def test_check_estimator_transformer_no_mixin_with_tags(): + with raises(AttributeError, ".*fit_transform.*"): + check_estimator(BadTransformerWithoutMixinWithTags()) + + +def test_check_estimator_transformer_no_mixin_without_tags(): + # TODO(1.7): replace the type of exception raised and remove the warning + with raises(AttributeError, ".*fit_transform.*"): + with warnings.catch_warnings(record=True) as record: + check_estimator(BadTransformerWithoutMixinWithoutTags()) + for rec in record: + assert issubclass(rec.category, FutureWarning) + assert "The transformer tags are not set properly" in str(rec.message) def test_check_estimator_clones(): From 55d7603b616a2b3e84ec9628c087b634d2d86f07 Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Wed, 20 Nov 2024 11:33:42 +0100 Subject: [PATCH 16/24] lint --- sklearn/utils/tests/test_estimator_checks.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sklearn/utils/tests/test_estimator_checks.py b/sklearn/utils/tests/test_estimator_checks.py index 05104c042d3be..f12493d098152 100644 --- a/sklearn/utils/tests/test_estimator_checks.py +++ 
b/sklearn/utils/tests/test_estimator_checks.py @@ -339,6 +339,7 @@ class BadTransformerWithoutMixinWithoutTags(BaseEstimator): mention that the `transformer_tags` tag is not set. As for 1.7, it will raise a RuntimeError because the tag is not set. """ + def fit(self, X, y=None): X = validate_data(self, X) return self @@ -349,7 +350,6 @@ def transform(self, X): return X - class NotInvariantPredict(BaseEstimator): def fit(self, X, y): # Convert data From 822e9c8d2eff1ea5a0bf272d5ff7216f96d7eda4 Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Wed, 20 Nov 2024 17:56:37 +0100 Subject: [PATCH 17/24] fix pipeline --- sklearn/pipeline.py | 44 ++++++++++++++++++++++++---------- sklearn/tests/test_pipeline.py | 4 ++-- 2 files changed, 33 insertions(+), 15 deletions(-) diff --git a/sklearn/pipeline.py b/sklearn/pipeline.py index 9ff8a3549ef28..ab4fa9d9c5be0 100644 --- a/sklearn/pipeline.py +++ b/sklearn/pipeline.py @@ -7,6 +7,7 @@ from collections import Counter, defaultdict from contextlib import contextmanager from copy import deepcopy +from functools import reduce from itertools import chain, islice import numpy as np @@ -23,7 +24,7 @@ _get_container_adapter, _safe_set_output, ) -from .utils._tags import get_tags +from .utils._tags import TransformerTags, get_tags from .utils._user_interface import _print_elapsed_time from .utils.deprecation import _deprecate_Xt_in_inverse_transform from .utils.metadata_routing import ( @@ -1229,18 +1230,35 @@ def __sklearn_tags__(self): # tuples and `fit` is not called yet to validate the steps. 
pass - try: - if self.steps[-1][1] is not None and self.steps[-1][1] != "passthrough": - last_step_tags = get_tags(self.steps[-1][1]) - tags.estimator_type = last_step_tags.estimator_type - tags.target_tags.multi_output = last_step_tags.target_tags.multi_output - tags.classifier_tags = deepcopy(last_step_tags.classifier_tags) - tags.regressor_tags = deepcopy(last_step_tags.regressor_tags) - tags.transformer_tags = deepcopy(last_step_tags.transformer_tags) - except (ValueError, AttributeError, TypeError): - # This happens when the `steps` is not a list of (name, estimator) - # tuples and `fit` is not called yet to validate the steps. - pass + # try: + # dtype preservation will depend on the intersection of all steps + preserves_dtype = [] + for step in self.steps: + if step[1] is not None and step[1] != "passthrough": + step_tags = get_tags(step[1]) + if step_tags.transformer_tags is not None: + preserves_dtype.append( + set(step_tags.transformer_tags.preserves_dtype) + ) + if preserves_dtype: + preserves_dtype = list(reduce(set.intersection, preserves_dtype)) + + if self.steps[-1][1] is not None and self.steps[-1][1] != "passthrough": + last_step_tags = get_tags(self.steps[-1][1]) + tags.estimator_type = last_step_tags.estimator_type + tags.target_tags.multi_output = last_step_tags.target_tags.multi_output + tags.classifier_tags = deepcopy(last_step_tags.classifier_tags) + tags.regressor_tags = deepcopy(last_step_tags.regressor_tags) + tags.transformer_tags = deepcopy(last_step_tags.transformer_tags) + if tags.transformer_tags is not None: + tags.transformer_tags.preserves_dtype = preserves_dtype + elif self.steps[-1][1] is None or self.steps[-1][1] == "passthrough": + # "passthrough" behaves like a transformer + tags.transformer_tags = TransformerTags(preserves_dtype=[]) + # except (ValueError, AttributeError, TypeError): + # # This happens when the `steps` is not a list of (name, estimator) + # # tuples and `fit` is not called yet to validate the steps. 
+ # pass return tags diff --git a/sklearn/tests/test_pipeline.py b/sklearn/tests/test_pipeline.py index d7a201f3abf6f..130793b7d9a46 100644 --- a/sklearn/tests/test_pipeline.py +++ b/sklearn/tests/test_pipeline.py @@ -2058,7 +2058,7 @@ def transform(self, X): ], ) def test_pipeline_warns_not_fitted(method): - class StatelessEstimator(BaseEstimator): + class StatelessEstimator(TransformerMixin, ClassifierMixin, BaseEstimator): """Stateless estimator that doesn't check if it's fitted. Stateless estimators that don't require fit, should properly set the @@ -2102,7 +2102,7 @@ def inverse_transform(self, X): # ===================================================================== -class SimpleEstimator(BaseEstimator): +class SimpleEstimator(TransformerMixin, ClassifierMixin, BaseEstimator): # This class is used in this section for testing routing in the pipeline. # This class should have every set_{method}_request def __sklearn_is_fitted__(self): From c914e0925ac57c25b974b5dc778b2f5863c11f8c Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Wed, 20 Nov 2024 18:26:30 +0100 Subject: [PATCH 18/24] fix vectorizer --- sklearn/feature_extraction/text.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/sklearn/feature_extraction/text.py b/sklearn/feature_extraction/text.py index e1bdfd5a7dee5..a94920024ff98 100644 --- a/sklearn/feature_extraction/text.py +++ b/sklearn/feature_extraction/text.py @@ -16,7 +16,7 @@ import numpy as np import scipy.sparse as sp -from sklearn.utils import metadata_routing +from sklearn.utils import TransformerTags, metadata_routing from ..base import BaseEstimator, OneToOneFeatureMixin, TransformerMixin, _fit_context from ..exceptions import NotFittedError @@ -554,6 +554,11 @@ def _warn_for_unused_params(self): " since 'analyzer' != 'word'" ) + def __sklearn_tags__(self): + tags = super().__sklearn_tags__() + tags.transformer_tags = TransformerTags(preserves_dtype=[]) + return tags + class HashingVectorizer( 
TransformerMixin, _VectorizerMixin, BaseEstimator, auto_wrap_output_keys=None From 521718cf25fdc263c3c8cf6d8c924291c8afabc2 Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Wed, 20 Nov 2024 19:10:25 +0100 Subject: [PATCH 19/24] more fixes --- sklearn/feature_selection/tests/test_rfe.py | 4 ++-- sklearn/manifold/_mds.py | 2 ++ sklearn/manifold/_spectral_embedding.py | 2 ++ sklearn/model_selection/_search.py | 1 + sklearn/model_selection/tests/test_search.py | 4 ++-- sklearn/tests/test_metaestimators.py | 7 +++++++ sklearn/utils/tests/test_estimator_checks.py | 1 + sklearn/utils/tests/test_validation.py | 4 ++-- 8 files changed, 19 insertions(+), 6 deletions(-) diff --git a/sklearn/feature_selection/tests/test_rfe.py b/sklearn/feature_selection/tests/test_rfe.py index ae11de2fadf59..dc54bdd6560a7 100644 --- a/sklearn/feature_selection/tests/test_rfe.py +++ b/sklearn/feature_selection/tests/test_rfe.py @@ -9,7 +9,7 @@ from joblib import parallel_backend from numpy.testing import assert_allclose, assert_array_almost_equal, assert_array_equal -from sklearn.base import BaseEstimator, ClassifierMixin, is_classifier +from sklearn.base import BaseEstimator, ClassifierMixin, TransformerMixin, is_classifier from sklearn.compose import TransformedTargetRegressor from sklearn.cross_decomposition import CCA, PLSCanonical, PLSRegression from sklearn.datasets import load_iris, make_classification, make_friedman1 @@ -27,7 +27,7 @@ from sklearn.utils.fixes import CSR_CONTAINERS -class MockClassifier(ClassifierMixin, BaseEstimator): +class MockClassifier(TransformerMixin, ClassifierMixin, BaseEstimator): """ Dummy classifier to test recursive feature elimination """ diff --git a/sklearn/manifold/_mds.py b/sklearn/manifold/_mds.py index dc9f88b502da5..07e7f541fa41a 100644 --- a/sklearn/manifold/_mds.py +++ b/sklearn/manifold/_mds.py @@ -16,6 +16,7 @@ from ..metrics import euclidean_distances from ..utils import check_array, check_random_state, check_symmetric from 
..utils._param_validation import Interval, StrOptions, validate_params +from ..utils._tags import TransformerTags from ..utils.parallel import Parallel, delayed from ..utils.validation import validate_data @@ -572,6 +573,7 @@ def __init__( def __sklearn_tags__(self): tags = super().__sklearn_tags__() + tags.transformer_tags = TransformerTags(preserves_dtype=["float64"]) tags.input_tags.pairwise = self.dissimilarity == "precomputed" return tags diff --git a/sklearn/manifold/_spectral_embedding.py b/sklearn/manifold/_spectral_embedding.py index ebd5d7c5b651b..3d7ede8fb9358 100644 --- a/sklearn/manifold/_spectral_embedding.py +++ b/sklearn/manifold/_spectral_embedding.py @@ -23,6 +23,7 @@ ) from ..utils._arpack import _init_arpack_v0 from ..utils._param_validation import Interval, StrOptions, validate_params +from ..utils._tags import TransformerTags from ..utils.extmath import _deterministic_vector_sign_flip from ..utils.fixes import laplacian as csgraph_laplacian from ..utils.fixes import parse_version, sp_version @@ -654,6 +655,7 @@ def __sklearn_tags__(self): "precomputed", "precomputed_nearest_neighbors", ] + tags.transformer_tags = TransformerTags(preserves_dtype=["float64"]) return tags def _get_affinity_matrix(self, X, Y=None): diff --git a/sklearn/model_selection/_search.py b/sklearn/model_selection/_search.py index 7515436af33da..3ee877e1f7597 100644 --- a/sklearn/model_selection/_search.py +++ b/sklearn/model_selection/_search.py @@ -487,6 +487,7 @@ def __sklearn_tags__(self): tags.estimator_type = sub_estimator_tags.estimator_type tags.classifier_tags = deepcopy(sub_estimator_tags.classifier_tags) tags.regressor_tags = deepcopy(sub_estimator_tags.regressor_tags) + tags.transformer_tags = deepcopy(sub_estimator_tags.transformer_tags) # allows cross-validation to see 'precomputed' metrics tags.input_tags.pairwise = get_tags(self.estimator).input_tags.pairwise tags.array_api_support = get_tags(self.estimator).array_api_support diff --git 
a/sklearn/model_selection/tests/test_search.py b/sklearn/model_selection/tests/test_search.py index 5313e5d28a1a7..6e38536d1f509 100644 --- a/sklearn/model_selection/tests/test_search.py +++ b/sklearn/model_selection/tests/test_search.py @@ -15,7 +15,7 @@ from scipy.stats import bernoulli, expon, uniform from sklearn import config_context -from sklearn.base import BaseEstimator, ClassifierMixin, is_classifier +from sklearn.base import BaseEstimator, ClassifierMixin, TransformerMixin, is_classifier from sklearn.cluster import KMeans from sklearn.compose import ColumnTransformer from sklearn.datasets import ( @@ -100,7 +100,7 @@ # Neither of the following two estimators inherit from BaseEstimator, # to test hyperparameter search on user-defined classifiers. -class MockClassifier(ClassifierMixin, BaseEstimator): +class MockClassifier(TransformerMixin, ClassifierMixin, BaseEstimator): """Dummy classifier to test the parameter search algorithms""" def __init__(self, foo_param=0): diff --git a/sklearn/tests/test_metaestimators.py b/sklearn/tests/test_metaestimators.py index 214fc75a68364..4cf2a05a325dd 100644 --- a/sklearn/tests/test_metaestimators.py +++ b/sklearn/tests/test_metaestimators.py @@ -19,6 +19,7 @@ from sklearn.preprocessing import MaxAbsScaler, StandardScaler from sklearn.semi_supervised import SelfTrainingClassifier from sklearn.utils import all_estimators +from sklearn.utils._tags import TransformerTags from sklearn.utils._test_common.instance_generator import _construct_instances from sklearn.utils._testing import SkipTest, set_random_state from sklearn.utils.estimator_checks import ( @@ -143,6 +144,12 @@ def score(self, X, y, *args, **kwargs): self._check_fit() return 1.0 + def __sklearn_tags__(self): + tags = super().__sklearn_tags__() + if hasattr(self, "transform"): + tags.transformer_tags = TransformerTags(preserves_dtype=[]) + return tags + methods = [ k for k in SubEstimator.__dict__.keys() diff --git a/sklearn/utils/tests/test_estimator_checks.py 
b/sklearn/utils/tests/test_estimator_checks.py index f12493d098152..4cb03bbfc4aa9 100644 --- a/sklearn/utils/tests/test_estimator_checks.py +++ b/sklearn/utils/tests/test_estimator_checks.py @@ -878,6 +878,7 @@ def test_check_estimator_transformer_no_mixin_without_tags(): # TODO(1.7): replace the type of exception raised and remove the warning with raises(AttributeError, ".*fit_transform.*"): with warnings.catch_warnings(record=True) as record: + warnings.filterwarnings("ignore", category=FutureWarning) check_estimator(BadTransformerWithoutMixinWithoutTags()) for rec in record: assert issubclass(rec.category, FutureWarning) diff --git a/sklearn/utils/tests/test_validation.py b/sklearn/utils/tests/test_validation.py index 669e40e137e17..e64f0ac4efb99 100644 --- a/sklearn/utils/tests/test_validation.py +++ b/sklearn/utils/tests/test_validation.py @@ -15,7 +15,7 @@ import sklearn from sklearn._config import config_context from sklearn._min_dependencies import dependent_packages -from sklearn.base import BaseEstimator +from sklearn.base import BaseEstimator, TransformerMixin from sklearn.datasets import make_blobs from sklearn.ensemble import RandomForestRegressor from sklearn.exceptions import NotFittedError, PositiveSpectrumWarning @@ -1990,7 +1990,7 @@ def test_get_feature_names_invalid_dtypes(names, dtypes): names = _get_feature_names(X) -class PassthroughTransformer(BaseEstimator): +class PassthroughTransformer(TransformerMixin, BaseEstimator): def fit(self, X, y=None): validate_data(self, X, reset=True) return self From dfe7c1fe192eaf8b1ad0139ebab4a0edbfb5da00 Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Wed, 20 Nov 2024 19:12:37 +0100 Subject: [PATCH 20/24] add comments --- sklearn/base.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/sklearn/base.py b/sklearn/base.py index 6f26dd84c5e71..76f72c684b0e8 100644 --- a/sklearn/base.py +++ b/sklearn/base.py @@ -413,6 +413,7 @@ def __sklearn_tags__(self): classifier_tags=None, ) + # TODO(1.7): Remove 
this method def _more_tags(self): warnings.warn( "The `_more_tags` method is deprecated in 1.6 and will be removed in " @@ -421,6 +422,7 @@ def _more_tags(self): ) return _to_old_tags(default_tags(self)) + # TODO(1.7): Remove this method def _get_tags(self): warnings.warn( "The `_get_tags` tag provider is deprecated in 1.6 and will be removed in " From 5c9f23f2cb2efd39ca125e7ab99e367f243b3eb5 Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Wed, 20 Nov 2024 19:13:53 +0100 Subject: [PATCH 21/24] put back try/except --- sklearn/pipeline.py | 58 ++++++++++++++++++++++----------------------- 1 file changed, 29 insertions(+), 29 deletions(-) diff --git a/sklearn/pipeline.py b/sklearn/pipeline.py index ab4fa9d9c5be0..96869d61342d6 100644 --- a/sklearn/pipeline.py +++ b/sklearn/pipeline.py @@ -1230,35 +1230,35 @@ def __sklearn_tags__(self): # tuples and `fit` is not called yet to validate the steps. pass - # try: - # dtype preservation will depend on the intersection of all steps - preserves_dtype = [] - for step in self.steps: - if step[1] is not None and step[1] != "passthrough": - step_tags = get_tags(step[1]) - if step_tags.transformer_tags is not None: - preserves_dtype.append( - set(step_tags.transformer_tags.preserves_dtype) - ) - if preserves_dtype: - preserves_dtype = list(reduce(set.intersection, preserves_dtype)) - - if self.steps[-1][1] is not None and self.steps[-1][1] != "passthrough": - last_step_tags = get_tags(self.steps[-1][1]) - tags.estimator_type = last_step_tags.estimator_type - tags.target_tags.multi_output = last_step_tags.target_tags.multi_output - tags.classifier_tags = deepcopy(last_step_tags.classifier_tags) - tags.regressor_tags = deepcopy(last_step_tags.regressor_tags) - tags.transformer_tags = deepcopy(last_step_tags.transformer_tags) - if tags.transformer_tags is not None: - tags.transformer_tags.preserves_dtype = preserves_dtype - elif self.steps[-1][1] is None or self.steps[-1][1] == "passthrough": - # "passthrough" behaves like a 
transformer - tags.transformer_tags = TransformerTags(preserves_dtype=[]) - # except (ValueError, AttributeError, TypeError): - # # This happens when the `steps` is not a list of (name, estimator) - # # tuples and `fit` is not called yet to validate the steps. - # pass + try: + # dtype preservation will depend on the intersection of all steps + preserves_dtype = [] + for step in self.steps: + if step[1] is not None and step[1] != "passthrough": + step_tags = get_tags(step[1]) + if step_tags.transformer_tags is not None: + preserves_dtype.append( + set(step_tags.transformer_tags.preserves_dtype) + ) + if preserves_dtype: + preserves_dtype = list(reduce(set.intersection, preserves_dtype)) + + if self.steps[-1][1] is not None and self.steps[-1][1] != "passthrough": + last_step_tags = get_tags(self.steps[-1][1]) + tags.estimator_type = last_step_tags.estimator_type + tags.target_tags.multi_output = last_step_tags.target_tags.multi_output + tags.classifier_tags = deepcopy(last_step_tags.classifier_tags) + tags.regressor_tags = deepcopy(last_step_tags.regressor_tags) + tags.transformer_tags = deepcopy(last_step_tags.transformer_tags) + if tags.transformer_tags is not None: + tags.transformer_tags.preserves_dtype = preserves_dtype + elif self.steps[-1][1] is None or self.steps[-1][1] == "passthrough": + # "passthrough" behaves like a transformer + tags.transformer_tags = TransformerTags(preserves_dtype=[]) + except (ValueError, AttributeError, TypeError): + # This happens when the `steps` is not a list of (name, estimator) + # tuples and `fit` is not called yet to validate the steps. 
+ pass return tags From 8dbb35805ce72c223f316606f57fbefd694e523b Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Wed, 20 Nov 2024 19:14:45 +0100 Subject: [PATCH 22/24] comment --- sklearn/pipeline.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sklearn/pipeline.py b/sklearn/pipeline.py index 96869d61342d6..0b52559a92334 100644 --- a/sklearn/pipeline.py +++ b/sklearn/pipeline.py @@ -1253,7 +1253,7 @@ def __sklearn_tags__(self): if tags.transformer_tags is not None: tags.transformer_tags.preserves_dtype = preserves_dtype elif self.steps[-1][1] is None or self.steps[-1][1] == "passthrough": - # "passthrough" behaves like a transformer + # None and "passthrough" behave like a transformer tags.transformer_tags = TransformerTags(preserves_dtype=[]) except (ValueError, AttributeError, TypeError): # This happens when the `steps` is not a list of (name, estimator) From 2bf876e1324b97a11ddcd3837ebe4bf2182ec76a Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Wed, 20 Nov 2024 19:47:39 +0100 Subject: [PATCH 23/24] last fix --- sklearn/tests/test_common.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sklearn/tests/test_common.py b/sklearn/tests/test_common.py index 59b45b93a7e24..63474daeca37f 100644 --- a/sklearn/tests/test_common.py +++ b/sklearn/tests/test_common.py @@ -35,7 +35,7 @@ OneHotEncoder, StandardScaler, ) -from sklearn.utils import all_estimators +from sklearn.utils import TransformerMixin, all_estimators from sklearn.utils._test_common.instance_generator import ( _get_check_estimator_ids, _get_expected_failed_checks, @@ -412,7 +412,7 @@ def test_transition_public_api_deprecations(): to the new developer public API from 1.5 to 1.6. 
""" - class OldEstimator(BaseEstimator): + class OldEstimator(TransformerMixin, BaseEstimator): def fit(self, X, y=None): X = self._validate_data(X) self._check_n_features(X, reset=True) From b41d1a6bafa25db5a7dc882ac21a6bdc131d744d Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Wed, 20 Nov 2024 22:42:24 +0100 Subject: [PATCH 24/24] fix --- sklearn/tests/test_common.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sklearn/tests/test_common.py b/sklearn/tests/test_common.py index 63474daeca37f..c74c69fb27eab 100644 --- a/sklearn/tests/test_common.py +++ b/sklearn/tests/test_common.py @@ -17,7 +17,7 @@ from scipy.linalg import LinAlgWarning import sklearn -from sklearn.base import BaseEstimator +from sklearn.base import BaseEstimator, TransformerMixin from sklearn.compose import ColumnTransformer from sklearn.datasets import make_classification from sklearn.exceptions import ConvergenceWarning @@ -35,7 +35,7 @@ OneHotEncoder, StandardScaler, ) -from sklearn.utils import TransformerMixin, all_estimators +from sklearn.utils import all_estimators from sklearn.utils._test_common.instance_generator import ( _get_check_estimator_ids, _get_expected_failed_checks,