Skip to content

TST introduce _safe_tags for estimator not inheriting from BaseEstimator #18797

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 53 commits into from
Dec 2, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
53 commits
Select commit Hold shift + click to select a range
a68194b
TST reintroduce _safe_tags for estimator not inheriting from BaseEsti…
glemaitre Nov 9, 2020
36f1c5c
typo
glemaitre Nov 9, 2020
9e54014
TST implement minimal classifier
glemaitre Nov 11, 2020
764968d
merge master
glemaitre Nov 25, 2020
8f571ac
Add future minimal tests
glemaitre Nov 25, 2020
88b01f1
refactor
glemaitre Nov 25, 2020
05d4225
fix change _get_tags in search and pipeline
glemaitre Nov 25, 2020
5630266
fix nested param name
glemaitre Nov 25, 2020
508d5e0
skip test
glemaitre Nov 25, 2020
610d645
upadte multiclass
glemaitre Nov 25, 2020
7d0a4f6
fix feature selection
glemaitre Nov 25, 2020
eeaf7b0
add default overwrite in safe_tag
glemaitre Nov 25, 2020
fdf1011
iter
glemaitre Nov 25, 2020
b9b2331
iter
glemaitre Nov 25, 2020
4e1f93b
TST safe_tags
glemaitre Nov 25, 2020
7b9b6be
iter
glemaitre Nov 25, 2020
45ada6b
whoops
glemaitre Nov 25, 2020
a51d75b
TST add test for passthtough in pipeline
glemaitre Nov 25, 2020
9874b7c
remove outdated test
glemaitre Nov 25, 2020
007fa09
revert _safe_tags when inheriting from BaseEstimator
glemaitre Nov 26, 2020
6839f58
add test _safe_tags
glemaitre Nov 26, 2020
5d73730
iter
glemaitre Nov 26, 2020
5021fbc
iter
glemaitre Nov 26, 2020
ddf6c79
iter
glemaitre Nov 26, 2020
369e7b7
fix
glemaitre Nov 26, 2020
61cdfca
mark as xfail
glemaitre Nov 26, 2020
5ba8c0c
cover transformer set_params
glemaitre Nov 26, 2020
41bc206
reduce check in set_params
glemaitre Nov 26, 2020
739d084
MNT move _sage_tags in _tags module
glemaitre Nov 26, 2020
04b0018
TST/DOC force estimator to have default tags when implementing _get_tgas
glemaitre Nov 26, 2020
18368ea
update documentation
glemaitre Nov 27, 2020
2e08d2a
slight rework of developers docs
NicolasHug Nov 27, 2020
4aeae95
didn't remove some stuff
NicolasHug Nov 27, 2020
15410b3
move test around
glemaitre Nov 30, 2020
7eaab71
first pass on Nicolas comments
glemaitre Nov 30, 2020
93dc099
add test for check
glemaitre Nov 30, 2020
d53adca
doc
glemaitre Nov 30, 2020
3b671d8
less diff
glemaitre Nov 30, 2020
7130ff6
remove useless parametre
glemaitre Nov 30, 2020
c9f6af4
add comment
glemaitre Nov 30, 2020
642c305
Update sklearn/utils/tests/test_estimator_checks.py
glemaitre Nov 30, 2020
1fd2e57
iter
glemaitre Nov 30, 2020
eb9c41b
PEP8
glemaitre Nov 30, 2020
8227075
Rephrase test comment [ci skip]
ogrisel Dec 1, 2020
07dace1
iter
glemaitre Dec 1, 2020
f3c2b02
Merge remote-tracking branch 'glemaitre/reintroduce_safe_tags' into r…
glemaitre Dec 1, 2020
d76b684
iter
glemaitre Dec 1, 2020
e8fa827
iter
glemaitre Dec 1, 2020
ed26968
update doc
glemaitre Dec 1, 2020
b8ecc41
fix test
glemaitre Dec 1, 2020
bb10791
doc
glemaitre Dec 1, 2020
b19137d
answer ogrisel comments
glemaitre Dec 2, 2020
754539f
more coverage
glemaitre Dec 2, 2020
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 18 additions & 9 deletions doc/developers/develop.rst
Original file line number Diff line number Diff line change
Expand Up @@ -511,12 +511,12 @@ Scikit-learn introduced estimator tags in version 0.21. These are annotations
of estimators that allow programmatic inspection of their capabilities, such as
sparse matrix support, supported output types and supported methods. The
estimator tags are a dictionary returned by the method ``_get_tags()``. These
tags are used by the common tests and the
:func:`sklearn.utils.estimator_checks.check_estimator` function to decide what
tests to run and what input data is appropriate. Tags can depend on estimator
parameters or even system architecture and can in general only be determined at
runtime. The default values for the estimator tags are defined in the
``BaseEstimator`` class.
tags are used in the common checks run by the
:func:`~sklearn.utils.estimator_checks.check_estimator` function and the
:func:`~sklearn.utils.estimator_checks.parametrize_with_checks` decorator.
Tags determine which checks to run and what input data is appropriate. Tags
can depend on estimator parameters or even system architecture and can in
general only be determined at runtime.

