Skip to content

Commit be0c70a

Browse files
authored
Merge pull request #12 from trinhcon/ENH/sklearn_tags_interface
Enh/sklearn tags interface
2 parents 85dbb9f + a93e79d commit be0c70a

File tree

87 files changed

+621
-370
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

87 files changed

+621
-370
lines changed

doc/developers/develop.rst

Lines changed: 7 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -510,7 +510,7 @@ Estimator Tags
510510
Scikit-learn introduced estimator tags in version 0.21. These are annotations
511511
of estimators that allow programmatic inspection of their capabilities, such as
512512
sparse matrix support, supported output types and supported methods. The
513-
estimator tags are a dictionary returned by the method ``_get_tags()``. These
513+
estimator tags are a dictionary returned by the method ``__sklearn_tags__()``. These
514514
tags are used in the common checks run by the
515515
:func:`~sklearn.utils.estimator_checks.check_estimator` function and the
516516
:func:`~sklearn.utils.estimator_checks.parametrize_with_checks` decorator.
@@ -620,23 +620,18 @@ X_types (default=['2darray'])
620620

621621
It is unlikely that the default values for each tag will suit the needs of your
622622
specific estimator. Additional tags can be created or default tags can be
623-
overridden by defining a `_more_tags()` method which returns a dict with the
623+
overridden by defining a `__sklearn_tags__()` method which returns a dict with the
624624
desired overridden tags or new tags. For example::
625625

626626
class MyMultiOutputEstimator(BaseEstimator):
627+
def __sklearn_tags__(self):
628+
more_tags = {'multioutput_only': True,
629+
'non_deterministic': True}
630+
return {**super().__sklearn_tags__(), **more_tags}
627631

628-
def _more_tags(self):
629-
return {'multioutput_only': True,
630-
'non_deterministic': True}
631-
632-
Any tag that is not in `_more_tags()` will just fall-back to the default values
632+
If inherited from BaseEstimator, any tags not explicitly defined will just fall-back to the default values
633633
documented above.
634634

635-
Even if it is not recommended, it is possible to override the method
636-
`_get_tags()`. Note however that **all tags must be present in the dict**. If
637-
any of the keys documented above is not present in the output of `_get_tags()`,
638-
an error will occur.
639-
640635
In addition to the tags, estimators also need to declare any non-optional
641636
parameters to ``__init__`` in the ``_required_parameters`` class attribute,
642637
which is a list or tuple. If ``_required_parameters`` is only

doc/tutorial/text_analytics/skeletons/exercise_02_sentiment.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,8 +44,8 @@
4444
# more useful.
4545
# Fit the pipeline on the training set using grid search for the parameters
4646

47-
# TASK: print the cross-validated scores for the each parameters set
48-
# explored by the grid search
47+
# TASK: print the mean and std for each candidate along with the parameter
48+
# settings for all the candidates explored by grid search.
4949

5050
# TASK: Predict the outcome on the testing set and store it in a variable
5151
# named y_predicted

doc/whats_new/v1.1.rst

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -613,6 +613,17 @@ Changelog
613613

614614
....................
615615

616+
- |Enhancement| _more_tags() and _get_tags() are gathered into
617+
a single method called __sklearn_tags() for simplicity and convenience.
618+
:pr:`12` by :user:`Pasa Ali Aslan <pasaaliaslan>`, :user:`Albert Li <lialber2>`,
619+
:user:`Jack Woodger <jwoodger>`, :user:`Ciaran Hogan <ciaran-h>`, :user:`Victor Ko <VKo232>`,
620+
and :user:`Conroy Trinh <trinhcon>`
621+
622+
:mod:`sklearn.base` and all other modules with classes that implement BaseEstimator
623+
or mixins in sklearn.base.
624+
625+
....................
626+
616627
Code and Documentation Contributors
617628
-----------------------------------
618629

sklearn/base.py

Lines changed: 18 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -343,19 +343,8 @@ def __setstate__(self, state):
343343
except AttributeError:
344344
self.__dict__.update(state)
345345

