diff --git a/doc/developers/develop.rst b/doc/developers/develop.rst index b7b5d2ac0316f..08ce24933dd8e 100644 --- a/doc/developers/develop.rst +++ b/doc/developers/develop.rst @@ -511,12 +511,12 @@ Scikit-learn introduced estimator tags in version 0.21. These are annotations of estimators that allow programmatic inspection of their capabilities, such as sparse matrix support, supported output types and supported methods. The estimator tags are a dictionary returned by the method ``_get_tags()``. These -tags are used by the common tests and the -:func:`sklearn.utils.estimator_checks.check_estimator` function to decide what -tests to run and what input data is appropriate. Tags can depend on estimator -parameters or even system architecture and can in general only be determined at -runtime. The default values for the estimator tags are defined in the -``BaseEstimator`` class. +tags are used in the common checks run by the +:func:`~sklearn.utils.estimator_checks.check_estimator` function and the +:func:`~sklearn.utils.estimator_checks.parametrize_with_checks` decorator. +Tags determine which checks to run and what input data is appropriate. Tags +can depend on estimator parameters or even system architecture and can in +general only be determined at runtime. The current set of estimator tags are: @@ -618,9 +618,10 @@ X_types (default=['2darray']) ``'categorical'`` data. For now, the test for sparse data do not make use of the ``'sparse'`` tag. - -To override the tags of a child class, one must define the `_more_tags()` -method and return a dict with the desired tags, e.g:: +It is unlikely that the default values for each tag will suit the needs of your +specific estimator. Additional tags can be created or default tags can be +overridden by defining a `_more_tags()` method which returns a dict with the +desired overridden tags or new tags. For example:: class MyMultiOutputEstimator(BaseEstimator): @@ -628,6 +629,14 @@ method and return a dict with the desired tags, e.g:: return {'multioutput_only': True, 'non_deterministic': True} +Any tag that is not in `_more_tags()` will just fall-back to the default values +documented above. + +Even if it is not recommended, it is possible to override the method +`_get_tags()`. Note however that **all tags must be present in the dict**. If +any of the keys documented above is not present in the output of `_get_tags()`, +an error will occur. + In addition to the tags, estimators also need to declare any non-optional parameters to ``__init__`` in the ``_required_parameters`` class attribute, which is a list or tuple. If ``_required_parameters`` is only diff --git a/sklearn/base.py b/sklearn/base.py index 96abc511b6125..3d49ec4fe96f6 100644 --- a/sklearn/base.py +++ b/sklearn/base.py @@ -15,32 +15,15 @@ from . 
import __version__ from ._config import get_config from .utils import _IS_32BIT +from .utils._tags import ( + _DEFAULT_TAGS, + _safe_tags, +) from .utils.validation import check_X_y from .utils.validation import check_array from .utils._estimator_html_repr import estimator_html_repr from .utils.validation import _deprecate_positional_args -_DEFAULT_TAGS = { - 'non_deterministic': False, - 'requires_positive_X': False, - 'requires_positive_y': False, - 'X_types': ['2darray'], - 'poor_score': False, - 'no_validation': False, - 'multioutput': False, - "allow_nan": False, - 'stateless': False, - 'multilabel': False, - '_skip_test': False, - '_xfail_checks': False, - 'multioutput_only': False, - 'binary_only': False, - 'requires_fit': True, - 'preserves_dtype': [np.float64], - 'requires_y': False, - 'pairwise': False, - } - @_deprecate_positional_args def clone(estimator, *, safe=True): @@ -858,11 +841,7 @@ def _is_pairwise(estimator): warnings.filterwarnings('ignore', category=FutureWarning) has_pairwise_attribute = hasattr(estimator, '_pairwise') pairwise_attribute = getattr(estimator, '_pairwise', False) - - if hasattr(estimator, '_get_tags') and callable(estimator._get_tags): - pairwise_tag = estimator._get_tags().get('pairwise', False) - else: - pairwise_tag = False + pairwise_tag = _safe_tags(estimator, key="pairwise") if has_pairwise_attribute: if pairwise_attribute != pairwise_tag: diff --git a/sklearn/feature_selection/_base.py b/sklearn/feature_selection/_base.py index a5d752cb3f4b6..60f891b69e2b7 100644 --- a/sklearn/feature_selection/_base.py +++ b/sklearn/feature_selection/_base.py @@ -12,9 +12,12 @@ from scipy.sparse import issparse, csc_matrix from ..base import TransformerMixin -from ..utils import check_array -from ..utils import safe_mask -from ..utils import safe_sqr +from ..utils import ( + check_array, + safe_mask, + safe_sqr, +) +from ..utils._tags import _safe_tags class SelectorMixin(TransformerMixin, metaclass=ABCMeta): @@ -74,9 +77,14 @@ def transform(self, X): X_r : array of shape [n_samples, n_selected_features] The input samples with only the selected features. """ - tags = self._get_tags() - X = check_array(X, dtype=None, accept_sparse='csr', - force_all_finite=not tags.get('allow_nan', True)) + # note: we use _safe_tags instead of _get_tags because this is a + # public Mixin. 
+        X = check_array(
+            X,
+            dtype=None,
+            accept_sparse="csr",
+            force_all_finite=not _safe_tags(self, key="allow_nan"),
+        )
         mask = self.get_support()
         if not mask.any():
             warn("No features were selected: either the data is"
diff --git a/sklearn/feature_selection/_from_model.py b/sklearn/feature_selection/_from_model.py
index 5fb519a2bd798..4b96804fbcc45 100644
--- a/sklearn/feature_selection/_from_model.py
+++ b/sklearn/feature_selection/_from_model.py
@@ -7,6 +7,7 @@
 from ._base import SelectorMixin
 from ._base import _get_feature_importances
 from ..base import BaseEstimator, clone, MetaEstimatorMixin
+from ..utils._tags import _safe_tags
 from ..utils.validation import check_is_fitted
 from ..exceptions import NotFittedError
 
@@ -283,5 +284,6 @@ def n_features_in_(self):
         return self.estimator_.n_features_in_
 
     def _more_tags(self):
-        estimator_tags = self.estimator._get_tags()
-        return {'allow_nan': estimator_tags.get('allow_nan', True)}
+        return {
+            'allow_nan': _safe_tags(self.estimator, key="allow_nan")
+        }
diff --git a/sklearn/feature_selection/_rfe.py b/sklearn/feature_selection/_rfe.py
index 9e6912792e837..16519dfba6761 100644
--- a/sklearn/feature_selection/_rfe.py
+++ b/sklearn/feature_selection/_rfe.py
@@ -10,8 +10,10 @@
 import numbers
 
 from joblib import Parallel, effective_n_jobs
+
 from ..utils.metaestimators import if_delegate_has_method
 from ..utils.metaestimators import _safe_split
+from ..utils._tags import _safe_tags
 from ..utils.validation import check_is_fitted
 from ..utils.validation import _deprecate_positional_args
 from ..utils.fixes import delayed
@@ -191,7 +193,7 @@ def _fit(self, X, y, step_score=None):
         X, y = self._validate_data(
             X, y, accept_sparse="csc",
             ensure_min_features=2,
-            force_all_finite=not tags.get('allow_nan', True),
+            force_all_finite=not tags.get("allow_nan", True),
             multi_output=True
         )
         error_msg = ("n_features_to_select must be either None, a "
@@ -371,11 +373,11 @@ def predict_log_proba(self, X):
         return self.estimator_.predict_log_proba(self.transform(X))
 
     def _more_tags(self):
-        estimator_tags = self.estimator._get_tags()
-        return {'poor_score': True,
-                'allow_nan': estimator_tags.get('allow_nan', True),
-                'requires_y': True,
-                }
+        return {
+            'poor_score': True,
+            'allow_nan': _safe_tags(self.estimator, key='allow_nan'),
+            'requires_y': True,
+        }
 
 
 class RFECV(RFE):
diff --git a/sklearn/feature_selection/_sequential.py b/sklearn/feature_selection/_sequential.py
index 545cde6a5cfef..271bc0062ef6b 100644
--- a/sklearn/feature_selection/_sequential.py
+++ b/sklearn/feature_selection/_sequential.py
@@ -7,6 +7,7 @@
 
 from ._base import SelectorMixin
 from ..base import BaseEstimator, MetaEstimatorMixin, clone
+from ..utils._tags import _safe_tags
 from ..utils.validation import check_is_fitted
 from ..model_selection import cross_val_score
 
@@ -128,12 +129,11 @@ def fit(self, X, y):
         -------
         self : object
         """
-        tags = self._get_tags()
         X, y = self._validate_data(
             X, y, accept_sparse="csc",
             ensure_min_features=2,
-            force_all_finite=not tags.get('allow_nan', True),
+            force_all_finite=not _safe_tags(self, key="allow_nan"),
             multi_output=True
         )
         n_features = X.shape[1]
@@ -207,8 +207,7 @@ def _get_support_mask(self):
         return self.support_
 
     def _more_tags(self):
-        estimator_tags = self.estimator._get_tags()
         return {
-            'allow_nan': estimator_tags.get('allow_nan', True),
+            'allow_nan': _safe_tags(self.estimator, key="allow_nan"),
             'requires_y': True,
         }
diff --git a/sklearn/feature_selection/tests/test_rfe.py b/sklearn/feature_selection/tests/test_rfe.py
index 6cbaeffb12997..9e6dfdbbd593a
100644 --- a/sklearn/feature_selection/tests/test_rfe.py +++ b/sklearn/feature_selection/tests/test_rfe.py @@ -56,8 +56,8 @@ def get_params(self, deep=True): def set_params(self, **params): return self - def _get_tags(self): - return {} + def _more_tags(self): + return {"allow_nan": True} def test_rfe_features_importance(): @@ -448,10 +448,7 @@ def test_rfe_importance_getter_validation(importance_getter, err_type, model.fit(X, y) -@pytest.mark.parametrize("cv", [ - None, - 5 -]) +@pytest.mark.parametrize("cv", [None, 5]) def test_rfe_allow_nan_inf_in_x(cv): iris = load_iris() X = iris.data diff --git a/sklearn/linear_model/_glm/tests/test_glm.py b/sklearn/linear_model/_glm/tests/test_glm.py index ece8f09c76acd..d6fc4e14b12fa 100644 --- a/sklearn/linear_model/_glm/tests/test_glm.py +++ b/sklearn/linear_model/_glm/tests/test_glm.py @@ -419,13 +419,13 @@ def test_tweedie_regression_family(regression_data): @pytest.mark.parametrize( - 'estimator, value', - [ - (PoissonRegressor(), True), - (GammaRegressor(), True), - (TweedieRegressor(power=1.5), True), - (TweedieRegressor(power=0), False) - ], + 'estimator, value', + [ + (PoissonRegressor(), True), + (GammaRegressor(), True), + (TweedieRegressor(power=1.5), True), + (TweedieRegressor(power=0), False), + ], ) def test_tags(estimator, value): assert estimator._get_tags()['requires_positive_y'] is value diff --git a/sklearn/linear_model/tests/test_coordinate_descent.py b/sklearn/linear_model/tests/test_coordinate_descent.py index 0c504dfbd8c6c..f9a7efc987699 100644 --- a/sklearn/linear_model/tests/test_coordinate_descent.py +++ b/sklearn/linear_model/tests/test_coordinate_descent.py @@ -300,7 +300,7 @@ def test_lasso_cv_positive_constraint(): (Lars, {}), (LinearRegression, {}), (LassoLarsIC, {})] - ) +) def test_model_pipeline_same_as_normalize_true(LinearModel, params): # Test that linear models (LinearModel) set with normalize set to True are # doing the same as the same linear model preceeded by StandardScaler @@ -315,7 +315,7 @@ def test_model_pipeline_same_as_normalize_true(LinearModel, params): LinearModel(normalize=False, fit_intercept=True, **params) ) - is_multitask = model_normalize._get_tags().get("multioutput_only", False) + is_multitask = model_normalize._get_tags()["multioutput_only"] # prepare the data n_samples, n_features = 100, 2 diff --git a/sklearn/model_selection/_search.py b/sklearn/model_selection/_search.py index eb282535dd4d5..51f43debf78ed 100644 --- a/sklearn/model_selection/_search.py +++ b/sklearn/model_selection/_search.py @@ -35,6 +35,7 @@ from joblib import Parallel from ..utils import check_random_state from ..utils.random import sample_without_replacement +from ..utils._tags import _safe_tags from ..utils.validation import indexable, check_is_fitted, _check_fit_params from ..utils.validation import _deprecate_positional_args from ..utils.metaestimators import if_delegate_has_method @@ -433,9 +434,8 @@ def _estimator_type(self): def _more_tags(self): # allows cross-validation to see 'precomputed' metrics - estimator_tags = self.estimator._get_tags() return { - 'pairwise': estimator_tags.get('pairwise', False), + 'pairwise': _safe_tags(self.estimator, "pairwise"), "_xfail_checks": {"check_supervised_y_2d": "DataConversionWarning not caught"}, } diff --git a/sklearn/model_selection/tests/test_search.py b/sklearn/model_selection/tests/test_search.py index 308d927911eaf..b1194600c530d 100644 --- a/sklearn/model_selection/tests/test_search.py +++ b/sklearn/model_selection/tests/test_search.py @@ -13,21 +13,26 @@ 
import scipy.sparse as sp import pytest -from sklearn.utils._testing import assert_raises -from sklearn.utils._testing import assert_warns -from sklearn.utils._testing import assert_warns_message -from sklearn.utils._testing import assert_raise_message -from sklearn.utils._testing import assert_array_equal -from sklearn.utils._testing import assert_array_almost_equal -from sklearn.utils._testing import assert_allclose -from sklearn.utils._testing import assert_almost_equal -from sklearn.utils._testing import ignore_warnings +from sklearn.utils._testing import ( + assert_raises, + assert_warns, + assert_warns_message, + assert_raise_message, + assert_array_equal, + assert_array_almost_equal, + assert_allclose, + assert_almost_equal, + ignore_warnings, + MinimalClassifier, + MinimalRegressor, + MinimalTransformer, +) from sklearn.utils._mocking import CheckingClassifier, MockDataFrame from scipy.stats import bernoulli, expon, uniform from sklearn.base import BaseEstimator, ClassifierMixin -from sklearn.base import clone +from sklearn.base import clone, is_classifier from sklearn.exceptions import NotFittedError from sklearn.datasets import make_classification from sklearn.datasets import make_blobs @@ -63,6 +68,7 @@ from sklearn.metrics import make_scorer from sklearn.metrics import roc_auc_score from sklearn.metrics import confusion_matrix +from sklearn.metrics import r2_score from sklearn.metrics.pairwise import euclidean_distances from sklearn.impute import SimpleImputer from sklearn.pipeline import Pipeline @@ -2079,3 +2085,35 @@ def _fit_param_callable(): 'scalar_param': 42, } model.fit(X_train, y_train, **fit_params) + + +# FIXME: Replace this test with a full `check_estimator` once we have API only +# checks. +@pytest.mark.filterwarnings("ignore:The total space of parameters 4 is") +@pytest.mark.parametrize("SearchCV", [GridSearchCV, RandomizedSearchCV]) +@pytest.mark.parametrize("Predictor", [MinimalRegressor, MinimalClassifier]) +def test_search_cv_using_minimal_compatible_estimator(SearchCV, Predictor): + # Check that third-party library can run tests without inheriting from + # BaseEstimator. 
+ rng = np.random.RandomState(0) + X, y = rng.randn(25, 2), np.array([0] * 5 + [1] * 20) + + model = Pipeline([ + ("transformer", MinimalTransformer()), ("predictor", Predictor()) + ]) + + params = { + "transformer__param": [1, 10], "predictor__parama": [1, 10], + } + search = SearchCV(model, params, error_score="raise") + search.fit(X, y) + + assert search.best_params_.keys() == params.keys() + + y_pred = search.predict(X) + if is_classifier(search): + assert_array_equal(y_pred, 1) + assert search.score(X, y) == pytest.approx(accuracy_score(y, y_pred)) + else: + assert_allclose(y_pred, y.mean()) + assert search.score(X, y) == pytest.approx(r2_score(y, y_pred)) diff --git a/sklearn/model_selection/tests/test_validation.py b/sklearn/model_selection/tests/test_validation.py index e8fdbc1dd32df..4437a7a4cb35c 100644 --- a/sklearn/model_selection/tests/test_validation.py +++ b/sklearn/model_selection/tests/test_validation.py @@ -1985,15 +1985,3 @@ def _more_tags(self): "Set the estimator tags of your estimator instead") with pytest.warns(FutureWarning, match=msg): cross_validate(svm, linear_kernel, y, cv=2) - - # the _pairwise attribute is present and set to True while the pairwise - # tag is not present - class NoEstimatorTagSVM(SVC): - def _get_tags(self): - tags = super()._get_tags() - del tags['pairwise'] - return tags - - svm = NoEstimatorTagSVM(kernel='precomputed') - with pytest.warns(FutureWarning, match=msg): - cross_validate(svm, linear_kernel, y, cv=2) diff --git a/sklearn/multiclass.py b/sklearn/multiclass.py index 50bf83d4eaa43..182a412f8313f 100644 --- a/sklearn/multiclass.py +++ b/sklearn/multiclass.py @@ -47,6 +47,7 @@ from .metrics.pairwise import euclidean_distances from .utils import check_random_state from .utils.deprecation import deprecated +from .utils._tags import _safe_tags from .utils.validation import _num_samples from .utils.validation import check_is_fitted from .utils.validation import check_X_y, check_array @@ -499,8 +500,7 @@ def _pairwise(self): def _more_tags(self): """Indicate if wrapped estimator is using a precomputed Gram matrix""" - estimator_tags = self.estimator._get_tags() - return {'pairwise': estimator_tags.get('pairwise', False)} + return {'pairwise': _safe_tags(self.estimator, key="pairwise")} @property def _first_estimator(self): @@ -780,8 +780,9 @@ def _pairwise(self): def _more_tags(self): """Indicate if wrapped estimator is using a precomputed Gram matrix""" - estimator_tags = self.estimator._get_tags() - return {'pairwise': estimator_tags.get('pairwise', True)} + return { + 'pairwise': _safe_tags(self.estimator, key="pairwise") + } class OutputCodeClassifier(MetaEstimatorMixin, ClassifierMixin, BaseEstimator): diff --git a/sklearn/neighbors/_base.py b/sklearn/neighbors/_base.py index 187e41c242405..1e666043347cf 100644 --- a/sklearn/neighbors/_base.py +++ b/sklearn/neighbors/_base.py @@ -23,8 +23,11 @@ from ..base import is_classifier from ..metrics import pairwise_distances_chunked from ..metrics.pairwise import PAIRWISE_DISTANCE_FUNCTIONS -from ..utils import check_array, gen_even_slices -from ..utils import _to_object_array +from ..utils import ( + check_array, + gen_even_slices, + _to_object_array, +) from ..utils.deprecation import deprecated from ..utils.multiclass import check_classification_targets from ..utils.validation import check_is_fitted diff --git a/sklearn/pipeline.py b/sklearn/pipeline.py index 8d738d4b90fff..6df8cddc476c4 100644 --- a/sklearn/pipeline.py +++ b/sklearn/pipeline.py @@ -19,8 +19,12 @@ from .base import clone, 
TransformerMixin from .utils._estimator_html_repr import _VisualBlock from .utils.metaestimators import if_delegate_has_method -from .utils import Bunch, _print_elapsed_time +from .utils import ( + Bunch, + _print_elapsed_time, +) from .utils.deprecation import deprecated +from .utils._tags import _safe_tags from .utils.validation import check_memory from .utils.validation import _deprecate_positional_args from .utils.fixes import delayed @@ -623,8 +627,7 @@ def classes_(self): def _more_tags(self): # check if first estimator expects pairwise input - estimator_tags = self.steps[0][1]._get_tags() - return {'pairwise': estimator_tags.get('pairwise', False)} + return {'pairwise': _safe_tags(self.steps[0][1], "pairwise")} # TODO: Remove in 0.26 # mypy error: Decorated property not supported diff --git a/sklearn/tests/test_base.py b/sklearn/tests/test_base.py index b8d78a96d8e85..0c07db459d128 100644 --- a/sklearn/tests/test_base.py +++ b/sklearn/tests/test_base.py @@ -558,24 +558,9 @@ class IncorrectTagPCA(KernelPCA): with pytest.warns(FutureWarning, match=msg): assert not _is_pairwise(pca) - # the _pairwise attribute is present and set to False while the pairwise - # tag is not present - class FalsePairwise(BaseEstimator): - _pairwise = False - - def _get_tags(self): - tags = super()._get_tags() - del tags['pairwise'] - return tags - - false_pairwise = FalsePairwise() - with pytest.warns(None) as record: - assert not _is_pairwise(false_pairwise) - assert not record - # the _pairwise attribute is present and set to True while pairwise tag is # not present - class TruePairwise(FalsePairwise): + class TruePairwise(BaseEstimator): _pairwise = True true_pairwise = TruePairwise() diff --git a/sklearn/tests/test_multiclass.py b/sklearn/tests/test_multiclass.py index dbf574db6fe4d..b60a85c7bde00 100644 --- a/sklearn/tests/test_multiclass.py +++ b/sklearn/tests/test_multiclass.py @@ -17,8 +17,10 @@ from sklearn.multiclass import OutputCodeClassifier from sklearn.utils.multiclass import (check_classification_targets, type_of_target) -from sklearn.utils import check_array -from sklearn.utils import shuffle +from sklearn.utils import ( + check_array, + shuffle, +) from sklearn.metrics import precision_score from sklearn.metrics import recall_score @@ -794,10 +796,10 @@ def test_pairwise_tag(MultiClassClassifier): clf_notprecomputed = svm.SVC() ovr_false = MultiClassClassifier(clf_notprecomputed) - assert not ovr_false._get_tags()['pairwise'] + assert not ovr_false._get_tags()["pairwise"] ovr_true = MultiClassClassifier(clf_precomputed) - assert ovr_true._get_tags()['pairwise'] + assert ovr_true._get_tags()["pairwise"] # TODO: Remove in 0.26 diff --git a/sklearn/tests/test_pipeline.py b/sklearn/tests/test_pipeline.py index bd88f4acd03c3..7989394d0a65e 100644 --- a/sklearn/tests/test_pipeline.py +++ b/sklearn/tests/test_pipeline.py @@ -12,21 +12,27 @@ from scipy import sparse import joblib -from sklearn.utils._testing import assert_raises -from sklearn.utils._testing import assert_raises_regex -from sklearn.utils._testing import assert_raise_message -from sklearn.utils._testing import assert_allclose -from sklearn.utils._testing import assert_array_equal -from sklearn.utils._testing import assert_array_almost_equal -from sklearn.utils._testing import assert_no_warnings from sklearn.utils.fixes import parse_version +from sklearn.utils._testing import ( + assert_raises, + assert_raises_regex, + assert_raise_message, + assert_allclose, + assert_array_equal, + assert_array_almost_equal, + assert_no_warnings, + 
MinimalClassifier, + MinimalRegressor, + MinimalTransformer, +) -from sklearn.base import clone, BaseEstimator, TransformerMixin +from sklearn.base import clone, is_classifier, BaseEstimator, TransformerMixin from sklearn.pipeline import Pipeline, FeatureUnion, make_pipeline, make_union from sklearn.svm import SVC from sklearn.neighbors import LocalOutlierFactor from sklearn.linear_model import LogisticRegression, Lasso from sklearn.linear_model import LinearRegression +from sklearn.metrics import accuracy_score, r2_score from sklearn.cluster import KMeans from sklearn.feature_selection import SelectKBest, f_classif from sklearn.dummy import DummyRegressor @@ -1264,3 +1270,36 @@ def test_feature_union_warns_unknown_transformer_weight(): union = FeatureUnion(transformer_list, transformer_weights=weights) with pytest.raises(ValueError, match=expected_msg): union.fit(X, y) + + +@pytest.mark.parametrize('passthrough', [None, 'passthrough']) +def test_pipeline_get_tags_none(passthrough): + # Checks that tags are set correctly when the first transformer is None or + # 'passthrough' + # Non-regression test for: + # https://github.com/scikit-learn/scikit-learn/issues/18815 + pipe = make_pipeline(passthrough, SVC()) + assert not pipe._get_tags()['pairwise'] + + +# FIXME: Replace this test with a full `check_estimator` once we have API only +# checks. +@pytest.mark.parametrize("Predictor", [MinimalRegressor, MinimalClassifier]) +def test_search_cv_using_minimal_compatible_estimator(Predictor): + # Check that third-party library estimators can be part of a pipeline + # and tuned by grid-search without inheriting from BaseEstimator. + rng = np.random.RandomState(0) + X, y = rng.randn(25, 2), np.array([0] * 5 + [1] * 20) + + model = Pipeline([ + ("transformer", MinimalTransformer()), ("predictor", Predictor()) + ]) + model.fit(X, y) + + y_pred = model.predict(X) + if is_classifier(model): + assert_array_equal(y_pred, 1) + assert model.score(X, y) == pytest.approx(accuracy_score(y, y_pred)) + else: + assert_allclose(y_pred, y.mean()) + assert model.score(X, y) == pytest.approx(r2_score(y, y_pred)) diff --git a/sklearn/utils/__init__.py b/sklearn/utils/__init__.py index 9d542102b1dda..ca2be9d14fe29 100644 --- a/sklearn/utils/__init__.py +++ b/sklearn/utils/__init__.py @@ -43,7 +43,6 @@ parallel_backend = _joblib.parallel_backend register_parallel_backend = _joblib.register_parallel_backend - __all__ = ["murmurhash3_32", "as_float_array", "assert_all_finite", "check_array", "check_random_state", @@ -53,8 +52,7 @@ "check_symmetric", "indices_to_mask", "deprecated", "parallel_backend", "register_parallel_backend", "resample", "shuffle", "check_matplotlib_support", "all_estimators", - "DataConversionWarning", "estimator_html_repr" - ] + "DataConversionWarning", "estimator_html_repr"] IS_PYPY = platform.python_implementation() == 'PyPy' _IS_32BIT = 8 * struct.calcsize("P") == 32 diff --git a/sklearn/utils/_tags.py b/sklearn/utils/_tags.py new file mode 100644 index 0000000000000..ac908ec63ce82 --- /dev/null +++ b/sklearn/utils/_tags.py @@ -0,0 +1,67 @@ +import numpy as np + +_DEFAULT_TAGS = { + 'non_deterministic': False, + 'requires_positive_X': False, + 'requires_positive_y': False, + 'X_types': ['2darray'], + 'poor_score': False, + 'no_validation': False, + 'multioutput': False, + "allow_nan": False, + 'stateless': False, + 'multilabel': False, + '_skip_test': False, + '_xfail_checks': False, + 'multioutput_only': False, + 'binary_only': False, + 'requires_fit': True, + 'preserves_dtype': [np.float64], + 
'requires_y': False, + 'pairwise': False, +} + + +def _safe_tags(estimator, key=None): + """Safely get estimator tags. + + :class:`~sklearn.BaseEstimator` provides the estimator tags machinery. + However, if an estimator does not inherit from this base class, we should + fall-back to the default tags. + + For scikit-learn built-in estimators, we should still rely on + `self._get_tags()`. `_safe_tags(est)` should be used when we are not sure + where `est` comes from: typically `_safe_tags(self.base_estimator)` where + `self` is a meta-estimator, or in the common checks. + + Parameters + ---------- + estimator : estimator object + The estimator from which to get the tag. + + key : str, default=None + Tag name to get. By default (`None`), all tags are returned. + + Returns + ------- + tags : dict or tag value + The estimator tags. A single value is returned if `key` is not None. + """ + if hasattr(estimator, "_get_tags"): + tags_provider = "_get_tags()" + tags = estimator._get_tags() + elif hasattr(estimator, "_more_tags"): + tags_provider = "_more_tags()" + tags = {**_DEFAULT_TAGS, **estimator._more_tags()} + else: + tags_provider = "_DEFAULT_TAGS" + tags = _DEFAULT_TAGS + + if key is not None: + if key not in tags: + raise ValueError( + f"The key {key} is not defined in {tags_provider} for the " + f"class {estimator.__class__.__name__}." + ) + return tags[key] + return tags diff --git a/sklearn/utils/_testing.py b/sklearn/utils/_testing.py index 0f340967b0cec..779e7b6574e3e 100644 --- a/sklearn/utils/_testing.py +++ b/sklearn/utils/_testing.py @@ -49,6 +49,12 @@ import sklearn from sklearn.utils import IS_PYPY, _IS_32BIT +from sklearn.utils.multiclass import check_classification_targets +from sklearn.utils.validation import ( + check_array, + check_is_fitted, + check_X_y, +) __all__ = ["assert_raises", @@ -866,3 +872,124 @@ def __exit__(self, exc_type, exc_value, _): self.raised_and_matched = True return True + + +class MinimalClassifier: + """Minimal classifier implementation with inheriting from BaseEstimator. + + This estimator should be tested with: + + * `check_estimator` in `test_estimator_checks.py`; + * within a `Pipeline` in `test_pipeline.py`; + * within a `SearchCV` in `test_search.py`. + """ + _estimator_type = "classifier" + + def __init__(self, param=None): + self.param = param + + def get_params(self, deep=True): + return {"param": self.param} + + def set_params(self, **params): + for key, value in params.items(): + setattr(self, key, value) + return self + + def fit(self, X, y): + X, y = check_X_y(X, y) + check_classification_targets(y) + self.classes_, counts = np.unique(y, return_counts=True) + self._most_frequent_class_idx = counts.argmax() + return self + + def predict_proba(self, X): + check_is_fitted(self) + X = check_array(X) + proba_shape = (X.shape[0], self.classes_.size) + y_proba = np.zeros(shape=proba_shape, dtype=np.float64) + y_proba[:, self._most_frequent_class_idx] = 1.0 + return y_proba + + def predict(self, X): + y_proba = self.predict_proba(X) + y_pred = y_proba.argmax(axis=1) + return self.classes_[y_pred] + + def score(self, X, y): + from sklearn.metrics import accuracy_score + return accuracy_score(y, self.predict(X)) + + +class MinimalRegressor: + """Minimal regressor implementation with inheriting from BaseEstimator. + + This estimator should be tested with: + + * `check_estimator` in `test_estimator_checks.py`; + * within a `Pipeline` in `test_pipeline.py`; + * within a `SearchCV` in `test_search.py`. 
+ """ + _estimator_type = "regressor" + + def __init__(self, param=None): + self.param = param + + def get_params(self, deep=True): + return {"param": self.param} + + def set_params(self, **params): + for key, value in params.items(): + setattr(self, key, value) + return self + + def fit(self, X, y): + X, y = check_X_y(X, y) + self.is_fitted_ = True + self._mean = np.mean(y) + return self + + def predict(self, X): + check_is_fitted(self) + X = check_array(X) + return np.ones(shape=(X.shape[0],)) * self._mean + + def score(self, X, y): + from sklearn.metrics import r2_score + return r2_score(y, self.predict(X)) + + +class MinimalTransformer: + """Minimal transformer implementation with inheriting from + BaseEstimator. + + This estimator should be tested with: + + * `check_estimator` in `test_estimator_checks.py`; + * within a `Pipeline` in `test_pipeline.py`; + * within a `SearchCV` in `test_search.py`. + """ + + def __init__(self, param=None): + self.param = param + + def get_params(self, deep=True): + return {"param": self.param} + + def set_params(self, **params): + for key, value in params.items(): + setattr(self, key, value) + return self + + def fit(self, X, y=None): + X = check_array(X) + self.is_fitted_ = True + return self + + def transform(self, X, y=None): + check_is_fitted(self) + X = check_array(X) + return X + + def fit_transform(self, X, y=None): + return self.fit(X, y).transform(X, y) diff --git a/sklearn/utils/estimator_checks.py b/sklearn/utils/estimator_checks.py index 3148767e79676..7925487bad522 100644 --- a/sklearn/utils/estimator_checks.py +++ b/sklearn/utils/estimator_checks.py @@ -52,6 +52,10 @@ from ..metrics.pairwise import (rbf_kernel, linear_kernel, pairwise_distances) from .import shuffle +from ._tags import ( + _DEFAULT_TAGS, + _safe_tags, +) from .validation import has_fit_parameter, _num_samples from ..preprocessing import StandardScaler from ..preprocessing import scale @@ -68,7 +72,7 @@ def _yield_checks(estimator): name = estimator.__class__.__name__ - tags = estimator._get_tags() + tags = _safe_tags(estimator) pairwise = _is_pairwise(estimator) yield check_no_attributes_set_in_init @@ -114,9 +118,10 @@ def _yield_checks(estimator): # give the same answer as before. yield check_estimators_pickle + yield check_estimator_get_tags_default_keys def _yield_classifier_checks(classifier): - tags = classifier._get_tags() + tags = _safe_tags(classifier) # test classifiers can handle non-array data and pandas objects yield check_classifier_data_not_an_array @@ -170,7 +175,7 @@ def check_supervised_y_no_nan(name, estimator_orig): def _yield_regressor_checks(regressor): - tags = regressor._get_tags() + tags = _safe_tags(regressor) # TODO: test with intercept # TODO: test with multiple responses # basic testing @@ -196,7 +201,7 @@ def _yield_regressor_checks(regressor): def _yield_transformer_checks(transformer): - tags = transformer._get_tags() + tags = _safe_tags(transformer) # All transformers should either deal with sparse data or raise an # exception with type TypeError and an intelligible error message if not tags["no_validation"]: @@ -206,7 +211,7 @@ def _yield_transformer_checks(transformer): if tags["preserves_dtype"]: yield check_transformer_preserve_dtypes yield partial(check_transformer_general, readonly_memmap=True) - if not transformer._get_tags()["stateless"]: + if not _safe_tags(transformer, key="stateless"): yield check_transformers_unfitted # Dependent on external solvers and hence accessing the iter # param is non-trivial. 
@@ -243,13 +248,13 @@ def _yield_outliers_checks(estimator): # test outlier detectors can handle non-array data yield check_classifier_data_not_an_array # test if NotFittedError is raised - if estimator._get_tags()["requires_fit"]: + if _safe_tags(estimator, key="requires_fit"): yield check_estimators_unfitted def _yield_all_checks(estimator): name = estimator.__class__.__name__ - tags = estimator._get_tags() + tags = _safe_tags(estimator) if "2darray" not in tags["X_types"]: warnings.warn("Can't test estimator {} which requires input " " of type {}".format(name, tags["X_types"]), @@ -416,7 +421,7 @@ def _should_be_skipped_or_marked(estimator, check): check_name = (check.func.__name__ if isinstance(check, partial) else check.__name__) - xfail_checks = estimator._get_tags()['_xfail_checks'] or {} + xfail_checks = _safe_tags(estimator, key='_xfail_checks') or {} if check_name in xfail_checks: return True, xfail_checks[check_name] @@ -736,7 +741,7 @@ def check_estimator_sparse_data(name, estimator_orig): with ignore_warnings(category=FutureWarning): estimator = clone(estimator_orig) y = _enforce_estimator_tags_y(estimator, y) - tags = estimator_orig._get_tags() + tags = _safe_tags(estimator_orig) for matrix_format, X in _generate_sparse_matrix(X_csr): # catch deprecation warnings with ignore_warnings(category=FutureWarning): @@ -793,7 +798,7 @@ def check_sample_weights_pandas_series(name, estimator_orig): X = pd.DataFrame(_pairwise_estimator_convert_X(X, estimator_orig)) y = pd.Series([1, 1, 1, 1, 2, 2, 2, 2, 1, 1, 2, 2]) weights = pd.Series([1] * 12) - if estimator._get_tags()["multioutput_only"]: + if _safe_tags(estimator, key="multioutput_only"): y = pd.DataFrame(y) try: estimator.fit(X, y, sample_weight=weights) @@ -818,7 +823,7 @@ def check_sample_weights_not_an_array(name, estimator_orig): X = _NotAnArray(_pairwise_estimator_convert_X(X, estimator_orig)) y = _NotAnArray([1, 1, 1, 1, 2, 2, 2, 2, 1, 1, 2, 2]) weights = _NotAnArray([1] * 12) - if estimator._get_tags()["multioutput_only"]: + if _safe_tags(estimator, key="multioutput_only"): y = _NotAnArray(y.data.reshape(-1, 1)) estimator.fit(X, y, sample_weight=weights) @@ -922,7 +927,7 @@ def check_dtype_object(name, estimator_orig): rng = np.random.RandomState(0) X = _pairwise_estimator_convert_X(rng.rand(40, 10), estimator_orig) X = X.astype(object) - tags = estimator_orig._get_tags() + tags = _safe_tags(estimator_orig) y = (X[:, 0] * 4).astype(int) estimator = clone(estimator_orig) y = _enforce_estimator_tags_y(estimator, y) @@ -1140,7 +1145,7 @@ def check_methods_sample_order_invariance(name, estimator_orig): X = 3 * rnd.uniform(size=(20, 3)) X = _pairwise_estimator_convert_X(X, estimator_orig) y = X[:, 0].astype(np.int64) - if estimator_orig._get_tags()['binary_only']: + if _safe_tags(estimator_orig, key='binary_only'): y[y == 2] = 1 estimator = clone(estimator_orig) y = _enforce_estimator_tags_y(estimator, y) @@ -1328,7 +1333,7 @@ def _check_transformer(name, transformer_orig, X, y): X_pred2 = transformer.transform(X) X_pred3 = transformer.fit_transform(X, y=y_) - if transformer_orig._get_tags()['non_deterministic']: + if _safe_tags(transformer_orig, key='non_deterministic'): msg = name + ' is non deterministic' raise SkipTest(msg) if isinstance(X_pred, tuple) and isinstance(X_pred2, tuple): @@ -1359,7 +1364,7 @@ def _check_transformer(name, transformer_orig, X, y): # raises error on malformed input for transform if hasattr(X, 'shape') and \ - not transformer._get_tags()["stateless"] and \ + not _safe_tags(transformer, 
key="stateless") and \ X.ndim == 2 and X.shape[1] > 1: # If it's not an array, it does not have a 'T' property @@ -1374,7 +1379,7 @@ def _check_transformer(name, transformer_orig, X, y): @ignore_warnings def check_pipeline_consistency(name, estimator_orig): - if estimator_orig._get_tags()['non_deterministic']: + if _safe_tags(estimator_orig, key='non_deterministic'): msg = name + ' is non deterministic' raise SkipTest(msg) @@ -1466,7 +1471,7 @@ def check_transformer_preserve_dtypes(name, transformer_orig): X -= X.min() X = _pairwise_estimator_convert_X(X, transformer_orig) - for dtype in transformer_orig._get_tags()["preserves_dtype"]: + for dtype in _safe_tags(transformer_orig, key="preserves_dtype"): X_cast = X.astype(dtype) transformer = clone(transformer_orig) set_random_state(transformer) @@ -1591,7 +1596,7 @@ def check_estimators_pickle(name, estimator_orig): X -= X.min() X = _pairwise_estimator_convert_X(X, estimator_orig, kernel=rbf_kernel) - tags = estimator_orig._get_tags() + tags = _safe_tags(estimator_orig) # include NaN values when the estimator should deal with them if tags['allow_nan']: # set randomly 10 elements to np.nan @@ -1608,7 +1613,12 @@ def check_estimators_pickle(name, estimator_orig): # pickle and unpickle! pickled_estimator = pickle.dumps(estimator) - if estimator.__module__.startswith('sklearn.'): + module_name = estimator.__module__ + if module_name.startswith('sklearn.') and not ( + "test_" in module_name or module_name.endswith("_testing") + ): + # strict check for sklearn estimators that are not implemented in test + # modules. assert b"version" in pickled_estimator unpickled_estimator = pickle.loads(pickled_estimator) @@ -1652,7 +1662,7 @@ def check_estimators_partial_fit_n_features(name, estimator_orig): @ignore_warnings(category=FutureWarning) def check_classifier_multioutput(name, estimator): n_samples, n_labels, n_classes = 42, 5, 3 - tags = estimator._get_tags() + tags = _safe_tags(estimator) estimator = clone(estimator) X, y = make_multilabel_classification(random_state=42, n_samples=n_samples, @@ -1759,7 +1769,7 @@ def check_clustering(name, clusterer_orig, readonly_memmap=False): pred = clusterer.labels_ assert pred.shape == (n_samples,) assert adjusted_rand_score(pred, y) > 0.4 - if clusterer._get_tags()['non_deterministic']: + if _safe_tags(clusterer, key='non_deterministic'): return set_random_state(clusterer) with warnings.catch_warnings(record=True): @@ -1851,7 +1861,7 @@ def check_classifiers_train( X_m, y_m, X_b, y_b = create_memmap_backed_data([X_m, y_m, X_b, y_b]) problems = [(X_b, y_b)] - tags = classifier_orig._get_tags() + tags = _safe_tags(classifier_orig) if not tags['binary_only']: problems.append((X_m, y_m)) @@ -2142,7 +2152,7 @@ def check_estimators_unfitted(name, estimator_orig): @ignore_warnings(category=FutureWarning) def check_supervised_y_2d(name, estimator_orig): - tags = estimator_orig._get_tags() + tags = _safe_tags(estimator_orig) rnd = np.random.RandomState(0) n_samples = 30 X = _pairwise_estimator_convert_X( @@ -2249,7 +2259,7 @@ def check_classifiers_classes(name, classifier_orig): y_names_binary = np.take(labels_binary, y_binary) problems = [(X_binary, y_binary, y_names_binary)] - if not classifier_orig._get_tags()['binary_only']: + if not _safe_tags(classifier_orig, key='binary_only'): problems.append((X_multiclass, y_multiclass, y_names_multiclass)) for X, y, y_names in problems: @@ -2336,7 +2346,7 @@ def check_regressors_train( # TODO: find out why PLS and CCA fail. 
RANSAC is random # and furthermore assumes the presence of outliers, hence # skipped - if not regressor._get_tags()["poor_score"]: + if not _safe_tags(regressor, key="poor_score"): assert regressor.score(X, y_) > 0.5 @@ -2360,7 +2370,7 @@ def check_regressors_no_decision_function(name, regressor_orig): @ignore_warnings(category=FutureWarning) def check_class_weight_classifiers(name, classifier_orig): - if classifier_orig._get_tags()['binary_only']: + if _safe_tags(classifier_orig, key='binary_only'): problems = [2] else: problems = [2, 3] @@ -2399,7 +2409,7 @@ def check_class_weight_classifiers(name, classifier_orig): y_pred = classifier.predict(X_test) # XXX: Generally can use 0.89 here. On Windows, LinearSVC gets # 0.88 (Issue #9111) - if not classifier_orig._get_tags()['poor_score']: + if not _safe_tags(classifier_orig, key='poor_score'): assert np.mean(y_pred == 0) > 0.87 @@ -2717,16 +2727,16 @@ def param_filter(p): def _enforce_estimator_tags_y(estimator, y): # Estimators with a `requires_positive_y` tag only accept strictly positive # data - if estimator._get_tags()["requires_positive_y"]: + if _safe_tags(estimator, key="requires_positive_y"): # Create strictly positive y. The minimal increment above 0 is 1, as # y could be of integer dtype. y += 1 + abs(y.min()) # Estimators with a `binary_only` tag only accept up to two unique y values - if estimator._get_tags()["binary_only"] and y.size > 0: + if _safe_tags(estimator, key="binary_only") and y.size > 0: y = np.where(y == y.flat[0], y, y.flat[0] + 1) # Estimators in mono_output_task_error raise ValueError if y is of 1-D # Convert into a 2-D y for those estimators. - if estimator._get_tags()["multioutput_only"]: + if _safe_tags(estimator, key="multioutput_only"): return np.reshape(y, (-1, 1)) return y @@ -2738,11 +2748,11 @@ def _enforce_estimator_tags_x(estimator, X): X = X.dot(X.T) # Estimators with `1darray` in `X_types` tag only accept # X of shape (`n_samples`,) - if '1darray' in estimator._get_tags()['X_types']: + if '1darray' in _safe_tags(estimator, key='X_types'): X = X[:, 0] # Estimators with a `requires_positive_X` tag only accept # strictly positive data - if estimator._get_tags()['requires_positive_X']: + if _safe_tags(estimator, key='requires_positive_X'): X -= X.min() return X @@ -2884,7 +2894,7 @@ def check_classifiers_regression_target(name, estimator_orig): X = X + 1 + abs(X.min(axis=0)) # be sure that X is non-negative e = clone(estimator_orig) msg = "Unknown label type: " - if not e._get_tags()["no_validation"]: + if not _safe_tags(e, key="no_validation"): with raises(ValueError, match=msg): e.fit(X, y) @@ -3097,7 +3107,7 @@ def check_requires_y_none(name, estimator_orig): def check_n_features_in_after_fitting(name, estimator_orig): # Make sure that n_features_in are checked after fitting - tags = estimator_orig._get_tags() + tags = _safe_tags(estimator_orig) if "2darray" not in tags["X_types"] or tags["no_validation"]: return @@ -3147,3 +3157,18 @@ def check_n_features_in_after_fitting(name, estimator_orig): with raises(ValueError, match=msg): estimator.partial_fit(X_bad, y) + + +def check_estimator_get_tags_default_keys(name, estimator_orig): + # check that if _get_tags is implemented, it contains all keys from + # _DEFAULT_KEYS + estimator = clone(estimator_orig) + if not hasattr(estimator, "_get_tags"): + return + + tags_keys = set(estimator._get_tags().keys()) + default_tags_keys = set(_DEFAULT_TAGS.keys()) + assert tags_keys.intersection(default_tags_keys) == default_tags_keys, ( + f"{name}._get_tags() is 
missing entries for the following default tags" + f": {default_tags_keys - tags_keys.intersection(default_tags_keys)}" + ) diff --git a/sklearn/utils/tests/test_estimator_checks.py b/sklearn/utils/tests/test_estimator_checks.py index ecbf7cb7be7f4..8fabe5f91ea31 100644 --- a/sklearn/utils/tests/test_estimator_checks.py +++ b/sklearn/utils/tests/test_estimator_checks.py @@ -7,10 +7,16 @@ from sklearn.base import BaseEstimator, ClassifierMixin from sklearn.utils import deprecated -from sklearn.utils._testing import (assert_raises_regex, - ignore_warnings, - assert_warns, assert_raises, - SkipTest) +from sklearn.utils._testing import ( + assert_raises, + assert_raises_regex, + assert_warns, + ignore_warnings, + MinimalClassifier, + MinimalRegressor, + MinimalTransformer, + SkipTest, +) from sklearn.utils.estimator_checks import check_estimator, _NotAnArray from sklearn.utils.estimator_checks \ import check_class_weight_balanced_linear_classifier @@ -21,6 +27,8 @@ from sklearn.utils.estimator_checks import check_no_attributes_set_in_init from sklearn.utils.estimator_checks import check_classifier_data_not_an_array from sklearn.utils.estimator_checks import check_regressor_data_not_an_array +from sklearn.utils.estimator_checks import \ + check_estimator_get_tags_default_keys from sklearn.utils.validation import check_is_fitted from sklearn.utils.estimator_checks import check_outlier_corruption from sklearn.utils.fixes import np_version, parse_version @@ -368,6 +376,13 @@ def _more_tags(self): return {'binary_only': True} +class EstimatorMissingDefaultTags(BaseEstimator): + def _get_tags(self): + tags = super()._get_tags().copy() + del tags["allow_nan"] + return tags + + class RequiresPositiveYRegressor(LinearRegression): def fit(self, X, y): @@ -418,8 +433,6 @@ def test_check_estimator(): # check that we have a set_params and can clone msg = "Passing a class was deprecated" assert_raises_regex(TypeError, msg, check_estimator, object) - msg = "object has no attribute '_get_tags'" - assert_raises_regex(AttributeError, msg, check_estimator, object()) msg = ( "Parameter 'p' of estimator 'HasMutableParameters' is of type " "object which is not allowed" @@ -636,6 +649,25 @@ def test_check_regressor_data_not_an_array(): EstimatorInconsistentForPandas()) +def test_check_estimator_get_tags_default_keys(): + estimator = EstimatorMissingDefaultTags() + err_msg = (r"EstimatorMissingDefaultTags._get_tags\(\) is missing entries" + r" for the following default tags: {'allow_nan'}") + assert_raises_regex( + AssertionError, + err_msg, + check_estimator_get_tags_default_keys, + estimator.__class__.__name__, + estimator, + ) + + # noop check when _get_tags is not available + estimator = MinimalTransformer() + check_estimator_get_tags_default_keys( + estimator.__class__.__name__, estimator + ) + + def run_tests_without_pytest(): """Runs the tests in this file without using pytest. """ @@ -677,3 +709,17 @@ def test_xfail_ignored_in_check_estimator(): # Make sure checks marked as xfail are just ignored and not run by # check_estimator(), but still raise a warning. assert_warns(SkipTestWarning, check_estimator, NuSVC()) + + +# FIXME: this test should be uncommented when the checks will be granular +# enough. In 0.24, these tests fail due to low estimator performance. +def test_minimal_class_implementation_checks(): + # Check that third-party library can run tests without inheriting from + # BaseEstimator. 
+ # FIXME + raise SkipTest + minimal_estimators = [ + MinimalTransformer(), MinimalRegressor(), MinimalClassifier() + ] + for estimator in minimal_estimators: + check_estimator(estimator) diff --git a/sklearn/utils/tests/test_tags.py b/sklearn/utils/tests/test_tags.py new file mode 100644 index 0000000000000..f96a4947164c3 --- /dev/null +++ b/sklearn/utils/tests/test_tags.py @@ -0,0 +1,47 @@ +import pytest + +from sklearn.base import BaseEstimator +from sklearn.utils._tags import ( + _DEFAULT_TAGS, + _safe_tags, +) + + +class NoTagsEstimator: + pass + + +class MoreTagsEstimator: + def _more_tags(self): + return {"allow_nan": True} + + +@pytest.mark.parametrize( + "estimator, err_msg", + [ + (BaseEstimator(), "The key xxx is not defined in _get_tags"), + (NoTagsEstimator(), "The key xxx is not defined in _DEFAULT_TAGS"), + ], +) +def test_safe_tags_error(estimator, err_msg): + # Check that safe_tags raises error in ambiguous case. + with pytest.raises(ValueError, match=err_msg): + _safe_tags(estimator, key="xxx") + + +@pytest.mark.parametrize( + "estimator, key, expected_results", + [ + (NoTagsEstimator(), None, _DEFAULT_TAGS), + (NoTagsEstimator(), "allow_nan", _DEFAULT_TAGS["allow_nan"]), + (MoreTagsEstimator(), None, {**_DEFAULT_TAGS, **{"allow_nan": True}}), + (MoreTagsEstimator(), "allow_nan", True), + (BaseEstimator(), None, _DEFAULT_TAGS), + (BaseEstimator(), "allow_nan", _DEFAULT_TAGS["allow_nan"]), + (BaseEstimator(), "allow_nan", _DEFAULT_TAGS["allow_nan"]), + ], +) +def test_safe_tags_no_get_tags(estimator, key, expected_results): + # check the behaviour of _safe_tags when an estimator does not implement + # _get_tags + assert _safe_tags(estimator, key=key) == expected_results
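
A minimal usage sketch of the ``_safe_tags`` helper added in ``sklearn/utils/_tags.py`` above, assuming the patch is applied. ``ThirdPartySelector`` is a hypothetical estimator used only for illustration: it does not inherit from ``BaseEstimator`` and only defines ``_more_tags``, which is exactly the situation ``_safe_tags`` is meant to handle::

    from sklearn.utils._tags import _DEFAULT_TAGS, _safe_tags


    class ThirdPartySelector:
        """Hypothetical estimator that does not inherit from BaseEstimator."""

        def _more_tags(self):
            return {"allow_nan": True}


    est = ThirdPartySelector()

    # No _get_tags: the defaults are merged with the estimator's _more_tags.
    tags = _safe_tags(est)
    assert tags["allow_nan"] is True
    assert tags["requires_fit"] is _DEFAULT_TAGS["requires_fit"]

    # A single tag can be requested by key; an unknown key raises a ValueError.
    assert _safe_tags(est, key="allow_nan") is True

Estimators that do inherit from ``BaseEstimator`` still go through ``_get_tags()`` inside ``_safe_tags``, so scikit-learn's own classes keep their current behaviour.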