From dbead5c86ab7b266682a9bec448f68c12f8b3b80 Mon Sep 17 00:00:00 2001 From: adrinjalali Date: Fri, 8 Oct 2021 17:27:15 +0200 Subject: [PATCH 01/18] initial base implementation commit --- doc/metadata_routing.rst | 193 +++++ doc/user_guide.rst | 1 + examples/metadata_routing.py | 518 ++++++++++++ sklearn/base.py | 9 +- sklearn/externals/_sentinels.py | 82 ++ sklearn/tests/test_props.py | 590 +++++++++++++ sklearn/utils/__init__.py | 8 + sklearn/utils/estimator_checks.py | 11 +- sklearn/utils/metadata_requests.py | 840 +++++++++++++++++++ sklearn/utils/tests/test_estimator_checks.py | 12 + 10 files changed, 2255 insertions(+), 9 deletions(-) create mode 100644 doc/metadata_routing.rst create mode 100644 examples/metadata_routing.py create mode 100644 sklearn/externals/_sentinels.py create mode 100644 sklearn/tests/test_props.py create mode 100644 sklearn/utils/metadata_requests.py diff --git a/doc/metadata_routing.rst b/doc/metadata_routing.rst new file mode 100644 index 0000000000000..034c685bfb2eb --- /dev/null +++ b/doc/metadata_routing.rst @@ -0,0 +1,193 @@ + +.. _metadata_routing: + +Metadata Routing +================ + +This guide demonstrates how metadata such as ``sample_weight`` can be routed +and passed along to estimators, scorers, and CV splitters through +meta-estimators such as ``Pipeline`` and ``GridSearchCV``. In order to pass +metadata to a method such as ``fit`` or ``score``, the object accepting the +metadata, must *request* it. For estimators and splitters this is done via +``*_requests`` methods, e.g. ``fit_requests(...)``, and for scorers this is +done via ``score_requests`` method of a scorer. For grouped splitters such as +``GroupKFold`` a ``groups`` parameter is requested by default. This is best +demonstrated by the following examples. + +Usage Examples +************** +Here we present a few examples to show different common use-cases. The examples +in this section require the following imports and data:: + +.. 
TODO: add once implemented + >>> import numpy as np + >>> from sklearn.metrics import make_scorer, accuracy_score + >>> from sklearn.linear_model import LogisticRegressionCV + >>> from sklearn.linear_model import LogisticRegression + >>> from sklearn.model_selection import cross_validate + >>> from sklearn.model_selection import GridSearchCV + >>> from sklearn.model_selection import GroupKFold + >>> from sklearn.feature_selection import SelectKBest + >>> from sklearn.pipeline import make_pipeline + >>> n_samples, n_features = 100, 4 + >>> X = np.random.rand(n_samples, n_features) + >>> y = np.random.randint(0, 2, size=n_samples) + >>> my_groups = np.random.randint(0, 10, size=n_samples) + >>> my_weights = np.random.rand(n_samples) + >>> my_other_weights = np.random.rand(n_samples) + +Weighted scoring and fitting +---------------------------- + +Here ``GroupKFold`` requests ``groups`` by default. However, we need to +explicitly request weights in ``make_scorer`` and for ``LogisticRegressionCV``. +Both of these *consumers* understand the meaning of the key +``"sample_weight"``:: + +.. TODO: add once implemented + >>> weighted_acc = make_scorer(accuracy_score).score_requests( + ... sample_weight=True + ... ) + >>> lr = LogisticRegressionCV( + ... cv=GroupKFold(), scoring=weighted_acc, + ... ).fit_requests(sample_weight=True) + >>> cv_results = cross_validate( + ... lr, + ... X, + ... y, + ... cv=GroupKFold(), + ... props={"sample_weight": my_weights, "groups": my_groups}, + ... scoring=weighted_acc, + ... ) + +Error handling: if ``props={'sample_weigh': my_weights, ...}`` were passed +(note the typo), cross_validate would raise an error, since 'sample_weigh' was +not requested by any of its children. 
+ +Weighted scoring and unweighted fitting +--------------------------------------- + +Since ``LogisticRegressionCV``, like all scikit-learn estimators, requires that +weights explicitly be requested, we need to explicitly say that +``sample_weight`` is not used for it, so that ``cross_validate`` doesn't pass +it along. + +.. TODO: add once implemented + >>> weighted_acc = make_scorer(accuracy_score).score_requests( + ... sample_weight=True + ... ) + >>> lr = LogisticRegressionCV( + ... cv=GroupKFold(), scoring=weighted_acc, + ... ).fit_requests(sample_weight=False) + >>> cv_results = cross_validate( + ... lr, + ... X, + ... y, + ... cv=GroupKFold(), + ... props={"sample_weight": my_weights, "groups": my_groups}, + ... scoring=weighted_acc, + ... ) + +Unweighted feature selection +---------------------------- + +Unlike ``LogisticRegressionCV``, ``SelectKBest`` doesn't accept weights and +therefore `"sample_weight"` is not routed to it:: + +.. TODO: add once implemented + >>> weighted_acc = make_scorer(accuracy_score).score_requests( + ... sample_weight=True + ... ) + >>> lr = LogisticRegressionCV( + ... cv=GroupKFold(), scoring=weighted_acc, + ... ).fit_requests(sample_weight=True) + >>> sel = SelectKBest(k=2) + >>> pipe = make_pipeline(sel, lr) + >>> cv_results = cross_validate( + ... pipe, + ... X, + ... y, + ... cv=GroupKFold(), + ... props={"sample_weight": my_weights, "groups": my_groups}, + ... scoring=weighted_acc, + ... ) + +Different scoring and fitting weights +------------------------------------- + +Despite ``make_scorer`` and ``LogisticRegressionCV`` both expecting a key +``sample_weight``, we can use aliases to pass different weights to different +consumers. In this example, we pass ``scoring_weight`` to the scorer, and +``fitting_weight`` to ``LogisticRegressionCV``:: + +.. TODO: add once implemented + >>> weighted_acc = make_scorer(accuracy_score).score_requests( + ... sample_weight="scoring_weight" + ... ) + >>> lr = LogisticRegressionCV( + ... 
cv=GroupKFold(), scoring=weighted_acc, + ... ).fit_requests(sample_weight="fitting_weight") + >>> cv_results = cross_validate( + ... lr, + ... X, + ... y, + ... cv=GroupKFold(), + ... props={ + ... "scoring_weight": my_weights, + ... "fitting_weight": my_other_weights, + ... "groups": my_groups, + ... }, + ... scoring=weighted_acc, + ... ) + +API Interface +************* + +A *consumer* is an object (estimator, meta-estimator, scorer, splitter) which +accepts and uses some metadata in at least one of their methods (``fit``, +``predict``, ``inverse_transform``, ``transform``, ``score``, ``split``). +Meta-estimators which only forward the metadata to other objects (the child +estimator, scorers, or splitters) and don't use the metadata themselves are not +consumers. (Meta)Estimators which route metadata to other objects are routers. +A (meta)estimator can be a consumer and a router at the same time. +(Meta)Estimators and splitters expose a ``*_requests`` method for each method +which accepts at least one metadata. For instance, if an estimator supports +``sample_weight`` in ``fit`` and ``score``, it exposes +``estimator.fit_requests(sample_weight=value)`` and +``estimator.score_requests(sample_weight=value)``. Here ``value`` can be: + +- ``RequestType.REQUESTED`` or ``True``: method requests a ``sample_weight``. + This means if the metadata is provided, it will be used, otherwise no error + is raised. +- ``RequestType.UNREQUESTED`` or ``False``: method does not request a + ``sample_weight``. +- ``RequestType.ERROR_IF_PASSED`` or ``None``: router will raise an error if + ``sample_weight`` is passed. This is in almost all cases the default value + when an object is instantiated and ensures the user sets the metadata + requests explicitly when a metadata is passed. +- ``"param_name"``: if this estimator is used in a meta-estimator, the + meta-estimator should forward ``"param_name"`` as ``sample_weight`` to this + estimator.
This means the mapping between the metadata required by the + object, e.g. ``sample_weight`` and what is provided by the user, e.g. + ``my_weights`` is done at the router level, and not by the object, e.g. + estimator, itself. + +For the scorers, this is done the same way, using ``.score_requests`` method. + +If a metadata, e.g. ``sample_weight`` is passed by the user, the metadata +request for all objects which potentially can accept ``sample_weight`` should +be set by the user, otherwise an error is raised by the router object. For +example, the following code would raise, since it hasn't been explicitly set +whether ``sample_weight`` should be passed to the estimator's scorer or not:: + +.. TODO: add once implemented + >>> param_grid = {"C": [0.1, 1]} + >>> lr = LogisticRegression().fit_requests(sample_weight=True) + >>> try: + ... GridSearchCV( + ... estimator=lr, param_grid=param_grid + ... ).fit(X, y, sample_weight=my_weights) + ... except ValueError as e: + ... print(e) + sample_weight is passed but is not explicitly set as requested or not. In + method: score diff --git a/doc/user_guide.rst b/doc/user_guide.rst index 7d48934d32727..7e656567f3249 100644 --- a/doc/user_guide.rst +++ b/doc/user_guide.rst @@ -30,3 +30,4 @@ User Guide computing.rst modules/model_persistence.rst common_pitfalls.rst + metadata_routing.rst diff --git a/examples/metadata_routing.py b/examples/metadata_routing.py new file mode 100644 index 0000000000000..9ee0090b32b59 --- /dev/null +++ b/examples/metadata_routing.py @@ -0,0 +1,518 @@ +""" +================ +Metadata Routing +================ + +.. currentmodule:: sklearn + +This document shows how you can use the metadata routing mechanism in +scikit-learn to route metadata through meta-estimators to the estimators using +them. To better understand the rest of the document, we need to introduce two +concepts: routers and consumers. 
A router is an object, in most cases a +meta-estimator, which routes given data and metadata to other objects and +estimators. A consumer, on the other hand, is an object which accepts and uses +a certain given metadata. For instance, an estimator taking into account +``sample_weight`` is a consumer of ``sample_weight``. It is possible for an +object to be both a router and a consumer. For instance, a meta-estimator may +take into account ``sample_weight`` in certain calculations, but it may also +route it to the underlying estimator. + +First a few imports and some random data for the rest of the script. +""" +# %% + +import numpy as np +from sklearn.base import BaseEstimator +from sklearn.base import ClassifierMixin +from sklearn.base import MetaEstimatorMixin +from sklearn.base import TransformerMixin +from sklearn.base import clone +from sklearn.utils.metadata_requests import RequestType +from sklearn.utils.metadata_requests import metadata_request_factory +from sklearn.utils.metadata_requests import MetadataRouter +from sklearn.utils.validation import check_is_fitted + +N, M = 100, 4 +X = np.random.rand(N, M) +y = np.random.randint(0, 2, size=N) +my_groups = np.random.randint(0, 10, size=N) +my_weights = np.random.rand(N) +my_other_weights = np.random.rand(N) + +# %% +# Estimators +# ---------- +# Here we demonstrate how an estimator can expose the required API to support +# metadata routing as a consumer. Imagine a simple classifier accepting ``foo`` +# as a metadata on its ``fit`` and ``bar`` in its ``predict`` method. We add +# two constructor arguments to help us check whether an expected metadata is +# given or not.
This is a minimal scikit-learn compatible classifier: + + +class ExampleClassifier(ClassifierMixin, BaseEstimator): + def __init__(self, foo_is_none=True, bar_is_none=True): + self.foo_is_none = foo_is_none + self.bar_is_none = bar_is_none + + def fit(self, X, y, foo=None): + if (foo is None) != self.foo_is_none: + raise ValueError("foo's value and foo_is_none disagree!") + # all classifiers need to expose a classes_ attribute once they're fit. + self.classes_ = np.array([0, 1]) + return self + + def predict(self, X, bar=None): + if (bar is None) != self.bar_is_none: + raise ValueError("bar's value and bar_is_none disagree!") + # return a constant value of 1, not a very smart classifier! + return np.ones(len(X)) + + +# %% +# The above estimator now has all it needs to consume metadata. This is done +# by some magic done in :class:`~base.BaseEstimator`. There are now three +# methods exposed by the above class: ``fit_requests``, ``predict_requests``, +# and ``get_metadata_request``. +# +# By default, no metadata is requested, which we can see as: + +ExampleClassifier().get_metadata_request() + +# %% +# The above output means that ``foo`` and ``bar`` are not requested, but if a +# router is given those metadata, it should raise an error, since the user has +# not explicitly set whether they are required or not. The same is true for +# ``sample_weight`` in ``score`` method, which is inherited from +# :class:`~base.ClassifierMixin`. In order to explicitly set request values for +# those metadata, we can use these methods: + +est = ExampleClassifier().fit_requests(foo=False).predict_requests(bar=True) +est.get_metadata_request() + +# %% +# As you can see, now the two metadata have explicit request values, one is +# requested and the other one is not. Instead of ``True`` and ``False``, we +# could also use the ``RequestType`` values. 
+ +est = ( + ExampleClassifier() + .fit_requests(foo=RequestType.UNREQUESTED) + .predict_requests(bar=RequestType.REQUESTED) +) +est.get_metadata_request() + +# %% +# Please note that as long as the above estimator is not used in another +# meta-estimator, the user does not need to set any requests for the metadata. +# A simple usage of the above estimator would work as expected. Remember that +# ``{foo, bar}_is_none`` are for testing/demonstration purposes and don't have +# anything to do with the routing mechanisms. + +est = ExampleClassifier(foo_is_none=False, bar_is_none=False) +est.fit(X, y, foo=my_weights) +est.predict(X[:3, :], bar=my_groups) + +# %% +# Now let's have a meta-estimator, which doesn't do much other than routing the +# metadata correctly. + + +class MetaClassifier(MetaEstimatorMixin, ClassifierMixin, BaseEstimator): + def __init__(self, estimator): + self.estimator = estimator + + def fit(self, X, y, **fit_params): + if self.estimator is None: + raise ValueError("estimator cannot be None!") + + # meta-estimators are responsible for validating the given metadata + metadata_request_factory(self).fit.validate_metadata( + ignore_extras=False, kwargs=fit_params + ) + # we can use provided utility methods to map the given metadata to what + # is required by the underlying estimator + fit_params_ = metadata_request_factory(self.estimator).fit.get_method_input( + ignore_extras=False, kwargs=fit_params + ) + self.estimator_ = clone(self.estimator).fit(X, y, **fit_params_) + self.classes_ = self.estimator_.classes_ + return self + + def predict(self, X, **predict_params): + check_is_fitted(self) + # same as in ``fit``, we validate the given metadata + metadata_request_factory(self).predict.validate_metadata( + ignore_extras=False, kwargs=predict_params + ) + # and then prepare the input to the underlying ``predict`` method. 
+ predict_params_ = metadata_request_factory( + self.estimator_ + ).predict.get_method_input(ignore_extras=False, kwargs=predict_params) + return self.estimator_.predict(X, **predict_params_) + + def get_metadata_request(self): + router = MetadataRouter().add( + self.estimator, mapping="one-to-one", overwrite=False, mask=True + ) + return router.get_metadata_request() + + +# %% +# Let's break down different parts of the above code. +# +# First, the :func:`~utils.metadata_requests.metadata_request_factory` takes +# an object from which a :class:`~utils.metadata_requests.MetadataRequest` can +# be constructed. This may be an estimator, or a dictionary representing a +# ``MetadataRequest`` object. If an estimator is given, it tries to call the +# estimator and construct the object from that, and if the estimator doesn't +# have such a method, then a default empty ``MetadataRequest`` is returned. +# +# Then in each method, we use the corresponding +# :meth:`~utils.metadata_requests.MethodMetadataRequest.get_method_input` to +# construct a dictionary of the form ``{"metadata": value}`` to pass to the +# underlying estimator's method. Please note that since in this example the +# meta-estimator does not consume any of the given metadata itself, and there +# is only one object to which the metadata is passed, we have +# ``ignore_extras=False`` which means passed metadata are also validated in the +# sense that it will be checked if anything extra is given.
This is to avoid +# silent bugs, and this is how it will work: + +est = MetaClassifier( + estimator=ExampleClassifier(foo_is_none=False).fit_requests(foo=True) +) +est.fit(X, y, foo=my_weights) + +# %% +# Note that the above example checks that ``foo`` is correctly passed to +# ``ExampleClassifier``, or else it would have raised: + +try: + est.fit(X, y) +except ValueError as e: + print(e) + +# %% +# If we pass an unknown metadata, it will be caught: +try: + est.fit(X, y, test=my_weights) +except ValueError as e: + print(e) + +# %% +# And if we pass something which is not explicitly requested: +try: + est.fit(X, y, foo=my_weights).predict(X, bar=my_groups) +except ValueError as e: + print(e) + +# %% +# Also, if we explicitly say it's not requested, but pass it: +est = MetaClassifier( + estimator=ExampleClassifier(foo_is_none=False) + .fit_requests(foo=True) + .predict_requests(bar=False) +) +try: + est.fit(X, y, foo=my_weights).predict(X[:3, :], bar=my_groups) +except ValueError as e: + print(e) + +# %% +# In order to understand the above implementation of ``get_metadata_request``, +# we need to also introduce an aliased metadata. This is when an estimator +# requests a metadata with a different name than the default value. For +# instance, in a setting where there are two estimators in a pipeline, one +# could request ``sample_weight1`` and the other ``sample_weight2``. Note that +# this doesn't change what the estimator expects, it only tells the +# meta-estimator how to map provided metadata to what's required.
Here's an +# example, where we pass ``aliased_foo`` to the meta-estimator, but the +# meta-estimator understands that ``aliased_foo`` is an alias for ``foo``, and +# passes it as ``foo`` to the underlying estimator: +est = MetaClassifier( + estimator=ExampleClassifier(foo_is_none=False).fit_requests(foo="aliased_foo") +) +est.fit(X, y, aliased_foo=my_weights) + +# %% +# And passing ``foo`` here will fail since it is requested with an alias: +try: + est.fit(X, y, foo=my_weights) +except ValueError as e: + print(e) + +# %% +# This leads us to the ``get_metadata_request``. The way routing works in +# scikit-learn is that consumers request what they need, and routers pass that +# along. But another thing a router does, is that it also exposes what it +# requires so that it can be used as a consumer inside another router, e.g. a +# pipeline inside a grid search object. However, routers (e.g. our +# meta-estimator) don't expose the mapping, and only expose what's required for +# them to do their job. In the above example, it looks like the following: +est.get_metadata_request()["fit"] + +# %% +# As you can see, the only metadata requested for method ``fit`` is +# ``"aliased_foo"``. This information is enough for another +# meta-estimator/router to know what needs to be passed to ``est``. In other +# words, ``foo`` is *masked* . The ``MetadataRouter`` class enables us to +# easily create the routing object which would create the output we need for +# our ``get_metadata_request``. In the above implementation, +# ``mapping="one-to-one"`` means all requests are mapped one to one from the +# sub-estimator to the meta-estimator's methods, and ``mask=True`` indicates +# that the requests should be masked, as explained. Masking is necessary since +# it's the meta-estimator which does the mapping between the alias and the +# original metadata name. Without it, having ``est`` in another meta-estimator +# would break the routing. 
Imagine this example: + +meta_est = MetaClassifier(estimator=est).fit(X, y, aliased_foo=my_weights) + +# %% +# In the above example, this is how each ``fit`` method will call the +# sub-estimator's ``fit``: +# +# meta_est.fit(X, y, aliased_foo=my_weights): +# ... # this estimator (est), expects aliased_foo as seen above +# self.estimator_.fit(X, y, aliased_foo=aliased_foo): +# ... # est passes aliased_foo's value as foo, which is expected +# # by the sub-estimator +# self.estimator_.fit(X, y, foo=aliased_foo) +# ... + +# %% +# Router and Consumer +# ------------------- +# To show how a slightly more complicated case would work, consider a case +# where a meta-estimator uses some metadata, but it also routes them to an +# underlying estimator. In this case, this meta-estimator is a consumer and a +# router at the same time. This is how we can implement one, and it is very +# similar to what we had before, with a few tweaks. + + +class RouterConsumerClassifier(MetaEstimatorMixin, ClassifierMixin, BaseEstimator): + def __init__(self, estimator, foo_is_none=True): + self.estimator = estimator + self.foo_is_none = foo_is_none + + def fit(self, X, y, foo, **fit_params): + if self.estimator is None: + raise ValueError("estimator cannot be None!") + + if (foo is None) != self.foo_is_none: + raise ValueError("foo's value and foo_is_none disagree!") + + if foo is not None: + fit_params["foo"] = foo + + # meta-estimators are responsible for validating the given metadata + metadata_request_factory(self).fit.validate_metadata( + ignore_extras=False, self_metadata=super(), kwargs=fit_params + ) + # we can use provided utility methods to map the given metadata to what + # is required by the underlying estimator + fit_params_ = metadata_request_factory(self.estimator).fit.get_method_input( + ignore_extras=False, kwargs=fit_params + ) + self.estimator_ = clone(self.estimator).fit(X, y, **fit_params_) + self.classes_ = self.estimator_.classes_ + return self + + def predict(self, X, 
**predict_params): + check_is_fitted(self) + # same as in ``fit``, we validate the given metadata + metadata_request_factory(self).predict.validate_metadata( + ignore_extras=False, kwargs=predict_params + ) + # and then prepare the input to the underlying ``predict`` method. + predict_params_ = metadata_request_factory( + self.estimator_ + ).predict.get_method_input(ignore_extras=False, kwargs=predict_params) + return self.estimator_.predict(X, **predict_params_) + + def get_metadata_request(self): + router = ( + MetadataRouter() + .add(super(), mapping="one-to-one", overwrite=False, mask=False) + .add(self.estimator, mapping="one-to-one", overwrite="smart", mask=True) + ) + return router.get_metadata_request() + + +# %% +# The two key parts where the above estimator differs from our previous +# meta-estimator is validation in ``fit``, and generating routing data in +# ``get_metadata_request``. In ``fit``, we pass ``self_metadata=super()`` to +# ``validate_metadata``. This is important since consumers don't validate how +# metadata is passed to them, it's only done by routers. In this case, this +# means validation should be done for the metadata consumed by the +# sub-estimator, but not for the metadata consumed by the meta-estimator +# itself. +# +# In ``get_metadata_request``, we add what's consumed by this meta-estimator +# without masking them, before adding what's requested by the sub-estimator. +# Passing ``super()`` here means only what's explicitly mentioned in the +# methods' signature is considered as metadata consumed by this estimator; in +# this case fit's ``foo``. 
Let's see what the routing metadata looks like with +# different settings: + +# %% +# no metadata requested +est = RouterConsumerClassifier(estimator=ExampleClassifier()) +est.get_metadata_request()["fit"] + + +# %% +# ``foo`` requested by child estimator +est = RouterConsumerClassifier(estimator=ExampleClassifier().fit_requests(foo=True)) +est.get_metadata_request()["fit"] +# %% +# ``foo`` requested by meta-estimator +est = RouterConsumerClassifier(estimator=ExampleClassifier()).fit_requests(foo=True) +est.get_metadata_request()["fit"] + +# %% +# As you can see, the last two are identical, which is fine since that's what a +# meta-estimator having ``RouterConsumerClassifier`` as a sub-estimator needs. +# The situation is different if we use named aliases: +# +# Aliased on both +est = RouterConsumerClassifier( + foo_is_none=False, + estimator=ExampleClassifier(foo_is_none=False).fit_requests( + foo="first_aliased_foo" + ), +).fit_requests(foo="second_aliased_foo") +est.get_metadata_request()["fit"] + +# %% +# However, ``fit`` of the meta-estimator only needs the alias for the +# sub-estimator: +est.fit(X, y, foo=my_weights, first_aliased_foo=my_other_weights) + +# %% +# Alias only on the sub-estimator +est = RouterConsumerClassifier( + estimator=ExampleClassifier().fit_requests(foo="aliased_foo") +).fit_requests(foo=True) +est.get_metadata_request()["fit"] + +# %% +# Alias only on the meta-estimator. This example raises an error since there +# will be two conflicting values for routing ``foo``. +est = RouterConsumerClassifier( + estimator=ExampleClassifier().fit_requests(foo=True) +).fit_requests(foo="aliased_foo") +try: + est.get_metadata_request()["fit"] +except ValueError as e: + print(e) + + +# %% +# Simple Pipeline +# --------------- +# A slightly more complicated use-case is a meta-estimator which does something +# similar to the ``Pipeline``. 
Here is a meta-estimator, which accepts a +# transformer and a classifier, and applies the transformer before running the +# classifier. + + +class SimplePipeline(ClassifierMixin, BaseEstimator): + _required_parameters = ["estimator"] + + def __init__(self, transformer, classifier): + self.transformer = transformer + self.classifier = classifier + + def fit(self, X, y, **fit_params): + metadata_request_factory(self).fit.validate_metadata(kwargs=fit_params) + + transformer_fit_params = metadata_request_factory( + self.transformer + ).fit.get_method_input(ignore_extras=True, kwargs=fit_params) + transformer_transform_params = metadata_request_factory( + self.transformer + ).transform.get_method_input(ignore_extras=True, kwargs=fit_params) + self.transformer_ = clone(self.transformer).fit(X, y, **transformer_fit_params) + X_transformed = self.transformer_.transform(X, **transformer_transform_params) + + classifier_fit_params = metadata_request_factory( + self.classifier + ).fit.get_method_input(ignore_extras=True, kwargs=fit_params) + self.classifier_ = clone(self.classifier).fit( + X_transformed, y, **classifier_fit_params + ) + return self + + def predict(self, X, **predict_params): + metadata_request_factory(self).predict.validate_metadata(kwargs=predict_params) + + transformer_transform_params = metadata_request_factory( + self.transformer + ).transform.get_method_input(ignore_extras=True, kwargs=predict_params) + X_transformed = self.transformer_.transform(X, **transformer_transform_params) + + classifier_predict_params = metadata_request_factory( + self.classifier + ).predict.get_method_input(ignore_extras=True, kwargs=predict_params) + return self.classifier_.predict(X_transformed, **classifier_predict_params) + + def get_metadata_request(self): + router = ( + MetadataRouter() + .add( + self.transformer, + mapping={ + "fit": ["fit", "transform"], + "predict": "transform", + }, + overwrite="smart", + mask=True, + ) + .add(self.classifier, mapping="one-to-one", 
overwrite="smart", mask=True) + ) + return router.get_metadata_request() + + +# %% +# As you can see, we use the transformer's ``transform`` and ``fit`` methods in +# ``fit``, and its ``transform`` method in ``predict``, and that's what you see +# implemented in the routing structure of the pipeline class. In order to test +# the above pipeline, let's add an example transformer. + + +class ExampleTransformer(TransformerMixin, BaseEstimator): + def fit(self, X, y, foo=None): + if foo is None: + raise ValueError("foo is None!") + return self + + def transform(self, X, bar=None): + if bar is None: + raise ValueError("bar is None!") + return X + + +# %% +# Now we can test our pipeline, and see if metadata is correctly passed around. +# This example uses our simple pipeline, and our transformer, and our +# consumer+router estimator which uses our simple classifier. + +est = SimplePipeline( + transformer=ExampleTransformer() + # we want transformer's fit to receive foo + .fit_requests(foo=True) + # we want transformer's transform to receive bar + .transform_requests(bar=True), + classifier=RouterConsumerClassifier( + foo_is_none=False, + estimator=ExampleClassifier(foo_is_none=False) + # we want this sub-estimator to receive foo in fit + .fit_requests(foo=True) + # but not bar in predict + .predict_requests(bar=False), + ).fit_requests( + # and we want the meta-estimator to receive foo as well + foo=True + ), +) +est.fit(X, y, foo=my_weights, bar=my_groups).predict(X[:3], bar=my_groups) diff --git a/sklearn/base.py b/sklearn/base.py index 60fc82eff6088..bde80fc65b3fb 100644 --- a/sklearn/base.py +++ b/sklearn/base.py @@ -25,6 +25,7 @@ from .utils.validation import _num_features from .utils.validation import _check_feature_names_in from .utils._estimator_html_repr import estimator_html_repr +from .utils.metadata_requests import _MetadataRequester from .utils.validation import _get_feature_names @@ -79,7 +80,13 @@ def clone(estimator, *, safe=True): new_object_params =
estimator.get_params(deep=False) for name, param in new_object_params.items(): new_object_params[name] = clone(param, safe=False) + new_object = klass(**new_object_params) + try: + new_object._metadata_request = copy.deepcopy(estimator._metadata_request) + except AttributeError: + pass + params_set = new_object.get_params(deep=False) # quick sanity check of the parameters of the clone @@ -144,7 +151,7 @@ def _pprint(params, offset=0, printer=repr): return lines -class BaseEstimator: +class BaseEstimator(_MetadataRequester): """Base class for all estimators in scikit-learn. Notes diff --git a/sklearn/externals/_sentinels.py b/sklearn/externals/_sentinels.py new file mode 100644 index 0000000000000..662a82864ca8d --- /dev/null +++ b/sklearn/externals/_sentinels.py @@ -0,0 +1,82 @@ +# type: ignore +""" +Copied from https://github.com/taleinat/python-stdlib-sentinels +PEP-0661: Status: Draft +""" +import sys as _sys +from typing import Optional + + +__all__ = ["sentinel"] + + +def sentinel( + name: str, + repr: Optional[str] = None, + module: Optional[str] = None, +): + """Create a unique sentinel object. + + *name* should be the fully-qualified name of the variable to which the + return value shall be assigned. + + *repr*, if supplied, will be used for the repr of the sentinel object. + If not provided, "<name>" will be used (with any leading class names + removed). + + *module*, if supplied, will be used as the module name for the purpose + of setting a unique name for the sentinel's unique class. The class is + set as an attribute of this name on the "sentinels" module, so that it + may be found by the pickling mechanism. In most cases, the module name + does not need to be provided, and it will be found by inspecting the + stack frame.
+ """ + name = _sys.intern(str(name)) + repr = repr or f'<{name.rsplit(".", 1)[-1]}>' + + if module is None: + try: + module = _get_parent_frame().f_globals.get("__name__", "__main__") + except (AttributeError, ValueError): + pass + class_name = _sys.intern(_get_class_name(name, module)) + + class_namespace = { + "__repr__": lambda self: repr, + } + cls = type(class_name, (), class_namespace) + cls.__module__ = __name__ + globals()[class_name] = cls + + sentinel = cls() + + def __new__(cls): + return sentinel + + __new__.__qualname__ = f"{class_name}.__new__" + cls.__new__ = __new__ + + return sentinel + + +if hasattr(_sys, "_getframe"): + _get_parent_frame = lambda: _sys._getframe(2) +else: # pragma: no cover + + def _get_parent_frame(): + """Return the frame object for the caller's parent stack frame.""" + try: + raise Exception + except Exception: + return _sys.exc_info()[2].tb_frame.f_back.f_back + + +def _get_class_name( + sentinel_qualname: str, + module_name: Optional[str] = None, +) -> str: + return ( + "_sentinel_type__" + f'{module_name.replace(".", "_") + "__" if module_name else ""}' + f'{sentinel_qualname.replace(".", "_")}' + ) diff --git a/sklearn/tests/test_props.py b/sklearn/tests/test_props.py new file mode 100644 index 0000000000000..57030949a7e68 --- /dev/null +++ b/sklearn/tests/test_props.py @@ -0,0 +1,590 @@ +import re +import numpy as np +import pytest + +from sklearn.base import BaseEstimator +from sklearn.base import ClassifierMixin +from sklearn.base import TransformerMixin +from sklearn.base import MetaEstimatorMixin +from sklearn.base import clone +from sklearn.utils import MetadataRequest +from sklearn.utils.validation import check_is_fitted +from sklearn.utils.metadata_requests import RequestType +from sklearn.utils.metadata_requests import metadata_request_factory +from sklearn.utils.metadata_requests import MetadataRouter +from sklearn.utils.metadata_requests import MethodMetadataRequest + +from sklearn.base import 
_MetadataRequester + +N, M = 100, 4 +X = np.random.rand(N, M) +y = np.random.randint(0, 2, size=N) +my_groups = np.random.randint(0, 10, size=N) +my_weights = np.random.rand(N) +my_other_weights = np.random.rand(N) + + +def assert_request_is_empty(metadata_request, exclude=None): + """Check if a metadata request dict is empty. + + One can exclude a method or a list of methods from the check using the + ``exclude`` parameter. + """ + if isinstance(metadata_request, MetadataRequest): + metadata_request = metadata_request.to_dict() + if exclude is None: + exclude = [] + for method, request in metadata_request.items(): + if method in exclude: + continue + props = [ + prop + for prop, alias in request.items() + if isinstance(alias, str) + or RequestType(alias) != RequestType.ERROR_IF_PASSED + ] + assert not len(props) + + +class TestEstimatorNoMetadata(ClassifierMixin, BaseEstimator): + """An estimator which accepts no metadata on any method.""" + + def fit(self, X, y): + return self + + def predict(self, X): + return np.ones(len(X)) + + +class TestEstimatorFitMetadata(ClassifierMixin, BaseEstimator): + """An estimator accepting two metadata in its ``fit`` method.""" + + def __init__(self, sample_weight_none=True, brand_none=True): + self.sample_weight_none = sample_weight_none + self.brand_none = brand_none + + def fit(self, X, y, sample_weight=None, brand=None): + assert ( + sample_weight is None + ) == self.sample_weight_none, "sample_weight and sample_weight_none don't agree" + assert (brand is None) == self.brand_none, "brand and brand_none don't agree" + return self + + def predict(self, X): + return np.ones(len(X)) + + +class TestSimpleMetaEstimator(MetaEstimatorMixin, ClassifierMixin, BaseEstimator): + """A meta-estimator which also consumes sample_weight itself in ``fit``.""" + + def __init__(self, estimator, sample_weight_none): + self.sample_weight_none = sample_weight_none + self.estimator = estimator + + def fit(self, X, y, sample_weight=None, **kwargs): +
assert ( + sample_weight is None + ) == self.sample_weight_none, "sample_weight and sample_weight_none don't agree" + + if sample_weight is not None: + kwargs["sample_weight"] = sample_weight + metadata_request_factory(self).fit.validate_metadata( + ignore_extras=False, self_metadata=super(), kwargs=kwargs + ) + fit_params = metadata_request_factory(self.estimator).fit.get_method_input( + ignore_extras=True, kwargs=kwargs + ) + self.estimator_ = clone(self.estimator).fit(X, y, **fit_params) + return self + + def get_metadata_request(self): + router = MetadataRouter().add(super(), mask=False) + router.add(self.estimator, mapping={"fit": "fit"}, mask=True, overwrite="smart") + return router.get_metadata_request() + + +class TestTransformer(TransformerMixin, BaseEstimator): + """A transformer which accepts metadata on fit and transform.""" + + def __init__( + self, + brand_none=True, + new_param_none=True, + fit_sample_weight_none=True, + transform_sample_weight_none=True, + ): + self.brand_none = brand_none + self.new_param_none = new_param_none + self.fit_sample_weight_none = fit_sample_weight_none + self.transform_sample_weight_none = transform_sample_weight_none + + def fit(self, X, y=None, brand=None, new_param=None, sample_weight=None): + assert ( + sample_weight is None + ) == self.fit_sample_weight_none, ( + "sample_weight and fit_sample_weight_none don't agree" + ) + assert ( + new_param is None + ) == self.new_param_none, "new_param and new_param_none don't agree" + assert (brand is None) == self.brand_none, "brand and brand_none don't agree" + return self + + def transform(self, X, y=None, sample_weight=None): + assert ( + sample_weight is None + ) == self.transform_sample_weight_none, ( + "sample_weight and transform_sample_weight_none don't agree" + ) + return X + + +class TestMetaTransformer(MetaEstimatorMixin, TransformerMixin, BaseEstimator): + """A simple meta-transformer.""" + + def __init__(self, transformer): + self.transformer = transformer + + 
def fit(self, X, y=None, **fit_params): + metadata_request_factory(self).fit.validate_metadata( + ignore_extras=False, + kwargs=fit_params, + ) + fit_params_ = metadata_request_factory(self.transformer).fit.get_method_input( + ignore_extras=False, kwargs=fit_params + ) + self.transformer_ = clone(self.transformer).fit(X, y, **fit_params_) + return self + + def transform(self, X, y=None, **transform_params): + # not validating since the following would validate due to ignore_extras=False + transform_params_ = metadata_request_factory( + self.transformer + ).transform.get_method_input(ignore_extras=False, kwargs=transform_params) + return self.transformer_.transform(X, **transform_params_) + + +class SimplePipeline(BaseEstimator): + """A very simple pipeline, assuming the last step is always a predictor.""" + + def __init__(self, steps): + self.steps = steps + + def fit(self, X, y, **fit_params): + self.steps_ = [] + metadata_request_factory(self).fit.validate( + ignore_extras=False, kwargs=fit_params + ) + X_transformed = X + for step in self.steps[:-1]: + requests = metadata_request_factory(step) + step_fit_params = requests.fit.get_method_input( + ignore_extras=True, kwargs=fit_params + ) + transformer = clone(step).fit(X_transformed, y, **step_fit_params) + self.steps_.append(transformer) + step_transform_params = requests.transform.get_method_input( + ignore_extras=True, kwargs=fit_params + ) + X_transformed = transformer.transform( + X_transformed, **step_transform_params + ) + + requests = metadata_request_factory(step) + step_fit_params = requests.fit.get_method_input( + ignore_extras=True, kwargs=fit_params + ) + self.steps_.append( + clone(self.steps[-1]).fit(X_transformed, y, **step_fit_params) + ) + return self + + def predict(self, X, **predict_params): + check_is_fitted(self) + X_transformed = X + metadata_request_factory(self).predict.validate_metadata( + ignore_extras=False, kwargs=predict_params + ) + for step in self.steps_[:-1]: + 
step_transform_params = metadata_request_factory( + step + ).transform.get_method_input(ignore_extras=True, kwargs=predict_params) + X_transformed = step.transform(X, **step_transform_params) + + step_predict_params = metadata_request_factory( + self.steps_[-1] + ).predict.get_method_input(ignore_extras=True, kwargs=predict_params) + return self.steps_[-1].predict(X_transformed, **step_predict_params) + + def get_metadata_request(self): + router = MetadataRouter() + if len(self.steps) > 1: + router.add( + self.steps[:-1], + mask=True, + mapping={"predict": "transform", "fit": ["transform", "fit"]}, + overwrite="smart", + ) + router.add(self.steps[-1], overwrite="smart", mapping="one-to-one") + return router.get_metadata_request() + + +def test_assert_request_is_empty(): + requests = MetadataRequest() + assert_request_is_empty(requests) + assert_request_is_empty(requests.to_dict()) + + requests.fit.add_request(prop="foo", alias=RequestType.ERROR_IF_PASSED) + # this should still work, since ERROR_IF_PASSED is the default value + assert_request_is_empty(requests) + + requests.fit.add_request(prop="bar", alias="value") + with pytest.raises(AssertionError): + # now requests is no more empty + assert_request_is_empty(requests) + + # but one can exclude a method + assert_request_is_empty(requests, exclude="fit") + + requests.score.add_request(prop="carrot", alias=RequestType.REQUESTED) + with pytest.raises(AssertionError): + # excluding `fit` is not enough + assert_request_is_empty(requests, exclude="fit") + + # and excluding both fit and score would avoid an exception + assert_request_is_empty(requests, exclude=["fit", "score"]) + + +def test_default_requests(): + class OddEstimator(BaseEstimator): + __metadata_request__sample_weight = { + "fit": {"sample_weight": RequestType.REQUESTED} # type: ignore + } # set a different default request + + odd_request = metadata_request_factory(OddEstimator()) + assert odd_request.fit.requests == {"sample_weight": 
RequestType.REQUESTED} + + # check other test estimators + assert not len(metadata_request_factory(TestEstimatorNoMetadata()).fit.requests) + assert_request_is_empty(TestEstimatorNoMetadata().get_metadata_request()) + + trs_request = metadata_request_factory(TestTransformer()) + assert trs_request.fit.requests == { + "sample_weight": RequestType(None), + "brand": RequestType(None), + "new_param": RequestType(None), + } + assert trs_request.transform.requests == { + "sample_weight": RequestType(None), + } + assert_request_is_empty(trs_request) + + est_request = metadata_request_factory(TestEstimatorFitMetadata()) + assert est_request.fit.requests == { + "sample_weight": RequestType(None), + "brand": RequestType(None), + } + assert_request_is_empty(est_request) + + +def test_simple_metadata_routing(): + # Tests that metadata is properly routed + # The underlying estimator doesn't accept or request metadata + cls = TestSimpleMetaEstimator( + estimator=TestEstimatorNoMetadata(), sample_weight_none=True + ) + cls.fit(X, y) + + # Meta-estimator consumes sample_weight, but doesn't forward it to the underlying + # estimator + cls = TestSimpleMetaEstimator( + estimator=TestEstimatorNoMetadata(), sample_weight_none=False + ) + cls.fit(X, y, sample_weight=my_weights) + + # If the estimator accepts the metadata but doesn't explicitly say it doesn't + # need it, there's an error + cls = TestSimpleMetaEstimator( + estimator=TestEstimatorFitMetadata(), + sample_weight_none=False, + ) + with pytest.raises( + ValueError, + match=( + "sample_weight is passed but is not explicitly set as requested or not. 
In" + " method: fit" + ), + ): + cls.fit(X, y, sample_weight=my_weights) + + # Explicitly saying the estimator doesn't need it, makes the error go away, + # but if a metadata is passed which is not requested by any object/estimator, + # there will be still an error + cls = TestSimpleMetaEstimator( + estimator=TestEstimatorFitMetadata().fit_requests(sample_weight=False), + sample_weight_none=False, + ) + with pytest.raises( + ValueError, + match=( + "sample_weight is not requested by any estimator but is provided. In" + " method: fit" + ), + ): + cls.fit(X, y, sample_weight=my_weights) + + # Requesting a metadata will make the meta-estimator forward it correctly + cls = TestSimpleMetaEstimator( + estimator=TestEstimatorFitMetadata(sample_weight_none=False).fit_requests( + sample_weight=True + ), + sample_weight_none=False, + ) + cls.fit(X, y, sample_weight=my_weights) + + +def test_invalid_metadata(): + # check that passing wrong metadata raises an error + trs = TestMetaTransformer( + transformer=TestTransformer().transform_requests(sample_weight=True) + ) + with pytest.raises( + ValueError, + match=( + re.escape( + "Metadata passed which is not understood: ['other_param']. In method:" + " transform" + ) + ), + ): + trs.fit(X, y).transform(X, other_param=my_weights) + + # passing a metadata which is not requested by any estimator should also raise + trs = TestMetaTransformer( + transformer=TestTransformer().transform_requests(sample_weight=False) + ) + with pytest.raises( + ValueError, + match=( + "sample_weight is not requested by any estimator but is provided. 
In" + " method: transform" + ), + ): + trs.fit(X, y).transform(X, sample_weight=my_weights) + + +def test_get_metadata_request(): + class TestDefaultsBadMetadataName(_MetadataRequester): + __metadata_request__sample_weight = { + "fit": "sample_weight", + "score": "sample_weight", + } + + __metadata_request__my_param = { + "score": {"my_param": True}, + # the following method raise an error + "other_method": {"my_param": True}, + } + + __metadata_request__my_other_param = { + "score": "my_other_param", + # this should raise since the name is different than the metadata + "fit": "my_param", + } + + class TestDefaultsBadMethodName(_MetadataRequester): + __metadata_request__sample_weight = { + "fit": "sample_weight", + "score": "sample_weight", + } + + __metadata_request__my_param = { + "score": {"my_param": True}, + # the following method raise an error + "other_method": {"my_param": True}, + } + + __metadata_request__my_other_param = { + "score": "my_other_param", + "fit": "my_other_param", + } + + class TestDefaults(_MetadataRequester): + __metadata_request__sample_weight = { + "fit": "sample_weight", + "score": "sample_weight", + } + + __metadata_request__my_param = { + "score": {"my_param": True}, + "predict": {"my_param": True}, + } + + __metadata_request__my_other_param = { + "score": "my_other_param", + "fit": "my_other_param", + } + + with pytest.raises(ValueError, match="Expected all metadata to be called"): + TestDefaultsBadMetadataName().get_metadata_request() + + with pytest.raises(ValueError, match="other_method is not supported as a method"): + TestDefaultsBadMethodName().get_metadata_request() + + expected = { + "score": { + "my_param": RequestType(True), + "my_other_param": RequestType(None), + "sample_weight": RequestType(None), + }, + "fit": { + "my_other_param": RequestType(None), + "sample_weight": RequestType(None), + }, + "partial_fit": {}, + "predict": {"my_param": RequestType(True)}, + "transform": {}, + "inverse_transform": {}, + "split": {}, 
+ } + assert TestDefaults().get_metadata_request() == expected + + est = TestDefaults().score_requests(my_param="other_param") + expected = { + "score": { + "my_param": "other_param", + "my_other_param": RequestType(None), + "sample_weight": RequestType(None), + }, + "fit": { + "my_other_param": RequestType(None), + "sample_weight": RequestType(None), + }, + "partial_fit": {}, + "predict": {"my_param": RequestType(True)}, + "transform": {}, + "inverse_transform": {}, + "split": {}, + } + assert est.get_metadata_request() == expected + + est = TestDefaults().fit_requests(sample_weight=True) + expected = { + "score": { + "my_param": RequestType(True), + "my_other_param": RequestType(None), + "sample_weight": RequestType(None), + }, + "fit": { + "my_other_param": RequestType(None), + "sample_weight": RequestType(True), + }, + "partial_fit": {}, + "predict": {"my_param": RequestType(True)}, + "transform": {}, + "inverse_transform": {}, + "split": {}, + } + assert est.get_metadata_request() == expected + + +def test__get_default_requests(): + # Test _get_default_requests method + class ExplicitRequest(BaseEstimator): + __metadata_request__prop = {"fit": "prop"} + + def fit(self, X, y): + return self + + assert metadata_request_factory(ExplicitRequest()).fit.requests == { + "prop": RequestType.ERROR_IF_PASSED + } + assert_request_is_empty(ExplicitRequest().get_metadata_request(), exclude="fit") + + class ExplicitRequestOverwrite(BaseEstimator): + __metadata_request__prop = {"fit": {"prop": RequestType.REQUESTED}} + + def fit(self, X, y, prop=None, **kwargs): + return self + + assert metadata_request_factory(ExplicitRequestOverwrite()).fit.requests == { + "prop": RequestType.REQUESTED + } + assert_request_is_empty( + ExplicitRequestOverwrite().get_metadata_request(), exclude="fit" + ) + + class ImplicitRequest(BaseEstimator): + def fit(self, X, y, prop=None, **kwargs): + return self + + assert metadata_request_factory(ImplicitRequest()).fit.requests == { + "prop": 
RequestType.ERROR_IF_PASSED + } + assert_request_is_empty(ImplicitRequest().get_metadata_request(), exclude="fit") + + class ImplicitRequestRemoval(BaseEstimator): + __metadata_request__prop = {"fit": {"prop": RequestType.UNUSED}} + + def fit(self, X, y, prop=None, **kwargs): + return self + + assert metadata_request_factory(ImplicitRequestRemoval()).fit.requests == {} + assert_request_is_empty(ImplicitRequestRemoval().get_metadata_request()) + + +def test_method_metadata_request(): + mmr = MethodMetadataRequest(name="fit") + with pytest.raises( + ValueError, + match="overwrite can only be one of {True, False, 'smart', 'ignore'}.", + ): + mmr.add_request(prop="test", alias=None, overwrite="test") + + with pytest.raises(ValueError, match="Expected all metadata to be called test"): + mmr.add_request(prop="foo", alias="bar", expected_metadata="test") + + with pytest.raises(ValueError, match="Aliasing is not allowed"): + mmr.add_request(prop="foo", alias="bar", allow_aliasing=False) + + with pytest.raises(ValueError, match="alias should be either a string or"): + mmr.add_request(prop="foo", alias=1.4) + + mmr.add_request(prop="foo", alias=None) + assert mmr.requests == {"foo": RequestType.ERROR_IF_PASSED} + with pytest.raises(ValueError, match="foo is already requested"): + mmr.add_request(prop="foo", alias=True) + with pytest.raises(ValueError, match="foo is already requested"): + mmr.add_request(prop="foo", alias=True) + mmr.add_request(prop="foo", alias=True, overwrite="smart") + assert mmr.requests == {"foo": RequestType.REQUESTED} + + with pytest.raises(ValueError, match="Can only add another MethodMetadataRequest"): + mmr.merge_method_request({}) + + assert MethodMetadataRequest.from_dict(None, name="fit").requests == {} + assert MethodMetadataRequest.from_dict("foo", name="fit").requests == { + "foo": RequestType.ERROR_IF_PASSED + } + assert MethodMetadataRequest.from_dict(["foo", "bar"], name="fit").requests == { + "foo": RequestType.ERROR_IF_PASSED, + "bar": 
RequestType.ERROR_IF_PASSED, + } + + +def test_metadata_request_factory(): + class Consumer(BaseEstimator): + __metadata_request__prop = {"fit": "prop"} + + assert_request_is_empty(metadata_request_factory(None)) + assert_request_is_empty(metadata_request_factory({})) + assert_request_is_empty(metadata_request_factory(object())) + + mr = MetadataRequest({"fit": "foo"}, default="bar") + mr_factory = metadata_request_factory(mr) + assert_request_is_empty(mr_factory, exclude="fit") + assert mr_factory.fit.requests == {"foo": "bar"} + + mr = metadata_request_factory(Consumer()) + assert_request_is_empty(mr, exclude="fit") + assert mr.fit.requests == {"prop": RequestType.ERROR_IF_PASSED} diff --git a/sklearn/utils/__init__.py b/sklearn/utils/__init__.py index 8290318d35deb..30ef0b0c3cb7f 100644 --- a/sklearn/utils/__init__.py +++ b/sklearn/utils/__init__.py @@ -40,6 +40,10 @@ check_scalar, ) from .. import get_config +from .metadata_requests import MetadataRequest +from .metadata_requests import MethodMetadataRequest +from .metadata_requests import metadata_request_factory +from .metadata_requests import MetadataRouter # Do not deprecate parallel_backend and register_parallel_backend as they are @@ -74,6 +78,10 @@ "all_estimators", "DataConversionWarning", "estimator_html_repr", + "MetadataRequest", + "metadata_request_factory", + "MetadataRouter", + "MethodMetadataRequest", ] IS_PYPY = platform.python_implementation() == "PyPy" diff --git a/sklearn/utils/estimator_checks.py b/sklearn/utils/estimator_checks.py index 37537bc1b0498..85b87655a4dc3 100644 --- a/sklearn/utils/estimator_checks.py +++ b/sklearn/utils/estimator_checks.py @@ -2941,6 +2941,8 @@ def check_no_attributes_set_in_init(name, estimator_orig): # Test for no setting apart from parameters during init invalid_attr = set(vars(estimator)) - set(init_params) - set(parents_init_params) + # Ignore private attributes + invalid_attr = set([attr for attr in invalid_attr if not attr.startswith("_")]) assert not 
invalid_attr, ( "Estimator %s should not set any attribute apart" " from parameters during init. Found attributes %s." @@ -3779,14 +3781,7 @@ def check_dataframe_column_names_consistency(name, estimator_orig): check_methods.append((method, callable_method)) for _, method in check_methods: - with warnings.catch_warnings(): - warnings.filterwarnings( - "error", - message="X does not have valid feature names", - category=UserWarning, - module="sklearn", - ) - method(X) # works without UserWarning for valid features + method(X) # works invalid_names = [ (names[::-1], "Feature names must be in the same order as they were in fit."), diff --git a/sklearn/utils/metadata_requests.py b/sklearn/utils/metadata_requests.py new file mode 100644 index 0000000000000..c04270f7c5cc2 --- /dev/null +++ b/sklearn/utils/metadata_requests.py @@ -0,0 +1,840 @@ +# from copy import deepcopy +import inspect +from enum import Enum +from collections import defaultdict +from typing import Union, Optional +from ..externals._sentinels import sentinel # type: ignore # mypy error!!! + + +class RequestType(Enum): + UNREQUESTED = False + REQUESTED = True + ERROR_IF_PASSED = None + # this sentinel is used in `__metadata_request__*` attributes to indicate + # that a metadata is not present even though it may be present in the + # corresponding method's signature. + UNUSED = sentinel("UNUSED") + + +# this sentinel is the default used in `{method}_requests` methods to indicate +# no change requested by the user. +UNCHANGED = sentinel("UNCHANGED") + +METHODS = [ + "fit", + "partial_fit", + "predict", + "score", + "split", + "transform", + "inverse_transform", +] + + +REQUESTER_DOC = """ Request metadata passed to the ``{method}`` method. + + Parameters + ---------- +""" +REQUESTER_DOC_PARAM = """ {metadata} : RequestType, str, True, False, or None, \ + default=UNCHANGED + Whether {metadata} should be passed to {method} by meta-estimators or + not, and if yes, should it have an alias. 
+ + - True or RequestType.REQUESTED: {metadata} is requested, and passed to \ +{method} if provided. + + - False or RequestType.UNREQUESTED: {metadata} is not requested and the \ +meta-estimator will not pass it to {method}. + + - None or RequestType.ERROR_IF_PASSED: {metadata} is not requested, and \ +the meta-estimator will raise an error if the user provides {metadata} + + - str: {metadata} should be passed to the meta-estimator with this given \ +alias instead of the original name. + +""" +REQUESTER_DOC_RETURN = """ Returns + ------- + self + Returns the object itself. +""" + + +class MethodMetadataRequest: + """Contains the metadata request info for a single method. + + .. versionadded:: 1.1 + + Parameters + ---------- + name : str + The name of the method to which these requests belong. + """ + + def __init__(self, name): + self.requests = dict() + self.name = name + + def add_request( + self, + *, + prop, + alias, + allow_aliasing=True, + overwrite=False, + expected_metadata=None, + ): + """Add request info for a prop. + + Parameters + ---------- + prop : str + The property for which a request is set. + + alias : str, RequestType, or {True, False, None} + The alias which is routed to `prop` + + - str: the name which should be used as an alias when a meta-estimator + routes the metadata. + + - True or RequestType.REQUESTED: requested + + - False or RequestType.UNREQUESTED: not requested + + - None or RequestType.ERROR_IF_PASSED: error if passed + + allow_aliasing : bool, default=True + If False, alias should be the same as prop if it's a string. + + overwrite : bool or str, default=False + + - True: ``alias`` replaces the existing routing. + + - False: a ``ValueError`` is raised if the given value conflicts + with an existing one. + + - "smart": overwrite in this order: + ``RequestType.REQUESTED`` over ``RequestType.UNREQUESTED`` over + ``RequestType.ERROR_IF_PASSED``, and error if existing value is + a string. 
+ + - "ignore": ignore the requested metadata if it already exists. + + expected_metadata : str, default=None + If provided, all props should be the same as this value. It is used to + handle default values. + """ + if overwrite not in {True, False, "smart", "ignore"}: + raise ValueError( + "overwrite can only be one of {True, False, 'smart', 'ignore'}; " + f"but {overwrite} is given." + ) + if expected_metadata is not None and expected_metadata != prop: + raise ValueError( + f"Expected all metadata to be called {expected_metadata} but " + f"{prop} was passed." + ) + if not allow_aliasing and isinstance(alias, str) and prop != alias: + raise ValueError( + "Aliasing is not allowed, prop and alias should " + "be the same strings if alias is a string." + ) + + if not isinstance(alias, str): + try: + alias = RequestType(alias) + except ValueError: + raise ValueError( + "alias should be either a string or one of " + "{None, True, False}, or a RequestType." + ) + + if alias == prop: + alias = RequestType.REQUESTED + + if alias == RequestType.UNUSED and prop in self.requests: + del self.requests[prop] + elif prop not in self.requests or overwrite is True: + self.requests[prop] = alias + elif prop in self.requests and overwrite == "ignore": + pass + elif overwrite == "smart": + current = self.requests[prop] + if isinstance(current, str): + raise ValueError( + f"Cannot overwrite {current} with {alias} when overwrite=smart." + ) + current = RequestType(current) + + # REQUESTED > UNREQUESTED > ERROR_IF_PASSED + if alias == RequestType.REQUESTED and current in { + RequestType.ERROR_IF_PASSED, + RequestType.UNREQUESTED, + }: + self.requests[prop] = alias + elif ( + alias == RequestType.UNREQUESTED + and current == RequestType.ERROR_IF_PASSED + ): + self.requests[prop] = alias + elif self.requests[prop] != alias: + raise ValueError( + f"{prop} is already requested as {self.requests[prop]}, " + f"which is not the same as the one given: {alias}.
Cannot " + "overwrite when overwrite=False." + ) + + def merge_method_request(self, other, overwrite=False, expected_metadata=None): + """Merge the metadata request info of two methods. + + The methods can be the same, or different. For example, merging + fit and score info of the same object, or merging fit request info + from two different sub estimators. + + Parameters + ---------- + other : MethodMetadataRequest + The other object to be merged with this instance. + + overwrite : bool or str, default=False + + - True: ``alias`` replaces the existing routing. + + - False: a ``ValueError`` is raised if the given value conflicts + with an existing one. + + - "smart": overwrite in this order: + ``RequestType.REQUESTED`` over ``RequestType.UNREQUESTED`` over + ``RequestType.ERROR_IF_PASSED``, and error if existing value is + a string. + + - "ignore": ignore the requested metadata if it already exists. + + expected_metadata : str, default=None + If provided, all props should be the same as this value. It is used to + handle default values. + """ + if not isinstance(other, MethodMetadataRequest): + raise ValueError("Can only add another MethodMetadataRequest.") + for prop, alias in other.requests.items(): + self.add_request( + prop=prop, + alias=alias, + overwrite=overwrite, + expected_metadata=expected_metadata, + ) + + def validate_metadata(self, ignore_extras=False, self_metadata=None, kwargs=None): + """Validate the given arguments against the requested ones. + + Parameters + ---------- + ignore_extras : bool, default=False + If ``True``, no error is raised if extra unknown args are passed. + + self_metadata : MetadataRequest-like, default=None + This parameter can be anything which can be an input to + ``metadata_request_factory``. Only the part of the metadata which + is the same as ``name`` is used. + + Consumers don't validate their own metadata. Validation is always + done by routers (i.e. usually meta-estimators).
But sometimes an + object is a consumer and a router, e.g. ``LogisticRegressionCV`` + which consumes ``sample_weight``, but also routes metadata to the + given scorer(s) and CV object, and therefore is also a router. In + such a case, ``sample_weight`` is the metadata being consumed. A + router can get its own required metadata, as opposed to the ones + required by its sub-objects, using + ``metadata_request_factory(super())``. ``validate_metadata`` then + uses the part which is relevant to this validation. Since this + object knows which method is relevant using its ``name``, passing + ``super()`` here would be sufficient. + + kwargs : dict + Provided metadata. + + Returns + ------- + None + """ + kwargs = {} if kwargs is None else kwargs + self_metadata = getattr( + metadata_request_factory(self_metadata), self.name + ).requests + # we then remove self metadata from kwargs, since they should not be + # validated. + kwargs = {v: k for v, k in kwargs.items() if v not in self_metadata} + args = {arg for arg, value in kwargs.items() if value is not None} + if not ignore_extras and args - set(self.requests.keys()): + raise ValueError( + "Metadata passed which is not understood: " + f"{sorted(args - set(self.requests.keys()))}. In method: " + f"{self.name}" + ) + + for prop, alias in self.requests.items(): + if not isinstance(alias, str): + alias = RequestType(alias) + if alias == RequestType.UNREQUESTED: + if prop in args: + raise ValueError( + f"{prop} is not requested by any estimator but is provided. In" + f" method: {self.name}" + ) + elif alias == RequestType.REQUESTED or isinstance(alias, str): + # we ignore what the given alias here is, since aliases are + # checked at the parent meta-estimator level, and the child + # still expects the original names for the metadata. 
+ # If a metadata is requested but not passed, no error is raised + continue + elif alias == RequestType.ERROR_IF_PASSED: + if prop in args: + raise ValueError( + f"{prop} is passed but is not explicitly set as " + f"requested or not. In method: {self.name}" + ) + + def get_method_input(self, ignore_extras=False, kwargs=None): + """Return the input parameters requested by the method. + + The output of this method can be used directly as the input to the + corresponding method as extra props. + + Parameters + ---------- + ignore_extras : bool, default=False + If ``True``, no error is raised if extra unknown args are passed. + + kwargs : dict + A dictionary of provided metadata. + + Returns + ------- + kwargs : dict + A dictionary of {prop: value} which can be given to the + corresponding method. + """ + kwargs = {} if kwargs is None else kwargs + args = {arg: value for arg, value in kwargs.items() if value is not None} + res = dict() + for prop, alias in self.requests.items(): + if not isinstance(alias, str): + alias = RequestType(alias) + + if alias == RequestType.UNREQUESTED: + continue + elif alias == RequestType.REQUESTED and prop in args: + res[prop] = args[prop] + elif alias == RequestType.ERROR_IF_PASSED and prop in args: + raise ValueError( + f"{prop} is passed but is not explicitly set as " + f"requested or not. In method: {self.name}" + ) + elif alias in args: + res[prop] = args[alias] + self.validate_metadata(ignore_extras=ignore_extras, kwargs=res) + return res + + def masked(self): + """Return a masked version of the requests. + + Returns + ------- + masked : MethodMetadataRequest + A masked version is one which converts a ``{'prop': 'alias'}`` to + ``{'alias': True}``. This is desired in meta-estimators passing + requests to their parent estimators. 
+ """ + res = MethodMetadataRequest(name=self.name) + for prop, alias in self.requests.items(): + if isinstance(alias, str): + res.add_request( + prop=alias, + alias=alias, + allow_aliasing=False, + overwrite=False, + ) + else: + res.add_request( + prop=prop, + alias=alias, + allow_aliasing=False, + overwrite=False, + ) + return res + + @classmethod + def from_dict( + cls, requests, name, allow_aliasing=True, default=RequestType.ERROR_IF_PASSED + ): + """Construct a MethodMetadataRequest from a given dictionary. + + Parameters + ---------- + requests : dict + A dictionary representing the requests. + + name : str + The name of the method to which these requests belong. + + allow_aliasing : bool, default=True + If false, only aliases with the same name as the parameter are + allowed. This is useful when handling the default values. + + default : RequestType, True, False, None, or str, \ + default=RequestType.ERROR_IF_PASSED + The default value to be used if parameters are provided as a string + or list instead of the fully specifying dict. + + Returns + ------- + requests: MethodMetadataRequest + A :class:`MethodMetadataRequest` object. + """ + if requests is None: + requests = dict() + elif isinstance(requests, str): + requests = {requests: default} + elif isinstance(requests, (list, set)): + requests = {r: default for r in requests} + result = cls(name=name) + for prop, alias in requests.items(): + result.add_request(prop=prop, alias=alias, allow_aliasing=allow_aliasing) + return result + + def __repr__(self): + return str(self.requests) + + def __str__(self): + return str(self.requests) + + +class MetadataRequest: + """Contains the metadata request info of an object. + + .. versionadded:: 1.1 + + Parameters + ---------- + requests : dict of dict of {str: str}, default=None + A dictionary where the keys are the names of the methods, and the values are + a dictionary of the form ``{"required_metadata": "provided_metadata"}``. 
+ ``"provided_metadata"`` can also be a ``RequestType`` or {True, False, None}. + + default : RequestType, True, False, None, or str, \ + default=RequestType.ERROR_IF_PASSED + The default value to be used if parameters are provided as a string instead of + the usual second layer dict. + """ + + def __init__(self, requests=None, default=RequestType.ERROR_IF_PASSED): + for method in METHODS: + setattr(self, method, MethodMetadataRequest(name=method)) + + if requests is None: + return + elif not isinstance(requests, dict): + raise ValueError( + "Can only construct an instance from a dict. Please call " + "metadata_request_factory for other types of input." + ) + + for method, method_requests in requests.items(): + if method not in METHODS: + raise ValueError(f"{method} is not supported as a method.") + setattr( + self, + method, + MethodMetadataRequest.from_dict( + method_requests, name=method, default=default + ), + ) + + def add_requests( + self, + obj, + mapping="one-to-one", + overwrite=False, + expected_metadata=None, + ): + """Add request info from the given object with the desired mapping. + + Parameters + ---------- + obj : object + An object from which a MetadataRequest can be constructed. + + mapping : dict or str, default="one-to-one" + The mapping between the ``obj``'s methods and this object's + methods. If ``"one-to-one"`` all methods' requests from ``obj`` are + merged into this instance's methods. If a dict, the mapping is of + the form ``{"destination_method": "source_method"}`` or + ``{"destination_method": ["source_method1", ...]}``. + + overwrite : bool or str, default=False + + - True: ``alias`` replaces the existing routing. + + - False: a ``ValueError`` is raised if the given value conflicts + with an existing one. + + - "smart": overwrite in this order: + ``RequestType.REQUESTED`` over ``RequestType.UNREQUESTED`` over + ``RequestType.ERROR_IF_PASSED``, and error if existing value is + a string. 
+ + - "ignore": ignore the requested metadata if it already exists. + + expected_metadata : str, default=None + If provided, all props should be the same as this value. It used to + handle default values. + """ + if not isinstance(mapping, dict) and mapping != "one-to-one": + raise ValueError( + "mapping can only be a dict or the literal 'one-to-one'. " + f"Given value: {mapping}" + ) + if mapping == "one-to-one": + mapping = [(method, method) for method in METHODS] + else: + _mapping = [] + for destination, sources in mapping.items(): + if isinstance(sources, list): + _mapping.extend([(destination, source) for source in sources]) + else: + _mapping.append((destination, sources)) + mapping = _mapping + other = metadata_request_factory(obj) + for destination, source in mapping: + my_method = getattr(self, destination) + other_method = getattr(other, source) + my_method.merge_method_request( + other_method, + overwrite=overwrite, + expected_metadata=expected_metadata, + ) + + def masked(self): + """Return a masked version of the requests. + + A masked version is one which converts a ``{'prop': 'alias'}`` to + ``{'alias': True}``. This is desired in meta-estimators passing + requests to their parent estimators. + """ + res = MetadataRequest() + for method in METHODS: + setattr(res, method, getattr(self, method).masked()) + return res + + def to_dict(self): + """Return dictionary representation of this object.""" + output = dict() + for method in METHODS: + output[method] = getattr(self, method).requests + return output + + def __repr__(self): + return str(self.to_dict()) + + def __str__(self): + return str(self.to_dict()) + + +def metadata_request_factory(obj=None): + """Get a MetadataRequest instance from the given object. + + .. versionadded:: 1.1 + + Parameters + ---------- + obj : object + If the object is already a MetadataRequest, return that. + If the object is an estimator, try to call `get_metadata_request` and get + an instance from that method. 
+ If the object is a dict, create a MetadataRequest from that. + + Returns + ------- + metadata_requests : MetadataRequest + A ``MetadataRequest`` taken or created from the given object. + """ + if obj is None: + return MetadataRequest() + + if isinstance(obj, MetadataRequest): + return obj + + if isinstance(obj, dict): + return MetadataRequest(obj) + + try: + return MetadataRequest(obj.get_metadata_request()) + except AttributeError: + # The object doesn't have a `get_metadata_request` method. + return MetadataRequest() + + +class MetadataRouter: + """Route the metadata to child objects. + + .. versionadded:: 1.1 + """ + + def __init__(self): + self.requests = MetadataRequest() + + def add(self, *obj, mapping="one-to-one", overwrite=False, mask=False): + """Add a set of requests to the existing ones. + + Parameters + ---------- + *obj : objects + A set of objects from which the requests are extracted. Passed as + arguments to this method. + + mapping : dict or str, default="one-to-one" + The mapping between the ``obj``'s methods and this routing object's + methods. If ``"one-to-one"`` all methods' requests from ``obj`` are + merged into this instance's methods. If a dict, the mapping is of + the form ``{"destination_method": "source_method"}`` or + ``{"destination_method": ["source_method1", ...]}``. + + overwrite : bool or str, default=False + + - True: ``alias`` replaces the existing routing. + + - False: a ``ValueError`` is raised if the given value conflicts + with an existing one. + + - "smart": overwrite in this order: + ``RequestType.REQUESTED`` over ``RequestType.UNREQUESTED`` over + ``RequestType.ERROR_IF_PASSED``, and error if existing value is + a string. + + - "ignore": ignore the requested metadata if it already exists. + + mask : bool, default=False + If the requested metadata should be masked by the alias. If + ``True``, then a request of the form + ``{'sample_weight' : 'my_weight'}`` is converted to + ``{'my_weight': 'my_weight'}``. 
This is required for meta-estimators + which should expose the requested parameters and not the ones + expected by the objects' methods. + """ + for x in obj: + if mask: + x = metadata_request_factory(x).masked() + self.requests.add_requests(x, mapping=mapping, overwrite=overwrite) + return self + + def get_metadata_request(self): + """Get requested data properties. + + Returns + ------- + request : dict + A dict of dict of str->value. The key to the first dict is the name + of the method, and the key to the second dict is the name of the + argument requested by the method. + """ + return self.requests.to_dict() + + +class RequestMethod: + """ + A descriptor for request methods. + + .. versionadded:: 1.1 + + Parameters + ---------- + name : str + The name of the method for which the request function should be + created, e.g. ``"fit"`` would create a ``fit_requests`` function. + + keys : list of str + A list of strings which are accepted parameters by the created + function, e.g. ``["sample_weight"]`` if the corresponding method + accepts it as a metadata. + + Notes + ----- + This class is a descriptor [1]_ and uses PEP-362 to set the signature of + the returned function [2]_. + + References + ---------- + .. [1] https://docs.python.org/3/howto/descriptor.html + + .. 
[2] https://www.python.org/dev/peps/pep-0362/ + """ + + def __init__(self, name, keys): + self.name = name + self.keys = keys + + def __get__(self, instance, owner): + # we would want to have a method which accepts only the expected args + def func(**kw): + if set(kw) - set(self.keys): + raise TypeError(f"Unexpected args: {set(kw) - set(self.keys)}") + + requests = metadata_request_factory(instance) + + try: + method_metadata_request = getattr(requests, self.name) + except AttributeError: + raise ValueError(f"{self.name} is not a supported method.") + + for prop, alias in kw.items(): + if alias is not UNCHANGED: + method_metadata_request.add_request( + prop=prop, alias=alias, allow_aliasing=True, overwrite=True + ) + instance._metadata_request = requests.to_dict() + + return instance + + # Now we set the relevant attributes of the function so that it seems + # like a normal method to the end user, with known expected arguments. + func.__name__ = f"{self.name}_requests" + params = [ + inspect.Parameter( + name="self", + kind=inspect.Parameter.POSITIONAL_OR_KEYWORD, + annotation=type(instance), + ) + ] + params.extend( + [ + inspect.Parameter( + k, + inspect.Parameter.KEYWORD_ONLY, + default=UNCHANGED, + annotation=Optional[Union[RequestType, str]], + ) + for k in self.keys + ] + ) + func.__signature__ = inspect.Signature( + params, + return_annotation=type(instance), + ) + doc = REQUESTER_DOC.format(method=self.name) + for metadata in self.keys: + doc += REQUESTER_DOC_PARAM.format(metadata=metadata, method=self.name) + doc += REQUESTER_DOC_RETURN + func.__doc__ = doc + return func + + +class _MetadataRequester: + """Mixin class for adding metadata request functionality. + + .. versionadded:: 1.1 + """ + + def __init_subclass__(cls, **kwargs): + """Set the ``{method}_requests`` methods. + + This uses PEP-487 [1]_ to set the ``{method}_requests`` methods. 
It + looks for the information available in the set default values which are + set using ``__metadata_request__*`` class attributes. + + References + ---------- + .. [1] https://www.python.org/dev/peps/pep-0487 + """ + try: + requests = cls._get_default_requests().to_dict() + except Exception: + # if there are any issues in the default values, it will be raised + # when ``get_metadata_request`` is called. Here we are going to + # ignore all the issues such as bad defaults etc.` + super().__init_subclass__(**kwargs) + return + + for request_method, request_keys in requests.items(): + # set ``{method}_requests``` methods + if not len(request_keys): + continue + setattr( + cls, + f"{request_method}_requests", + RequestMethod(request_method, sorted(request_keys)), + ) + super().__init_subclass__(**kwargs) + + @classmethod + def _get_default_requests(cls): + """Collect default request values. + + This method combines the information present in ``metadata_request__*`` + class attributes. + """ + + requests = MetadataRequest() + + # need to go through the MRO since this is a class attribute and + # ``vars`` doesn't report the parent class attributes. We go through + # the reverse of the MRO since cls is the first in the tuple and object + # is the last. + defaults = defaultdict() + for klass in reversed(inspect.getmro(cls)): + klass_defaults = { + attr: value + for attr, value in vars(klass).items() + if attr.startswith("__metadata_request__") + } + defaults.update(klass_defaults) + defaults = dict(sorted(defaults.items())) + + # First take all arguments from the method signatures and have them as + # ERROR_IF_PASSED, except X, y, *args, and **kwargs. + for method in METHODS: + # Here we use `isfunction` instead of `ismethod` because calling `getattr` + # on a class instead of an instance returns an unbound function. 
+ if not hasattr(cls, method) or not inspect.isfunction(getattr(cls, method)): + continue + # ignore the first parameter of the method, which is usually "self" + params = list(inspect.signature(getattr(cls, method)).parameters.items())[ + 1: + ] + for pname, param in params: + if pname in {"X", "y", "Y"}: + continue + if param.kind in {param.VAR_POSITIONAL, param.VAR_KEYWORD}: + continue + getattr(requests, method).add_request( + prop=pname, + alias=RequestType.ERROR_IF_PASSED, + allow_aliasing=False, + overwrite=False, + ) + + # Then overwrite those defaults with the ones provided in + # __metadata_request__* attributes, which are provided in `requests` here. + + for attr, value in defaults.items(): + requests.add_requests( + value, overwrite=True, expected_metadata="__".join(attr.split("__")[1:]) + ) + return requests + + def get_metadata_request(self): + """Get requested data properties. + + Returns + ------- + request : dict + A dict of dict of str->value. The key to the first dict is the name + of the method, and the key to the second dict is the name of the + argument requested by the method. + """ + if hasattr(self, "_metadata_request"): + requests = metadata_request_factory(self._metadata_request) + else: + requests = self._get_default_requests() + + return requests.to_dict() diff --git a/sklearn/utils/tests/test_estimator_checks.py b/sklearn/utils/tests/test_estimator_checks.py index c4f954790cd26..ed383c8e58ee6 100644 --- a/sklearn/utils/tests/test_estimator_checks.py +++ b/sklearn/utils/tests/test_estimator_checks.py @@ -651,6 +651,10 @@ class NonConformantEstimatorNoParamSet(BaseEstimator): def __init__(self, you_should_set_this_=None): pass + class ConformantEstimatorClassAttribute(BaseEstimator): + # making sure our __metadata_request__* class attributes are okay! + __metadata_request__foo = {"fit": "foo"} + msg = ( "Estimator estimator_name should not set any" " attribute apart from parameters during init." 
@@ -670,6 +674,14 @@ def __init__(self, you_should_set_this_=None): "estimator_name", NonConformantEstimatorNoParamSet() ) + # a private class attribute is okay! + check_no_attributes_set_in_init( + "estimator_name", ConformantEstimatorClassAttribute() + ) + check_no_attributes_set_in_init( + "estimator_name", ConformantEstimatorClassAttribute().fit_requests(foo=True) + ) + def test_check_estimator_pairwise(): # check that check_estimator() works on estimator with _pairwise From 78689507172b8be90a1a1ced69caf0b6f4beee4b Mon Sep 17 00:00:00 2001 From: adrinjalali Date: Fri, 8 Oct 2021 19:03:58 +0200 Subject: [PATCH 02/18] fix test_props and the issue with attribute starting with __ --- sklearn/tests/test_props.py | 22 +++++++++++----------- sklearn/utils/metadata_requests.py | 8 ++++++-- 2 files changed, 17 insertions(+), 13 deletions(-) diff --git a/sklearn/tests/test_props.py b/sklearn/tests/test_props.py index 57030949a7e68..cccdc1abcf9d4 100644 --- a/sklearn/tests/test_props.py +++ b/sklearn/tests/test_props.py @@ -157,7 +157,9 @@ def fit(self, X, y=None, **fit_params): return self def transform(self, X, y=None, **transform_params): - # not validating since the following would validate due to ignore_extras=False + metadata_request_factory(self).transform.validate_metadata( + ignore_extras=False, kwargs=transform_params + ) transform_params_ = metadata_request_factory( self.transformer ).transform.get_method_input(ignore_extras=False, kwargs=transform_params) @@ -324,14 +326,10 @@ def test_simple_metadata_routing(): estimator=TestEstimatorFitMetadata().fit_requests(sample_weight=False), sample_weight_none=False, ) - with pytest.raises( - ValueError, - match=( - "sample_weight is not requested by any estimator but is provided. 
In" - " method: fit" - ), - ): - cls.fit(X, y, sample_weight=my_weights) + # this doesn't raise since TestSimpleMetaEstimator itself is a consumer, + # and passing metadata to the consumer directly is fine regardless of its + # metadata_request values. + cls.fit(X, y, sample_weight=my_weights) # Requesting a metadata will make the meta-estimator forward it correctly cls = TestSimpleMetaEstimator( @@ -366,8 +364,10 @@ def test_invalid_metadata(): with pytest.raises( ValueError, match=( - "sample_weight is not requested by any estimator but is provided. In" - " method: transform" + re.escape( + "Metadata passed which is not understood: ['sample_weight']. In method:" + " transform" + ) ), ): trs.fit(X, y).transform(X, sample_weight=my_weights) diff --git a/sklearn/utils/metadata_requests.py b/sklearn/utils/metadata_requests.py index c04270f7c5cc2..47b4086755ef5 100644 --- a/sklearn/utils/metadata_requests.py +++ b/sklearn/utils/metadata_requests.py @@ -785,7 +785,7 @@ class attributes. klass_defaults = { attr: value for attr, value in vars(klass).items() - if attr.startswith("__metadata_request__") + if "__metadata_request__" in attr } defaults.update(klass_defaults) defaults = dict(sorted(defaults.items())) @@ -817,8 +817,12 @@ class attributes. # __metadata_request__* attributes, which are provided in `requests` here. for attr, value in defaults.items(): + # we don't check for attr.startswith() since python prefixes attrs + # starting with __ with the `_ClassName`. 
+ substr = "__metadata_request__" + expected_metadata = attr[attr.index(substr) + len(substr) :] requests.add_requests( - value, overwrite=True, expected_metadata="__".join(attr.split("__")[1:]) + value, overwrite=True, expected_metadata=expected_metadata ) return requests From 5793318d2a1743beb6054bab25eb29d542a6441d Mon Sep 17 00:00:00 2001 From: adrinjalali Date: Mon, 11 Oct 2021 14:22:03 +0200 Subject: [PATCH 03/18] skip doctest in metadata_routing.rst for now --- doc/conftest.py | 7 +++++++ doc/metadata_routing.rst | 6 ------ 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/doc/conftest.py b/doc/conftest.py index d0f49ac087477..b06b1c6bb6d4c 100644 --- a/doc/conftest.py +++ b/doc/conftest.py @@ -138,6 +138,13 @@ def pytest_runtest_setup(item): setup_preprocessing() elif fname.endswith("statistical_inference/unsupervised_learning.rst"): setup_unsupervised_learning() + elif fname.endswith("metadata_routing.rst"): + # TODO: remove this once implemented + # Skip metarouting because is it is not fully implemented yet + raise SkipTest( + "Skipping doctest for metadata_routing.rst because it " + "is not fully implemented yet" + ) rst_files_requiring_matplotlib = [ "modules/partial_dependence.rst", diff --git a/doc/metadata_routing.rst b/doc/metadata_routing.rst index 034c685bfb2eb..342e1a78f2b6a 100644 --- a/doc/metadata_routing.rst +++ b/doc/metadata_routing.rst @@ -19,7 +19,6 @@ Usage Examples Here we present a few examples to show different common use-cases. The examples in this section require the following imports and data:: -.. TODO: add once implemented >>> import numpy as np >>> from sklearn.metrics import make_scorer, accuracy_score >>> from sklearn.linear_model import LogisticRegressionCV @@ -44,7 +43,6 @@ explicitly request weights in ``make_scorer`` and for ``LogisticRegressionCV``. Both of these *consumers* understand the meaning of the key ``"sample_weight"``:: -.. 
TODO: add once implemented >>> weighted_acc = make_scorer(accuracy_score).score_requests( ... sample_weight=True ... ) @@ -72,7 +70,6 @@ weights explicitly be requested, we need to explicitly say that ``sample_weight`` is not used for it, so that ``cross_validate`` doesn't pass it along. -.. TODO: add once implemented >>> weighted_acc = make_scorer(accuracy_score).score_requests( ... sample_weight=True ... ) @@ -94,7 +91,6 @@ Unweighted feature selection Unlike ``LogisticRegressionCV``, ``SelectKBest`` doesn't accept weights and therefore `"sample_weight"` is not routed to it:: -.. TODO: add once implemented >>> weighted_acc = make_scorer(accuracy_score).score_requests( ... sample_weight=True ... ) @@ -120,7 +116,6 @@ Despite ``make_scorer`` and ``LogisticRegressionCV`` both expecting a key consumers. In this example, we pass ``scoring_weight`` to the scorer, and ``fitting_weight`` to ``LogisticRegressionCV``:: -.. TODO: add once implemented >>> weighted_acc = make_scorer(accuracy_score).score_requests( ... sample_weight="scoring_weight" ... ) @@ -180,7 +175,6 @@ be set by the user, otherwise an error is raised by the router object. For example, the following code would raise, since it hasn't been explicitly set whether ``sample_weight`` should be passed to the estimator's scorer or not:: -.. 
TODO: add once implemented >>> param_grid = {"C": [0.1, 1]} >>> lr = LogisticRegression().fit_requests(sample_weight=True) >>> try: From 669649760822c98eda0f7c01ce26af0667a531b5 Mon Sep 17 00:00:00 2001 From: adrinjalali Date: Mon, 11 Oct 2021 14:24:59 +0200 Subject: [PATCH 04/18] DOC explain why aliasing on sub-estimator of a consumer/router is useful --- examples/metadata_routing.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/examples/metadata_routing.py b/examples/metadata_routing.py index 9ee0090b32b59..9e1a336f33f0d 100644 --- a/examples/metadata_routing.py +++ b/examples/metadata_routing.py @@ -389,7 +389,9 @@ def get_metadata_request(self): est.fit(X, y, foo=my_weights, first_aliased_foo=my_other_weights) # %% -# Alias only on the sub-estimator +# Alias only on the sub-estimator. This is useful if we don't want the +# meta-estimator to use the metadata, and we only want the metadata to be used +# by the sub-estimator. est = RouterConsumerClassifier( estimator=ExampleClassifier().fit_requests(foo="aliased_foo") ).fit_requests(foo=True) From c0841c893df1007ef0a111d68fd59cc2249ad73a Mon Sep 17 00:00:00 2001 From: adrinjalali Date: Mon, 11 Oct 2021 14:33:34 +0200 Subject: [PATCH 05/18] reduce diff --- sklearn/utils/estimator_checks.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/sklearn/utils/estimator_checks.py b/sklearn/utils/estimator_checks.py index 85b87655a4dc3..cd4a6568871d6 100644 --- a/sklearn/utils/estimator_checks.py +++ b/sklearn/utils/estimator_checks.py @@ -3781,7 +3781,14 @@ def check_dataframe_column_names_consistency(name, estimator_orig): check_methods.append((method, callable_method)) for _, method in check_methods: - method(X) # works + with warnings.catch_warnings(): + warnings.filterwarnings( + "error", + message="X does not have valid feature names", + category=UserWarning, + module="sklearn", + ) + method(X) # works without UserWarning for valid features invalid_names = [ (names[::-1], 
"Feature names must be in the same order as they were in fit."), From 1aff2eb427cf43e8c250bb0c5f0b0fec0634971e Mon Sep 17 00:00:00 2001 From: adrinjalali Date: Mon, 11 Oct 2021 16:17:24 +0200 Subject: [PATCH 06/18] DOC add user guide link to method docstrings --- sklearn/utils/metadata_requests.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/sklearn/utils/metadata_requests.py b/sklearn/utils/metadata_requests.py index 47b4086755ef5..802b2ffb26916 100644 --- a/sklearn/utils/metadata_requests.py +++ b/sklearn/utils/metadata_requests.py @@ -33,6 +33,9 @@ class RequestType(Enum): REQUESTER_DOC = """ Request metadata passed to the ``{method}`` method. + Please check :ref:`User Guide ` on how the routing + mechanism works. + Parameters ---------- """ @@ -631,6 +634,9 @@ def add(self, *obj, mapping="one-to-one", overwrite=False, mask=False): def get_metadata_request(self): """Get requested data properties. + Please check :ref:`User Guide ` on how the routing + mechanism works. + Returns ------- request : dict @@ -829,6 +835,9 @@ class attributes. def get_metadata_request(self): """Get requested data properties. + Please check :ref:`User Guide ` on how the routing + mechanism works. + Returns ------- request : dict From 14572937a20d56b940ffcf51eb32f22056fc84fc Mon Sep 17 00:00:00 2001 From: adrinjalali Date: Wed, 13 Oct 2021 18:04:37 +0200 Subject: [PATCH 07/18] DOC apply Thomas's suggestions to the rst file --- doc/metadata_routing.rst | 19 +++++++++++++++---- 1 file changed, 15 insertions(+), 4 deletions(-) diff --git a/doc/metadata_routing.rst b/doc/metadata_routing.rst index 342e1a78f2b6a..35ca6abc8fbfc 100644 --- a/doc/metadata_routing.rst +++ b/doc/metadata_routing.rst @@ -65,10 +65,10 @@ not requested by any of its children. 
Weighted scoring and unweighted fitting --------------------------------------- -Since ``LogisticRegressionCV``, like all scikit-learn estimators, requires that -weights explicitly be requested, we need to explicitly say that -``sample_weight`` is not used for it, so that ``cross_validate`` doesn't pass -it along. +All scikit-learn estimators require weights to be explicitly requested or not +requested. To perform an unweighted fit, we need to configure +:class:`~linear_model.LogisticRegressionCV` to not request sample weights, so +that :func:`~model_selection.cross_validate` does not pass the weights along:: >>> weighted_acc = make_scorer(accuracy_score).score_requests( ... sample_weight=True @@ -85,6 +85,11 @@ it along. ... scoring=weighted_acc, ... ) +If :class:`~linear_model.LogisticRegressionCV` did not call ``fit_requests``, +:func:`~model_selection.cross_validate` will raise an error because weights are +passed in but :class:`~linear_model.LogisticRegressionCV` was not configured to +recognize the weights. + Unweighted feature selection ---------------------------- @@ -185,3 +190,9 @@ whether ``sample_weight`` should be passed to the estimator's scorer or not:: ... print(e) sample_weight is passed but is not explicitly set as requested or not. In method: score + +The issue can be fixed by explicitly setting the request value:: + + >>> lr = LogisticRegression().fit_requests( + ... sample_weight=True + ...
).score_requests(sample_weight=False) From af86e825ccee054502eb44688be168d29327f03e Mon Sep 17 00:00:00 2001 From: adrinjalali Date: Mon, 25 Oct 2021 15:13:43 +0200 Subject: [PATCH 08/18] CLN address a few comments in docs --- doc/metadata_routing.rst | 12 +++++++++--- ...{metadata_routing.py => plot_metadata_routing.py} | 5 +++-- sklearn/utils/metadata_requests.py | 1 - 3 files changed, 12 insertions(+), 6 deletions(-) rename examples/{metadata_routing.py => plot_metadata_routing.py} (99%) diff --git a/doc/metadata_routing.rst b/doc/metadata_routing.rst index 35ca6abc8fbfc..4ae743374b6ab 100644 --- a/doc/metadata_routing.rst +++ b/doc/metadata_routing.rst @@ -1,6 +1,8 @@ .. _metadata_routing: +.. TODO: update doc/conftest.py once document is updated and examples run. + Metadata Routing ================ @@ -58,6 +60,9 @@ Both of these *consumers* understand the meaning of the key ... scoring=weighted_acc, ... ) +Note that in this example, ``my_weights`` is passed to both the scorer and +``~linear_model.LogisticRegressionCV``. + Error handling: if ``props={'sample_weigh': my_weights, ...}`` were passed (note the typo), cross_validate would raise an error, since 'sample_weigh' was not requested by any of its children. @@ -146,7 +151,7 @@ API Interface A *consumer* is an object (estimator, meta-estimator, scorer, splitter) which accepts and uses some metadata in at least one of their methods (``fit``, ``predict``, ``inverse_transform``, ``transform``, ``score``, ``split``). -Meta-estimators which only forward the metadata other objects (the child +Meta-estimators which only forward the metadata to other objects (the child estimator, scorers, or splitters) and don't use the metadata themselves are not consumers. (Meta)Estimators which route metadata to other objects are routers. An (meta)estimator can be a consumer and a router at the same time. @@ -164,7 +169,8 @@ which accepts at least one metadata. 
For instance, if an estimator supports - ``RequestType.ERROR_IF_PASSED`` or ``None``: router will raise an error if ``sample_weight`` is passed. This is in almost all cases the default value when an object is instantiated and ensures the user sets the metadata - requests explicitly when a metadata is passed. + requests explicitly when a metadata is passed. The only exceptions are + ``Group*Fold`` splitters. - ``"param_name"``: if this estimator is used in a meta-estimator, the meta-estimator should forward ``"param_name"`` as ``sample_weight`` to this estimator. This means the mapping between the metadata required by the @@ -174,7 +180,7 @@ which accepts at least one metadata. For instance, if an estimator supports For the scorers, this is done the same way, using ``.score_requests`` method. -If a metadata, e.g. ``sample_weight`` is passed by the user, the metadata +If a metadata, e.g. ``sample_weight``, is passed by the user, the metadata request for all objects which potentially can accept ``sample_weight`` should be set by the user, otherwise an error is raised by the router object. For example, the following code would raise, since it hasn't been explicitly set diff --git a/examples/metadata_routing.py b/examples/plot_metadata_routing.py similarity index 99% rename from examples/metadata_routing.py rename to examples/plot_metadata_routing.py index 9e1a336f33f0d..bba2defbe329b 100644 --- a/examples/metadata_routing.py +++ b/examples/plot_metadata_routing.py @@ -92,7 +92,8 @@ def predict(self, X, bar=None): # %% # As you can see, now the two metadata have explicit request values, one is # requested and the other one is not. Instead of ``True`` and ``False``, we -# could also use the ``RequestType`` values. +# could also use the :class:`~sklearn.utils.metadata_requests.RequestType` +# values.
est = ( ExampleClassifier() @@ -219,7 +220,7 @@ def get_metadata_request(self): # %% # In order to understand the above implementation of ``get_metadata_request``, -# we need to also introduce an aliaced metadata. This is when an estimator +# we need to also introduce an aliased metadata. This is when an estimator # requests a metadata with a different name than the default value. For # instance, in a setting where there are two estimators in a pipeline, one # could request ``sample_weight1`` and the other ``sample_weight2``. Note that diff --git a/sklearn/utils/metadata_requests.py b/sklearn/utils/metadata_requests.py index 802b2ffb26916..f71fc4bf6b425 100644 --- a/sklearn/utils/metadata_requests.py +++ b/sklearn/utils/metadata_requests.py @@ -1,4 +1,3 @@ -# from copy import deepcopy import inspect from enum import Enum from collections import defaultdict From 11649d95d1f17343a6c428e59ca13d6d4dcc48fd Mon Sep 17 00:00:00 2001 From: adrinjalali Date: Mon, 25 Oct 2021 21:23:02 +0200 Subject: [PATCH 09/18] ignore sentinel docstring check --- maint_tools/test_docstrings.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/maint_tools/test_docstrings.py b/maint_tools/test_docstrings.py index edbc05d260dee..573167a8d92c5 100644 --- a/maint_tools/test_docstrings.py +++ b/maint_tools/test_docstrings.py @@ -235,6 +235,8 @@ "sklearn.utils.validation.check_random_state", "sklearn.utils.validation.column_or_1d", "sklearn.utils.validation.has_fit_parameter", + # Never fix this one, it's vendord code + "sklearn.externals._sentinels.sentinel", ] FUNCTION_DOCSTRING_IGNORE_LIST = set(FUNCTION_DOCSTRING_IGNORE_LIST) From b5c962c74c7bf8d27e9598660f3577299f81c0d3 Mon Sep 17 00:00:00 2001 From: adrinjalali Date: Fri, 5 Nov 2021 18:21:41 +0100 Subject: [PATCH 10/18] handling backward compatibility and deprecation prototype --- examples/plot_metadata_routing.py | 103 +++++++++++++++++++++++++++++ sklearn/utils/metadata_requests.py | 38 +++++++++-- 2 files changed, 135 insertions(+), 6 
deletions(-) diff --git a/examples/plot_metadata_routing.py b/examples/plot_metadata_routing.py index bba2defbe329b..3710ae1efe6fa 100644 --- a/examples/plot_metadata_routing.py +++ b/examples/plot_metadata_routing.py @@ -22,8 +22,10 @@ # %% import numpy as np +import warnings from sklearn.base import BaseEstimator from sklearn.base import ClassifierMixin +from sklearn.base import RegressorMixin from sklearn.base import MetaEstimatorMixin from sklearn.base import TransformerMixin from sklearn.base import clone @@ -31,6 +33,7 @@ from sklearn.utils.metadata_requests import metadata_request_factory from sklearn.utils.metadata_requests import MetadataRouter from sklearn.utils.validation import check_is_fitted +from sklearn.linear_model import LinearRegression N, M = 100, 4 X = np.random.rand(N, M) @@ -519,3 +522,103 @@ def transform(self, X, bar=None): ), ) est.fit(X, y, foo=my_weights, bar=my_groups).predict(X[:3], bar=my_groups) + +# %% +# Deprecation / Default Value Change +# ----------------------------------- +# In this section we show how one should handle the case where a router becomes +# also a consumer, especially when it consumes the same metadata as its +# sub-estimator. In this case, a warning should be raised for a while, to let +# users know the behavior is changed from previous versions.
+ + +class MetaRegressor(MetaEstimatorMixin, RegressorMixin, BaseEstimator): + def __init__(self, estimator): + self.estimator = estimator + + def fit(self, X, y, **fit_params): + metadata_request_factory(self).fit.validate_metadata( + ignore_extras=False, self_metadata=super(), kwargs=fit_params + ) + fit_params_ = metadata_request_factory(self.estimator).fit.get_method_input( + ignore_extras=False, kwargs=fit_params + ) + self.estimator_ = clone(self.estimator).fit(X, y, **fit_params_) + + def get_metadata_request(self): + router = MetadataRouter().add( + self.estimator, mapping="one-to-one", overwrite=False, mask=True + ) + return router.get_metadata_request() + + +# %% +# As explained above, this is now a valid usage: + +reg = MetaRegressor(estimator=LinearRegression().fit_requests(sample_weight=True)) +reg.fit(X, y, sample_weight=my_weights) + + +# %% +# Now imagine we further develop ``MetaRegressor`` and it now also *consumes* +# ``sample_weight``: + + +class SampledMetaRegressor(MetaEstimatorMixin, RegressorMixin, BaseEstimator): + __metadata_request__sample_weight = {"fit": {"sample_weight": RequestType.WARN}} + + def __init__(self, estimator): + self.estimator = estimator + + def fit(self, X, y, sample_weight=None, **fit_params): + if sample_weight is not None: + fit_params["sample_weight"] = sample_weight + metadata_request_factory(self).fit.validate_metadata( + ignore_extras=False, self_metadata=super(), kwargs=fit_params + ) + estimator_fit_params = metadata_request_factory( + self.estimator + ).fit.get_method_input(ignore_extras=True, kwargs=fit_params) + self.estimator_ = clone(self.estimator).fit(X, y, **estimator_fit_params) + + def get_metadata_request(self): + router = ( + MetadataRouter() + .add(super(), mapping="one-to-one", overwrite=False, mask=False) + .add(self.estimator, mapping="one-to-one", overwrite="smart", mask=True) + ) + return router.get_metadata_request() + + +# %% +# The above implementation is almost no different than 
``MetaRegressor``, and +# because of the default request value defined in `__metadata_request__sample_weight`` +# there is a warning raised. + +with warnings.catch_warnings(record=True) as record: + SampledMetaRegressor( + estimator=LinearRegression().fit_requests(sample_weight=False) + ).fit(X, y, sample_weight=my_weights) +for w in record: + print(w.message) + + +# %% +# When an estimator supports a metadata which wasn't supported before, the +# following pattern can be used to warn the users about it. + + +class ExampleRegressor(RegressorMixin, BaseEstimator): + __metadata_request__sample_weight = {"fit": {"sample_weight": RequestType.WARN}} + + def fit(self, X, y, sample_weight=None): + return self + + def predict(self, X): + return np.zeros(shape=(len(X))) + + +with warnings.catch_warnings(record=True) as record: + MetaRegressor(estimator=ExampleRegressor()).fit(X, y, sample_weight=my_weights) +for w in record: + print(w.message) diff --git a/sklearn/utils/metadata_requests.py b/sklearn/utils/metadata_requests.py index f71fc4bf6b425..f7d50f0ccf62f 100644 --- a/sklearn/utils/metadata_requests.py +++ b/sklearn/utils/metadata_requests.py @@ -2,6 +2,7 @@ from enum import Enum from collections import defaultdict from typing import Union, Optional +from warnings import warn from ..externals._sentinels import sentinel # type: ignore # mypy error!!! @@ -13,6 +14,11 @@ class RequestType(Enum): # that a metadata is not present even though it may be present in the # corresponding method's signature. UNUSED = sentinel("UNUSED") + # this sentinel is used whenever a default value is changed, and therefore + # the user should explicitly set the value, otherwise a warning is shown. + # An example is when a meta-estimator is only a router, but then becomes + # also a consumer. 
+ WARN = sentinel("WARN") # this sentinel is the default used in `{method}_requests` methods to indicate @@ -173,13 +179,14 @@ def add_request( if alias == RequestType.REQUESTED and current in { RequestType.ERROR_IF_PASSED, RequestType.UNREQUESTED, + RequestType.WARN, }: self.requests[prop] = alias - elif ( - alias == RequestType.UNREQUESTED - and current == RequestType.ERROR_IF_PASSED - ): - self.requests[prop] = alias + elif alias in {RequestType.UNREQUESTED, RequestType.WARN} and current in { + RequestType.ERROR_IF_PASSED, + RequestType.WARN, + }: + self.requests[prop] = RequestType.UNREQUESTED elif self.requests[prop] != alias: raise ValueError( f"{prop} is already requested as {self.requests[prop]}, " @@ -264,6 +271,17 @@ def validate_metadata(self, ignore_extras=False, self_metadata=None, kwargs=None self_metadata = getattr( metadata_request_factory(self_metadata), self.name ).requests + warn_metadata = [k for k, v in self_metadata.items() if v == RequestType.WARN] + warn_kwargs = [k for k in kwargs.keys() if k in warn_metadata] + if warn_kwargs: + warn( + "The following metadata are provided, which are now supported by this " + f"class: {warn_kwargs}. These metadata were not processed in previous " + "versions. Set their requested value to RequestType.UNREQUESTED " + "to maintain previous behavior, or to RequestType.REQUESTED to " + "consume and use the metadata.", + UserWarning, + ) # we then remove self metadata from kwargs, since they should not be # validated. kwargs = {v: k for v, k in kwargs.items() if v not in self_metadata} @@ -324,7 +342,15 @@ def get_method_input(self, ignore_extras=False, kwargs=None): if not isinstance(alias, str): alias = RequestType(alias) - if alias == RequestType.UNREQUESTED: + if alias == RequestType.WARN: + warn( + f"Support for {prop} has recently been added to this class. " + "To maintain backward compatibility, it is ignored now. 
" + "You can set the request value to RequestType.UNREQUESTED " + "to silence this warning, or to RequestType.REQUESTED to " + "consume and use the metadata." + ) + elif alias == RequestType.UNREQUESTED: continue elif alias == RequestType.REQUESTED and prop in args: res[prop] = args[prop] From fb200e20bf259fb540a1e6d76cff992cb232767d Mon Sep 17 00:00:00 2001 From: Adrin Jalali Date: Tue, 7 Dec 2021 17:05:59 +0100 Subject: [PATCH 11/18] Update examples/plot_metadata_routing.py Co-authored-by: Christian Lorentzen --- examples/plot_metadata_routing.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/plot_metadata_routing.py b/examples/plot_metadata_routing.py index 3710ae1efe6fa..f8e9a836783d8 100644 --- a/examples/plot_metadata_routing.py +++ b/examples/plot_metadata_routing.py @@ -524,7 +524,7 @@ def transform(self, X, bar=None): est.fit(X, y, foo=my_weights, bar=my_groups).predict(X[:3], bar=my_groups) # %% -# Deprechation / Default Value Change +# Deprecation / Default Value Change # ----------------------------------- # In this section we show how one should handle the case where a router becomes # also a consumer, especially when it consumes the same metadata as its From 6f849b207fb1461e74ff01ec20024966ab866d45 Mon Sep 17 00:00:00 2001 From: adrinjalali Date: Fri, 10 Dec 2021 15:22:01 +0100 Subject: [PATCH 12/18] make __metadata_request__* format more intuitive and less redundant --- examples/plot_metadata_routing.py | 6 +++--- sklearn/utils/metadata_requests.py | 17 ++++++++++------- 2 files changed, 13 insertions(+), 10 deletions(-) diff --git a/examples/plot_metadata_routing.py b/examples/plot_metadata_routing.py index 3710ae1efe6fa..6b662898ef3a6 100644 --- a/examples/plot_metadata_routing.py +++ b/examples/plot_metadata_routing.py @@ -565,7 +565,7 @@ def get_metadata_request(self): class SampledMetaRegressor(MetaEstimatorMixin, RegressorMixin, BaseEstimator): - __metadata_request__sample_weight = {"fit": {"sample_weight": 
RequestType.WARN}} + __metadata_request__fit = {"sample_weight": RequestType.WARN} def __init__(self, estimator): self.estimator = estimator @@ -592,7 +592,7 @@ def get_metadata_request(self): # %% # The above implementation is almost no different than ``MetaRegressor``, and -# because of the default request value defined in `__metadata_request__sample_weight`` +# because of the default request value defined in ``__metadata_request__fit`` # there is a warning raised. with warnings.catch_warnings(record=True) as record: @@ -609,7 +609,7 @@ def get_metadata_request(self): class ExampleRegressor(RegressorMixin, BaseEstimator): - __metadata_request__sample_weight = {"fit": {"sample_weight": RequestType.WARN}} + __metadata_request__fit = {"sample_weight": RequestType.WARN} def fit(self, X, y, sample_weight=None): return self diff --git a/sklearn/utils/metadata_requests.py b/sklearn/utils/metadata_requests.py index f7d50f0ccf62f..8852888cb9852 100644 --- a/sklearn/utils/metadata_requests.py +++ b/sklearn/utils/metadata_requests.py @@ -587,6 +587,7 @@ def metadata_request_factory(obj=None): metadata_requests : MetadataRequest A ``MetadataRequest`` taken or created from the given object. """ + # TODO: this should always return a copy if obj is None: return MetadataRequest() @@ -596,10 +597,11 @@ def metadata_request_factory(obj=None): if isinstance(obj, dict): return MetadataRequest(obj) - try: + # doing this instead of a try/except since an AttributeError could be raised + # for other reasons. + if hasattr(obj, "get_metadata_request"): return MetadataRequest(obj.get_metadata_request()) - except AttributeError: - # The object doesn't have a `get_metadata_request` method. + else: return MetadataRequest() @@ -851,10 +853,11 @@ class attributes. # we don't check for attr.startswith() since python prefixes attrs # starting with __ with the `_ClassName`. 
substr = "__metadata_request__" - expected_metadata = attr[attr.index(substr) + len(substr) :] - requests.add_requests( - value, overwrite=True, expected_metadata=expected_metadata - ) + method = attr[attr.index(substr) + len(substr) :] + for prop, alias in value.items(): + getattr(requests, method).add_request( + prop=prop, alias=alias, overwrite=True + ) return requests def get_metadata_request(self): From 82b2128593a8851ad1cdb3b2bc561a20e580c48b Mon Sep 17 00:00:00 2001 From: adrinjalali Date: Fri, 10 Dec 2021 15:32:09 +0100 Subject: [PATCH 13/18] metadata_request_factory always returns a copy --- sklearn/utils/metadata_requests.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/sklearn/utils/metadata_requests.py b/sklearn/utils/metadata_requests.py index 8852888cb9852..f605ff5eee24e 100644 --- a/sklearn/utils/metadata_requests.py +++ b/sklearn/utils/metadata_requests.py @@ -1,4 +1,5 @@ import inspect +from copy import deepcopy from enum import Enum from collections import defaultdict from typing import Union, Optional @@ -572,6 +573,10 @@ def __str__(self): def metadata_request_factory(obj=None): """Get a MetadataRequest instance from the given object. + This function always returns a copy or an instance constructed from the + input, such that changing the output of this function will not change the + original object. + .. versionadded:: 1.1 Parameters @@ -587,12 +592,11 @@ def metadata_request_factory(obj=None): metadata_requests : MetadataRequest A ``MetadataRequest`` taken or created from the given object. 
""" - # TODO: this should always return a copy if obj is None: return MetadataRequest() if isinstance(obj, MetadataRequest): - return obj + return deepcopy(obj) if isinstance(obj, dict): return MetadataRequest(obj) From 16c47b233ef27d06d14a620f8f3e6ebc9ee7c9c6 Mon Sep 17 00:00:00 2001 From: adrinjalali Date: Sun, 12 Dec 2021 14:25:11 +0100 Subject: [PATCH 14/18] fix tests for the changed __metadata_request__* format --- sklearn/tests/test_props.py | 82 +++++++++++++------------------------ 1 file changed, 28 insertions(+), 54 deletions(-) diff --git a/sklearn/tests/test_props.py b/sklearn/tests/test_props.py index cccdc1abcf9d4..8b84fa830251a 100644 --- a/sklearn/tests/test_props.py +++ b/sklearn/tests/test_props.py @@ -259,9 +259,10 @@ def test_assert_request_is_empty(): def test_default_requests(): class OddEstimator(BaseEstimator): - __metadata_request__sample_weight = { - "fit": {"sample_weight": RequestType.REQUESTED} # type: ignore - } # set a different default request + __metadata_request__fit = { + # set a different default request + "sample_weight": RequestType.REQUESTED + } # type: ignore odd_request = metadata_request_factory(OddEstimator()) assert odd_request.fit.requests == {"sample_weight": RequestType.REQUESTED} @@ -374,61 +375,34 @@ def test_invalid_metadata(): def test_get_metadata_request(): - class TestDefaultsBadMetadataName(_MetadataRequester): - __metadata_request__sample_weight = { - "fit": "sample_weight", - "score": "sample_weight", - } - - __metadata_request__my_param = { - "score": {"my_param": True}, - # the following method raise an error - "other_method": {"my_param": True}, - } - - __metadata_request__my_other_param = { - "score": "my_other_param", - # this should raise since the name is different than the metadata - "fit": "my_param", - } - class TestDefaultsBadMethodName(_MetadataRequester): - __metadata_request__sample_weight = { - "fit": "sample_weight", - "score": "sample_weight", + __metadata_request__fit = { + "sample_weight": 
RequestType.ERROR_IF_PASSED, + "my_param": RequestType.ERROR_IF_PASSED, } - - __metadata_request__my_param = { - "score": {"my_param": True}, - # the following method raise an error - "other_method": {"my_param": True}, - } - - __metadata_request__my_other_param = { - "score": "my_other_param", - "fit": "my_other_param", + __metadata_request__score = { + "sample_weight": RequestType.ERROR_IF_PASSED, + "my_param": True, + "my_other_param": RequestType.ERROR_IF_PASSED, } + # this will raise an error since we don't understand "other_method" as a method + __metadata_request__other_method = {"my_param": True} class TestDefaults(_MetadataRequester): - __metadata_request__sample_weight = { - "fit": "sample_weight", - "score": "sample_weight", - } - - __metadata_request__my_param = { - "score": {"my_param": True}, - "predict": {"my_param": True}, + __metadata_request__fit = { + "sample_weight": RequestType.ERROR_IF_PASSED, + "my_other_param": RequestType.ERROR_IF_PASSED, } - - __metadata_request__my_other_param = { - "score": "my_other_param", - "fit": "my_other_param", + __metadata_request__score = { + "sample_weight": RequestType.ERROR_IF_PASSED, + "my_param": True, + "my_other_param": RequestType.ERROR_IF_PASSED, } + __metadata_request__predict = {"my_param": True} - with pytest.raises(ValueError, match="Expected all metadata to be called"): - TestDefaultsBadMetadataName().get_metadata_request() - - with pytest.raises(ValueError, match="other_method is not supported as a method"): + with pytest.raises( + AttributeError, match="'MetadataRequest' object has no attribute 'other_method'" + ): TestDefaultsBadMethodName().get_metadata_request() expected = { @@ -491,7 +465,7 @@ class TestDefaults(_MetadataRequester): def test__get_default_requests(): # Test _get_default_requests method class ExplicitRequest(BaseEstimator): - __metadata_request__prop = {"fit": "prop"} + __metadata_request__fit = {"prop": RequestType.ERROR_IF_PASSED} def fit(self, X, y): return self @@ -502,7 
+476,7 @@ def fit(self, X, y): assert_request_is_empty(ExplicitRequest().get_metadata_request(), exclude="fit") class ExplicitRequestOverwrite(BaseEstimator): - __metadata_request__prop = {"fit": {"prop": RequestType.REQUESTED}} + __metadata_request__fit = {"prop": RequestType.REQUESTED} def fit(self, X, y, prop=None, **kwargs): return self @@ -524,7 +498,7 @@ def fit(self, X, y, prop=None, **kwargs): assert_request_is_empty(ImplicitRequest().get_metadata_request(), exclude="fit") class ImplicitRequestRemoval(BaseEstimator): - __metadata_request__prop = {"fit": {"prop": RequestType.UNUSED}} + __metadata_request__fit = {"prop": RequestType.UNUSED} def fit(self, X, y, prop=None, **kwargs): return self @@ -574,7 +548,7 @@ def test_method_metadata_request(): def test_metadata_request_factory(): class Consumer(BaseEstimator): - __metadata_request__prop = {"fit": "prop"} + __metadata_request__fit = {"prop": RequestType.ERROR_IF_PASSED} assert_request_is_empty(metadata_request_factory(None)) assert_request_is_empty(metadata_request_factory({})) From 1c591fef65f9fddecccf6b4610ca05b18dd9e46f Mon Sep 17 00:00:00 2001 From: adrinjalali Date: Sun, 12 Dec 2021 14:41:40 +0100 Subject: [PATCH 15/18] in example: foo->sample_weight, bar->groups --- examples/plot_metadata_routing.py | 202 ++++++++++++++++-------------- 1 file changed, 110 insertions(+), 92 deletions(-) diff --git a/examples/plot_metadata_routing.py b/examples/plot_metadata_routing.py index 6b662898ef3a6..1087579ef495b 100644 --- a/examples/plot_metadata_routing.py +++ b/examples/plot_metadata_routing.py @@ -46,27 +46,30 @@ # Estimators # ---------- # Here we demonstrate how an estimator can expose the required API to support -# metadata routing as a consumer. Imagine a simple classifier accepting ``foo`` -# as a metadata on its ``fit`` and ``bar`` in its ``predict`` method. We add -# two constructor arguments to helps us check whether an expected metadata is -# given or not. 
This is a minimal scikit-learn compatible classifier: +# metadata routing as a consumer. Imagine a simple classifier accepting +# ``sample_weight`` as a metadata on its ``fit`` and ``groups`` in its +# ``predict`` method. We add two constructor arguments to helps us check +# whether an expected metadata is given or not. This is a minimal scikit-learn +# compatible classifier: class ExampleClassifier(ClassifierMixin, BaseEstimator): - def __init__(self, foo_is_none=True, bar_is_none=True): - self.foo_is_none = foo_is_none - self.bar_is_none = bar_is_none + def __init__(self, sample_weight_is_none=True, groups_is_none=True): + self.sample_weight_is_none = sample_weight_is_none + self.groups_is_none = groups_is_none - def fit(self, X, y, foo=None): - if (foo is None) != self.foo_is_none: - raise ValueError("foo's value and foo_is_none disagree!") + def fit(self, X, y, sample_weight=None): + if (sample_weight is None) != self.sample_weight_is_none: + raise ValueError( + "sample_weight's value and sample_weight_is_none disagree!" + ) # all classifiers need to expose a classes_ attribute once they're fit. self.classes_ = np.array([0, 1]) return self - def predict(self, X, bar=None): - if (bar is None) != self.bar_is_none: - raise ValueError("bar's value and bar_is_none disagree!") + def predict(self, X, groups=None): + if (groups is None) != self.groups_is_none: + raise ValueError("groups's value and groups_is_none disagree!") # return a constant value of 1, not a very smart classifier! return np.ones(len(X)) @@ -82,14 +85,16 @@ def predict(self, X, bar=None): ExampleClassifier().get_metadata_request() # %% -# The above output means that ``foo`` and ``bar`` are not requested, but if a -# router is given those metadata, it should raise an error, since the user has -# not explicitly set whether they are required or not. The same is true for -# ``sample_weight`` in ``score`` method, which is inherited from -# :class:`~base.ClassifierMixin`. 
In order to explicitly set request values for -# those metadata, we can use these methods: +# The above output means that ``sample_weight`` and ``groups`` are not +# requested, but if a router is given those metadata, it should raise an error, +# since the user has not explicitly set whether they are required or not. The +# same is true for ``sample_weight`` in ``score`` method, which is inherited +# from :class:`~base.ClassifierMixin`. In order to explicitly set request +# values for those metadata, we can use these methods: -est = ExampleClassifier().fit_requests(foo=False).predict_requests(bar=True) +est = ( + ExampleClassifier().fit_requests(sample_weight=False).predict_requests(groups=True) +) est.get_metadata_request() # %% @@ -100,8 +105,8 @@ def predict(self, X, bar=None): est = ( ExampleClassifier() - .fit_requests(foo=RequestType.UNREQUESTED) - .predict_requests(bar=RequestType.REQUESTED) + .fit_requests(sample_weight=RequestType.UNREQUESTED) + .predict_requests(groups=RequestType.REQUESTED) ) est.get_metadata_request() @@ -109,12 +114,12 @@ def predict(self, X, bar=None): # Please note that as long as the above estimator is not used in another # meta-estimator, the user does not need to set any requests for the metadata. # A simple usage of the above estimator would work as expected. Remember that -# ``{foo, bar}_is_none`` are for testing/demonstration purposes and don't have -# anything to do with the routing mechanisms. +# ``{sample_weight, groups}_is_none`` are for testing/demonstration purposes +# and don't have anything to do with the routing mechanisms. 
-est = ExampleClassifier(foo_is_none=False, bar_is_none=False) -est.fit(X, y, foo=my_weights) -est.predict(X[:3, :], bar=my_groups) +est = ExampleClassifier(sample_weight_is_none=False, groups_is_none=False) +est.fit(X, y, sample_weight=my_weights) +est.predict(X[:3, :], groups=my_groups) # %% # Now let's have a meta-estimator, which doesn't do much other than routing the @@ -182,12 +187,14 @@ def get_metadata_request(self): # silent bugs, and this is how it will work: est = MetaClassifier( - estimator=ExampleClassifier(foo_is_none=False).fit_requests(foo=True) + estimator=ExampleClassifier(sample_weight_is_none=False).fit_requests( + sample_weight=True + ) ) -est.fit(X, y, foo=my_weights) +est.fit(X, y, sample_weight=my_weights) # %% -# Note that the above example checks that ``foo`` is correctly passed to +# Note that the above example checks that ``sample_weight`` is correctly passed to # ``ExampleClassifier``, or else it would have raised: try: @@ -205,19 +212,19 @@ def get_metadata_request(self): # %% # And if we pass something which is not explicitly requested: try: - est.fit(X, y, foo=my_weights).predict(X, bar=my_groups) + est.fit(X, y, sample_weight=my_weights).predict(X, groups=my_groups) except ValueError as e: print(e) # %% # Also, if we explicitly say it's not requested, but pass it: est = MetaClassifier( - estimator=ExampleClassifier(foo_is_none=False) - .fit_requests(foo=True) - .predict_requests(bar=False) + estimator=ExampleClassifier(sample_weight_is_none=False) + .fit_requests(sample_weight=True) + .predict_requests(groups=False) ) try: - est.fit(X, y, foo=my_weights).predict(X[:3, :], bar=my_groups) + est.fit(X, y, sample_weight=my_weights).predict(X[:3, :], groups=my_groups) except ValueError as e: print(e) @@ -229,18 +236,21 @@ def get_metadata_request(self): # could request ``sample_weight1`` and the other ``sample_weight2``. 
Note that # this doesn't change what the estimator expects, it only tells the # meta-estimator how to map provided metadata to what's required. Here's an -# example, where we pass ``aliased_foo`` to the meta-estimator, but the -# meta-estimator understands that ``aliased_foo`` is an alias for ``foo``, and -# passes it as ``foo`` to the underlying estimator: +# example, where we pass ``aliased_sample_weight`` to the meta-estimator, but +# the meta-estimator understands that ``aliased_sample_weight`` is an alias for +# ``sample_weight``, and passes it as ``sample_weight`` to the underlying +# estimator: est = MetaClassifier( - estimator=ExampleClassifier(foo_is_none=False).fit_requests(foo="aliased_foo") + estimator=ExampleClassifier(sample_weight_is_none=False).fit_requests( + sample_weight="aliased_sample_weight" + ) ) -est.fit(X, y, aliased_foo=my_weights) +est.fit(X, y, aliased_sample_weight=my_weights) # %% -# And passing ``foo`` here will fail since it is requested with an alias: +# And passing ``sample_weight`` here will fail since it is requested with an alias: try: - est.fit(X, y, foo=my_weights) + est.fit(X, y, sample_weight=my_weights) except ValueError as e: print(e) @@ -256,9 +266,9 @@ def get_metadata_request(self): # %% # As you can see, the only metadata requested for method ``fit`` is -# ``"aliased_foo"``. This information is enough for another +# ``"aliased_sample_weight"``. This information is enough for another # meta-estimator/router to know what needs to be passed to ``est``. In other -# words, ``foo`` is *masked* . The ``MetadataRouter`` class enables us to +# words, ``sample_weight`` is *masked* . The ``MetadataRouter`` class enables us to # easily create the routing object which would create the output we need for # our ``get_metadata_request``. In the above implementation, # ``mapping="one-to-one"`` means all requests are mapped one to one from the @@ -268,18 +278,18 @@ def get_metadata_request(self): # original metadata name. 
Without it, having ``est`` in another meta-estimator # would break the routing. Imagine this example: -meta_est = MetaClassifier(estimator=est).fit(X, y, aliased_foo=my_weights) +meta_est = MetaClassifier(estimator=est).fit(X, y, aliased_sample_weight=my_weights) # %% # In the above example, this is how each ``fit`` method will call the # sub-estimator's ``fit``: # -# meta_est.fit(X, y, aliased_foo=my_weights): -# ... # this estimator (est), expects aliased_foo as seen above -# self.estimator_.fit(X, y, aliased_foo=aliased_foo): -# ... # est passes aliased_foo's value as foo, which is expected -# # by the sub-estimator -# self.estimator_.fit(X, y, foo=aliased_foo) +# meta_est.fit(X, y, aliased_sample_weight=my_weights): +# ... # this estimator (est), expects aliased_sample_weight as seen above +# self.estimator_.fit(X, y, aliased_sample_weight=aliased_sample_weight): +# ... # est passes aliased_sample_weight's value as sample_weight, +# # which is expected by the sub-estimator +# self.estimator_.fit(X, y, sample_weight=aliased_sample_weight) # ... # %% @@ -293,19 +303,21 @@ def get_metadata_request(self): class RouterConsumerClassifier(MetaEstimatorMixin, ClassifierMixin, BaseEstimator): - def __init__(self, estimator, foo_is_none=True): + def __init__(self, estimator, sample_weight_is_none=True): self.estimator = estimator - self.foo_is_none = foo_is_none + self.sample_weight_is_none = sample_weight_is_none - def fit(self, X, y, foo, **fit_params): + def fit(self, X, y, sample_weight, **fit_params): if self.estimator is None: raise ValueError("estimator cannot be None!") - if (foo is None) != self.foo_is_none: - raise ValueError("foo's value and foo_is_none disagree!") + if (sample_weight is None) != self.sample_weight_is_none: + raise ValueError( + "sample_weight's value and sample_weight_is_none disagree!" 
+ ) - if foo is not None: - fit_params["foo"] = foo + if sample_weight is not None: + fit_params["sample_weight"] = sample_weight # meta-estimators are responsible for validating the given metadata metadata_request_factory(self).fit.validate_metadata( @@ -355,7 +367,7 @@ def get_metadata_request(self): # without masking them, before adding what's requested by the sub-estimator. # Passing ``super()`` here means only what's explicitly mentioned in the # methods' signature is considered as metadata consumed by this estimator; in -# this case fit's ``foo``. Let's see what the routing metadata looks like with +# this case fit's ``sample_weight``. Let's see what the routing metadata looks like with # different settings: # %% @@ -365,12 +377,16 @@ def get_metadata_request(self): # %% -# ``foo`` requested by child estimator -est = RouterConsumerClassifier(estimator=ExampleClassifier().fit_requests(foo=True)) +# ``sample_weight`` requested by child estimator +est = RouterConsumerClassifier( + estimator=ExampleClassifier().fit_requests(sample_weight=True) +) est.get_metadata_request()["fit"] # %% -# ``foo`` requested by meta-estimator -est = RouterConsumerClassifier(estimator=ExampleClassifier()).fit_requests(foo=True) +# ``sample_weight`` requested by meta-estimator +est = RouterConsumerClassifier(estimator=ExampleClassifier()).fit_requests( + sample_weight=True +) est.get_metadata_request()["fit"] # %% @@ -380,33 +396,33 @@ def get_metadata_request(self): # # Aliased on both est = RouterConsumerClassifier( - foo_is_none=False, - estimator=ExampleClassifier(foo_is_none=False).fit_requests( - foo="first_aliased_foo" + sample_weight_is_none=False, + estimator=ExampleClassifier(sample_weight_is_none=False).fit_requests( + sample_weight="first_aliased_sample_weight" ), -).fit_requests(foo="second_aliased_foo") +).fit_requests(sample_weight="second_aliased_sample_weight") est.get_metadata_request()["fit"] # %% # However, ``fit`` of the meta-estimator only needs the alias for the 
# sub-estimator: -est.fit(X, y, foo=my_weights, first_aliased_foo=my_other_weights) +est.fit(X, y, sample_weight=my_weights, first_aliased_sample_weight=my_other_weights) # %% # Alias only on the sub-estimator. This is useful if we don't want the # meta-estimator to use the metadata, and we only want the metadata to be used # by the sub-estimator. est = RouterConsumerClassifier( - estimator=ExampleClassifier().fit_requests(foo="aliased_foo") -).fit_requests(foo=True) + estimator=ExampleClassifier().fit_requests(sample_weight="aliased_sample_weight") +).fit_requests(sample_weight=True) est.get_metadata_request()["fit"] # %% # Alias only on the meta-estimator. This example raises an error since there -# will be two conflicting values for routing ``foo``. +# will be two conflicting values for routing ``sample_weight``. est = RouterConsumerClassifier( - estimator=ExampleClassifier().fit_requests(foo=True) -).fit_requests(foo="aliased_foo") + estimator=ExampleClassifier().fit_requests(sample_weight=True) +).fit_requests(sample_weight="aliased_sample_weight") try: est.get_metadata_request()["fit"] except ValueError as e: @@ -487,14 +503,14 @@ def get_metadata_request(self): class ExampleTransformer(TransformerMixin, BaseEstimator): - def fit(self, X, y, foo=None): - if foo is None: - raise ValueError("foo is None!") + def fit(self, X, y, sample_weight=None): + if sample_weight is None: + raise ValueError("sample_weight is None!") return self - def transform(self, X, bar=None): - if bar is None: - raise ValueError("bar is None!") + def transform(self, X, groups=None): + if groups is None: + raise ValueError("groups is None!") return X @@ -505,23 +521,25 @@ def transform(self, X, bar=None): est = SimplePipeline( transformer=ExampleTransformer() - # we transformer's fit to receive foo - .fit_requests(foo=True) - # we want transformer's transform to receive bar - .transform_requests(bar=True), + # we transformer's fit to receive sample_weight + 
.fit_requests(sample_weight=True) + # we want transformer's transform to receive groups + .transform_requests(groups=True), classifier=RouterConsumerClassifier( - foo_is_none=False, - estimator=ExampleClassifier(foo_is_none=False) - # we want this sub-estimator to receive foo in fit - .fit_requests(foo=True) - # but not bar in predict - .predict_requests(bar=False), + sample_weight_is_none=False, + estimator=ExampleClassifier(sample_weight_is_none=False) + # we want this sub-estimator to receive sample_weight in fit + .fit_requests(sample_weight=True) + # but not groups in predict + .predict_requests(groups=False), ).fit_requests( - # and we want the meta-estimator to receive foo as well - foo=True + # and we want the meta-estimator to receive sample_weight as well + sample_weight=True ), ) -est.fit(X, y, foo=my_weights, bar=my_groups).predict(X[:3], bar=my_groups) +est.fit(X, y, sample_weight=my_weights, groups=my_groups).predict( + X[:3], groups=my_groups +) # %% # Deprechation / Default Value Change From 93d448e4951031bb8bb2d06f0b846e6820443837 Mon Sep 17 00:00:00 2001 From: adrinjalali Date: Sun, 12 Dec 2021 14:49:57 +0100 Subject: [PATCH 16/18] get_method_input->get_input --- examples/plot_metadata_routing.py | 42 +++++++++++++++--------------- sklearn/tests/test_props.py | 22 +++++++--------- sklearn/utils/metadata_requests.py | 2 +- 3 files changed, 32 insertions(+), 34 deletions(-) diff --git a/examples/plot_metadata_routing.py b/examples/plot_metadata_routing.py index 1087579ef495b..bc4f538ef74b0 100644 --- a/examples/plot_metadata_routing.py +++ b/examples/plot_metadata_routing.py @@ -100,7 +100,7 @@ def predict(self, X, groups=None): # %% # As you can see, now the two metadata have explicit request values, one is # requested and the other one is not. Instead of ``True`` and ``False``, we -# could also use the :class:`~sklearn.utils.metadata_requests.RequestType`` +# could also use the :class:`~sklearn.utils.metadata_requests.RequestType` # values. 
est = ( @@ -140,7 +140,7 @@ def fit(self, X, y, **fit_params): ) # we can use provided utility methods to map the given metadata to what # is required by the underlying estimator - fit_params_ = metadata_request_factory(self.estimator).fit.get_method_input( + fit_params_ = metadata_request_factory(self.estimator).fit.get_input( ignore_extras=False, kwargs=fit_params ) self.estimator_ = clone(self.estimator).fit(X, y, **fit_params_) @@ -154,9 +154,9 @@ def predict(self, X, **predict_params): ignore_extras=False, kwargs=predict_params ) # and then prepare the input to the underlying ``predict`` method. - predict_params_ = metadata_request_factory( - self.estimator_ - ).predict.get_method_input(ignore_extras=False, kwargs=predict_params) + predict_params_ = metadata_request_factory(self.estimator_).predict.get_input( + ignore_extras=False, kwargs=predict_params + ) return self.estimator_.predict(X, **predict_params_) def get_metadata_request(self): @@ -177,7 +177,7 @@ def get_metadata_request(self): # have such a method, then a default empty ``MetadataRequest`` is returned. # # Then in each method, we use the corresponding -# :method:`~utils.metadata_requests.MethodMetadataRequest.get_method_input` to +# :method:`~utils.metadata_requests.MethodMetadataRequest.get_input` to # construct a dictionary of the form ``{"metadata": value}`` to pass to the # underlying estimator's method. 
Please note that since in this example the # meta-estimator does not consume any of the given metadata itself, and there @@ -325,7 +325,7 @@ def fit(self, X, y, sample_weight, **fit_params): ) # we can use provided utility methods to map the given metadata to what # is required by the underlying estimator - fit_params_ = metadata_request_factory(self.estimator).fit.get_method_input( + fit_params_ = metadata_request_factory(self.estimator).fit.get_input( ignore_extras=False, kwargs=fit_params ) self.estimator_ = clone(self.estimator).fit(X, y, **fit_params_) @@ -339,9 +339,9 @@ def predict(self, X, **predict_params): ignore_extras=False, kwargs=predict_params ) # and then prepare the input to the underlying ``predict`` method. - predict_params_ = metadata_request_factory( - self.estimator_ - ).predict.get_method_input(ignore_extras=False, kwargs=predict_params) + predict_params_ = metadata_request_factory(self.estimator_).predict.get_input( + ignore_extras=False, kwargs=predict_params + ) return self.estimator_.predict(X, **predict_params_) def get_metadata_request(self): @@ -450,16 +450,16 @@ def fit(self, X, y, **fit_params): transformer_fit_params = metadata_request_factory( self.transformer - ).fit.get_method_input(ignore_extras=True, kwargs=fit_params) + ).fit.get_input(ignore_extras=True, kwargs=fit_params) transformer_transform_params = metadata_request_factory( self.transformer - ).transform.get_method_input(ignore_extras=True, kwargs=fit_params) + ).transform.get_input(ignore_extras=True, kwargs=fit_params) self.transformer_ = clone(self.transformer).fit(X, y, **transformer_fit_params) X_transformed = self.transformer_.transform(X, **transformer_transform_params) - classifier_fit_params = metadata_request_factory( - self.classifier - ).fit.get_method_input(ignore_extras=True, kwargs=fit_params) + classifier_fit_params = metadata_request_factory(self.classifier).fit.get_input( + ignore_extras=True, kwargs=fit_params + ) self.classifier_ = 
clone(self.classifier).fit( X_transformed, y, **classifier_fit_params ) @@ -470,12 +470,12 @@ def predict(self, X, **predict_params): transformer_transform_params = metadata_request_factory( self.transformer - ).transform.get_method_input(ignore_extras=True, kwargs=predict_params) + ).transform.get_input(ignore_extras=True, kwargs=predict_params) X_transformed = self.transformer_.transform(X, **transformer_transform_params) classifier_predict_params = metadata_request_factory( self.classifier - ).predict.get_method_input(ignore_extras=True, kwargs=predict_params) + ).predict.get_input(ignore_extras=True, kwargs=predict_params) return self.classifier_.predict(X_transformed, **classifier_predict_params) def get_metadata_request(self): @@ -558,7 +558,7 @@ def fit(self, X, y, **fit_params): metadata_request_factory(self).fit.validate_metadata( ignore_extras=False, self_metadata=super(), kwargs=fit_params ) - fit_params_ = metadata_request_factory(self.estimator).fit.get_method_input( + fit_params_ = metadata_request_factory(self.estimator).fit.get_input( ignore_extras=False, kwargs=fit_params ) self.estimator_ = clone(self.estimator).fit(X, y, **fit_params_) @@ -594,9 +594,9 @@ def fit(self, X, y, sample_weight=None, **fit_params): metadata_request_factory(self).fit.validate_metadata( ignore_extras=False, self_metadata=super(), kwargs=fit_params ) - estimator_fit_params = metadata_request_factory( - self.estimator - ).fit.get_method_input(ignore_extras=True, kwargs=fit_params) + estimator_fit_params = metadata_request_factory(self.estimator).fit.get_input( + ignore_extras=True, kwargs=fit_params + ) self.estimator_ = clone(self.estimator).fit(X, y, **estimator_fit_params) def get_metadata_request(self): diff --git a/sklearn/tests/test_props.py b/sklearn/tests/test_props.py index 8b84fa830251a..c5e1ea52cb67b 100644 --- a/sklearn/tests/test_props.py +++ b/sklearn/tests/test_props.py @@ -91,7 +91,7 @@ def fit(self, X, y, sample_weight=None, **kwargs): 
metadata_request_factory(self).fit.validate_metadata( ignore_extras=False, self_metadata=super(), kwargs=kwargs ) - fit_params = metadata_request_factory(self.estimator).fit.get_method_input( + fit_params = metadata_request_factory(self.estimator).fit.get_input( ignore_extras=True, kwargs=kwargs ) self.estimator_ = clone(self.estimator).fit(X, y, **fit_params) @@ -150,7 +150,7 @@ def fit(self, X, y=None, **fit_params): ignore_extras=False, kwargs=fit_params, ) - fit_params_ = metadata_request_factory(self.transformer).fit.get_method_input( + fit_params_ = metadata_request_factory(self.transformer).fit.get_input( ignore_extras=False, kwargs=fit_params ) self.transformer_ = clone(self.transformer).fit(X, y, **fit_params_) @@ -162,7 +162,7 @@ def transform(self, X, y=None, **transform_params): ) transform_params_ = metadata_request_factory( self.transformer - ).transform.get_method_input(ignore_extras=False, kwargs=transform_params) + ).transform.get_input(ignore_extras=False, kwargs=transform_params) return self.transformer_.transform(X, **transform_params_) @@ -180,12 +180,12 @@ def fit(self, X, y, **fit_params): X_transformed = X for step in self.steps[:-1]: requests = metadata_request_factory(step) - step_fit_params = requests.fit.get_method_input( + step_fit_params = requests.fit.get_input( ignore_extras=True, kwargs=fit_params ) transformer = clone(step).fit(X_transformed, y, **step_fit_params) self.steps_.append(transformer) - step_transform_params = requests.transform.get_method_input( + step_transform_params = requests.transform.get_input( ignore_extras=True, kwargs=fit_params ) X_transformed = transformer.transform( @@ -193,9 +193,7 @@ def fit(self, X, y, **fit_params): ) requests = metadata_request_factory(step) - step_fit_params = requests.fit.get_method_input( - ignore_extras=True, kwargs=fit_params - ) + step_fit_params = requests.fit.get_input(ignore_extras=True, kwargs=fit_params) self.steps_.append( clone(self.steps[-1]).fit(X_transformed, y, 
**step_fit_params) ) @@ -208,14 +206,14 @@ def predict(self, X, **predict_params): ignore_extras=False, kwargs=predict_params ) for step in self.steps_[:-1]: - step_transform_params = metadata_request_factory( - step - ).transform.get_method_input(ignore_extras=True, kwargs=predict_params) + step_transform_params = metadata_request_factory(step).transform.get_input( + ignore_extras=True, kwargs=predict_params + ) X_transformed = step.transform(X, **step_transform_params) step_predict_params = metadata_request_factory( self.steps_[-1] - ).predict.get_method_input(ignore_extras=True, kwargs=predict_params) + ).predict.get_input(ignore_extras=True, kwargs=predict_params) return self.steps_[-1].predict(X_transformed, **step_predict_params) def get_metadata_request(self): diff --git a/sklearn/utils/metadata_requests.py b/sklearn/utils/metadata_requests.py index f605ff5eee24e..c3463e2d3b167 100644 --- a/sklearn/utils/metadata_requests.py +++ b/sklearn/utils/metadata_requests.py @@ -316,7 +316,7 @@ def validate_metadata(self, ignore_extras=False, self_metadata=None, kwargs=None f"requested or not. In method: {self.name}" ) - def get_method_input(self, ignore_extras=False, kwargs=None): + def get_input(self, ignore_extras=False, kwargs=None): """Return the input parameters requested by the method. 
The output of this method can be used directly as the input to the From 167e4c290ab98bd17fcb7eced08c41053914fb68 Mon Sep 17 00:00:00 2001 From: adrinjalali Date: Sun, 12 Dec 2021 14:54:00 +0100 Subject: [PATCH 17/18] minor comments from Guillaume --- examples/plot_metadata_routing.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/examples/plot_metadata_routing.py b/examples/plot_metadata_routing.py index bc4f538ef74b0..63b73ea0a7f7e 100644 --- a/examples/plot_metadata_routing.py +++ b/examples/plot_metadata_routing.py @@ -282,7 +282,7 @@ def get_metadata_request(self): # %% # In the above example, this is how each ``fit`` method will call the -# sub-estimator's ``fit``: +# sub-estimator's ``fit``:: # # meta_est.fit(X, y, aliased_sample_weight=my_weights): # ... # this estimator (est), expects aliased_sample_weight as seen above @@ -398,15 +398,15 @@ def get_metadata_request(self): est = RouterConsumerClassifier( sample_weight_is_none=False, estimator=ExampleClassifier(sample_weight_is_none=False).fit_requests( - sample_weight="first_aliased_sample_weight" + sample_weight="clf_sample_weight" ), -).fit_requests(sample_weight="second_aliased_sample_weight") +).fit_requests(sample_weight="meta_clf_sample_weight") est.get_metadata_request()["fit"] # %% # However, ``fit`` of the meta-estimator only needs the alias for the # sub-estimator: -est.fit(X, y, sample_weight=my_weights, first_aliased_sample_weight=my_other_weights) +est.fit(X, y, sample_weight=my_weights, clf_sample_weight=my_other_weights) # %% # Alias only on the sub-estimator. 
This is useful if we don't want the From 20fe48aa987e39c12f3e5ed1550a4c74b5ed614b Mon Sep 17 00:00:00 2001 From: adrinjalali Date: Mon, 13 Dec 2021 17:43:15 +0100 Subject: [PATCH 18/18] fix estimator checks tests --- sklearn/utils/tests/test_estimator_checks.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sklearn/utils/tests/test_estimator_checks.py b/sklearn/utils/tests/test_estimator_checks.py index 9de14a3525eee..cb1fa7bc385fd 100644 --- a/sklearn/utils/tests/test_estimator_checks.py +++ b/sklearn/utils/tests/test_estimator_checks.py @@ -653,7 +653,7 @@ def __init__(self, you_should_set_this_=None): class ConformantEstimatorClassAttribute(BaseEstimator): # making sure our __metadata_request__* class attributes are okay! - __metadata_request__foo = {"fit": "foo"} + __metadata_request__fit = {"foo": True} msg = ( "Estimator estimator_name should not set any"