Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
- |Feature| Added support for default metadata routing through
``set_config(enable_metadata_routing="default_routing")``. This enables
automatic routing of common metadata like ``sample_weight`` and ``groups`` by
default. Estimators can now define instance-level default routing via
``__sklearn_default_request__`` method, which complements the existing
class-level defaults set through ``__metadata_request__*`` attributes.
By `Adrin Jalali`_.
53 changes: 52 additions & 1 deletion examples/miscellaneous/plot_metadata_routing.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@

import numpy as np

from sklearn import set_config
from sklearn import config_context, set_config
from sklearn.base import (
BaseEstimator,
ClassifierMixin,
Expand Down Expand Up @@ -464,6 +464,57 @@ def predict(self, X, **predict_params):
# it passed as `sample_weight`. This would apply even if
# `set_fit_request(sample_weight=True)` was set on it.

# %%
# Auto-Requesting Metadata
# ------------------------
# There are two ways a class can modify the default metadata routing requests:
#
# 1. Class-level defaults using `__metadata_request__method` class attributes,
# which set default request values for all instances of a class, and can even
# remove a metadata from the metadata routing machinery if necessary.
# 2. Instance-level defaults via the `add_auto_request` method, which would only
# request the metadata if ``set_config(metadata_request_policy="auto")`` is
# set.
#
# Here's an example demonstrating both approaches:


class DefaultRoutingClassifier(ClassifierMixin, BaseEstimator):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could we rename this, since it's not a router?

Suggested change
class DefaultRoutingClassifier(ClassifierMixin, BaseEstimator):
class DefaultClassifier(ClassifierMixin, BaseEstimator):

# Class-level default request for fit method
__metadata_request__fit = {"sample_weight": True}

def get_metadata_routing(self):
# Each instance can configure metadata which should be requested by default if
# `set_config(metadata_request_policy="auto")` is set. The `add_auto_request`
# method does this.
requests = super().get_metadata_routing()
requests.predict.add_auto_request("groups")
return requests

def fit(self, X, y, sample_weight=None):
check_metadata(self, sample_weight=sample_weight)
self.classes_ = np.array([0, 1])
return self

def predict(self, X, groups=None):
check_metadata(self, groups=groups)
return np.ones(len(X))


# Let's see the default routing configuration
clf = DefaultRoutingClassifier()
print_routing(clf)

# %%
# And now with default routing enabled:
with config_context(enable_metadata_routing="default_routing"):
print_routing(clf)

# %%
# The routing can still be modified using set_*_request methods
clf.set_fit_request(sample_weight=False)
print_routing(clf)

# %%
# Simple Pipeline
# ---------------
Expand Down
36 changes: 34 additions & 2 deletions sklearn/_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
"array_api_dispatch": False,
"transform_output": "default",
"enable_metadata_routing": False,
"metadata_request_policy": "empty",
"skip_parameter_validation": False,
}
_threadlocal = threading.local()
Expand Down Expand Up @@ -67,6 +68,7 @@ def set_config(
array_api_dispatch=None,
transform_output=None,
enable_metadata_routing=None,
metadata_request_policy=None,
skip_parameter_validation=None,
):
"""Set global scikit-learn configuration.
Expand Down Expand Up @@ -151,7 +153,7 @@ def set_config(
.. versionadded:: 1.4
`"polars"` option was added.

enable_metadata_routing : bool, default=None
enable_metadata_routing : bool, str, default=None
Enable metadata routing. By default this feature is disabled.

Refer to :ref:`metadata routing user guide <metadata_routing>` for more
Expand All @@ -163,6 +165,19 @@ def set_config(

.. versionadded:: 1.3

metadata_request_policy : str, default=None
Configure the default metadata request policy.

The default value of this configuration is "empty". Refer to :ref:`metadata
routing user guide <metadata_routing>` for more details.

- `"empty"`: No metadata is requested by default.
- `"auto"`: Metadata is requested if the consumer has flagged it as an
auto-request.
Comment on lines +175 to +176
Copy link
Member

@StefanieSenger StefanieSenger May 22, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Trying to find a formulation that communicates that this entails that a sub-part of the metadata can be requested. I think it's not clear otherwise.

I would also avoid talking about "the" consumer, since there can be several.

Suggested change
- `"auto"`: Metadata is requested if the consumer has flagged it as an
auto-request.
- `"auto"`: Only the subset of metadata flagged by consumers for
auto-request will be requested.

- `None`: Configuration is unchanged.

.. versionadded:: 1.8

skip_parameter_validation : bool, default=None
If `True`, disable the validation of the hyper-parameters' types and values in
the fit method of estimators and for arguments passed to public helper
Expand Down Expand Up @@ -207,6 +222,8 @@ def set_config(
local_config["transform_output"] = transform_output
if enable_metadata_routing is not None:
local_config["enable_metadata_routing"] = enable_metadata_routing
if metadata_request_policy is not None:
local_config["metadata_request_policy"] = metadata_request_policy
if skip_parameter_validation is not None:
local_config["skip_parameter_validation"] = skip_parameter_validation

Expand All @@ -223,6 +240,7 @@ def config_context(
array_api_dispatch=None,
transform_output=None,
enable_metadata_routing=None,
metadata_request_policy=None,
skip_parameter_validation=None,
):
"""Context manager for global scikit-learn configuration.
Expand Down Expand Up @@ -306,7 +324,7 @@ def config_context(
.. versionadded:: 1.4
`"polars"` option was added.

enable_metadata_routing : bool, default=None
enable_metadata_routing : bool, str, default=None
Enable metadata routing. By default this feature is disabled.

Refer to :ref:`metadata routing user guide <metadata_routing>` for more
Expand All @@ -318,6 +336,19 @@ def config_context(

.. versionadded:: 1.3

metadata_request_policy : str, default=None
Configure the default metadata request policy.

The default value of this configuration is "empty". Refer to :ref:`metadata
routing user guide <metadata_routing>` for more details.

- `"empty"`: No metadata is requested by default.
- `"auto"`: Metadata is requested if the consumer has flagged it as an
auto-request.
- `None`: Configuration is unchanged.

.. versionadded:: 1.8

skip_parameter_validation : bool, default=None
If `True`, disable the validation of the hyper-parameters' types and values in
the fit method of estimators and for arguments passed to public helper
Expand Down Expand Up @@ -367,6 +398,7 @@ def config_context(
array_api_dispatch=array_api_dispatch,
transform_output=transform_output,
enable_metadata_routing=enable_metadata_routing,
metadata_request_policy=metadata_request_policy,
skip_parameter_validation=skip_parameter_validation,
)

Expand Down
36 changes: 36 additions & 0 deletions sklearn/tests/test_metadata_routing.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
SIMPLE_METHODS,
MethodMetadataRequest,
MethodPair,
_auto_routing_enabled,
_MetadataRequester,
request_is_alias,
request_is_valid,
Expand Down Expand Up @@ -1156,3 +1157,38 @@ def fit(self, X, y, sample_weight=None):
# Test positional arguments error after making the descriptor method unbound.
with pytest.raises(TypeError, match=error_message):
A().set_fit_request(True)


@pytest.mark.parametrize(
"enable_metadata_routing, default_routing",
[
(True, False),
(False, False),
("default_routing", True),
],
)
def test_default_routing_disabled(enable_metadata_routing, default_routing):
"""Check correctness of _auto_routing_enabled."""
with config_context(enable_metadata_routing=enable_metadata_routing):
assert _auto_routing_enabled() == default_routing


def test_default_instance_routing_overrides_class_level():
"""Test that instance-level default routing overrides class-level."""

class DefaultRoutingEstimator(BaseEstimator):
__metadata_request__fit = {"prop": False}

def __sklearn_default_request__(self):
values = super().__sklearn_default_request__()
values["fit"]["prop"] = True # Override class-level False with True
values["predict"]["prop"] = True # Add new method request
return values

est = DefaultRoutingEstimator()

with config_context(enable_metadata_routing="default_routing"):
# Instance-level True should override class-level False
assert est.get_metadata_routing().fit.requests["prop"] is True
# New method request should be present
assert est.get_metadata_routing().predict.requests["prop"] is True
106 changes: 95 additions & 11 deletions sklearn/utils/_metadata_requests.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,19 @@ def _routing_enabled():
return get_config().get("enable_metadata_routing", False)


def _auto_routing_enabled():
"""Return whether auto-requested metadata routing is enabled.

.. versionadded:: 1.8

Returns
-------
enabled : bool
Whether auto-requested metadata routing is enabled.
"""
return get_config().get("metadata_request_policy", "empty") == "auto"


def _raise_for_params(params, owner, method, allow=None):
"""Raise an error if metadata routing is not enabled and params are passed.

Expand Down Expand Up @@ -1385,7 +1398,7 @@ def __init_subclass__(cls, **kwargs):
.. [1] https://www.python.org/dev/peps/pep-0487
"""
try:
requests = cls._get_default_requests()
requests = cls._get_class_requests()
except Exception:
# if there are any issues in the default values, it will be raised
# when ``get_metadata_routing`` is called. Here we are going to
Expand Down Expand Up @@ -1444,12 +1457,30 @@ def _build_request_for_signature(cls, router, method):
return mmr

@classmethod
def _get_default_requests(cls):
"""Collect default request values.

This method combines the information present in ``__metadata_request__*``
class attributes, as well as determining request keys from method
signatures.
def _get_class_requests(cls):
"""Collect class level request values.

This method serves two purposes:

1. During class creation via `__init_subclass__`, it determines what metadata
routing methods should be created. It does this by:
- Collecting metadata request info from `__metadata_request__*` class
attributes
- Analyzing method signatures for implicit metadata parameters
The collected information is used to create `set_{method}_request` methods
(e.g. set_fit_request) that allow runtime configuration of metadata
routing.

2. Before the user sets any specific routing, via `_get_default_requests`, it
provides the default metadata routing configuration for the instance. This
ensures each instance starts with the class-level routing settings before any
instance specific configurations are applied.

For example, if a method's signature includes `sample_weight`, this method will:
- During class creation: Create a `set_{method}_request` method to configure
how `sample_weight` should be routed
- Right after initialization: Provide the default routing configuration for
`sample_weight` based on class attributes and method signatures
"""
requests = MetadataRequest(owner=cls.__name__)

Expand Down Expand Up @@ -1487,12 +1518,54 @@ class attributes, as well as determining request keys from method

return requests

def _get_metadata_request(self):
def _get_default_requests(self, **auto_requests):
"""Get default request values for this object.

This method combines class level default values returned by
`_get_class_requests()` with instance specific auto-requested metadata.

Parameters
----------
**auto_requests : str or list of str
Auto-requested metadata. These metadata request values override default
request values if `set_config(metadata_request_policy="auto")` is set.

Call pattern is `_get_default_request(fit="sample_weight",
predict=["metadata1", "metadata2"])`.

.. versionadded:: 1.8

"""
requests = self._get_class_requests()
if _auto_routing_enabled():
for method, method_requests in auto_requests.items():
method_requests = (
[method_requests]
if isinstance(method_requests, str)
else method_requests
)
for param in method_requests:
getattr(requests, method).add_request(param=param, alias=True)

return requests

def _get_metadata_request(self, **auto_requests):
"""Get requested data properties.

Please check :ref:`User Guide <metadata_routing>` on how the routing
mechanism works.

Parameters
----------
**auto_requests : str or list of str
Auto-requested metadata. These metadata request values override default
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The default are the empty requests, right?

Suggested change
Auto-requested metadata. These metadata request values override default
Auto-requested metadata. These metadata request values override empty

request values if `set_config(metadata_request_policy="auto")` is set.

Call pattern is `_get_metadata_request(fit="sample_weight",
predict=["metadata1", "metadata2"])`.

.. versionadded:: 1.8

Returns
-------
request : MetadataRequest
Expand All @@ -1501,23 +1574,34 @@ def _get_metadata_request(self):
if hasattr(self, "_metadata_request"):
requests = get_routing_for_object(self._metadata_request)
else:
requests = self._get_default_requests()
requests = self._get_default_requests(**auto_requests)

return requests

def get_metadata_routing(self):
def get_metadata_routing(self, **auto_requests):
"""Get metadata routing of this object.

Please check :ref:`User Guide <metadata_routing>` on how the routing
mechanism works.

Parameters
----------
**auto_requests : str or list of str
Auto-requested metadata. These metadata request values override default
request values if `set_config(metadata_request_policy="auto")` is set.
Comment on lines +1590 to +1591
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe explain this a bit better:

Suggested change
Auto-requested metadata. These metadata request values override default
request values if `set_config(metadata_request_policy="auto")` is set.
Auto-requested metadata. Keyword arguments where the key is a method
name (e.g., 'fit', 'predict') and the value is a metadata request
(either a str or a list of str). These override default request values
if `set_config(metadata_request_policy="auto")` is set.


Call pattern is `get_metadata_routing(fit="sample_weight",
predict=["metadata1", "metadata2"])`.

.. versionadded:: 1.8

Returns
-------
routing : MetadataRequest
A :class:`~sklearn.utils.metadata_routing.MetadataRequest` encapsulating
routing information.
"""
return self._get_metadata_request()
return self._get_metadata_request(**auto_requests)


# Process Routing in Routers
Expand Down
Loading