The current set of estimator tags are:

Expand Down Expand Up @@ -618,16 +618,25 @@ X_types (default=['2darray'])
``'categorical'`` data. For now, the test for sparse data do not make use
of the ``'sparse'`` tag.


To override the tags of a child class, one must define the `_more_tags()`
method and return a dict with the desired tags, e.g::
It is unlikely that the default values for each tag will suit the needs of your
specific estimator. Additional tags can be created or default tags can be
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

"Additionnal tags can be created"

I thought we agreed not to support that #18797 (comment)? (or that's how I interpret @ogrisel's +1)

Copy link
Member Author

@glemaitre glemaitre Dec 2, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There is a difference between supporting in _safe_tags and people creating their own tags within their libraries using _more_tags. This is a real need here:

https://github.com/rapidsai/cuml/pull/3113/files#diff-e4bd6eee2eca2b0619b03a5f6ba7b471b4ca03080a6619d0079105d5f13c2165R34-R35

We have something similar in imbalanced-learn since the introduction of tags.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My +1 was to remove the default param to the _safe_tags. I think third-party implementers are free to add other tags in their own estimators if they which. cuML is already doing in in their master branch apparently:

https://github.com/rapidsai/cuml/pull/3113/files

overridden by defining a `_more_tags()` method which returns a dict with the
desired overridden tags or new tags. For example::

class MyMultiOutputEstimator(BaseEstimator):

def _more_tags(self):
return {'multioutput_only': True,
'non_deterministic': True}

Any tag that is not in `_more_tags()` will just fall-back to the default values
documented above.

Even if it is not recommended, it is possible to override the method
`_get_tags()`. Note however that **all tags must be present in the dict**. If
any of the keys documented above is not present in the output of `_get_tags()`,
an error will occur.

In addition to the tags, estimators also need to declare any non-optional
parameters to ``__init__`` in the ``_required_parameters`` class attribute,
which is a list or tuple. If ``_required_parameters`` is only
Expand Down
31 changes: 5 additions & 26 deletions sklearn/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,32 +15,15 @@
from . import __version__
from ._config import get_config
from .utils import _IS_32BIT
from .utils._tags import (
_DEFAULT_TAGS,
_safe_tags,
)
from .utils.validation import check_X_y
from .utils.validation import check_array
from .utils._estimator_html_repr import estimator_html_repr
from .utils.validation import _deprecate_positional_args

_DEFAULT_TAGS = {
'non_deterministic': False,
'requires_positive_X': False,
'requires_positive_y': False,
'X_types': ['2darray'],
'poor_score': False,
'no_validation': False,
'multioutput': False,
"allow_nan": False,
'stateless': False,
'multilabel': False,
'_skip_test': False,
'_xfail_checks': False,
'multioutput_only': False,
'binary_only': False,
'requires_fit': True,
'preserves_dtype': [np.float64],
'requires_y': False,
'pairwise': False,
}


