diff --git a/doc/conftest.py b/doc/conftest.py index d0f49ac087477..b06b1c6bb6d4c 100644 --- a/doc/conftest.py +++ b/doc/conftest.py @@ -138,6 +138,13 @@ def pytest_runtest_setup(item): setup_preprocessing() elif fname.endswith("statistical_inference/unsupervised_learning.rst"): setup_unsupervised_learning() + elif fname.endswith("metadata_routing.rst"): + # TODO: remove this once implemented + # Skip metarouting because is it is not fully implemented yet + raise SkipTest( + "Skipping doctest for metadata_routing.rst because it " + "is not fully implemented yet" + ) rst_files_requiring_matplotlib = [ "modules/partial_dependence.rst", diff --git a/doc/metadata_routing.rst b/doc/metadata_routing.rst new file mode 100644 index 0000000000000..4ae743374b6ab --- /dev/null +++ b/doc/metadata_routing.rst @@ -0,0 +1,204 @@ + +.. _metadata_routing: + +.. TODO: update doc/conftest.py once document is updated and examples run. + +Metadata Routing +================ + +This guide demonstrates how metadata such as ``sample_weight`` can be routed +and passed along to estimators, scorers, and CV splitters through +meta-estimators such as ``Pipeline`` and ``GridSearchCV``. In order to pass +metadata to a method such as ``fit`` or ``score``, the object accepting the +metadata, must *request* it. For estimators and splitters this is done via +``*_requests`` methods, e.g. ``fit_requests(...)``, and for scorers this is +done via ``score_requests`` method of a scorer. For grouped splitters such as +``GroupKFold`` a ``groups`` parameter is requested by default. This is best +demonstrated by the following examples. + +Usage Examples +************** +Here we present a few examples to show different common use-cases. The examples +in this section require the following imports and data:: + + >>> import numpy as np + >>> from sklearn.metrics import make_scorer, accuracy_score + >>> from sklearn.linear_model import LogisticRegressionCV + >>> from sklearn.linear_model import LogisticRegression + >>> from sklearn.model_selection import cross_validate + >>> from sklearn.model_selection import GridSearchCV + >>> from sklearn.model_selection import GroupKFold + >>> from sklearn.feature_selection import SelectKBest + >>> from sklearn.pipeline import make_pipeline + >>> n_samples, n_features = 100, 4 + >>> X = np.random.rand(n_samples, n_features) + >>> y = np.random.randint(0, 2, size=n_samples) + >>> my_groups = np.random.randint(0, 10, size=n_samples) + >>> my_weights = np.random.rand(n_samples) + >>> my_other_weights = np.random.rand(n_samples) + +Weighted scoring and fitting +---------------------------- + +Here ``GroupKFold`` requests ``groups`` by default. However, we need to +explicitly request weights in ``make_scorer`` and for ``LogisticRegressionCV``. +Both of these *consumers* understand the meaning of the key +``"sample_weight"``:: + + >>> weighted_acc = make_scorer(accuracy_score).score_requests( + ... sample_weight=True + ... ) + >>> lr = LogisticRegressionCV( + ... cv=GroupKFold(), scoring=weighted_acc, + ... ).fit_requests(sample_weight=True) + >>> cv_results = cross_validate( + ... lr, + ... X, + ... y, + ... cv=GroupKFold(), + ... props={"sample_weight": my_weights, "groups": my_groups}, + ... scoring=weighted_acc, + ... ) + +Note that in this example, ``my_weights`` is passed to both the scorer and +``~linear_model.LogisticRegressionCV``. + +Error handling: if ``props={'sample_weigh': my_weights, ...}`` were passed +(note the typo), cross_validate would raise an error, since 'sample_weigh' was +not requested by any of its children. + +Weighted scoring and unweighted fitting +--------------------------------------- + +All scikit-learn estimators requires weights to be explicitly requested or not +requested. To perform a unweighted fit, we need to configure +:class:`~linear_model.LogisticRegressionCV` to not request sample weights, so +that :func:`~model_selection.cross_validate` does not pass the weights along:: + + >>> weighted_acc = make_scorer(accuracy_score).score_requests( + ... sample_weight=True + ... ) + >>> lr = LogisticRegressionCV( + ... cv=GroupKFold(), scoring=weighted_acc, + ... ).fit_requests(sample_weight=False) + >>> cv_results = cross_validate( + ... lr, + ... X, + ... y, + ... cv=GroupKFold(), + ... props={"sample_weight": my_weights, "groups": my_groups}, + ... scoring=weighted_acc, + ... ) + +If :class:`~linear_model.LogisticRegressionCV` did not call ``fit_requests``, +:func:`~model_selection.cross_validate` will raise an error because weights is +passed in but :class:`~linear_model.LogisticRegressionCV` was not configured to +recognize the weights. + +Unweighted feature selection +---------------------------- + +Unlike ``LogisticRegressionCV``, ``SelectKBest`` doesn't accept weights and +therefore `"sample_weight"` is not routed to it:: + + >>> weighted_acc = make_scorer(accuracy_score).score_requests( + ... sample_weight=True + ... ) + >>> lr = LogisticRegressionCV( + ... cv=GroupKFold(), scoring=weighted_acc, + ... ).fit_requests(sample_weight=True) + >>> sel = SelectKBest(k=2) + >>> pipe = make_pipeline(sel, lr) + >>> cv_results = cross_validate( + ... pipe, + ... X, + ... y, + ... cv=GroupKFold(), + ... props={"sample_weight": my_weights, "groups": my_groups}, + ... scoring=weighted_acc, + ... ) + +Different scoring and fitting weights +------------------------------------- + +Despite ``make_scorer`` and ``LogisticRegressionCV`` both expecting a key +``sample_weight``, we can use aliases to pass different weights to different +consumers. In this example, we pass ``scoring_weight`` to the scorer, and +``fitting_weight`` to ``LogisticRegressionCV``:: + + >>> weighted_acc = make_scorer(accuracy_score).score_requests( + ... sample_weight="scoring_weight" + ... ) + >>> lr = LogisticRegressionCV( + ... cv=GroupKFold(), scoring=weighted_acc, + ... ).fit_requests(sample_weight="fitting_weight") + >>> cv_results = cross_validate( + ... lr, + ... X, + ... y, + ... cv=GroupKFold(), + ... props={ + ... "scoring_weight": my_weights, + ... "fitting_weight": my_other_weights, + ... "groups": my_groups, + ... }, + ... scoring=weighted_acc, + ... ) + +API Interface +************* + +A *consumer* is an object (estimator, meta-estimator, scorer, splitter) which +accepts and uses some metadata in at least one of their methods (``fit``, +``predict``, ``inverse_transform``, ``transform``, ``score``, ``split``). +Meta-estimators which only forward the metadata to other objects (the child +estimator, scorers, or splitters) and don't use the metadata themselves are not +consumers. (Meta)Estimators which route metadata to other objects are routers. +An (meta)estimator can be a consumer and a router at the same time. +(Meta)Estimators and splitters expose a ``*_requests`` method for each method +which accepts at least one metadata. For instance, if an estimator supports +``sample_weight`` in ``fit`` and ``score``, it exposes +``estimator.fit_requests(sample_weight=value)`` and +``estimator.score_requests(sample_weight=value)``. Here ``value`` can be: + +- ``RequestType.REQUESTED`` or ``True``: method requests a ``sample_weight``. + This means if the metadata is provided, it will be used, otherwise no error + is raised. +- ``RequestType.UNREQUESTED`` or ``False``: method does not request a + ``sample_weight``. +- ``RequestType.ERROR_IF_PASSED`` or ``None``: router will raise an error if + ``sample_weight`` is passed. This is in almost all cases the default value + when an object is instantiated and ensures the user sets the metadata + requests explicitly when a metadata is passed. The only exception are + ``Group*Fold`` splitters. +- ``"param_name"``: if this estimator is used in a meta-estimator, the + meta-estimator should forward ``"param_name"`` as ``sample_weight`` to this + estimator. This means the mapping between the metadata required by the + object, e.g. ``sample_weight`` and what is provided by the user, e.g. + ``my_weights`` is done at the router level, and not by the object, e.g. + estimator, itself. + +For the scorers, this is done the same way, using ``.score_requests`` method. + +If a metadata, e.g. ``sample_weight``, is passed by the user, the metadata +request for all objects which potentially can accept ``sample_weight`` should +be set by the user, otherwise an error is raised by the router object. For +example, the following code would raise, since it hasn't been explicitly set +whether ``sample_weight`` should be passed to the estimator's scorer or not:: + + >>> param_grid = {"C": [0.1, 1]} + >>> lr = LogisticRegression().fit_requests(sample_weight=True) + >>> try: + ... GridSearchCV( + ... estimator=lr, param_grid=param_grid + ... ).fit(X, y, sample_weight=my_weights) + ... except ValueError as e: + ... print(e) + sample_weight is passed but is not explicitly set as requested or not. In + method: score + +The issue can be fixed by explicitly setting the request value:: + + >>> lr = LogisticRegression().fit_requests( + ... sample_weight=True + ... ).score_requests(sample_weight=False) diff --git a/doc/user_guide.rst b/doc/user_guide.rst index 7d48934d32727..7e656567f3249 100644 --- a/doc/user_guide.rst +++ b/doc/user_guide.rst @@ -30,3 +30,4 @@ User Guide computing.rst modules/model_persistence.rst common_pitfalls.rst + metadata_routing.rst diff --git a/examples/plot_metadata_routing.py b/examples/plot_metadata_routing.py new file mode 100644 index 0000000000000..eab6ab40b26a8 --- /dev/null +++ b/examples/plot_metadata_routing.py @@ -0,0 +1,642 @@ +""" +================ +Metadata Routing +================ + +.. currentmodule:: sklearn + +This document shows how you can use the metadata routing mechanism in +scikit-learn to route metadata through meta-estimators to the estimators using +them. To better understand the rest of the document, we need to introduce two +concepts: routers and consumers. A router is an object, in most cases a +meta-estimator, which routes given data and metadata to other objects and +estimators. A consumer, on the other hand, is an object which accepts and uses +a certain given metadata. For instance, an estimator taking into account +``sample_weight`` is a consumer of ``sample_weight``. It is possible for an +object to be both a router and a consumer. For instance, a meta-estimator may +take into account ``sample_weight`` in certain calculations, but it may also +route it to the underlying estimator. + +First a few imports and some random data for the rest of the script. +""" +# %% + +import numpy as np +import warnings +from sklearn.base import BaseEstimator +from sklearn.base import ClassifierMixin +from sklearn.base import RegressorMixin +from sklearn.base import MetaEstimatorMixin +from sklearn.base import TransformerMixin +from sklearn.base import clone +from sklearn.utils.metadata_requests import RequestType +from sklearn.utils.metadata_requests import metadata_request_factory +from sklearn.utils.metadata_requests import MetadataRouter +from sklearn.utils.validation import check_is_fitted +from sklearn.linear_model import LinearRegression + +N, M = 100, 4 +X = np.random.rand(N, M) +y = np.random.randint(0, 2, size=N) +my_groups = np.random.randint(0, 10, size=N) +my_weights = np.random.rand(N) +my_other_weights = np.random.rand(N) + +# %% +# Estimators +# ---------- +# Here we demonstrate how an estimator can expose the required API to support +# metadata routing as a consumer. Imagine a simple classifier accepting +# ``sample_weight`` as a metadata on its ``fit`` and ``groups`` in its +# ``predict`` method. We add two constructor arguments to helps us check +# whether an expected metadata is given or not. This is a minimal scikit-learn +# compatible classifier: + + +class ExampleClassifier(ClassifierMixin, BaseEstimator): + def __init__(self, sample_weight_is_none=True, groups_is_none=True): + self.sample_weight_is_none = sample_weight_is_none + self.groups_is_none = groups_is_none + + def fit(self, X, y, sample_weight=None): + if (sample_weight is None) != self.sample_weight_is_none: + raise ValueError( + "sample_weight's value and sample_weight_is_none disagree!" + ) + # all classifiers need to expose a classes_ attribute once they're fit. + self.classes_ = np.array([0, 1]) + return self + + def predict(self, X, groups=None): + if (groups is None) != self.groups_is_none: + raise ValueError("groups's value and groups_is_none disagree!") + # return a constant value of 1, not a very smart classifier! + return np.ones(len(X)) + + +# %% +# The above estimator now has all it needs to consume metadata. This is done +# by some magic done in :class:`~base.BaseEstimator`. There are now three +# methods exposed by the above class: ``fit_requests``, ``predict_requests``, +# and ``get_metadata_request``. +# +# By default, no metadata is requested, which we can see as: + +ExampleClassifier().get_metadata_request() + +# %% +# The above output means that ``sample_weight`` and ``groups`` are not +# requested, but if a router is given those metadata, it should raise an error, +# since the user has not explicitly set whether they are required or not. The +# same is true for ``sample_weight`` in ``score`` method, which is inherited +# from :class:`~base.ClassifierMixin`. In order to explicitly set request +# values for those metadata, we can use these methods: + +est = ( + ExampleClassifier().fit_requests(sample_weight=False).predict_requests(groups=True) +) +est.get_metadata_request() + +# %% +# As you can see, now the two metadata have explicit request values, one is +# requested and the other one is not. Instead of ``True`` and ``False``, we +# could also use the :class:`~sklearn.utils.metadata_requests.RequestType` +# values. + +est = ( + ExampleClassifier() + .fit_requests(sample_weight=RequestType.UNREQUESTED) + .predict_requests(groups=RequestType.REQUESTED) +) +est.get_metadata_request() + +# %% +# Please note that as long as the above estimator is not used in another +# meta-estimator, the user does not need to set any requests for the metadata. +# A simple usage of the above estimator would work as expected. Remember that +# ``{sample_weight, groups}_is_none`` are for testing/demonstration purposes +# and don't have anything to do with the routing mechanisms. + +est = ExampleClassifier(sample_weight_is_none=False, groups_is_none=False) +est.fit(X, y, sample_weight=my_weights) +est.predict(X[:3, :], groups=my_groups) + +# %% +# Now let's have a meta-estimator, which doesn't do much other than routing the +# metadata correctly. + + +class MetaClassifier(MetaEstimatorMixin, ClassifierMixin, BaseEstimator): + def __init__(self, estimator): + self.estimator = estimator + + def fit(self, X, y, **fit_params): + if self.estimator is None: + raise ValueError("estimator cannot be None!") + + # meta-estimators are responsible for validating the given metadata + metadata_request_factory(self).fit.validate_metadata( + ignore_extras=False, kwargs=fit_params + ) + # we can use provided utility methods to map the given metadata to what + # is required by the underlying estimator + fit_params_ = metadata_request_factory(self.estimator).fit.get_input( + ignore_extras=False, kwargs=fit_params + ) + self.estimator_ = clone(self.estimator).fit(X, y, **fit_params_) + self.classes_ = self.estimator_.classes_ + return self + + def predict(self, X, **predict_params): + check_is_fitted(self) + # same as in ``fit``, we validate the given metadata + metadata_request_factory(self).predict.validate_metadata( + ignore_extras=False, kwargs=predict_params + ) + # and then prepare the input to the underlying ``predict`` method. + predict_params_ = metadata_request_factory(self.estimator_).predict.get_input( + ignore_extras=False, kwargs=predict_params + ) + return self.estimator_.predict(X, **predict_params_) + + def get_metadata_request(self): + router = MetadataRouter().add( + self.estimator, mapping="one-to-one", overwrite=False, mask=True + ) + return router.get_metadata_request() + + +# %% +# Let's break down different parts of the above code. +# +# First, the :method:`~utils.metadata_requests.metadata_request_factory` takes +# an object from which a :class:`~utils.metadata_requests.MetadataRequest` can +# be constructed. This may be an estimator, or a dictionary representing a +# ``MetadataRequest`` object. If an estimator is given, it tries to call the +# estimator and construct the object from that, and if the estimator doesn't +# have such a method, then a default empty ``MetadataRequest`` is returned. +# +# Then in each method, we use the corresponding +# :method:`~utils.metadata_requests.MethodMetadataRequest.get_input` to +# construct a dictionary of the form ``{"metadata": value}`` to pass to the +# underlying estimator's method. Please note that since in this example the +# meta-estimator does not consume any of the given metadata itself, and there +# is only one object to which the metadata is passed, we have +# ``ignore_extras=False`` which means passed metadata are also validated in the +# sense that it will be checked if anything extra is given. This is to avoid +# silent bugs, and this is how it will work: + +est = MetaClassifier( + estimator=ExampleClassifier(sample_weight_is_none=False).fit_requests( + sample_weight=True + ) +) +est.fit(X, y, sample_weight=my_weights) + +# %% +# Note that the above example checks that ``sample_weight`` is correctly passed to +# ``ExampleClassifier``, or else it would have raised: + +try: + est.fit(X, y) +except ValueError as e: + print(e) + +# %% +# If we pass an unknown metadata, it will be caught: +try: + est.fit(X, y, test=my_weights) +except ValueError as e: + print(e) + +# %% +# And if we pass something which is not explicitly requested: +try: + est.fit(X, y, sample_weight=my_weights).predict(X, groups=my_groups) +except ValueError as e: + print(e) + +# %% +# Also, if we explicitly say it's not requested, but pass it: +est = MetaClassifier( + estimator=ExampleClassifier(sample_weight_is_none=False) + .fit_requests(sample_weight=True) + .predict_requests(groups=False) +) +try: + est.fit(X, y, sample_weight=my_weights).predict(X[:3, :], groups=my_groups) +except ValueError as e: + print(e) + +# %% +# In order to understand the above implementation of ``get_metadata_request``, +# we need to also introduce an aliased metadata. This is when an estimator +# requests a metadata with a different name than the default value. For +# instance, in a setting where there are two estimators in a pipeline, one +# could request ``sample_weight1`` and the other ``sample_weight2``. Note that +# this doesn't change what the estimator expects, it only tells the +# meta-estimator how to map provided metadata to what's required. Here's an +# example, where we pass ``aliased_sample_weight`` to the meta-estimator, but +# the meta-estimator understands that ``aliased_sample_weight`` is an alias for +# ``sample_weight``, and passes it as ``sample_weight`` to the underlying +# estimator: +est = MetaClassifier( + estimator=ExampleClassifier(sample_weight_is_none=False).fit_requests( + sample_weight="aliased_sample_weight" + ) +) +est.fit(X, y, aliased_sample_weight=my_weights) + +# %% +# And passing ``sample_weight`` here will fail since it is requested with an alias: +try: + est.fit(X, y, sample_weight=my_weights) +except ValueError as e: + print(e) + +# %% +# This leads us to the ``get_metadata_request``. The way routing works in +# scikit-learn is that consumers request what they need, and routers pass that +# along. But another thing a router does, is that it also exposes what it +# requires so that it can be used as a consumer inside another router, e.g. a +# pipeline inside a grid search object. However, routers (e.g. our +# meta-estimator) don't expose the mapping, and only expose what's required for +# them to do their job. In the above example, it looks like the following: +est.get_metadata_request()["fit"] + +# %% +# As you can see, the only metadata requested for method ``fit`` is +# ``"aliased_sample_weight"``. This information is enough for another +# meta-estimator/router to know what needs to be passed to ``est``. In other +# words, ``sample_weight`` is *masked* . The ``MetadataRouter`` class enables us to +# easily create the routing object which would create the output we need for +# our ``get_metadata_request``. In the above implementation, +# ``mapping="one-to-one"`` means all requests are mapped one to one from the +# sub-estimator to the meta-estimator's methods, and ``mask=True`` indicates +# that the requests should be masked, as explained. Masking is necessary since +# it's the meta-estimator which does the mapping between the alias and the +# original metadata name. Without it, having ``est`` in another meta-estimator +# would break the routing. Imagine this example: + +meta_est = MetaClassifier(estimator=est).fit(X, y, aliased_sample_weight=my_weights) + +# %% +# In the above example, this is how each ``fit`` method will call the +# sub-estimator's ``fit``:: +# +# meta_est.fit(X, y, aliased_sample_weight=my_weights): +# ... # this estimator (est), expects aliased_sample_weight as seen above +# self.estimator_.fit(X, y, aliased_sample_weight=aliased_sample_weight): +# ... # est passes aliased_sample_weight's value as sample_weight, +# # which is expected by the sub-estimator +# self.estimator_.fit(X, y, sample_weight=aliased_sample_weight) +# ... + +# %% +# Router and Consumer +# ------------------- +# To show how a slightly more complicated case would work, consider a case +# where a meta-estimator uses some metadata, but it also routes them to an +# underlying estimator. In this case, this meta-estimator is a consumer and a +# router at the same time. This is how we can implement one, and it is very +# similar to what we had before, with a few tweaks. + + +class RouterConsumerClassifier(MetaEstimatorMixin, ClassifierMixin, BaseEstimator): + def __init__(self, estimator, sample_weight_is_none=True): + self.estimator = estimator + self.sample_weight_is_none = sample_weight_is_none + + def fit(self, X, y, sample_weight, **fit_params): + if self.estimator is None: + raise ValueError("estimator cannot be None!") + + if (sample_weight is None) != self.sample_weight_is_none: + raise ValueError( + "sample_weight's value and sample_weight_is_none disagree!" + ) + + if sample_weight is not None: + fit_params["sample_weight"] = sample_weight + + # meta-estimators are responsible for validating the given metadata + metadata_request_factory(self).fit.validate_metadata( + ignore_extras=False, self_metadata=super(), kwargs=fit_params + ) + # we can use provided utility methods to map the given metadata to what + # is required by the underlying estimator + fit_params_ = metadata_request_factory(self.estimator).fit.get_input( + ignore_extras=False, kwargs=fit_params + ) + self.estimator_ = clone(self.estimator).fit(X, y, **fit_params_) + self.classes_ = self.estimator_.classes_ + return self + + def predict(self, X, **predict_params): + check_is_fitted(self) + # same as in ``fit``, we validate the given metadata + metadata_request_factory(self).predict.validate_metadata( + ignore_extras=False, kwargs=predict_params + ) + # and then prepare the input to the underlying ``predict`` method. + predict_params_ = metadata_request_factory(self.estimator_).predict.get_input( + ignore_extras=False, kwargs=predict_params + ) + return self.estimator_.predict(X, **predict_params_) + + def get_metadata_request(self): + router = ( + MetadataRouter() + .add(super(), mapping="one-to-one", overwrite=False, mask=False) + .add(self.estimator, mapping="one-to-one", overwrite="smart", mask=True) + ) + return router.get_metadata_request() + + +# %% +# The two key parts where the above estimator differs from our previous +# meta-estimator is validation in ``fit``, and generating routing data in +# ``get_metadata_request``. In ``fit``, we pass ``self_metadata=super()`` to +# ``validate_metadata``. This is important since consumers don't validate how +# metadata is passed to them, it's only done by routers. In this case, this +# means validation should be done for the metadata consumed by the +# sub-estimator, but not for the metadata consumed by the meta-estimator +# itself. +# +# In ``get_metadata_request``, we add what's consumed by this meta-estimator +# without masking them, before adding what's requested by the sub-estimator. +# Passing ``super()`` here means only what's explicitly mentioned in the +# methods' signature is considered as metadata consumed by this estimator; in +# this case fit's ``sample_weight``. Let's see what the routing metadata looks like with +# different settings: + +# %% +# no metadata requested +est = RouterConsumerClassifier(estimator=ExampleClassifier()) +est.get_metadata_request()["fit"] + + +# %% +# ``sample_weight`` requested by child estimator +est = RouterConsumerClassifier( + estimator=ExampleClassifier().fit_requests(sample_weight=True) +) +est.get_metadata_request()["fit"] +# %% +# ``sample_weight`` requested by meta-estimator +est = RouterConsumerClassifier(estimator=ExampleClassifier()).fit_requests( + sample_weight=True +) +est.get_metadata_request()["fit"] + +# %% +# As you can see, the last two are identical, which is fine since that's what a +# meta-estimator having ``RouterConsumerClassifier`` as a sub-estimator needs. +# The situation is different if we use named aliases: +# +# Aliased on both +est = RouterConsumerClassifier( + sample_weight_is_none=False, + estimator=ExampleClassifier(sample_weight_is_none=False).fit_requests( + sample_weight="clf_sample_weight" + ), +).fit_requests(sample_weight="meta_clf_sample_weight") +est.get_metadata_request()["fit"] + +# %% +# However, ``fit`` of the meta-estimator only needs the alias for the +# sub-estimator: +est.fit(X, y, sample_weight=my_weights, clf_sample_weight=my_other_weights) + +# %% +# Alias only on the sub-estimator. This is useful if we don't want the +# meta-estimator to use the metadata, and we only want the metadata to be used +# by the sub-estimator. +est = RouterConsumerClassifier( + estimator=ExampleClassifier().fit_requests(sample_weight="aliased_sample_weight") +).fit_requests(sample_weight=True) +est.get_metadata_request()["fit"] + +# %% +# Alias only on the meta-estimator. This example raises an error since there +# will be two conflicting values for routing ``sample_weight``. +est = RouterConsumerClassifier( + estimator=ExampleClassifier().fit_requests(sample_weight=True) +).fit_requests(sample_weight="aliased_sample_weight") +try: + est.get_metadata_request()["fit"] +except ValueError as e: + print(e) + + +# %% +# Simple Pipeline +# --------------- +# A slightly more complicated use-case is a meta-estimator which does something +# similar to the ``Pipeline``. Here is a meta-estimator, which accepts a +# transformer and a classifier, and applies the transformer before running the +# classifier. + + +class SimplePipeline(ClassifierMixin, BaseEstimator): + _required_parameters = ["estimator"] + + def __init__(self, transformer, classifier): + self.transformer = transformer + self.classifier = classifier + + def fit(self, X, y, **fit_params): + metadata_request_factory(self).fit.validate_metadata(kwargs=fit_params) + + transformer_fit_params = metadata_request_factory( + self.transformer + ).fit.get_input(ignore_extras=True, kwargs=fit_params) + transformer_transform_params = metadata_request_factory( + self.transformer + ).transform.get_input(ignore_extras=True, kwargs=fit_params) + self.transformer_ = clone(self.transformer).fit(X, y, **transformer_fit_params) + X_transformed = self.transformer_.transform(X, **transformer_transform_params) + + classifier_fit_params = metadata_request_factory(self.classifier).fit.get_input( + ignore_extras=True, kwargs=fit_params + ) + self.classifier_ = clone(self.classifier).fit( + X_transformed, y, **classifier_fit_params + ) + return self + + def predict(self, X, **predict_params): + metadata_request_factory(self).predict.validate_metadata(kwargs=predict_params) + + transformer_transform_params = metadata_request_factory( + self.transformer + ).transform.get_input(ignore_extras=True, kwargs=predict_params) + X_transformed = self.transformer_.transform(X, **transformer_transform_params) + + classifier_predict_params = metadata_request_factory( + self.classifier + ).predict.get_input(ignore_extras=True, kwargs=predict_params) + return self.classifier_.predict(X_transformed, **classifier_predict_params) + + def get_metadata_request(self): + router = ( + MetadataRouter() + .add( + self.transformer, + mapping={ + "fit": ["fit", "transform"], + "predict": "transform", + }, + overwrite="smart", + mask=True, + ) + .add(self.classifier, mapping="one-to-one", overwrite="smart", mask=True) + ) + return router.get_metadata_request() + + +# %% +# As you can see, we use the transformer's ``transform`` and ``fit`` methods in +# ``fit``, and its ``transform`` method in ``predict``, and that's what you see +# implemented in the routing structure of the pipeline class. In order to test +# the above pipeline, let's add an example transformer. + + +class ExampleTransformer(TransformerMixin, BaseEstimator): + def fit(self, X, y, sample_weight=None): + if sample_weight is None: + raise ValueError("sample_weight is None!") + return self + + def transform(self, X, groups=None): + if groups is None: + raise ValueError("groups is None!") + return X + + +# %% +# Now we can test our pipeline, and see if metadata is correctly passed around. +# This example uses our simple pipeline, and our transformer, and our +# consumer+router estimator which uses our simple classifier. + +est = SimplePipeline( + transformer=ExampleTransformer() + # we transformer's fit to receive sample_weight + .fit_requests(sample_weight=True) + # we want transformer's transform to receive groups + .transform_requests(groups=True), + classifier=RouterConsumerClassifier( + sample_weight_is_none=False, + estimator=ExampleClassifier(sample_weight_is_none=False) + # we want this sub-estimator to receive sample_weight in fit + .fit_requests(sample_weight=True) + # but not groups in predict + .predict_requests(groups=False), + ).fit_requests( + # and we want the meta-estimator to receive sample_weight as well + sample_weight=True + ), +) +est.fit(X, y, sample_weight=my_weights, groups=my_groups).predict( + X[:3], groups=my_groups +) + +# %% +# Deprecation / Default Value Change +# ----------------------------------- +# In this section we show how one should handle the case where a router becomes +# also a consumer, especially when it consumes the same metadata as its +# sub-estimator. In this case, a warning should be raised for a while, to let +# users know the behavior is changed from previous versions. + + +class MetaRegressor(MetaEstimatorMixin, RegressorMixin, BaseEstimator): + def __init__(self, estimator): + self.estimator = estimator + + def fit(self, X, y, **fit_params): + metadata_request_factory(self).fit.validate_metadata( + ignore_extras=False, self_metadata=super(), kwargs=fit_params + ) + fit_params_ = metadata_request_factory(self.estimator).fit.get_input( + ignore_extras=False, kwargs=fit_params + ) + self.estimator_ = clone(self.estimator).fit(X, y, **fit_params_) + + def get_metadata_request(self): + router = MetadataRouter().add( + self.estimator, mapping="one-to-one", overwrite=False, mask=True + ) + return router.get_metadata_request() + + +# %% +# As explained above, this is now a valid usage: + +reg = MetaRegressor(estimator=LinearRegression().fit_requests(sample_weight=True)) +reg.fit(X, y, sample_weight=my_weights) + + +# %% +# Now imagine we further develop ``MetaRegressor`` and it now also *consumes* +# ``sample_weight``: + + +class SampledMetaRegressor(MetaEstimatorMixin, RegressorMixin, BaseEstimator): + __metadata_request__fit = {"sample_weight": RequestType.WARN} + + def __init__(self, estimator): + self.estimator = estimator + + def fit(self, X, y, sample_weight=None, **fit_params): + if sample_weight is not None: + fit_params["sample_weight"] = sample_weight + metadata_request_factory(self).fit.validate_metadata( + ignore_extras=False, self_metadata=super(), kwargs=fit_params + ) + estimator_fit_params = metadata_request_factory(self.estimator).fit.get_input( + ignore_extras=True, kwargs=fit_params + ) + self.estimator_ = clone(self.estimator).fit(X, y, **estimator_fit_params) + + def get_metadata_request(self): + router = ( + MetadataRouter() + .add(super(), mapping="one-to-one", overwrite=False, mask=False) + .add(self.estimator, mapping="one-to-one", overwrite="smart", mask=True) + ) + return router.get_metadata_request() + + +# %% +# The above implementation is almost no different than ``MetaRegressor``, and +# because of the default request value defined in `__metadata_request__fit`` +# there is a warning raised. + +with warnings.catch_warnings(record=True) as record: + SampledMetaRegressor( + estimator=LinearRegression().fit_requests(sample_weight=False) + ).fit(X, y, sample_weight=my_weights) +for w in record: + print(w.message) + + +# %% +# When an estimator suports a metadata which wasn't supported before, the +# following pattern can be used to warn the users about it. + + +class ExampleRegressor(RegressorMixin, BaseEstimator): + __metadata_request__fit = {"sample_weight": RequestType.WARN} + + def fit(self, X, y, sample_weight=None): + return self + + def predict(self, X): + return np.zeros(shape=(len(X))) + + +with warnings.catch_warnings(record=True) as record: + MetaRegressor(estimator=ExampleRegressor()).fit(X, y, sample_weight=my_weights) +for w in record: + print(w.message) diff --git a/sklearn/base.py b/sklearn/base.py index 390a458a0962c..ced54396f7bde 100644 --- a/sklearn/base.py +++ b/sklearn/base.py @@ -27,6 +27,7 @@ from .utils.validation import _generate_get_feature_names_out from .utils.validation import check_is_fitted from .utils._estimator_html_repr import estimator_html_repr +from .utils.metadata_requests import _MetadataRequester from .utils.validation import _get_feature_names @@ -86,7 +87,13 @@ def clone(estimator, *, safe=True): new_object_params = estimator.get_params(deep=False) for name, param in new_object_params.items(): new_object_params[name] = clone(param, safe=False) + new_object = klass(**new_object_params) + try: + new_object._metadata_request = copy.deepcopy(estimator._metadata_request) + except AttributeError: + pass + params_set = new_object.get_params(deep=False) # quick sanity check of the parameters of the clone @@ -151,7 +158,7 @@ def _pprint(params, offset=0, printer=repr): return lines -class BaseEstimator: +class BaseEstimator(_MetadataRequester): """Base class for all estimators in scikit-learn. Notes diff --git a/sklearn/externals/_sentinels.py b/sklearn/externals/_sentinels.py new file mode 100644 index 0000000000000..662a82864ca8d --- /dev/null +++ b/sklearn/externals/_sentinels.py @@ -0,0 +1,82 @@ +# type: ignore +""" +Copied from https://github.com/taleinat/python-stdlib-sentinels +PEP-0661: Status: Draft +""" +import sys as _sys +from typing import Optional + + +__all__ = ["sentinel"] + + +def sentinel( + name: str, + repr: Optional[str] = None, + module: Optional[str] = None, +): + """Create a unique sentinel object. + + *name* should be the fully-qualified name of the variable to which the + return value shall be assigned. + + *repr*, if supplied, will be used for the repr of the sentinel object. + If not provided, "" will be used (with any leading class names + removed). + + *module*, if supplied, will be used as the module name for the purpose + of setting a unique name for the sentinels unique class. The class is + set as an attribute of this name on the "sentinels" module, so that it + may be found by the pickling mechanism. In most cases, the module name + does not need to be provided, and it will be found by inspecting the + stack frame. + """ + name = _sys.intern(str(name)) + repr = repr or f'<{name.rsplit(".", 1)[-1]}>' + + if module is None: + try: + module = _get_parent_frame().f_globals.get("__name__", "__main__") + except (AttributeError, ValueError): + pass + class_name = _sys.intern(_get_class_name(name, module)) + + class_namespace = { + "__repr__": lambda self: repr, + } + cls = type(class_name, (), class_namespace) + cls.__module__ = __name__ + globals()[class_name] = cls + + sentinel = cls() + + def __new__(cls): + return sentinel + + __new__.__qualname__ = f"{class_name}.__new__" + cls.__new__ = __new__ + + return sentinel + + +if hasattr(_sys, "_getframe"): + _get_parent_frame = lambda: _sys._getframe(2) +else: # pragma: no cover + + def _get_parent_frame(): + """Return the frame object for the caller's parent stack frame.""" + try: + raise Exception + except Exception: + return _sys.exc_info()[2].tb_frame.f_back.f_back + + +def _get_class_name( + sentinel_qualname: str, + module_name: Optional[str] = None, +) -> str: + return ( + "_sentinel_type__" + f'{module_name.replace(".", "_") + "__" if module_name else ""}' + f'{sentinel_qualname.replace(".", "_")}' + ) diff --git a/sklearn/tests/test_docstrings.py b/sklearn/tests/test_docstrings.py index 07dab43adb633..f010255eff3ac 100644 --- a/sklearn/tests/test_docstrings.py +++ b/sklearn/tests/test_docstrings.py @@ -208,6 +208,8 @@ "sklearn.utils.validation.check_is_fitted", "sklearn.utils.validation.check_memory", "sklearn.utils.validation.check_random_state", + # Never fix this one, it's vendord code + "sklearn.externals._sentinels.sentinel", ] FUNCTION_DOCSTRING_IGNORE_LIST = set(FUNCTION_DOCSTRING_IGNORE_LIST) diff --git a/sklearn/tests/test_props.py b/sklearn/tests/test_props.py new file mode 100644 index 0000000000000..c5e1ea52cb67b --- /dev/null +++ b/sklearn/tests/test_props.py @@ -0,0 +1,562 @@ +import re +import numpy as np +import pytest + +from sklearn.base import BaseEstimator +from sklearn.base import ClassifierMixin +from sklearn.base import TransformerMixin +from sklearn.base import MetaEstimatorMixin +from sklearn.base import clone +from sklearn.utils import MetadataRequest +from sklearn.utils.validation import check_is_fitted +from sklearn.utils.metadata_requests import RequestType +from sklearn.utils.metadata_requests import metadata_request_factory +from sklearn.utils.metadata_requests import MetadataRouter +from sklearn.utils.metadata_requests import MethodMetadataRequest + +from sklearn.base import _MetadataRequester + +N, M = 100, 4 +X = np.random.rand(N, M) +y = np.random.randint(0, 2, size=N) +my_groups = np.random.randint(0, 10, size=N) +my_weights = np.random.rand(N) +my_other_weights = np.random.rand(N) + + +def assert_request_is_empty(metadata_request, exclude=None): + """Check if a metadata request dict is empty. + + One can exclude a method or a list of methods from the check using the + ``exclude`` perameter. + """ + if isinstance(metadata_request, MetadataRequest): + metadata_request = metadata_request.to_dict() + if exclude is None: + exclude = [] + for method, request in metadata_request.items(): + if method in exclude: + continue + props = [ + prop + for prop, alias in request.items() + if isinstance(alias, str) + or RequestType(alias) != RequestType.ERROR_IF_PASSED + ] + assert not len(props) + + +class TestEstimatorNoMetadata(ClassifierMixin, BaseEstimator): + """An estimator which accepts no metadata on any method.""" + + def fit(self, X, y): + return self + + def predict(self, X): + return np.ones(len(X)) + + +class TestEstimatorFitMetadata(ClassifierMixin, BaseEstimator): + """An estimator accepting two metadata in its ``fit`` method.""" + + def __init__(self, sample_weight_none=True, brand_none=True): + self.sample_weight_none = sample_weight_none + self.brand_none = brand_none + + def fit(self, X, y, sample_weight=None, brand=None): + assert ( + sample_weight is None + ) == self.sample_weight_none, "sample_weight and sample_weight_none don't agree" + assert (brand is None) == self.brand_none, "brand and brand_none don't agree" + return self + + def predict(self, X): + return np.ones(len(X)) + + +class TestSimpleMetaEstimator(MetaEstimatorMixin, ClassifierMixin, BaseEstimator): + """A meta-estimator which also consumes sample_weight itself in ``fit``.""" + + def __init__(self, estimator, sample_weight_none): + self.sample_weight_none = sample_weight_none + self.estimator = estimator + + def fit(self, X, y, sample_weight=None, **kwargs): + assert ( + sample_weight is None + ) == self.sample_weight_none, "sample_weight and sample_weight_none don't agree" + + if sample_weight is not None: + kwargs["sample_weight"] = sample_weight + metadata_request_factory(self).fit.validate_metadata( + ignore_extras=False, self_metadata=super(), kwargs=kwargs + ) + fit_params = metadata_request_factory(self.estimator).fit.get_input( + ignore_extras=True, kwargs=kwargs + ) + self.estimator_ = clone(self.estimator).fit(X, y, **fit_params) + return self + + def get_metadata_request(self): + router = MetadataRouter().add(super(), mask=False) + router.add(self.estimator, mapping={"fit": "fit"}, mask=True, overwrite="smart") + return router.get_metadata_request() + + +class TestTransformer(TransformerMixin, BaseEstimator): + """A transformer which accepts metadata on fit and transform.""" + + def __init__( + self, + brand_none=True, + new_param_none=True, + fit_sample_weight_none=True, + transform_sample_weight_none=True, + ): + self.brand_none = brand_none + self.new_param_none = new_param_none + self.fit_sample_weight_none = fit_sample_weight_none + self.transform_sample_weight_none = transform_sample_weight_none + + def fit(self, X, y=None, brand=None, new_param=None, sample_weight=None): + assert ( + sample_weight is None + ) == self.fit_sample_weight_none, ( + "sample_weight and fit_sample_weight_none don't agree" + ) + assert ( + new_param is None + ) == self.new_param_none, "new_param and new_param_none don't agree" + assert (brand is None) == self.brand_none, "brand and brand_none don't agree" + return self + + def transform(self, X, y=None, sample_weight=None): + assert ( + sample_weight is None + ) == self.transform_sample_weight_none, ( + "sample_weight and transform_sample_weight_none don't agree" + ) + return X + + +class TestMetaTransformer(MetaEstimatorMixin, TransformerMixin, BaseEstimator): + """A simple meta-transformer.""" + + def __init__(self, transformer): + self.transformer = transformer + + def fit(self, X, y=None, **fit_params): + metadata_request_factory(self).fit.validate_metadata( + ignore_extras=False, + kwargs=fit_params, + ) + fit_params_ = metadata_request_factory(self.transformer).fit.get_input( + ignore_extras=False, kwargs=fit_params + ) + self.transformer_ = clone(self.transformer).fit(X, y, **fit_params_) + return self + + def transform(self, X, y=None, **transform_params): + metadata_request_factory(self).transform.validate_metadata( + ignore_extras=False, kwargs=transform_params + ) + transform_params_ = metadata_request_factory( + self.transformer + ).transform.get_input(ignore_extras=False, kwargs=transform_params) + return self.transformer_.transform(X, **transform_params_) + + +class SimplePipeline(BaseEstimator): + """A very simple pipeline, assuming the last step is always a predictor.""" + + def __init__(self, steps): + self.steps = steps + + def fit(self, X, y, **fit_params): + self.steps_ = [] + metadata_request_factory(self).fit.validate( + ignore_extras=False, kwargs=fit_params + ) + X_transformed = X + for step in self.steps[:-1]: + requests = metadata_request_factory(step) + step_fit_params = requests.fit.get_input( + ignore_extras=True, kwargs=fit_params + ) + transformer = clone(step).fit(X_transformed, y, **step_fit_params) + self.steps_.append(transformer) + step_transform_params = requests.transform.get_input( + ignore_extras=True, kwargs=fit_params + ) + X_transformed = transformer.transform( + X_transformed, **step_transform_params + ) + + requests = metadata_request_factory(step) + step_fit_params = requests.fit.get_input(ignore_extras=True, kwargs=fit_params) + self.steps_.append( + clone(self.steps[-1]).fit(X_transformed, y, **step_fit_params) + ) + return self + + def predict(self, X, **predict_params): + check_is_fitted(self) + X_transformed = X + metadata_request_factory(self).predict.validate_metadata( + ignore_extras=False, kwargs=predict_params + ) + for step in self.steps_[:-1]: + step_transform_params = metadata_request_factory(step).transform.get_input( + ignore_extras=True, kwargs=predict_params + ) + X_transformed = step.transform(X, **step_transform_params) + + step_predict_params = metadata_request_factory( + self.steps_[-1] + ).predict.get_input(ignore_extras=True, kwargs=predict_params) + return self.steps_[-1].predict(X_transformed, **step_predict_params) + + def get_metadata_request(self): + router = MetadataRouter() + if len(self.steps) > 1: + router.add( + self.steps[:-1], + mask=True, + mapping={"predict": "transform", "fit": ["transform", "fit"]}, + overwrite="smart", + ) + router.add(self.steps[-1], overwrite="smart", mapping="one-to-one") + return router.get_metadata_request() + + +def test_assert_request_is_empty(): + requests = MetadataRequest() + assert_request_is_empty(requests) + assert_request_is_empty(requests.to_dict()) + + requests.fit.add_request(prop="foo", alias=RequestType.ERROR_IF_PASSED) + # this should still work, since ERROR_IF_PASSED is the default value + assert_request_is_empty(requests) + + requests.fit.add_request(prop="bar", alias="value") + with pytest.raises(AssertionError): + # now requests is no more empty + assert_request_is_empty(requests) + + # but one can exclude a method + assert_request_is_empty(requests, exclude="fit") + + requests.score.add_request(prop="carrot", alias=RequestType.REQUESTED) + with pytest.raises(AssertionError): + # excluding `fit` is not enough + assert_request_is_empty(requests, exclude="fit") + + # and excluding both fit and score would avoid an exception + assert_request_is_empty(requests, exclude=["fit", "score"]) + + +def test_default_requests(): + class OddEstimator(BaseEstimator): + __metadata_request__fit = { + # set a different default request + "sample_weight": RequestType.REQUESTED + } # type: ignore + + odd_request = metadata_request_factory(OddEstimator()) + assert odd_request.fit.requests == {"sample_weight": RequestType.REQUESTED} + + # check other test estimators + assert not len(metadata_request_factory(TestEstimatorNoMetadata()).fit.requests) + assert_request_is_empty(TestEstimatorNoMetadata().get_metadata_request()) + + trs_request = metadata_request_factory(TestTransformer()) + assert trs_request.fit.requests == { + "sample_weight": RequestType(None), + "brand": RequestType(None), + "new_param": RequestType(None), + } + assert trs_request.transform.requests == { + "sample_weight": RequestType(None), + } + assert_request_is_empty(trs_request) + + est_request = metadata_request_factory(TestEstimatorFitMetadata()) + assert est_request.fit.requests == { + "sample_weight": RequestType(None), + "brand": RequestType(None), + } + assert_request_is_empty(est_request) + + +def test_simple_metadata_routing(): + # Tests that metadata is properly routed + # The underlying estimator doesn't accept or request metadata + cls = TestSimpleMetaEstimator( + estimator=TestEstimatorNoMetadata(), sample_weight_none=True + ) + cls.fit(X, y) + + # Meta-estimator consumes sample_weight, but doesn't forward it to the underlying + # estimator + cls = TestSimpleMetaEstimator( + estimator=TestEstimatorNoMetadata(), sample_weight_none=False + ) + cls.fit(X, y, sample_weight=my_weights) + + # If the estimator accepts the metadata but doesn't explicitly say it doesn't + # need it, there's an error + cls = TestSimpleMetaEstimator( + estimator=TestEstimatorFitMetadata(), + sample_weight_none=False, + ) + with pytest.raises( + ValueError, + match=( + "sample_weight is passed but is not explicitly set as requested or not. In" + " method: fit" + ), + ): + cls.fit(X, y, sample_weight=my_weights) + + # Explicitly saying the estimator doesn't need it, makes the error go away, + # but if a metadata is passed which is not requested by any object/estimator, + # there will be still an error + cls = TestSimpleMetaEstimator( + estimator=TestEstimatorFitMetadata().fit_requests(sample_weight=False), + sample_weight_none=False, + ) + # this doesn't raise since TestSimpleMetaEstimator itself is a consumer, + # and passing metadata to the consumer directly is fine regardless of its + # metadata_request values. + cls.fit(X, y, sample_weight=my_weights) + + # Requesting a metadata will make the meta-estimator forward it correctly + cls = TestSimpleMetaEstimator( + estimator=TestEstimatorFitMetadata(sample_weight_none=False).fit_requests( + sample_weight=True + ), + sample_weight_none=False, + ) + cls.fit(X, y, sample_weight=my_weights) + + +def test_invalid_metadata(): + # check that passing wrong metadata raises an error + trs = TestMetaTransformer( + transformer=TestTransformer().transform_requests(sample_weight=True) + ) + with pytest.raises( + ValueError, + match=( + re.escape( + "Metadata passed which is not understood: ['other_param']. In method:" + " transform" + ) + ), + ): + trs.fit(X, y).transform(X, other_param=my_weights) + + # passing a metadata which is not requested by any estimator should also raise + trs = TestMetaTransformer( + transformer=TestTransformer().transform_requests(sample_weight=False) + ) + with pytest.raises( + ValueError, + match=( + re.escape( + "Metadata passed which is not understood: ['sample_weight']. In method:" + " transform" + ) + ), + ): + trs.fit(X, y).transform(X, sample_weight=my_weights) + + +def test_get_metadata_request(): + class TestDefaultsBadMethodName(_MetadataRequester): + __metadata_request__fit = { + "sample_weight": RequestType.ERROR_IF_PASSED, + "my_param": RequestType.ERROR_IF_PASSED, + } + __metadata_request__score = { + "sample_weight": RequestType.ERROR_IF_PASSED, + "my_param": True, + "my_other_param": RequestType.ERROR_IF_PASSED, + } + # this will raise an error since we don't understand "other_method" as a method + __metadata_request__other_method = {"my_param": True} + + class TestDefaults(_MetadataRequester): + __metadata_request__fit = { + "sample_weight": RequestType.ERROR_IF_PASSED, + "my_other_param": RequestType.ERROR_IF_PASSED, + } + __metadata_request__score = { + "sample_weight": RequestType.ERROR_IF_PASSED, + "my_param": True, + "my_other_param": RequestType.ERROR_IF_PASSED, + } + __metadata_request__predict = {"my_param": True} + + with pytest.raises( + AttributeError, match="'MetadataRequest' object has no attribute 'other_method'" + ): + TestDefaultsBadMethodName().get_metadata_request() + + expected = { + "score": { + "my_param": RequestType(True), + "my_other_param": RequestType(None), + "sample_weight": RequestType(None), + }, + "fit": { + "my_other_param": RequestType(None), + "sample_weight": RequestType(None), + }, + "partial_fit": {}, + "predict": {"my_param": RequestType(True)}, + "transform": {}, + "inverse_transform": {}, + "split": {}, + } + assert TestDefaults().get_metadata_request() == expected + + est = TestDefaults().score_requests(my_param="other_param") + expected = { + "score": { + "my_param": "other_param", + "my_other_param": RequestType(None), + "sample_weight": RequestType(None), + }, + "fit": { + "my_other_param": RequestType(None), + "sample_weight": RequestType(None), + }, + "partial_fit": {}, + "predict": {"my_param": RequestType(True)}, + "transform": {}, + "inverse_transform": {}, + "split": {}, + } + assert est.get_metadata_request() == expected + + est = TestDefaults().fit_requests(sample_weight=True) + expected = { + "score": { + "my_param": RequestType(True), + "my_other_param": RequestType(None), + "sample_weight": RequestType(None), + }, + "fit": { + "my_other_param": RequestType(None), + "sample_weight": RequestType(True), + }, + "partial_fit": {}, + "predict": {"my_param": RequestType(True)}, + "transform": {}, + "inverse_transform": {}, + "split": {}, + } + assert est.get_metadata_request() == expected + + +def test__get_default_requests(): + # Test _get_default_requests method + class ExplicitRequest(BaseEstimator): + __metadata_request__fit = {"prop": RequestType.ERROR_IF_PASSED} + + def fit(self, X, y): + return self + + assert metadata_request_factory(ExplicitRequest()).fit.requests == { + "prop": RequestType.ERROR_IF_PASSED + } + assert_request_is_empty(ExplicitRequest().get_metadata_request(), exclude="fit") + + class ExplicitRequestOverwrite(BaseEstimator): + __metadata_request__fit = {"prop": RequestType.REQUESTED} + + def fit(self, X, y, prop=None, **kwargs): + return self + + assert metadata_request_factory(ExplicitRequestOverwrite()).fit.requests == { + "prop": RequestType.REQUESTED + } + assert_request_is_empty( + ExplicitRequestOverwrite().get_metadata_request(), exclude="fit" + ) + + class ImplicitRequest(BaseEstimator): + def fit(self, X, y, prop=None, **kwargs): + return self + + assert metadata_request_factory(ImplicitRequest()).fit.requests == { + "prop": RequestType.ERROR_IF_PASSED + } + assert_request_is_empty(ImplicitRequest().get_metadata_request(), exclude="fit") + + class ImplicitRequestRemoval(BaseEstimator): + __metadata_request__fit = {"prop": RequestType.UNUSED} + + def fit(self, X, y, prop=None, **kwargs): + return self + + assert metadata_request_factory(ImplicitRequestRemoval()).fit.requests == {} + assert_request_is_empty(ImplicitRequestRemoval().get_metadata_request()) + + +def test_method_metadata_request(): + mmr = MethodMetadataRequest(name="fit") + with pytest.raises( + ValueError, + match="overwrite can only be one of {True, False, 'smart', 'ignore'}.", + ): + mmr.add_request(prop="test", alias=None, overwrite="test") + + with pytest.raises(ValueError, match="Expected all metadata to be called test"): + mmr.add_request(prop="foo", alias="bar", expected_metadata="test") + + with pytest.raises(ValueError, match="Aliasing is not allowed"): + mmr.add_request(prop="foo", alias="bar", allow_aliasing=False) + + with pytest.raises(ValueError, match="alias should be either a string or"): + mmr.add_request(prop="foo", alias=1.4) + + mmr.add_request(prop="foo", alias=None) + assert mmr.requests == {"foo": RequestType.ERROR_IF_PASSED} + with pytest.raises(ValueError, match="foo is already requested"): + mmr.add_request(prop="foo", alias=True) + with pytest.raises(ValueError, match="foo is already requested"): + mmr.add_request(prop="foo", alias=True) + mmr.add_request(prop="foo", alias=True, overwrite="smart") + assert mmr.requests == {"foo": RequestType.REQUESTED} + + with pytest.raises(ValueError, match="Can only add another MethodMetadataRequest"): + mmr.merge_method_request({}) + + assert MethodMetadataRequest.from_dict(None, name="fit").requests == {} + assert MethodMetadataRequest.from_dict("foo", name="fit").requests == { + "foo": RequestType.ERROR_IF_PASSED + } + assert MethodMetadataRequest.from_dict(["foo", "bar"], name="fit").requests == { + "foo": RequestType.ERROR_IF_PASSED, + "bar": RequestType.ERROR_IF_PASSED, + } + + +def test_metadata_request_factory(): + class Consumer(BaseEstimator): + __metadata_request__fit = {"prop": RequestType.ERROR_IF_PASSED} + + assert_request_is_empty(metadata_request_factory(None)) + assert_request_is_empty(metadata_request_factory({})) + assert_request_is_empty(metadata_request_factory(object())) + + mr = MetadataRequest({"fit": "foo"}, default="bar") + mr_factory = metadata_request_factory(mr) + assert_request_is_empty(mr_factory, exclude="fit") + assert mr_factory.fit.requests == {"foo": "bar"} + + mr = metadata_request_factory(Consumer()) + assert_request_is_empty(mr, exclude="fit") + assert mr.fit.requests == {"prop": RequestType.ERROR_IF_PASSED} diff --git a/sklearn/utils/__init__.py b/sklearn/utils/__init__.py index 3d8a1ca87d210..44734a8a49a07 100644 --- a/sklearn/utils/__init__.py +++ b/sklearn/utils/__init__.py @@ -41,6 +41,10 @@ check_scalar, ) from .. import get_config +from .metadata_requests import MetadataRequest +from .metadata_requests import MethodMetadataRequest +from .metadata_requests import metadata_request_factory +from .metadata_requests import MetadataRouter # Do not deprecate parallel_backend and register_parallel_backend as they are @@ -75,6 +79,10 @@ "all_estimators", "DataConversionWarning", "estimator_html_repr", + "MetadataRequest", + "metadata_request_factory", + "MetadataRouter", + "MethodMetadataRequest", ] IS_PYPY = platform.python_implementation() == "PyPy" diff --git a/sklearn/utils/estimator_checks.py b/sklearn/utils/estimator_checks.py index e6cbc38adbcac..9687a9a67489b 100644 --- a/sklearn/utils/estimator_checks.py +++ b/sklearn/utils/estimator_checks.py @@ -2994,6 +2994,8 @@ def check_no_attributes_set_in_init(name, estimator_orig): # Test for no setting apart from parameters during init invalid_attr = set(vars(estimator)) - set(init_params) - set(parents_init_params) + # Ignore private attributes + invalid_attr = set([attr for attr in invalid_attr if not attr.startswith("_")]) assert not invalid_attr, ( "Estimator %s should not set any attribute apart" " from parameters during init. Found attributes %s." diff --git a/sklearn/utils/metadata_requests.py b/sklearn/utils/metadata_requests.py new file mode 100644 index 0000000000000..c3463e2d3b167 --- /dev/null +++ b/sklearn/utils/metadata_requests.py @@ -0,0 +1,885 @@ +import inspect +from copy import deepcopy +from enum import Enum +from collections import defaultdict +from typing import Union, Optional +from warnings import warn +from ..externals._sentinels import sentinel # type: ignore # mypy error!!! + + +class RequestType(Enum): + UNREQUESTED = False + REQUESTED = True + ERROR_IF_PASSED = None + # this sentinel is used in `__metadata_request__*` attributes to indicate + # that a metadata is not present even though it may be present in the + # corresponding method's signature. + UNUSED = sentinel("UNUSED") + # this sentinel is used whenever a default value is changed, and therefore + # the user should explicitly set the value, otherwise a warning is shown. + # An example is when a meta-estimator is only a router, but then becomes + # also a consumer. + WARN = sentinel("WARN") + + +# this sentinel is the default used in `{method}_requests` methods to indicate +# no change requested by the user. +UNCHANGED = sentinel("UNCHANGED") + +METHODS = [ + "fit", + "partial_fit", + "predict", + "score", + "split", + "transform", + "inverse_transform", +] + + +REQUESTER_DOC = """ Request metadata passed to the ``{method}`` method. + + Please check :ref:`User Guide ` on how the routing + mechanism works. + + Parameters + ---------- +""" +REQUESTER_DOC_PARAM = """ {metadata} : RequestType, str, True, False, or None, \ + default=UNCHANGED + Whether {metadata} should be passed to {method} by meta-estimators or + not, and if yes, should it have an alias. + + - True or RequestType.REQUESTED: {metadata} is requested, and passed to \ +{method} if provided. + + - False or RequestType.UNREQUESTED: {metadata} is not requested and the \ +meta-estimator will not pass it to {method}. + + - None or RequestType.ERROR_IF_PASSED: {metadata} is not requested, and \ +the meta-estimator will raise an error if the user provides {metadata} + + - str: {metadata} should be passed to the meta-estimator with this given \ +alias instead of the original name. + +""" +REQUESTER_DOC_RETURN = """ Returns + ------- + self + Returns the object itself. +""" + + +class MethodMetadataRequest: + """Contains the metadata request info for a single method. + + .. versionadded:: 1.1 + + Parameters + ---------- + name : str + The name of the method to which these requests belong. + """ + + def __init__(self, name): + self.requests = dict() + self.name = name + + def add_request( + self, + *, + prop, + alias, + allow_aliasing=True, + overwrite=False, + expected_metadata=None, + ): + """Add request info for a prop. + + Parameters + ---------- + prop : str + The property for which a request is set. + + alias : str, RequestType, or {True, False, None} + The alias which is routed to `prop` + + - str: the name which should be used as an alias when a meta-estimator + routes the metadata. + + - True or RequestType.REQUESTED: requested + + - False or RequestType.UNREQUESTED: not requested + + - None or RequestType.ERROR_IF_PASSED: error if passed + + allow_aliasing : bool, default=True + If False, alias should be the same as prop if it's a string. + + overwrite : bool or str, default=False + + - True: ``alias`` replaces the existing routing. + + - False: a ``ValueError`` is raised if the given value conflicts + with an existing one. + + - "smart": overwrite in this order: + ``RequestType.REQUESTED`` over ``RequestType.UNREQUESTED`` over + ``RequestType.ERROR_IF_PASSED``, and error if existing value is + a string. + + - "ignore": ignore the requested metadata if it already exists. + + expected_metadata : str, default=None + If provided, all props should be the same as this value. It used to + handle default values. + """ + if overwrite not in {True, False, "smart", "ignore"}: + raise ValueError( + "overwrite can only be one of {True, False, 'smart', 'ignore'}; " + f"but f{overwrite} is given." + ) + if expected_metadata is not None and expected_metadata != prop: + raise ValueError( + f"Expected all metadata to be called {expected_metadata} but " + f"{prop} was passed." + ) + if not allow_aliasing and isinstance(alias, str) and prop != alias: + raise ValueError( + "Aliasing is not allowed, prop and alias should " + "be the same strings if alias is a string." + ) + + if not isinstance(alias, str): + try: + alias = RequestType(alias) + except ValueError: + raise ValueError( + "alias should be either a string or one of " + "{None, True, False}, or a RequestType." + ) + + if alias == prop: + alias = RequestType.REQUESTED + + if alias == RequestType.UNUSED and prop in self.requests: + del self.requests[prop] + elif prop not in self.requests or overwrite is True: + self.requests[prop] = alias + elif prop in self.requests and overwrite == "ignore": + pass + elif overwrite == "smart": + current = self.requests[prop] + if isinstance(current, str): + raise ValueError( + f"Cannot overwrite {current} with {alias} when overwrite=smart." + ) + current = RequestType(current) + + # REQUESTED > UNREQUESTED > ERROR_IF_PASSED + if alias == RequestType.REQUESTED and current in { + RequestType.ERROR_IF_PASSED, + RequestType.UNREQUESTED, + RequestType.WARN, + }: + self.requests[prop] = alias + elif alias in {RequestType.UNREQUESTED, RequestType.WARN} and current in { + RequestType.ERROR_IF_PASSED, + RequestType.WARN, + }: + self.requests[prop] = RequestType.UNREQUESTED + elif self.requests[prop] != alias: + raise ValueError( + f"{prop} is already requested as {self.requests[prop]}, " + f"which is not the same as the one given: {alias}. Cannot " + "overwrite when overwrite=False." + ) + + def merge_method_request(self, other, overwrite=False, expected_metadata=None): + """Merge the metadata request info of two methods. + + The methods can be the same, or different. For example, merging + fit and score info of the same object, or merging fit request info + from two different sub estimators. + + Parameters + ---------- + other : MethodMetadataRequest + The other object to be merged with this instance. + + overwrite : bool or str, default=False + + - True: ``alias`` replaces the existing routing. + + - False: a ``ValueError`` is raised if the given value conflicts + with an existing one. + + - "smart": overwrite in this order: + ``RequestType.REQUESTED`` over ``RequestType.UNREQUESTED`` over + ``RequestType.ERROR_IF_PASSED``, and error if existing value is + a string. + + - "ignore": ignore the requested metadata if it already exists. + + expected_metadata : str, default=None + If provided, all props should be the same as this value. It used to + handle default values. + """ + if not isinstance(other, MethodMetadataRequest): + raise ValueError("Can only add another MethodMetadataRequest.") + for prop, alias in other.requests.items(): + self.add_request( + prop=prop, + alias=alias, + overwrite=overwrite, + expected_metadata=expected_metadata, + ) + + def validate_metadata(self, ignore_extras=False, self_metadata=None, kwargs=None): + """Validate the given arguments against the requested ones. + + Parameters + ---------- + ignore_extras : bool, default=False + If ``True``, no error is raised if extra unknown args are passed. + + self_metadata : MetadataRequest-like, default=None + This parameter can be anything which can be an input to + ``metadata_request_factory``. Only the part of the metadata which + is the same as ``name`` is used. + + Consumers don't validate their own metadata. Validation is always + done by routers (i.e. usually meta-estimators). But sometimes an + object is a consumer and a router, e.g. ``LogisticRegressionCV`` + which consumes ``sample_weight``, but also routes metadata to the + given scorer(s) and CV object, and therefore is also a router. In + such a case, ``sample_weight`` is the metadata being consumed. A + router can get its own required metadata, as opposed to the ones + required by its sub-objects, using + ``metadata_request_factory(super())``. ``validate_metadata`` then + uses the part which is relevant to this validation. Since this + object knows which method is relevant using its ``name``, passing + ``super()`` here would be sufficient. + + kwargs : dict + Provided metadata. + + Returns + ------- + None + """ + kwargs = {} if kwargs is None else kwargs + self_metadata = getattr( + metadata_request_factory(self_metadata), self.name + ).requests + warn_metadata = [k for k, v in self_metadata.items() if v == RequestType.WARN] + warn_kwargs = [k for k in kwargs.keys() if k in warn_metadata] + if warn_kwargs: + warn( + "The following metadata are provided, which are now supported by this " + f"class: {warn_kwargs}. These metadata were not processed in previous " + "versions. Set their requested value to RequestType.UNREQUESTED " + "to maintain previous behavior, or to RequestType.REQUESTED to " + "consume and use the metadata.", + UserWarning, + ) + # we then remove self metadata from kwargs, since they should not be + # validated. + kwargs = {v: k for v, k in kwargs.items() if v not in self_metadata} + args = {arg for arg, value in kwargs.items() if value is not None} + if not ignore_extras and args - set(self.requests.keys()): + raise ValueError( + "Metadata passed which is not understood: " + f"{sorted(args - set(self.requests.keys()))}. In method: " + f"{self.name}" + ) + + for prop, alias in self.requests.items(): + if not isinstance(alias, str): + alias = RequestType(alias) + if alias == RequestType.UNREQUESTED: + if prop in args: + raise ValueError( + f"{prop} is not requested by any estimator but is provided. In" + f" method: {self.name}" + ) + elif alias == RequestType.REQUESTED or isinstance(alias, str): + # we ignore what the given alias here is, since aliases are + # checked at the parent meta-estimator level, and the child + # still expects the original names for the metadata. + # If a metadata is requested but not passed, no error is raised + continue + elif alias == RequestType.ERROR_IF_PASSED: + if prop in args: + raise ValueError( + f"{prop} is passed but is not explicitly set as " + f"requested or not. In method: {self.name}" + ) + + def get_input(self, ignore_extras=False, kwargs=None): + """Return the input parameters requested by the method. + + The output of this method can be used directly as the input to the + corresponding method as extra props. + + Parameters + ---------- + ignore_extras : bool, default=False + If ``True``, no error is raised if extra unknown args are passed. + + kwargs : dict + A dictionary of provided metadata. + + Returns + ------- + kwargs : dict + A dictionary of {prop: value} which can be given to the + corresponding method. + """ + kwargs = {} if kwargs is None else kwargs + args = {arg: value for arg, value in kwargs.items() if value is not None} + res = dict() + for prop, alias in self.requests.items(): + if not isinstance(alias, str): + alias = RequestType(alias) + + if alias == RequestType.WARN: + warn( + f"Support for {prop} has recently been added to this class. " + "To maintain backward compatibility, it is ignored now. " + "You can set the request value to RequestType.UNREQUESTED " + "to silence this warning, or to RequestType.REQUESTED to " + "consume and use the metadata." + ) + elif alias == RequestType.UNREQUESTED: + continue + elif alias == RequestType.REQUESTED and prop in args: + res[prop] = args[prop] + elif alias == RequestType.ERROR_IF_PASSED and prop in args: + raise ValueError( + f"{prop} is passed but is not explicitly set as " + f"requested or not. In method: {self.name}" + ) + elif alias in args: + res[prop] = args[alias] + self.validate_metadata(ignore_extras=ignore_extras, kwargs=res) + return res + + def masked(self): + """Return a masked version of the requests. + + Returns + ------- + masked : MethodMetadataRequest + A masked version is one which converts a ``{'prop': 'alias'}`` to + ``{'alias': True}``. This is desired in meta-estimators passing + requests to their parent estimators. + """ + res = MethodMetadataRequest(name=self.name) + for prop, alias in self.requests.items(): + if isinstance(alias, str): + res.add_request( + prop=alias, + alias=alias, + allow_aliasing=False, + overwrite=False, + ) + else: + res.add_request( + prop=prop, + alias=alias, + allow_aliasing=False, + overwrite=False, + ) + return res + + @classmethod + def from_dict( + cls, requests, name, allow_aliasing=True, default=RequestType.ERROR_IF_PASSED + ): + """Construct a MethodMetadataRequest from a given dictionary. + + Parameters + ---------- + requests : dict + A dictionary representing the requests. + + name : str + The name of the method to which these requests belong. + + allow_aliasing : bool, default=True + If false, only aliases with the same name as the parameter are + allowed. This is useful when handling the default values. + + default : RequestType, True, False, None, or str, \ + default=RequestType.ERROR_IF_PASSED + The default value to be used if parameters are provided as a string + or list instead of the fully specifying dict. + + Returns + ------- + requests: MethodMetadataRequest + A :class:`MethodMetadataRequest` object. + """ + if requests is None: + requests = dict() + elif isinstance(requests, str): + requests = {requests: default} + elif isinstance(requests, (list, set)): + requests = {r: default for r in requests} + result = cls(name=name) + for prop, alias in requests.items(): + result.add_request(prop=prop, alias=alias, allow_aliasing=allow_aliasing) + return result + + def __repr__(self): + return str(self.requests) + + def __str__(self): + return str(self.requests) + + +class MetadataRequest: + """Contains the metadata request info of an object. + + .. versionadded:: 1.1 + + Parameters + ---------- + requests : dict of dict of {str: str}, default=None + A dictionary where the keys are the names of the methods, and the values are + a dictionary of the form ``{"required_metadata": "provided_metadata"}``. + ``"provided_metadata"`` can also be a ``RequestType`` or {True, False, None}. + + default : RequestType, True, False, None, or str, \ + default=RequestType.ERROR_IF_PASSED + The default value to be used if parameters are provided as a string instead of + the usual second layer dict. + """ + + def __init__(self, requests=None, default=RequestType.ERROR_IF_PASSED): + for method in METHODS: + setattr(self, method, MethodMetadataRequest(name=method)) + + if requests is None: + return + elif not isinstance(requests, dict): + raise ValueError( + "Can only construct an instance from a dict. Please call " + "metadata_request_factory for other types of input." + ) + + for method, method_requests in requests.items(): + if method not in METHODS: + raise ValueError(f"{method} is not supported as a method.") + setattr( + self, + method, + MethodMetadataRequest.from_dict( + method_requests, name=method, default=default + ), + ) + + def add_requests( + self, + obj, + mapping="one-to-one", + overwrite=False, + expected_metadata=None, + ): + """Add request info from the given object with the desired mapping. + + Parameters + ---------- + obj : object + An object from which a MetadataRequest can be constructed. + + mapping : dict or str, default="one-to-one" + The mapping between the ``obj``'s methods and this object's + methods. If ``"one-to-one"`` all methods' requests from ``obj`` are + merged into this instance's methods. If a dict, the mapping is of + the form ``{"destination_method": "source_method"}`` or + ``{"destination_method": ["source_method1", ...]}``. + + overwrite : bool or str, default=False + + - True: ``alias`` replaces the existing routing. + + - False: a ``ValueError`` is raised if the given value conflicts + with an existing one. + + - "smart": overwrite in this order: + ``RequestType.REQUESTED`` over ``RequestType.UNREQUESTED`` over + ``RequestType.ERROR_IF_PASSED``, and error if existing value is + a string. + + - "ignore": ignore the requested metadata if it already exists. + + expected_metadata : str, default=None + If provided, all props should be the same as this value. It used to + handle default values. + """ + if not isinstance(mapping, dict) and mapping != "one-to-one": + raise ValueError( + "mapping can only be a dict or the literal 'one-to-one'. " + f"Given value: {mapping}" + ) + if mapping == "one-to-one": + mapping = [(method, method) for method in METHODS] + else: + _mapping = [] + for destination, sources in mapping.items(): + if isinstance(sources, list): + _mapping.extend([(destination, source) for source in sources]) + else: + _mapping.append((destination, sources)) + mapping = _mapping + other = metadata_request_factory(obj) + for destination, source in mapping: + my_method = getattr(self, destination) + other_method = getattr(other, source) + my_method.merge_method_request( + other_method, + overwrite=overwrite, + expected_metadata=expected_metadata, + ) + + def masked(self): + """Return a masked version of the requests. + + A masked version is one which converts a ``{'prop': 'alias'}`` to + ``{'alias': True}``. This is desired in meta-estimators passing + requests to their parent estimators. + """ + res = MetadataRequest() + for method in METHODS: + setattr(res, method, getattr(self, method).masked()) + return res + + def to_dict(self): + """Return dictionary representation of this object.""" + output = dict() + for method in METHODS: + output[method] = getattr(self, method).requests + return output + + def __repr__(self): + return str(self.to_dict()) + + def __str__(self): + return str(self.to_dict()) + + +def metadata_request_factory(obj=None): + """Get a MetadataRequest instance from the given object. + + This function always returns a copy or an instance constructed from the + intput, such that changing the output of this function will not change the + original object. + + .. versionadded:: 1.1 + + Parameters + ---------- + obj : object + If the object is already a MetadataRequest, return that. + If the object is an estimator, try to call `get_metadata_request` and get + an instance from that method. + If the object is a dict, create a MetadataRequest from that. + + Returns + ------- + metadata_requests : MetadataRequest + A ``MetadataRequest`` taken or created from the given object. + """ + if obj is None: + return MetadataRequest() + + if isinstance(obj, MetadataRequest): + return deepcopy(obj) + + if isinstance(obj, dict): + return MetadataRequest(obj) + + # doing this instead of a try/except since an AttributeError could be raised + # for other reasons. + if hasattr(obj, "get_metadata_request"): + return MetadataRequest(obj.get_metadata_request()) + else: + return MetadataRequest() + + +class MetadataRouter: + """Route the metadata to child objects. + + .. versionadded:: 1.1 + """ + + def __init__(self): + self.requests = MetadataRequest() + + def add(self, *obj, mapping="one-to-one", overwrite=False, mask=False): + """Add a set of requests to the existing ones. + + Parameters + ---------- + *obj : objects + A set of objects from which the requests are extracted. Passed as + arguments to this method. + + mapping : dict or str, default="one-to-one" + The mapping between the ``obj``'s methods and this routing object's + methods. If ``"one-to-one"`` all methods' requests from ``obj`` are + merged into this instance's methods. If a dict, the mapping is of + the form ``{"destination_method": "source_method"}`` or + ``{"destination_method": ["source_method1", ...]}``. + + overwrite : bool or str, default=False + + - True: ``alias`` replaces the existing routing. + + - False: a ``ValueError`` is raised if the given value conflicts + with an existing one. + + - "smart": overwrite in this order: + ``RequestType.REQUESTED`` over ``RequestType.UNREQUESTED`` over + ``RequestType.ERROR_IF_PASSED``, and error if existing value is + a string. + + - "ignore": ignore the requested metadata if it already exists. + + mask : bool, default=False + If the requested metadata should be masked by the alias. If + ``True``, then a request of the form + ``{'sample_weight' : 'my_weight'}`` is converted to + ``{'my_weight': 'my_weight'}``. This is required for meta-estimators + which should expose the requested parameters and not the ones + expected by the objects' methods. + """ + for x in obj: + if mask: + x = metadata_request_factory(x).masked() + self.requests.add_requests(x, mapping=mapping, overwrite=overwrite) + return self + + def get_metadata_request(self): + """Get requested data properties. + + Please check :ref:`User Guide ` on how the routing + mechanism works. + + Returns + ------- + request : dict + A dict of dict of str->value. The key to the first dict is the name + of the method, and the key to the second dict is the name of the + argument requested by the method. + """ + return self.requests.to_dict() + + +class RequestMethod: + """ + A descriptor for request methods. + + .. versionadded:: 1.1 + + Parameters + ---------- + name : str + The name of the method for which the request function should be + created, e.g. ``"fit"`` would create a ``fit_requests`` function. + + keys : list of str + A list of strings which are accepted parameters by the created + function, e.g. ``["sample_weight"]`` if the corresponding method + accepts it as a metadata. + + Notes + ----- + This class is a descriptor [1]_ and uses PEP-362 to set the signature of + the returned function [2]_. + + References + ---------- + .. [1] https://docs.python.org/3/howto/descriptor.html + + .. [2] https://www.python.org/dev/peps/pep-0362/ + """ + + def __init__(self, name, keys): + self.name = name + self.keys = keys + + def __get__(self, instance, owner): + # we would want to have a method which accepts only the expected args + def func(**kw): + if set(kw) - set(self.keys): + raise TypeError(f"Unexpected args: {set(kw) - set(self.keys)}") + + requests = metadata_request_factory(instance) + + try: + method_metadata_request = getattr(requests, self.name) + except AttributeError: + raise ValueError(f"{self.name} is not a supported method.") + + for prop, alias in kw.items(): + if alias is not UNCHANGED: + method_metadata_request.add_request( + prop=prop, alias=alias, allow_aliasing=True, overwrite=True + ) + instance._metadata_request = requests.to_dict() + + return instance + + # Now we set the relevant attributes of the function so that it seems + # like a normal method to the end user, with known expected arguments. + func.__name__ = f"{self.name}_requests" + params = [ + inspect.Parameter( + name="self", + kind=inspect.Parameter.POSITIONAL_OR_KEYWORD, + annotation=type(instance), + ) + ] + params.extend( + [ + inspect.Parameter( + k, + inspect.Parameter.KEYWORD_ONLY, + default=UNCHANGED, + annotation=Optional[Union[RequestType, str]], + ) + for k in self.keys + ] + ) + func.__signature__ = inspect.Signature( + params, + return_annotation=type(instance), + ) + doc = REQUESTER_DOC.format(method=self.name) + for metadata in self.keys: + doc += REQUESTER_DOC_PARAM.format(metadata=metadata, method=self.name) + doc += REQUESTER_DOC_RETURN + func.__doc__ = doc + return func + + +class _MetadataRequester: + """Mixin class for adding metadata request functionality. + + .. versionadded:: 1.1 + """ + + def __init_subclass__(cls, **kwargs): + """Set the ``{method}_requests`` methods. + + This uses PEP-487 [1]_ to set the ``{method}_requests`` methods. It + looks for the information available in the set default values which are + set using ``__metadata_request__*`` class attributes. + + References + ---------- + .. [1] https://www.python.org/dev/peps/pep-0487 + """ + try: + requests = cls._get_default_requests().to_dict() + except Exception: + # if there are any issues in the default values, it will be raised + # when ``get_metadata_request`` is called. Here we are going to + # ignore all the issues such as bad defaults etc.` + super().__init_subclass__(**kwargs) + return + + for request_method, request_keys in requests.items(): + # set ``{method}_requests``` methods + if not len(request_keys): + continue + setattr( + cls, + f"{request_method}_requests", + RequestMethod(request_method, sorted(request_keys)), + ) + super().__init_subclass__(**kwargs) + + @classmethod + def _get_default_requests(cls): + """Collect default request values. + + This method combines the information present in ``metadata_request__*`` + class attributes. + """ + + requests = MetadataRequest() + + # need to go through the MRO since this is a class attribute and + # ``vars`` doesn't report the parent class attributes. We go through + # the reverse of the MRO since cls is the first in the tuple and object + # is the last. + defaults = defaultdict() + for klass in reversed(inspect.getmro(cls)): + klass_defaults = { + attr: value + for attr, value in vars(klass).items() + if "__metadata_request__" in attr + } + defaults.update(klass_defaults) + defaults = dict(sorted(defaults.items())) + + # First take all arguments from the method signatures and have them as + # ERROR_IF_PASSED, except X, y, *args, and **kwargs. + for method in METHODS: + # Here we use `isfunction` instead of `ismethod` because calling `getattr` + # on a class instead of an instance returns an unbound function. + if not hasattr(cls, method) or not inspect.isfunction(getattr(cls, method)): + continue + # ignore the first parameter of the method, which is usually "self" + params = list(inspect.signature(getattr(cls, method)).parameters.items())[ + 1: + ] + for pname, param in params: + if pname in {"X", "y", "Y"}: + continue + if param.kind in {param.VAR_POSITIONAL, param.VAR_KEYWORD}: + continue + getattr(requests, method).add_request( + prop=pname, + alias=RequestType.ERROR_IF_PASSED, + allow_aliasing=False, + overwrite=False, + ) + + # Then overwrite those defaults with the ones provided in + # __metadata_request__* attributes, which are provided in `requests` here. + + for attr, value in defaults.items(): + # we don't check for attr.startswith() since python prefixes attrs + # starting with __ with the `_ClassName`. + substr = "__metadata_request__" + method = attr[attr.index(substr) + len(substr) :] + for prop, alias in value.items(): + getattr(requests, method).add_request( + prop=prop, alias=alias, overwrite=True + ) + return requests + + def get_metadata_request(self): + """Get requested data properties. + + Please check :ref:`User Guide ` on how the routing + mechanism works. + + Returns + ------- + request : dict + A dict of dict of str->value. The key to the first dict is the name + of the method, and the key to the second dict is the name of the + argument requested by the method. + """ + if hasattr(self, "_metadata_request"): + requests = metadata_request_factory(self._metadata_request) + else: + requests = self._get_default_requests() + + return requests.to_dict() diff --git a/sklearn/utils/tests/test_estimator_checks.py b/sklearn/utils/tests/test_estimator_checks.py index 691f531a07e6d..cb1fa7bc385fd 100644 --- a/sklearn/utils/tests/test_estimator_checks.py +++ b/sklearn/utils/tests/test_estimator_checks.py @@ -651,6 +651,10 @@ class NonConformantEstimatorNoParamSet(BaseEstimator): def __init__(self, you_should_set_this_=None): pass + class ConformantEstimatorClassAttribute(BaseEstimator): + # making sure our __metadata_request__* class attributes are okay! + __metadata_request__fit = {"foo": True} + msg = ( "Estimator estimator_name should not set any" " attribute apart from parameters during init." @@ -670,6 +674,14 @@ def __init__(self, you_should_set_this_=None): "estimator_name", NonConformantEstimatorNoParamSet() ) + # a private class attribute is okay! + check_no_attributes_set_in_init( + "estimator_name", ConformantEstimatorClassAttribute() + ) + check_no_attributes_set_in_init( + "estimator_name", ConformantEstimatorClassAttribute().fit_requests(foo=True) + ) + def test_check_estimator_pairwise(): # check that check_estimator() works on estimator with _pairwise