Closed

21 commits
dbead5c
initial base implementation commit
adrinjalali Oct 8, 2021
7868950
fix test_props and the issue with attribute starting with __
adrinjalali Oct 8, 2021
5793318
skip doctest in metadata_routing.rst for now
adrinjalali Oct 11, 2021
6696497
DOC explain why aliasing on sub-estimator of a consumer/router is useful
adrinjalali Oct 11, 2021
c0841c8
reduce diff
adrinjalali Oct 11, 2021
1aff2eb
DOC add user guide link to method docstrings
adrinjalali Oct 11, 2021
1457293
DOC apply Thomas's suggestions to the rst file
adrinjalali Oct 13, 2021
af86e82
CLN address a few comments in docs
adrinjalali Oct 25, 2021
4c228cf
Merge remote-tracking branch 'upstream/sample-props' into sample-prop…
adrinjalali Oct 25, 2021
11649d9
ignore sentinel docstring check
adrinjalali Oct 25, 2021
b5c962c
handling backward compatibility and deprecation prototype
adrinjalali Nov 5, 2021
fb200e2
Update examples/plot_metadata_routing.py
adrinjalali Dec 7, 2021
6f849b2
make __metadata_request__* format more intuitive and less redundant
adrinjalali Dec 10, 2021
82b2128
metadata_request_factory always returns a copy
adrinjalali Dec 10, 2021
6f3f590
Merge remote-tracking branch 'upstream/main' into sample-props-base
adrinjalali Dec 11, 2021
16c47b2
fix tests for the changed __metadata_request__* format
adrinjalali Dec 12, 2021
1c591fe
in example: foo->sample_weight, bar->groups
adrinjalali Dec 12, 2021
93d448e
get_method_input->get_input
adrinjalali Dec 12, 2021
167e4c2
minor comments from Guillaume
adrinjalali Dec 12, 2021
3d199ee
Merge branch 'sample-props-base' of github.com:adrinjalali/scikit-lea…
adrinjalali Dec 12, 2021
20fe48a
fix estimator checks tests
adrinjalali Dec 13, 2021
7 changes: 7 additions & 0 deletions doc/conftest.py
@@ -138,6 +138,13 @@ def pytest_runtest_setup(item):
setup_preprocessing()
elif fname.endswith("statistical_inference/unsupervised_learning.rst"):
setup_unsupervised_learning()
elif fname.endswith("metadata_routing.rst"):
# TODO: remove this once implemented
# Skip metadata_routing because it is not fully implemented yet
raise SkipTest(
"Skipping doctest for metadata_routing.rst because it "
"is not fully implemented yet"
)

rst_files_requiring_matplotlib = [
"modules/partial_dependence.rst",
204 changes: 204 additions & 0 deletions doc/metadata_routing.rst
@@ -0,0 +1,204 @@

.. _metadata_routing:

.. TODO: update doc/conftest.py once document is updated and examples run.

Metadata Routing
================

This guide demonstrates how metadata such as ``sample_weight`` can be routed
and passed along to estimators, scorers, and CV splitters through
meta-estimators such as ``Pipeline`` and ``GridSearchCV``. In order to pass
metadata to a method such as ``fit`` or ``score``, the object accepting the
metadata must *request* it. For estimators and splitters this is done via
``*_requests`` methods, e.g. ``fit_requests(...)``, and for scorers it is
done via the ``score_requests`` method. For grouped splitters such as
``GroupKFold`` a ``groups`` parameter is requested by default. This is best
demonstrated by the following examples.

Usage Examples
**************
Here we present a few examples to show different common use-cases. The examples
in this section require the following imports and data::

>>> import numpy as np
>>> from sklearn.metrics import make_scorer, accuracy_score
>>> from sklearn.linear_model import LogisticRegressionCV
>>> from sklearn.linear_model import LogisticRegression
>>> from sklearn.model_selection import cross_validate
>>> from sklearn.model_selection import GridSearchCV
>>> from sklearn.model_selection import GroupKFold
>>> from sklearn.feature_selection import SelectKBest
>>> from sklearn.pipeline import make_pipeline
>>> n_samples, n_features = 100, 4
>>> X = np.random.rand(n_samples, n_features)
>>> y = np.random.randint(0, 2, size=n_samples)
>>> my_groups = np.random.randint(0, 10, size=n_samples)
>>> my_weights = np.random.rand(n_samples)
>>> my_other_weights = np.random.rand(n_samples)
Comment on lines +33 to +38

Reviewer comment (Member): For reproducibility, we can define a RandomState object.

Reply (Member Author): will need to rework this rst file once things are implemented anyway, this example doesn't run the code for now.

Weighted scoring and fitting
----------------------------

Here ``GroupKFold`` requests ``groups`` by default. However, we need to
explicitly request weights in ``make_scorer`` and for ``LogisticRegressionCV``.
Both of these *consumers* understand the meaning of the key
``"sample_weight"``::

>>> weighted_acc = make_scorer(accuracy_score).score_requests(
... sample_weight=True
... )
>>> lr = LogisticRegressionCV(
... cv=GroupKFold(), scoring=weighted_acc,
... ).fit_requests(sample_weight=True)
>>> cv_results = cross_validate(
... lr,
... X,
... y,
... cv=GroupKFold(),
... props={"sample_weight": my_weights, "groups": my_groups},
... scoring=weighted_acc,
... )

Note that in this example, ``my_weights`` is passed to both the scorer and
:class:`~linear_model.LogisticRegressionCV`.

Error handling: if ``props={'sample_weigh': my_weights, ...}`` were passed
(note the typo), ``cross_validate`` would raise an error, since
``'sample_weigh'`` was not requested by any of its children.
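
The unknown-key check can be sketched in a few lines of plain Python. This is an illustrative model of the behavior described above, not scikit-learn's actual implementation: a router collects the keys requested by its children and rejects anything else, which is how a typo such as ``sample_weigh`` surfaces as an error.

```python
def route_metadata(props, requested_keys):
    """Return only the requested metadata; raise on unknown keys.

    props          -- user-supplied metadata, e.g. {"sample_weight": w}
    requested_keys -- keys requested by the router's children
    """
    unknown = set(props) - set(requested_keys)
    if unknown:
        # A misspelled key such as "sample_weigh" ends up here.
        raise ValueError(
            f"Metadata {sorted(unknown)} was passed but not requested "
            "by any child object."
        )
    return dict(props)
```

With this sketch, ``route_metadata({"sample_weigh": w}, {"sample_weight", "groups"})`` raises, while correctly spelled keys pass through unchanged.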

Weighted scoring and unweighted fitting
---------------------------------------

All scikit-learn estimators require weights to be either explicitly requested
or explicitly not requested. To perform an unweighted fit, we need to configure
:class:`~linear_model.LogisticRegressionCV` to not request sample weights, so
that :func:`~model_selection.cross_validate` does not pass the weights along::

>>> weighted_acc = make_scorer(accuracy_score).score_requests(
... sample_weight=True
... )
>>> lr = LogisticRegressionCV(
... cv=GroupKFold(), scoring=weighted_acc,
... ).fit_requests(sample_weight=False)
>>> cv_results = cross_validate(
... lr,
... X,
... y,
... cv=GroupKFold(),
... props={"sample_weight": my_weights, "groups": my_groups},
... scoring=weighted_acc,
... )

If :class:`~linear_model.LogisticRegressionCV` had not called ``fit_requests``,
:func:`~model_selection.cross_validate` would raise an error because weights
are passed in but :class:`~linear_model.LogisticRegressionCV` was not
explicitly configured to recognize them.
Comment on lines +93 to +96

Reviewer comment (Member): I guess this error would be raised no matter if scoring is weighted or unweighted?

Reply (Member Author): yes

Unweighted feature selection
----------------------------

Unlike ``LogisticRegressionCV``, ``SelectKBest`` doesn't accept weights and
therefore ``"sample_weight"`` is not routed to it::

>>> weighted_acc = make_scorer(accuracy_score).score_requests(
... sample_weight=True
... )
>>> lr = LogisticRegressionCV(
... cv=GroupKFold(), scoring=weighted_acc,
... ).fit_requests(sample_weight=True)
>>> sel = SelectKBest(k=2)
Reviewer comment (Member): Would this workflow break if SelectKBest were to accept weights in the future?

In other words, let's say in 1.3 SelectKBest accepts weights, would we need to call fit_requests(sample_weight=False) to have the same behavior?

I guess we would need a deprecation cycle migrating from RequestType.UNREQUESTED to RequestType.ERROR_IF_PASSED as the default.

Reply (Member Author): yes, all you say is true, and the deprecation cycle is not hard to implement, since we have the mechanism of having UNREQUESTED as the default.

Reply (Member Author): We can also implement a RequestType.DEPRECATED or something, to make the deprecation easier if necessary.

Comment (@jnothman, Oct 21, 2021): Hmmm... The case of a metaestimator adding support for a prop that is requested by its child is indeed a tricky one. I can't yet see a way to make this generally backwards compatible within the SLEP006 proposal. This makes me sad.

Indeed, generally a metaestimator supporting the same prop name as one of its children is tricky. I.e. if the metaestimator supports metadata x and its child requests metadata x, the metaestimator should only work where either:

  • the child's request aliases x to another name without such a clash;
  • the child's request and the metaestimator's request for x imply being passed the same metadata.

In other cases, this must raise an error. This is something, I'm pretty sure, we've not yet covered in SLEP006 (and it's a pretty messy and intricate consequence of having the caller responsible for delivering metadata in accordance with the request).

Deprecation would be pretty tricky as far as I can tell.

Comment (Member): It is also the obligation of a user to know which estimator supports which property. This could be confusing.

Comment (Member): @lorentzenchr Are you referring to a case where the user does not already need to know which estimator supports which property? I'm not sure what burden you are referring to.

Comment (Member): I just realised my response didn't really relate to @thomasjpfan's question, which wasn't about a metaestimator adding support. Anyway, I've opened this issue as scikit-learn/enhancement_proposals#58
>>> pipe = make_pipeline(sel, lr)
>>> cv_results = cross_validate(
... pipe,
... X,
... y,
... cv=GroupKFold(),
... props={"sample_weight": my_weights, "groups": my_groups},
... scoring=weighted_acc,
... )

Different scoring and fitting weights
-------------------------------------

Despite ``make_scorer`` and ``LogisticRegressionCV`` both expecting a key
``sample_weight``, we can use aliases to pass different weights to different
consumers. In this example, we pass ``scoring_weight`` to the scorer, and
``fitting_weight`` to ``LogisticRegressionCV``::

>>> weighted_acc = make_scorer(accuracy_score).score_requests(
... sample_weight="scoring_weight"
... )
>>> lr = LogisticRegressionCV(
... cv=GroupKFold(), scoring=weighted_acc,
... ).fit_requests(sample_weight="fitting_weight")
>>> cv_results = cross_validate(
... lr,
... X,
... y,
... cv=GroupKFold(),
... props={
... "scoring_weight": my_weights,
... "fitting_weight": my_other_weights,
... "groups": my_groups,
... },
... scoring=weighted_acc,
... )
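
Alias resolution of this kind can be modelled with a small helper. This is hypothetical code, assuming a ``{parameter: alias_or_True}`` request mapping; it is not the actual scikit-learn machinery. The point it illustrates is that the router, not the consumer, translates user-facing keys such as ``fitting_weight`` back to the parameter name the consumer understands.

```python
def resolve_aliases(props, request):
    """Map user-supplied metadata onto a consumer's parameter names.

    props   -- metadata given to the router, e.g. {"fitting_weight": w}
    request -- {param_name: alias}; an alias of True means "same name"
    """
    resolved = {}
    for param, alias in request.items():
        # True keeps the original name; a string redirects to another key.
        key = param if alias is True else alias
        if key in props:
            resolved[param] = props[key]
    return resolved
```

Under this model, the router would call the scorer with ``resolve_aliases(props, {"sample_weight": "scoring_weight"})`` and the estimator with the ``fitting_weight`` alias, so each consumer still sees a plain ``sample_weight`` argument.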

API Interface
*************

A *consumer* is an object (estimator, meta-estimator, scorer, or splitter)
which accepts and uses some metadata in at least one of its methods (``fit``,
``predict``, ``inverse_transform``, ``transform``, ``score``, ``split``).
Meta-estimators which only forward the metadata to other objects (the child
estimator, scorers, or splitters) and don't use the metadata themselves are
not consumers. (Meta)Estimators which route metadata to other objects are
*routers*. A (meta)estimator can be a consumer and a router at the same time.
(Meta)Estimators and splitters expose a ``*_requests`` method for each of
their methods which accepts at least one metadata item. For instance, if an
estimator supports ``sample_weight`` in ``fit`` and ``score``, it exposes
``estimator.fit_requests(sample_weight=value)`` and
``estimator.score_requests(sample_weight=value)``. Here ``value`` can be:

- ``RequestType.REQUESTED`` or ``True``: the method requests a
  ``sample_weight``. This means the metadata will be used if provided, and no
  error is raised otherwise.
- ``RequestType.UNREQUESTED`` or ``False``: the method does not request a
  ``sample_weight``.
- ``RequestType.ERROR_IF_PASSED`` or ``None``: the router will raise an error
  if ``sample_weight`` is passed. This is in almost all cases the default
  value when an object is instantiated, and it ensures the user sets the
  metadata requests explicitly when a metadata item is passed. The only
  exceptions are the ``Group*Fold`` splitters.
- ``"param_name"``: if this estimator is used in a meta-estimator, the
  meta-estimator should forward ``"param_name"`` as ``sample_weight`` to this
  estimator. This means the mapping between the metadata required by the
  object, e.g. ``sample_weight``, and what is provided by the user, e.g.
  ``my_weights``, is done at the router level, and not by the object, e.g.
  the estimator, itself.

For scorers, this is done the same way, using the ``score_requests`` method.
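
The semantics of the four request values can be sketched in one function. This is an illustrative model of the rules listed above, not the actual implementation:

```python
def forward(request_value, name, props):
    """Decide what a router forwards for one metadata item.

    request_value -- True, False, None, or an alias string, as set via
                     a *_requests method
    name          -- the consumer's parameter name, e.g. "sample_weight"
    props         -- metadata supplied to the router
    """
    if request_value is None:  # ERROR_IF_PASSED
        if name in props:
            raise ValueError(
                f"{name} is passed but is not explicitly set as "
                "requested or not."
            )
        return None
    if request_value is False:  # UNREQUESTED: never forwarded
        return None
    # True (REQUESTED) forwards under the same name; a string is an alias.
    key = name if request_value is True else request_value
    return props.get(key)
```

For example, ``forward("fitting_weight", "sample_weight", props)`` looks up the ``fitting_weight`` key but delivers it to the consumer as ``sample_weight``.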

If a metadata item, e.g. ``sample_weight``, is passed by the user, the
metadata request for all objects which could potentially accept
``sample_weight`` must be set by the user, otherwise an error is raised by
the router object. For example, the following code raises an error, since it
hasn't been explicitly set whether ``sample_weight`` should be passed to the
estimator's scorer or not::

>>> param_grid = {"C": [0.1, 1]}
>>> lr = LogisticRegression().fit_requests(sample_weight=True)
>>> try:
... GridSearchCV(
... estimator=lr, param_grid=param_grid
... ).fit(X, y, sample_weight=my_weights)
... except ValueError as e:
... print(e)
sample_weight is passed but is not explicitly set as requested or not. In
method: score

The issue can be fixed by explicitly setting the request value::

>>> lr = LogisticRegression().fit_requests(
... sample_weight=True
... ).score_requests(sample_weight=False)
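
The chained calls work because each ``*_requests`` setter returns the estimator itself. A minimal model of that fluent pattern (a hypothetical class, not the scikit-learn implementation):

```python
class Consumer:
    """Toy consumer recording per-method metadata requests."""

    def __init__(self):
        self._requests = {"fit": {}, "score": {}}

    def fit_requests(self, **kwargs):
        self._requests["fit"].update(kwargs)
        return self  # returning self is what enables chaining

    def score_requests(self, **kwargs):
        self._requests["score"].update(kwargs)
        return self
```

With this sketch, ``Consumer().fit_requests(sample_weight=True).score_requests(sample_weight=False)`` builds one object carrying both request settings.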
1 change: 1 addition & 0 deletions doc/user_guide.rst
@@ -30,3 +30,4 @@ User Guide
computing.rst
modules/model_persistence.rst
common_pitfalls.rst
metadata_routing.rst