@_deprecate_positional_args
def clone(estimator, *, safe=True):
Expand Down Expand Up @@ -858,11 +841,7 @@ def _is_pairwise(estimator):
warnings.filterwarnings('ignore', category=FutureWarning)
has_pairwise_attribute = hasattr(estimator, '_pairwise')
pairwise_attribute = getattr(estimator, '_pairwise', False)

if hasattr(estimator, '_get_tags') and callable(estimator._get_tags):
pairwise_tag = estimator._get_tags().get('pairwise', False)
else:
pairwise_tag = False
pairwise_tag = _safe_tags(estimator, key="pairwise")

if has_pairwise_attribute:
if pairwise_attribute != pairwise_tag:
Expand Down
20 changes: 14 additions & 6 deletions sklearn/feature_selection/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,12 @@
from scipy.sparse import issparse, csc_matrix

from ..base import TransformerMixin
from ..utils import check_array
from ..utils import safe_mask
from ..utils import safe_sqr
from ..utils import (
check_array,
safe_mask,
safe_sqr,
)
from ..utils._tags import _safe_tags


class SelectorMixin(TransformerMixin, metaclass=ABCMeta):
Expand Down Expand Up @@ -74,9 +77,14 @@ def transform(self, X):
X_r : array of shape [n_samples, n_selected_features]
The input samples with only the selected features.
"""
tags = self._get_tags()
X = check_array(X, dtype=None, accept_sparse='csr',
force_all_finite=not tags.get('allow_nan', True))
# note: we use _safe_tags instead of _get_tags because this is a
# public Mixin.
X = check_array(
X,
dtype=None,
accept_sparse="csr",
force_all_finite=not _safe_tags(self, key="allow_nan"),
)
mask = self.get_support()
if not mask.any():
warn("No features were selected: either the data is"
Expand Down
6 changes: 4 additions & 2 deletions sklearn/feature_selection/_from_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from ._base import SelectorMixin
from ._base import _get_feature_importances
from ..base import BaseEstimator, clone, MetaEstimatorMixin
from ..utils._tags import _safe_tags
from ..utils.validation import check_is_fitted

from ..exceptions import NotFittedError
Expand Down Expand Up @@ -283,5 +284,6 @@ def n_features_in_(self):
return self.estimator_.n_features_in_

def _more_tags(self):
estimator_tags = self.estimator._get_tags()
return {'allow_nan': estimator_tags.get('allow_nan', True)}
return {
'allow_nan': _safe_tags(self.estimator, key="allow_nan")
}
14 changes: 8 additions & 6 deletions sklearn/feature_selection/_rfe.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,10 @@
import numbers
from joblib import Parallel, effective_n_jobs


from ..utils.metaestimators import if_delegate_has_method
from ..utils.metaestimators import _safe_split
from ..utils._tags import _safe_tags
from ..utils.validation import check_is_fitted
from ..utils.validation import _deprecate_positional_args
from ..utils.fixes import delayed
Expand Down Expand Up @@ -191,7 +193,7 @@ def _fit(self, X, y, step_score=None):
X, y = self._validate_data(
X, y, accept_sparse="csc",
ensure_min_features=2,
force_all_finite=not tags.get('allow_nan', True),
force_all_finite=not tags.get("allow_nan", True),
multi_output=True
)
error_msg = ("n_features_to_select must be either None, a "
Expand Down Expand Up @@ -371,11 +373,11 @@ def predict_log_proba(self, X):
return self.estimator_.predict_log_proba(self.transform(X))

def _more_tags(self):
estimator_tags = self.estimator._get_tags()
return {'poor_score': True,
'allow_nan': estimator_tags.get('allow_nan', True),
'requires_y': True,
}
return {
'poor_score': True,
'allow_nan': _safe_tags(self.estimator, key='allow_nan'),
'requires_y': True,
}


class RFECV(RFE):
Expand Down
7 changes: 3 additions & 4 deletions sklearn/feature_selection/_sequential.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

from ._base import SelectorMixin
from ..base import BaseEstimator, MetaEstimatorMixin, clone
from ..utils._tags import _safe_tags
from ..utils.validation import check_is_fitted
from ..model_selection import cross_val_score

Expand Down Expand Up @@ -128,12 +129,11 @@ def fit(self, X, y):
-------
self : object
"""