346-
def _more_tags(self):
347-
return _DEFAULT_TAGS
348-
349-
def _get_tags(self):
350-
collected_tags = {}
351-
for base_class in reversed(inspect.getmro(self.__class__)):
352-
if hasattr(base_class, "_more_tags"):
353-
# need the if because mixins might not have _more_tags
354-
# but might do redundant work in estimators
355-
# (i.e. calling more tags on BaseEstimator multiple times)
356-
more_tags = base_class._more_tags(self)
357-
collected_tags.update(more_tags)
358-
return collected_tags
346+
def __sklearn_tags__(self):
347+
return copy.deepcopy(_DEFAULT_TAGS)
359348

360349
def _check_n_features(self, X, reset):
361350
"""Set the `n_features_in_` attribute, or check against it.
@@ -560,7 +549,7 @@ def _validate_data(
560549
"""
561550
self._check_feature_names(X, reset=reset)
562551

563-
if y is None and self._get_tags()["requires_y"]:
552+
if y is None and self.__sklearn_tags__()["requires_y"]:
564553
raise ValueError(
565554
f"This {self.__class__.__name__} estimator "
566555
"requires y to be passed, but the target y is None."
@@ -666,8 +655,9 @@ def score(self, X, y, sample_weight=None):
666655

667656
return accuracy_score(y, self.predict(X), sample_weight=sample_weight)
668657

669-
def _more_tags(self):
670-
return {"requires_y": True}
658+
def __sklearn_tags__(self):
659+
more_tags = {"requires_y": True}
660+
return {**super().__sklearn_tags__(), **more_tags}
671661

672662

673663
class RegressorMixin:
@@ -721,8 +711,9 @@ def score(self, X, y, sample_weight=None):
721711
y_pred = self.predict(X)
722712
return r2_score(y, y_pred, sample_weight=sample_weight)
723713

724-
def _more_tags(self):
725-
return {"requires_y": True}
714+
def __sklearn_tags__(self):
715+
more_tags = {"requires_y": True}
716+
return {**super().__sklearn_tags__(), **more_tags}
726717

727718

728719
class ClusterMixin:
@@ -752,8 +743,9 @@ def fit_predict(self, X, y=None):
752743
self.fit(X)
753744
return self.labels_
754745

755-
def _more_tags(self):
756-
return {"preserves_dtype": []}
746+
def __sklearn_tags__(self):
747+
more_tags = {"preserves_dtype": []}
748+
return {**super().__sklearn_tags__(), **more_tags}
757749

758750

759751
class BiclusterMixin:
@@ -983,19 +975,21 @@ class MetaEstimatorMixin:
983975
class MultiOutputMixin:
984976
"""Mixin to mark estimators that support multioutput."""
985977

986-
def _more_tags(self):
987-
return {"multioutput": True}
978+
def __sklearn_tags__(self):
979+
more_tags = {"multioutput": True}
980+
return {**super().__sklearn_tags__(), **more_tags}
988981

989982

990983
class _UnstableArchMixin:
991984
"""Mark estimators that are non-determinstic on 32bit or PowerPC"""
992985

993-
def _more_tags(self):
994-
return {
986+
def __sklearn_tags__(self):
987+
more_tags = {
995988
"non_deterministic": (
996989
_IS_32BIT or platform.machine().startswith(("ppc", "powerpc"))
997990
)
998991
}
992+
return {**super().__sklearn_tags__(), **more_tags}
999993

1000994

1001995
def is_classifier(estimator):

sklearn/calibration.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -456,8 +456,8 @@ def predict(self, X):
456456
check_is_fitted(self)
457457
return self.classes_[np.argmax(self.predict_proba(X), axis=1)]
458458

459-
def _more_tags(self):
460-
return {
459+
def __sklearn_tags__(self):
460+
more_tags = {
461461
"_xfail_checks": {
462462
"check_sample_weights_invariance": (
463463
"Due to the cross-validation and sample ordering, removing a sample"
@@ -466,6 +466,7 @@ def _more_tags(self):
466466
),
467467
}
468468
}
469+
return {**super().__sklearn_tags__(), **more_tags}
469470

470471

471472
def _fit_classifier_calibrator_pair(

sklearn/cluster/_affinity_propagation.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -426,8 +426,9 @@ def __init__(
426426
def _pairwise(self):
427427
return self.affinity == "precomputed"
428428

429-
def _more_tags(self):
430-
return {"pairwise": self.affinity == "precomputed"}
429+
def __sklearn_tags__(self):
430+
more_tags = {"pairwise": self.affinity == "precomputed"}
431+
return {**super().__sklearn_tags__(), **more_tags}
431432

432433
def fit(self, X, y=None):
433434
"""Fit the clustering from features, or affinity matrix.

sklearn/cluster/_bicluster.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -189,8 +189,8 @@ def _k_means(self, data, n_clusters):
189189
labels = model.labels_
190190
return centroid, labels
191191

192-
def _more_tags(self):
193-
return {
192+
def __sklearn_tags__(self):
193+
more_tags = {
194194
"_xfail_checks": {
195195
"check_estimators_dtypes": "raises nan error",
196196
"check_fit2d_1sample": "_scale_normalize fails",
@@ -201,6 +201,7 @@ def _more_tags(self):
201201
"check_fit2d_predict1d": "empty array passed inside",
202202
}
203203
}
204+
return {**super().__sklearn_tags__(), **more_tags}
204205

205206

206207
class SpectralCoclustering(BaseSpectral):

sklearn/cluster/_kmeans.py

Lines changed: 17 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1389,14 +1389,11 @@ def score(self, X, y=None, sample_weight=None):
13891389
X, sample_weight, x_squared_norms, self.cluster_centers_, self._n_threads
13901390
)[1]
13911391

1392-
def _more_tags(self):
1393-
return {
1394-
"_xfail_checks": {
1395-
"check_sample_weights_invariance": (
1396-
"zero sample_weight is not equivalent to removing samples"
1397-
),
1398-
},
1399-
}
1392+
def __sklearn_tags__(self):
1393+
more_tags = {"_xfail_checks": {"check_sample_weights_invariance": (
1394+
"zero sample_weight is not equivalent to removing samples"
1395+
)}}
1396+
return {**super().__sklearn_tags__(), **more_tags}
14001397

14011398

14021399
def _mini_batch_step(
@@ -2184,11 +2181,15 @@ def predict(self, X, sample_weight=None):
21842181

21852182
return labels
21862183

2187-
def _more_tags(self):
2188-
return {
2189-
"_xfail_checks": {
2190-
"check_sample_weights_invariance": (
2191-
"zero sample_weight is not equivalent to removing samples"
2192-
),
2193-
}
2194-
}
2184+
def __sklearn_tags__(self):
2185+
super_tags = super().__sklearn_tags__()
2186+
xfail_checks = super_tags["_xfail_checks"]
2187+
if xfail_checks:
2188+
super_tags["_xfail_checks"].update({"check_sample_weights_invariance": (
2189+
"zero sample_weight is not equivalent to removing samples"
2190+
)})
2191+
else:
2192+
super_tags.update({"_xfail_checks": {"check_sample_weights_invariance": (
2193+
"zero sample_weight is not equivalent to removing samples"
2194+
)}})
2195+
return super_tags

sklearn/cluster/_spectral.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -776,11 +776,12 @@ def fit_predict(self, X, y=None):
776776
"""
777777
return super().fit_predict(X, y)
778778

779-
def _more_tags(self):
780-
return {
779+
def __sklearn_tags__(self):
780+
more_tags = {
781781
"pairwise": self.affinity
782782
in ["precomputed", "precomputed_nearest_neighbors"]
783783
}
784+
return {**super().__sklearn_tags__(), **more_tags}
784785

785786
# TODO: Remove in 1.1
786787
# mypy error: Decorated property not supported

sklearn/compose/_target.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -286,17 +286,18 @@ def predict(self, X, **predict_params):
286286

287287
return pred_trans
288288

289-
def _more_tags(self):
289+
def __sklearn_tags__(self):
290290
regressor = self.regressor
291291
if regressor is None:
292292
from ..linear_model import LinearRegression
293293

294294
regressor = LinearRegression()
295295

296-
return {
296+
more_tags = {
297297
"poor_score": True,
298298
"multioutput": _safe_tags(regressor, key="multioutput"),
299299
}
300+
return {**super().__sklearn_tags__(), **more_tags}
300301

301302
@property
302303
def n_features_in_(self):

0 commit comments

Comments
 (0)