Base sample-prop implementation and docs (alternative to #21284) #22083

Merged
88 commits
dbead5c
initial base implementation commit
adrinjalali Oct 8, 2021
7868950
fix test_props and the issue with attribute starting with __
adrinjalali Oct 8, 2021
5793318
skip doctest in metadata_routing.rst for now
adrinjalali Oct 11, 2021
6696497
DOC explain why aliasing on sub-estimator of a consumer/router is useful
adrinjalali Oct 11, 2021
c0841c8
reduce diff
adrinjalali Oct 11, 2021
1aff2eb
DOC add user guide link to method docstrings
adrinjalali Oct 11, 2021
1457293
DOC apply Thomas's suggestions to the rst file
adrinjalali Oct 13, 2021
af86e82
CLN address a few comments in docs
adrinjalali Oct 25, 2021
4c228cf
Merge remote-tracking branch 'upstream/sample-props' into sample-prop…
adrinjalali Oct 25, 2021
11649d9
ignore sentinel docstring check
adrinjalali Oct 25, 2021
b5c962c
handling backward compatibility and deprecation prototype
adrinjalali Nov 5, 2021
fb200e2
Update examples/plot_metadata_routing.py
adrinjalali Dec 7, 2021
6f849b2
make __metadata_request__* format more intuitive and less redundant
adrinjalali Dec 10, 2021
82b2128
metadata_request_factory always returns a copy
adrinjalali Dec 10, 2021
6f3f590
Merge remote-tracking branch 'upstream/main' into sample-props-base
adrinjalali Dec 11, 2021
16c47b2
fix tests for the changed __metadata_request__* format
adrinjalali Dec 12, 2021
1c591fe
in example: foo->sample_weight, bar->groups
adrinjalali Dec 12, 2021
93d448e
get_method_input->get_input
adrinjalali Dec 12, 2021
167e4c2
minor comments from Guillaume
adrinjalali Dec 12, 2021
3d199ee
Merge branch 'sample-props-base' of github.com:adrinjalali/scikit-lea…
adrinjalali Dec 12, 2021
20fe48a
fix estimator checks tests
adrinjalali Dec 13, 2021
39a462d
Improved sample props developer API
adrinjalali Dec 21, 2021
bd5ae36
fixes, updated doc, decorator
adrinjalali Dec 22, 2021
79b43f1
Add docstrings and some API cleanup
adrinjalali Dec 27, 2021
515d00c
unify serialize/deserialize methods
adrinjalali Dec 27, 2021
01c942a
Merge remote-tracking branch 'upstream/sample-props' into sample-prop…
adrinjalali Dec 27, 2021
346532d
Add more docstring to process_routing
adrinjalali Dec 27, 2021
e2adff0
fix MetadataRouter.get_params parameter mismatch
adrinjalali Dec 27, 2021
98af496
DOC add missing name to MethodMetadataRequest.deserialize docstring
adrinjalali Dec 27, 2021
93f0698
DOC add MethodMapping.add docstring
adrinjalali Dec 27, 2021
93e7b5e
DOC fix colons after versionadded
adrinjalali Dec 28, 2021
3426d54
fix {method}_requests return type annotation
adrinjalali Dec 28, 2021
d07f949
metadata_request_factory -> metadata_router_factory and docstring fixes
adrinjalali Dec 30, 2021
b0cfdd5
move 'me' out of the map in MetadataRouter
adrinjalali Dec 30, 2021
6d3942f
more docstring refinements
adrinjalali Dec 30, 2021
6b3c2d1
cleanup API addresses and create a utils.metadata_routing sub-folder
adrinjalali Dec 30, 2021
9f0741e
fix module import issue
adrinjalali Dec 30, 2021
e2c9376
more tests and a few bug fixes
adrinjalali Jan 2, 2022
c99b340
Merge remote-tracking branch 'upstream/sample-props' into sample-prop…
adrinjalali Jan 2, 2022
0ad69f2
Joel's comments
adrinjalali Jan 3, 2022
c4eb53e
make process_routing a function
adrinjalali Jan 4, 2022
16fc971
docstring fix
adrinjalali Jan 4, 2022
19da9f7
^type -> $type
adrinjalali Jan 4, 2022
1adc00b
remove deserialize, return instance, and add type as an attribute
adrinjalali Jan 9, 2022
13dc2ff
remove sentinels and use strings instead
adrinjalali Jan 9, 2022
f8e5005
make RequestType searchable and check for valid identifier
adrinjalali Jan 9, 2022
6f54e2c
Route -> MethodPair
adrinjalali Jan 9, 2022
32c7a52
remove unnecessary sorted
adrinjalali Jan 9, 2022
1bfe7ae
Merge remote-tracking branch 'upstream/sample-props' into sample-prop…
adrinjalali Jan 10, 2022
4c5ebfc
clarification on usage of the process_routing func in the example
adrinjalali Jan 10, 2022
227e727
only print methods with non-empty requests
adrinjalali Jan 10, 2022
8f6dbd7
fix test_string_representations
adrinjalali Jan 11, 2022
c834ba3
remove source build cache from CircleCI (temporarily)
adrinjalali Jan 11, 2022
4d067d1
Trigger CI
ogrisel Jan 12, 2022
4fc1ac1
Invalidate linux-arm64 ccache my changing the key
ogrisel Jan 12, 2022
59b779e
Trigger CI
ogrisel Jan 12, 2022
4da93c6
method, used_in -> callee, caller
adrinjalali Jan 13, 2022
27ba25c
show RequestType instead of RequestType.value in _serialize()
adrinjalali Jan 14, 2022
78de01c
more informative error messages
adrinjalali Jan 14, 2022
2f03a1b
fix checking for conflicting keys
adrinjalali Jan 14, 2022
ada1b69
Merge remote-tracking branch 'upstream/sample-props' into sample-prop…
adrinjalali Jan 16, 2022
6df8049
get_router_for_object -> get_routing_for_object
adrinjalali Jan 18, 2022
d52f2f6
{method}_requests -> set_{method}_request
adrinjalali Feb 9, 2022
4d908e3
Merge remote-tracking branch 'upstream/main' into sample-props-base-a…
adrinjalali Feb 9, 2022
3bfe856
address metadata_routing.rst comments
adrinjalali Feb 15, 2022
5fb9366
some test enhancements
adrinjalali Feb 15, 2022
c259c93
TypeError for extra arguments
adrinjalali Feb 16, 2022
c278696
add_request: prop -> param
adrinjalali Feb 16, 2022
3ebbf69
original_names -> return_alias
adrinjalali Feb 16, 2022
d75a803
Merge remote-tracking branch 'upstream/sample-props' into sample-prop…
adrinjalali Feb 16, 2022
beb2544
add more tests for MetadataRouter and MethodMapping
adrinjalali Feb 17, 2022
ad63680
more suggestions from Joel's review
adrinjalali Feb 22, 2022
71e79df
fix return type
adrinjalali Feb 24, 2022
5871e50
apply more suggestions from Joel's review
adrinjalali Feb 25, 2022
e3f897c
Merge remote-tracking branch 'upstream/sample-props' into sample-prop…
adrinjalali Feb 25, 2022
c380ad7
Christian's suggestions
adrinjalali Feb 28, 2022
bae8402
more notes from Christian
adrinjalali Feb 28, 2022
8ca978b
test_get_routing_for_object returns empty requests on unknown objects
adrinjalali Mar 1, 2022
1aed83c
more notes from Christian
adrinjalali Mar 1, 2022
5fda075
remove double line break
adrinjalali Mar 1, 2022
7663cfe
more notes from Christian
adrinjalali Mar 2, 2022
2f51480
more notes from Christian
adrinjalali Mar 3, 2022
d87cbc2
make type private
adrinjalali Mar 7, 2022
46dccf2
add more comments/docs
adrinjalali Mar 7, 2022
e5d46e3
fix test
adrinjalali Mar 7, 2022
70f8c6c
fix nits
adrinjalali Mar 10, 2022
560a2da
add forgotten nit
adrinjalali Mar 10, 2022
e932501
Merge branch 'sample-props' into sample-props-base-alternate2
adrinjalali Mar 10, 2022
4 changes: 2 additions & 2 deletions .circleci/config.yml
@@ -132,10 +132,10 @@ jobs:
       - checkout
       - run: ./build_tools/circle/checkout_merge_commit.sh
       - restore_cache:
-          key: linux-arm64-{{ .Branch }}
+          key: linux-arm64-ccache-v1-{{ .Branch }}
       - run: ./build_tools/circle/build_test_arm.sh
       - save_cache:
-          key: linux-arm64-{{ .Branch }}
+          key: linux-arm64-ccache-v1-{{ .Branch }}
           paths:
             - ~/.cache/ccache
             - ~/.cache/pip
7 changes: 7 additions & 0 deletions doc/conftest.py
@@ -138,6 +138,13 @@ def pytest_runtest_setup(item):
         setup_preprocessing()
     elif fname.endswith("statistical_inference/unsupervised_learning.rst"):
         setup_unsupervised_learning()
+    elif fname.endswith("metadata_routing.rst"):
+        # TODO: remove this once implemented
+        # Skip metadata routing because it is not fully implemented yet
+        raise SkipTest(
+            "Skipping doctest for metadata_routing.rst because it "
+            "is not fully implemented yet"
+        )
 
     rst_files_requiring_matplotlib = [
         "modules/partial_dependence.rst",
215 changes: 215 additions & 0 deletions doc/metadata_routing.rst
@@ -0,0 +1,215 @@

.. _metadata_routing:

.. TODO: update doc/conftest.py once document is updated and examples run.

Metadata Routing
================

This guide demonstrates how metadata such as ``sample_weight`` can be routed
and passed along to estimators, scorers, and CV splitters through
meta-estimators such as ``Pipeline`` and ``GridSearchCV``. In order to pass
metadata to a method such as ``fit`` or ``score``, the object accepting the
metadata must *request* it. For estimators and splitters this is done via
``set_*_request`` methods, e.g. ``set_fit_request(...)``, and for scorers this
is done via the ``set_score_request`` method. For grouped splitters such as
``GroupKFold``, a ``groups`` parameter is requested by default. This is best
demonstrated by the following examples.

If you are developing a scikit-learn compatible estimator or meta-estimator,
you can check our related developer guide:
:ref:`sphx_glr_auto_examples_plot_metadata_routing.py`.

Usage Examples
**************
Here we present a few examples to show different common use-cases. The examples
in this section require the following imports and data::

    >>> import numpy as np
    >>> from sklearn.metrics import make_scorer, accuracy_score
    >>> from sklearn.linear_model import LogisticRegressionCV
    >>> from sklearn.linear_model import LogisticRegression
    >>> from sklearn.model_selection import cross_validate
    >>> from sklearn.model_selection import GridSearchCV
    >>> from sklearn.model_selection import GroupKFold
    >>> from sklearn.feature_selection import SelectKBest
    >>> from sklearn.utils.metadata_routing import RequestType
    >>> from sklearn.pipeline import make_pipeline
    >>> n_samples, n_features = 100, 4
    >>> X = np.random.rand(n_samples, n_features)
    >>> y = np.random.randint(0, 2, size=n_samples)
    >>> my_groups = np.random.randint(0, 10, size=n_samples)
    >>> my_weights = np.random.rand(n_samples)
    >>> my_other_weights = np.random.rand(n_samples)

Weighted scoring and fitting
----------------------------

Here ``GroupKFold`` requests ``groups`` by default. However, we need to
explicitly request weights in ``make_scorer`` and for ``LogisticRegressionCV``.
Both of these *consumers* know how to use metadata called ``"sample_weight"``::

    >>> weighted_acc = make_scorer(accuracy_score).set_score_request(
    ...     sample_weight=True
    ... )
    >>> lr = LogisticRegressionCV(
    ...     cv=GroupKFold(), scoring=weighted_acc,
    ... ).set_fit_request(sample_weight=True)
    >>> cv_results = cross_validate(
    ...     lr,
    ...     X,
    ...     y,
    ...     cv=GroupKFold(),
    ...     props={"sample_weight": my_weights, "groups": my_groups},
    ...     scoring=weighted_acc,
    ... )

Note that in this example, ``my_weights`` is passed to both the scorer and
:class:`~linear_model.LogisticRegressionCV`.

Error handling: if ``props={"sample_weigh": my_weights, ...}`` were passed
(note the typo), ``cross_validate`` would raise an error, since
``sample_weigh`` was not requested by any of its children.
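
The routing and error behaviour described above can be pictured with a small, self-contained toy model. This is only a sketch under stated assumptions, not scikit-learn's implementation: the ``route`` function, the consumer names, and the error message are all hypothetical and exist purely for illustration.

```python
# Toy model of a router. Hypothetical helper, not part of scikit-learn.
def route(props, requests):
    """Split user-provided ``props`` among consumers.

    ``requests`` maps a (hypothetical) consumer name to the set of metadata
    keys that consumer has requested.
    """
    requested = set().union(*requests.values())
    unknown = sorted(set(props) - requested)
    if unknown:
        # Mirrors the idea that a key nobody requested is an error.
        raise ValueError(f"Metadata not requested by any consumer: {unknown}")
    return {
        name: {key: props[key] for key in keys if key in props}
        for name, keys in requests.items()
    }


requests = {
    "scorer.score": {"sample_weight"},
    "estimator.fit": {"sample_weight"},
    "splitter.split": {"groups"},
}

# The same weights reach both the scorer and the estimator.
routed = route({"sample_weight": [1.0, 2.0], "groups": [0, 1]}, requests)

# A misspelled key is requested by no consumer, so routing fails loudly.
try:
    route({"sample_weigh": [1.0, 2.0]}, requests)
except ValueError as exc:
    message = str(exc)
```

Raising on unknown keys, rather than silently dropping them, is what turns a typo into an immediate error instead of an unweighted fit that goes unnoticed.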

Weighted scoring and unweighted fitting
---------------------------------------

All scikit-learn estimators require weights to be either explicitly requested
or not requested (i.e. ``UNREQUESTED``) when used in another router such as a
``Pipeline`` or a ``*GridSearchCV``. To perform an unweighted fit, we need to
configure :class:`~linear_model.LogisticRegressionCV` to not request sample
weights, so that :func:`~model_selection.cross_validate` does not pass the
weights along::

    >>> weighted_acc = make_scorer(accuracy_score).set_score_request(
    ...     sample_weight=True
    ... )
    >>> lr = LogisticRegressionCV(
    ...     cv=GroupKFold(), scoring=weighted_acc,
    ... ).set_fit_request(sample_weight=RequestType.UNREQUESTED)
    >>> cv_results = cross_validate(
    ...     lr,
    ...     X,
    ...     y,
    ...     cv=GroupKFold(),
    ...     props={"sample_weight": my_weights, "groups": my_groups},
    ...     scoring=weighted_acc,
    ... )

Note the usage of ``RequestType.UNREQUESTED``, which is equivalent to
``False``; the ``RequestType`` values are explained further at the end of this
document.

If ``set_fit_request`` is not called on
:class:`~linear_model.LogisticRegressionCV`,
:func:`~model_selection.cross_validate` will raise an error because weights
are passed in but :class:`~linear_model.LogisticRegressionCV` is not
explicitly configured to recognize them.

Unweighted feature selection
----------------------------

Unlike ``LogisticRegressionCV``, ``SelectKBest`` doesn't accept weights and
therefore ``"sample_weight"`` is not routed to it::

    >>> weighted_acc = make_scorer(accuracy_score).set_score_request(
    ...     sample_weight=True
    ... )
    >>> lr = LogisticRegressionCV(
    ...     cv=GroupKFold(), scoring=weighted_acc,
    ... ).set_fit_request(sample_weight=True)
    >>> sel = SelectKBest(k=2)
    >>> pipe = make_pipeline(sel, lr)
    >>> cv_results = cross_validate(
    ...     pipe,
    ...     X,
    ...     y,
    ...     cv=GroupKFold(),
    ...     props={"sample_weight": my_weights, "groups": my_groups},
    ...     scoring=weighted_acc,
    ... )

Advanced: Different scoring and fitting weights
-----------------------------------------------

Despite ``make_scorer`` and ``LogisticRegressionCV`` both expecting the key
``sample_weight``, we can use aliases to pass different weights to different
consumers. In this example, we pass ``scoring_weight`` to the scorer, and
``fitting_weight`` to ``LogisticRegressionCV``::

    >>> weighted_acc = make_scorer(accuracy_score).set_score_request(
    ...     sample_weight="scoring_weight"
    ... )
    >>> lr = LogisticRegressionCV(
    ...     cv=GroupKFold(), scoring=weighted_acc,
    ... ).set_fit_request(sample_weight="fitting_weight")
    >>> cv_results = cross_validate(
    ...     lr,
    ...     X,
    ...     y,
    ...     cv=GroupKFold(),
    ...     props={
    ...         "scoring_weight": my_weights,
    ...         "fitting_weight": my_other_weights,
    ...         "groups": my_groups,
    ...     },
    ...     scoring=weighted_acc,
    ... )
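
One way to think about aliasing is as a per-consumer lookup from the parameter name the consumer expects to the key the user supplied. The sketch below illustrates that idea with plain dictionaries; ``resolve_aliases`` is a hypothetical helper written for this illustration and is not a scikit-learn function.

```python
# Toy illustration of alias resolution. Hypothetical helper, not scikit-learn API.
def resolve_aliases(request, props):
    """Build keyword arguments for one consumer.

    ``request`` maps the parameter name the consumer expects (for example
    ``sample_weight``) to the alias under which the user passes the data.
    """
    return {
        param: props[alias] for param, alias in request.items() if alias in props
    }


props = {"scoring_weight": [0.5, 0.5], "fitting_weight": [1.0, 3.0]}

# Both consumers expect ``sample_weight`` but each receives a different array.
score_kwargs = resolve_aliases({"sample_weight": "scoring_weight"}, props)
fit_kwargs = resolve_aliases({"sample_weight": "fitting_weight"}, props)
```

Because the mapping lives with the request rather than with the data, neither consumer needs to know that the other exists or what its weights are called.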

API Interface
*************

A *consumer* is an object (estimator, meta-estimator, scorer, splitter) which

Member

Worth adding a glossary entry for metadata consumer, metadata router and metadata request based on some of this text?

Member

Some of this seems to be over-complicated, in the sense that a consumer is just an estimator that can make use of some metadata, and a router is just a SLEP006-compliant estimator that calls a method or function which accepts metadata.

Member Author

I'm not sure how to simplify the text AND be accurate: e.g. a consumer is not just an estimator. It can be any other object which is used in the router.

I think we can add the glossary entries once we're set on these definitions.

accepts and uses some metadata in at least one of its methods (``fit``,
``predict``, ``inverse_transform``, ``transform``, ``score``, ``split``).
Meta-estimators which only forward the metadata to other objects (the child
estimator, scorers, or splitters) and don't use the metadata themselves are not
consumers. (Meta)Estimators which route metadata to other objects are
*routers*. A (meta)estimator can be a consumer and a router at the same time.
(Meta)Estimators and splitters expose a ``set_*_request`` method for each
method which accepts at least one piece of metadata. For instance, if an estimator
supports ``sample_weight`` in ``fit`` and ``score``, it exposes
``estimator.set_fit_request(sample_weight=value)`` and
``estimator.set_score_request(sample_weight=value)``. Here ``value`` can be:

- ``RequestType.REQUESTED`` or ``True``: method requests a ``sample_weight``.
  This means if the metadata is provided, it will be used; otherwise no error
  is raised.
- ``RequestType.UNREQUESTED`` or ``False``: method does not request a
  ``sample_weight``.
- ``RequestType.ERROR_IF_PASSED`` or ``None``: router will raise an error if
  ``sample_weight`` is passed. This is in almost all cases the default value
  when an object is instantiated and ensures the user sets the metadata
  requests explicitly when a metadata is passed. The only exceptions are
  ``Group*Fold`` splitters.
- ``"param_name"``: if this estimator is used in a meta-estimator, the
  meta-estimator should forward ``"param_name"`` as ``sample_weight`` to this
  estimator. This means the mapping between the metadata required by the
  object, e.g. ``sample_weight``, and what is provided by the user, e.g.
  ``my_weights``, is done at the router level, and not by the object, e.g.
  estimator, itself.
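
The four request values listed above can be summarised as a small decision function. The following is a minimal sketch that assumes the plain ``True`` / ``False`` / ``None`` / string spellings; ``resolve_request`` and its error text are hypothetical and only model the described behaviour, they are not the code added in this PR.

```python
# Toy decision function for one parameter. Hypothetical, not scikit-learn code.
def resolve_request(param, request, props):
    """Return the keyword arguments one consumer method receives for ``param``.

    ``request`` is True, False, None, or an alias string, mirroring the four
    values listed above. ``props`` is what the user passed to the router.
    """
    if request is True:
        # Requested: forward the metadata if the user provided it.
        return {param: props[param]} if param in props else {}
    if request is False:
        # Unrequested: never forward, even if provided.
        return {}
    if request is None:
        # Error-if-passed: the user has not made an explicit choice yet.
        if param in props:
            raise ValueError(
                f"{param} is passed but is not explicitly set as requested or not"
            )
        return {}
    if isinstance(request, str):
        # Alias: look the metadata up under the user-facing name.
        return {param: props[request]} if request in props else {}
    raise TypeError(f"Unsupported request value: {request!r}")


props = {"sample_weight": [1.0, 2.0], "my_weights": [3.0, 4.0]}
requested = resolve_request("sample_weight", True, props)
unrequested = resolve_request("sample_weight", False, props)
aliased = resolve_request("sample_weight", "my_weights", props)
try:
    resolve_request("sample_weight", None, props)
    raised = False
except ValueError:
    raised = True
```

Keeping ``None`` distinct from ``False`` is the design point here: ``False`` is a deliberate opt-out, while ``None`` means the user has not decided yet and should be told so when the metadata is actually passed.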

For the scorers, this is done the same way, using the ``set_score_request`` method.

If a metadata, e.g. ``sample_weight``, is passed by the user, the metadata
request for all objects which potentially can accept ``sample_weight`` should
be set by the user, otherwise an error is raised by the router object. For
example, the following code raises an error, since it hasn't been explicitly
specified whether ``sample_weight`` should be passed to the estimator's scorer
or not::

    >>> param_grid = {"C": [0.1, 1]}
    >>> lr = LogisticRegression().set_fit_request(sample_weight=True)
    >>> try:
    ...     GridSearchCV(
    ...         estimator=lr, param_grid=param_grid
    ...     ).fit(X, y, sample_weight=my_weights)
    ... except ValueError as e:
    ...     print(e)
    sample_weight is passed but is not explicitly set as requested or not for
    LogisticRegression.score

The issue can be fixed by explicitly setting the request value::

    >>> lr = LogisticRegression().set_fit_request(
    ...     sample_weight=True
    ... ).set_score_request(sample_weight=False)
7 changes: 7 additions & 0 deletions doc/modules/classes.rst
@@ -34,6 +34,7 @@ Base classes
    base.DensityMixin
    base.RegressorMixin
    base.TransformerMixin
+   base.MetaEstimatorMixin
    feature_selection.SelectorMixin
 
 Functions
@@ -1640,6 +1641,12 @@ Plotting
    utils.validation.column_or_1d
    utils.validation.has_fit_parameter
    utils.all_estimators
+   utils.metadata_routing.RequestType
+   utils.metadata_routing.get_routing_for_object
+   utils.metadata_routing.MetadataRouter
+   utils.metadata_routing.MetadataRequest
+   utils.metadata_routing.MethodMapping
+   utils.metadata_routing.process_routing
 
 Utilities from joblib:
 
1 change: 1 addition & 0 deletions doc/user_guide.rst
@@ -27,6 +27,7 @@ User Guide
    visualizations.rst
    data_transforms.rst
    datasets.rst
+   metadata_routing.rst
    computing.rst
    model_persistence.rst
    common_pitfalls.rst