tags = self._get_tags()
X, y = self._validate_data(
X, y, accept_sparse="csc",
ensure_min_features=2,
force_all_finite=not tags.get('allow_nan', True),
force_all_finite=not tags.get("allow_nan", True),
multi_output=True
)
n_features = X.shape[1]
Expand Down Expand Up @@ -207,8 +207,7 @@ def _get_support_mask(self):
return self.support_

def _more_tags(self):
estimator_tags = self.estimator._get_tags()
return {
'allow_nan': estimator_tags.get('allow_nan', True),
'allow_nan': _safe_tags(self.estimator, key="allow_nan"),
'requires_y': True,
}
9 changes: 3 additions & 6 deletions sklearn/feature_selection/tests/test_rfe.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,8 @@ def get_params(self, deep=True):
def set_params(self, **params):
return self

def _get_tags(self):
return {}
def _more_tags(self):
return {"allow_nan": True}


def test_rfe_features_importance():
Expand Down Expand Up @@ -448,10 +448,7 @@ def test_rfe_importance_getter_validation(importance_getter, err_type,
model.fit(X, y)


@pytest.mark.parametrize("cv", [
None,
5
])
@pytest.mark.parametrize("cv", [None, 5])
def test_rfe_allow_nan_inf_in_x(cv):
iris = load_iris()
X = iris.data
Expand Down
14 changes: 7 additions & 7 deletions sklearn/linear_model/_glm/tests/test_glm.py
Original file line number Diff line number Diff line change
Expand Up @@ -419,13 +419,13 @@ def test_tweedie_regression_family(regression_data):


@pytest.mark.parametrize(
'estimator, value',
[
(PoissonRegressor(), True),
(GammaRegressor(), True),
(TweedieRegressor(power=1.5), True),
(TweedieRegressor(power=0), False)
],
'estimator, value',
[
(PoissonRegressor(), True),
(GammaRegressor(), True),
(TweedieRegressor(power=1.5), True),
(TweedieRegressor(power=0), False),
],
)
def test_tags(estimator, value):
assert estimator._get_tags()['requires_positive_y'] is value
4 changes: 2 additions & 2 deletions sklearn/linear_model/tests/test_coordinate_descent.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,7 +300,7 @@ def test_lasso_cv_positive_constraint():
(Lars, {}),
(LinearRegression, {}),
(LassoLarsIC, {})]
)
)
def test_model_pipeline_same_as_normalize_true(LinearModel, params):
# Test that linear models (LinearModel) set with normalize set to True are
# doing the same as the same linear model preceeded by StandardScaler
Expand All @@ -315,7 +315,7 @@ def test_model_pipeline_same_as_normalize_true(LinearModel, params):
LinearModel(normalize=False, fit_intercept=True, **params)
)

is_multitask = model_normalize._get_tags().get("multioutput_only", False)
is_multitask = model_normalize._get_tags()["multioutput_only"]

# prepare the data
n_samples, n_features = 100, 2
Expand Down
4 changes: 2 additions & 2 deletions sklearn/model_selection/_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
from joblib import Parallel
from ..utils import check_random_state
from ..utils.random import sample_without_replacement
from ..utils._tags import _safe_tags
from ..utils.validation import indexable, check_is_fitted, _check_fit_params
from ..utils.validation import _deprecate_positional_args
from ..utils.metaestimators import if_delegate_has_method
Expand Down Expand Up @@ -433,9 +434,8 @@ def _estimator_type(self):

def _more_tags(self):
# allows cross-validation to see 'precomputed' metrics
estimator_tags = self.estimator._get_tags()
return {
'pairwise': estimator_tags.get('pairwise', False),
'pairwise': _safe_tags(self.estimator, "pairwise"),
"_xfail_checks": {"check_supervised_y_2d":
"DataConversionWarning not caught"},
}
Expand Down
58 changes: 48 additions & 10 deletions sklearn/model_selection/tests/test_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,21 +13,26 @@
import scipy.sparse as sp
import pytest

