Skip to content

Commit 255718b

Browse files
glemaitreNicolasHugogrisel
authored
introduce _safe_tags for estimator not inheriting from BaseEstimator (scikit-learn#18797)
Co-authored-by: Nicolas Hug <contact@nicolas-hug.com> Co-authored-by: Olivier Grisel <olivier.grisel@ensta.org>
1 parent 41f3fd5 commit 255718b

24 files changed

+537
-172
lines changed

doc/developers/develop.rst

Lines changed: 18 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -511,12 +511,12 @@ Scikit-learn introduced estimator tags in version 0.21. These are annotations
511511
of estimators that allow programmatic inspection of their capabilities, such as
512512
sparse matrix support, supported output types and supported methods. The
513513
estimator tags are a dictionary returned by the method ``_get_tags()``. These
514-
tags are used by the common tests and the
515-
:func:`sklearn.utils.estimator_checks.check_estimator` function to decide what
516-
tests to run and what input data is appropriate. Tags can depend on estimator
517-
parameters or even system architecture and can in general only be determined at
518-
runtime. The default values for the estimator tags are defined in the
519-
``BaseEstimator`` class.
514+
tags are used in the common checks run by the
515+
:func:`~sklearn.utils.estimator_checks.check_estimator` function and the
516+
:func:`~sklearn.utils.estimator_checks.parametrize_with_checks` decorator.
517+
Tags determine which checks to run and what input data is appropriate. Tags
518+
can depend on estimator parameters or even system architecture and can in
519+
general only be determined at runtime.
520520

521521
The current set of estimator tags is:
522522

@@ -618,16 +618,25 @@ X_types (default=['2darray'])
618618
``'categorical'`` data. For now, the tests for sparse data do not make use
619619
of the ``'sparse'`` tag.
620620

621-
622-
To override the tags of a child class, one must define the `_more_tags()`
623-
method and return a dict with the desired tags, e.g::
621+
It is unlikely that the default values for each tag will suit the needs of your
622+
specific estimator. Additional tags can be created or default tags can be
623+
overridden by defining a `_more_tags()` method which returns a dict with the
624+
desired overridden tags or new tags. For example::
624625

625626
class MyMultiOutputEstimator(BaseEstimator):
626627

627628
def _more_tags(self):
628629
return {'multioutput_only': True,
629630
'non_deterministic': True}
630631

632+
Any tag that is not in `_more_tags()` will just fall back to the default values
633+
documented above.
634+
635+
Even if it is not recommended, it is possible to override the method
636+
`_get_tags()`. Note however that **all tags must be present in the dict**. If
637+
any of the keys documented above is not present in the output of `_get_tags()`,
638+
an error will occur.
639+
631640
In addition to the tags, estimators also need to declare any non-optional
632641
parameters to ``__init__`` in the ``_required_parameters`` class attribute,
633642
which is a list or tuple. If ``_required_parameters`` is only

sklearn/base.py

Lines changed: 5 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -15,32 +15,15 @@
1515
from . import __version__
1616
from ._config import get_config
1717
from .utils import _IS_32BIT
18+
from .utils._tags import (
19+
_DEFAULT_TAGS,
20+
_safe_tags,
21+
)
1822
from .utils.validation import check_X_y
1923
from .utils.validation import check_array
2024
from .utils._estimator_html_repr import estimator_html_repr
2125
from .utils.validation import _deprecate_positional_args
2226

23-
_DEFAULT_TAGS = {
24-
'non_deterministic': False,
25-
'requires_positive_X': False,
26-
'requires_positive_y': False,
27-
'X_types': ['2darray'],
28-
'poor_score': False,
29-
'no_validation': False,
30-
'multioutput': False,
31-
"allow_nan": False,
32-
'stateless': False,
33-
'multilabel': False,
34-
'_skip_test': False,
35-
'_xfail_checks': False,
36-
'multioutput_only': False,
37-
'binary_only': False,
38-
'requires_fit': True,
39-
'preserves_dtype': [np.float64],
40-
'requires_y': False,
41-
'pairwise': False,
42-
}
43-
4427

4528
@_deprecate_positional_args
4629
def clone(estimator, *, safe=True):
@@ -858,11 +841,7 @@ def _is_pairwise(estimator):
858841
warnings.filterwarnings('ignore', category=FutureWarning)
859842
has_pairwise_attribute = hasattr(estimator, '_pairwise')
860843
pairwise_attribute = getattr(estimator, '_pairwise', False)
861-
862-
if hasattr(estimator, '_get_tags') and callable(estimator._get_tags):
863-
pairwise_tag = estimator._get_tags().get('pairwise', False)
864-
else:
865-
pairwise_tag = False
844+
pairwise_tag = _safe_tags(estimator, key="pairwise")
866845

867846
if has_pairwise_attribute:
868847
if pairwise_attribute != pairwise_tag:

sklearn/feature_selection/_base.py

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,12 @@
1212
from scipy.sparse import issparse, csc_matrix
1313

1414
from ..base import TransformerMixin
15-
from ..utils import check_array
16-
from ..utils import safe_mask
17-
from ..utils import safe_sqr
15+
from ..utils import (
16+
check_array,
17+
safe_mask,
18+
safe_sqr,
19+
)
20+
from ..utils._tags import _safe_tags
1821

1922

2023
class SelectorMixin(TransformerMixin, metaclass=ABCMeta):
@@ -74,9 +77,14 @@ def transform(self, X):
7477
X_r : array of shape [n_samples, n_selected_features]
7578
The input samples with only the selected features.
7679
"""
77-
tags = self._get_tags()
78-
X = check_array(X, dtype=None, accept_sparse='csr',
79-
force_all_finite=not tags.get('allow_nan', True))
80+
# note: we use _safe_tags instead of _get_tags because this is a
81+
# public Mixin.
82+
X = check_array(
83+
X,
84+
dtype=None,
85+
accept_sparse="csr",
86+
force_all_finite=not _safe_tags(self, key="allow_nan"),
87+
)
8088
mask = self.get_support()
8189
if not mask.any():
8290
warn("No features were selected: either the data is"

sklearn/feature_selection/_from_model.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from ._base import SelectorMixin
88
from ._base import _get_feature_importances
99
from ..base import BaseEstimator, clone, MetaEstimatorMixin
10+
from ..utils._tags import _safe_tags
1011
from ..utils.validation import check_is_fitted
1112

1213
from ..exceptions import NotFittedError
@@ -283,5 +284,6 @@ def n_features_in_(self):
283284
return self.estimator_.n_features_in_
284285

285286
def _more_tags(self):
286-
estimator_tags = self.estimator._get_tags()
287-
return {'allow_nan': estimator_tags.get('allow_nan', True)}
287+
return {
288+
'allow_nan': _safe_tags(self.estimator, key="allow_nan")
289+
}

sklearn/feature_selection/_rfe.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,10 @@
1010
import numbers
1111
from joblib import Parallel, effective_n_jobs
1212

13+
1314
from ..utils.metaestimators import if_delegate_has_method
1415
from ..utils.metaestimators import _safe_split
16+
from ..utils._tags import _safe_tags
1517
from ..utils.validation import check_is_fitted
1618
from ..utils.validation import _deprecate_positional_args
1719
from ..utils.fixes import delayed
@@ -191,7 +193,7 @@ def _fit(self, X, y, step_score=None):
191193
X, y = self._validate_data(
192194
X, y, accept_sparse="csc",
193195
ensure_min_features=2,
194-
force_all_finite=not tags.get('allow_nan', True),
196+
force_all_finite=not tags.get("allow_nan", True),
195197
multi_output=True
196198
)
197199
error_msg = ("n_features_to_select must be either None, a "
@@ -371,11 +373,11 @@ def predict_log_proba(self, X):
371373
return self.estimator_.predict_log_proba(self.transform(X))
372374

373375
def _more_tags(self):
374-
estimator_tags = self.estimator._get_tags()
375-
return {'poor_score': True,
376-
'allow_nan': estimator_tags.get('allow_nan', True),
377-
'requires_y': True,
378-
}
376+
return {
377+
'poor_score': True,
378+
'allow_nan': _safe_tags(self.estimator, key='allow_nan'),
379+
'requires_y': True,
380+
}
379381

380382

381383
class RFECV(RFE):

sklearn/feature_selection/_sequential.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
from ._base import SelectorMixin
99
from ..base import BaseEstimator, MetaEstimatorMixin, clone
10+
from ..utils._tags import _safe_tags
1011
from ..utils.validation import check_is_fitted
1112
from ..model_selection import cross_val_score
1213

@@ -128,12 +129,11 @@ def fit(self, X, y):
128129
-------
129130
self : object
130131
"""
131-
132132
tags = self._get_tags()
133133
X, y = self._validate_data(
134134
X, y, accept_sparse="csc",
135135
ensure_min_features=2,
136-
force_all_finite=not tags.get('allow_nan', True),
136+
force_all_finite=not tags.get("allow_nan", True),
137137
multi_output=True
138138
)
139139
n_features = X.shape[1]
@@ -207,8 +207,7 @@ def _get_support_mask(self):
207207
return self.support_
208208

209209
def _more_tags(self):
210-
estimator_tags = self.estimator._get_tags()
211210
return {
212-
'allow_nan': estimator_tags.get('allow_nan', True),
211+
'allow_nan': _safe_tags(self.estimator, key="allow_nan"),
213212
'requires_y': True,
214213
}

sklearn/feature_selection/tests/test_rfe.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -56,8 +56,8 @@ def get_params(self, deep=True):
5656
def set_params(self, **params):
5757
return self
5858

59-
def _get_tags(self):
60-
return {}
59+
def _more_tags(self):
60+
return {"allow_nan": True}
6161

6262

6363
def test_rfe_features_importance():
@@ -448,10 +448,7 @@ def test_rfe_importance_getter_validation(importance_getter, err_type,
448448
model.fit(X, y)
449449

450450

451-
@pytest.mark.parametrize("cv", [
452-
None,
453-
5
454-
])
451+
@pytest.mark.parametrize("cv", [None, 5])
455452
def test_rfe_allow_nan_inf_in_x(cv):
456453
iris = load_iris()
457454
X = iris.data

sklearn/linear_model/_glm/tests/test_glm.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -419,13 +419,13 @@ def test_tweedie_regression_family(regression_data):
419419

420420

421421
@pytest.mark.parametrize(
422-
'estimator, value',
423-
[
424-
(PoissonRegressor(), True),
425-
(GammaRegressor(), True),
426-
(TweedieRegressor(power=1.5), True),
427-
(TweedieRegressor(power=0), False)
428-
],
422+
'estimator, value',
423+
[
424+
(PoissonRegressor(), True),
425+
(GammaRegressor(), True),
426+
(TweedieRegressor(power=1.5), True),
427+
(TweedieRegressor(power=0), False),
428+
],
429429
)
430430
def test_tags(estimator, value):
431431
assert estimator._get_tags()['requires_positive_y'] is value

sklearn/linear_model/tests/test_coordinate_descent.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -300,7 +300,7 @@ def test_lasso_cv_positive_constraint():
300300
(Lars, {}),
301301
(LinearRegression, {}),
302302
(LassoLarsIC, {})]
303-
)
303+
)
304304
def test_model_pipeline_same_as_normalize_true(LinearModel, params):
305305
# Test that linear models (LinearModel) set with normalize set to True are
306306
# doing the same as the same linear model preceded by StandardScaler
@@ -315,7 +315,7 @@ def test_model_pipeline_same_as_normalize_true(LinearModel, params):
315315
LinearModel(normalize=False, fit_intercept=True, **params)
316316
)
317317

318-
is_multitask = model_normalize._get_tags().get("multioutput_only", False)
318+
is_multitask = model_normalize._get_tags()["multioutput_only"]
319319

320320
# prepare the data
321321
n_samples, n_features = 100, 2

sklearn/model_selection/_search.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
from joblib import Parallel
3636
from ..utils import check_random_state
3737
from ..utils.random import sample_without_replacement
38+
from ..utils._tags import _safe_tags
3839
from ..utils.validation import indexable, check_is_fitted, _check_fit_params
3940
from ..utils.validation import _deprecate_positional_args
4041
from ..utils.metaestimators import if_delegate_has_method
@@ -433,9 +434,8 @@ def _estimator_type(self):
433434

434435
def _more_tags(self):
435436
# allows cross-validation to see 'precomputed' metrics
436-
estimator_tags = self.estimator._get_tags()
437437
return {
438-
'pairwise': estimator_tags.get('pairwise', False),
438+
'pairwise': _safe_tags(self.estimator, "pairwise"),
439439
"_xfail_checks": {"check_supervised_y_2d":
440440
"DataConversionWarning not caught"},
441441
}

0 commit comments

Comments
 (0)