from sklearn.utils._testing import assert_raises
from sklearn.utils._testing import assert_warns
from sklearn.utils._testing import assert_warns_message
from sklearn.utils._testing import assert_raise_message
from sklearn.utils._testing import assert_array_equal
from sklearn.utils._testing import assert_array_almost_equal
from sklearn.utils._testing import assert_allclose
from sklearn.utils._testing import assert_almost_equal
from sklearn.utils._testing import ignore_warnings
from sklearn.utils._testing import (
assert_raises,
assert_warns,
assert_warns_message,
assert_raise_message,
assert_array_equal,
assert_array_almost_equal,
assert_allclose,
assert_almost_equal,
ignore_warnings,
MinimalClassifier,
MinimalRegressor,
MinimalTransformer,
)
from sklearn.utils._mocking import CheckingClassifier, MockDataFrame

from scipy.stats import bernoulli, expon, uniform

from sklearn.base import BaseEstimator, ClassifierMixin
from sklearn.base import clone
from sklearn.base import clone, is_classifier
from sklearn.exceptions import NotFittedError
from sklearn.datasets import make_classification
from sklearn.datasets import make_blobs
Expand Down Expand Up @@ -63,6 +68,7 @@
from sklearn.metrics import make_scorer
from sklearn.metrics import roc_auc_score
from sklearn.metrics import confusion_matrix
from sklearn.metrics import r2_score
from sklearn.metrics.pairwise import euclidean_distances
from sklearn.impute import SimpleImputer
from sklearn.pipeline import Pipeline
Expand Down Expand Up @@ -2079,3 +2085,35 @@ def _fit_param_callable():
'scalar_param': 42,
}
model.fit(X_train, y_train, **fit_params)


# FIXME: Replace this test with a full `check_estimator` once we have API only
# checks.
@pytest.mark.filterwarnings("ignore:The total space of parameters 4 is")
@pytest.mark.parametrize("SearchCV", [GridSearchCV, RandomizedSearchCV])
@pytest.mark.parametrize("Predictor", [MinimalRegressor, MinimalClassifier])
def test_search_cv_using_minimal_compatible_estimator(SearchCV, Predictor):
# Check that third-party library can run tests without inheriting from
# BaseEstimator.
rng = np.random.RandomState(0)
X, y = rng.randn(25, 2), np.array([0] * 5 + [1] * 20)

model = Pipeline([
("transformer", MinimalTransformer()), ("predictor", Predictor())
])

params = {
"transformer__param": [1, 10], "predictor__parama": [1, 10],
}
search = SearchCV(model, params, error_score="raise")
search.fit(X, y)

assert search.best_params_.keys() == params.keys()

y_pred = search.predict(X)
if is_classifier(search):
assert_array_equal(y_pred, 1)
assert search.score(X, y) == pytest.approx(accuracy_score(y, y_pred))
else:
assert_allclose(y_pred, y.mean())
assert search.score(X, y) == pytest.approx(r2_score(y, y_pred))
12 changes: 0 additions & 12 deletions sklearn/model_selection/tests/test_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -1985,15 +1985,3 @@ def _more_tags(self):
"Set the estimator tags of your estimator instead")
with pytest.warns(FutureWarning, match=msg):
cross_validate(svm, linear_kernel, y, cv=2)

# the _pairwise attribute is present and set to True while the pairwise
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@NicolasHug by not being permissive (getting default with _get_tags), we need to remove this test. What do you think about this.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We are not sure that this case is actually possible in practice.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I haven't followed the introduction of the pairwise tag but since the test assumes that the tag doesn't exist and since we're telling 3rd parties that all tags should exist, I'd say it makes sense to remove the test

# tag is not present
class NoEstimatorTagSVM(SVC):
def _get_tags(self):
tags = super()._get_tags()
del tags['pairwise']
return tags

svm = NoEstimatorTagSVM(kernel='precomputed')
with pytest.warns(FutureWarning, match=msg):
cross_validate(svm, linear_kernel, y, cv=2)
Loading