-
-
Notifications
You must be signed in to change notification settings - Fork 26.2k
FEAT allow configuring automatically requested metadata #31401
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
4e9c8a7
d4bd852
2c38534
53152e8
d14200e
fcbc431
f24eecb
a35bf40
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
- |Feature| Added support for default metadata routing through | ||
``set_config(enable_metadata_routing="default_routing")``. This enables | ||
automatic routing of common metadata like ``sample_weight`` and ``groups`` by | ||
default. Estimators can now define instance-level default routing via | ||
``__sklearn_default_request__`` method, which complements the existing | ||
class-level defaults set through ``__metadata_request__*`` attributes. | ||
By `Adrin Jalali`_. |
Original file line number | Diff line number | Diff line change | ||||||||
---|---|---|---|---|---|---|---|---|---|---|
|
@@ -19,6 +19,7 @@ | |||||||||
"array_api_dispatch": False, | ||||||||||
"transform_output": "default", | ||||||||||
"enable_metadata_routing": False, | ||||||||||
"metadata_request_policy": "empty", | ||||||||||
"skip_parameter_validation": False, | ||||||||||
} | ||||||||||
_threadlocal = threading.local() | ||||||||||
|
@@ -67,6 +68,7 @@ def set_config( | |||||||||
array_api_dispatch=None, | ||||||||||
transform_output=None, | ||||||||||
enable_metadata_routing=None, | ||||||||||
metadata_request_policy=None, | ||||||||||
skip_parameter_validation=None, | ||||||||||
): | ||||||||||
"""Set global scikit-learn configuration. | ||||||||||
|
@@ -151,7 +153,7 @@ def set_config( | |||||||||
.. versionadded:: 1.4 | ||||||||||
`"polars"` option was added. | ||||||||||
|
||||||||||
enable_metadata_routing : bool, default=None | ||||||||||
enable_metadata_routing : bool, str, default=None | ||||||||||
Enable metadata routing. By default this feature is disabled. | ||||||||||
|
||||||||||
Refer to :ref:`metadata routing user guide <metadata_routing>` for more | ||||||||||
|
@@ -163,6 +165,19 @@ def set_config( | |||||||||
|
||||||||||
.. versionadded:: 1.3 | ||||||||||
|
||||||||||
metadata_request_policy : str, default=None | ||||||||||
Configure the default metadata request policy. | ||||||||||
|
||||||||||
The default value of this configuration is "empty". Refer to :ref:`metadata | ||||||||||
routing user guide <metadata_routing>` for more details. | ||||||||||
|
||||||||||
- `"empty"`: No metadata is requested by default. | ||||||||||
- `"auto"`: Metadata is requested if the consumer has flagged it as an | ||||||||||
auto-request. | ||||||||||
Comment on lines
+175
to
+176
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Trying to find a formulation that communicates that this entails that a sub-part of the metadata can be requested. I think it's not clear otherwise. I would also avoid talking about "the" consumer, since there can be several.
Suggested change
|
||||||||||
- `None`: Configuration is unchanged. | ||||||||||
|
||||||||||
.. versionadded:: 1.8 | ||||||||||
|
||||||||||
skip_parameter_validation : bool, default=None | ||||||||||
If `True`, disable the validation of the hyper-parameters' types and values in | ||||||||||
the fit method of estimators and for arguments passed to public helper | ||||||||||
|
@@ -207,6 +222,8 @@ def set_config( | |||||||||
local_config["transform_output"] = transform_output | ||||||||||
if enable_metadata_routing is not None: | ||||||||||
local_config["enable_metadata_routing"] = enable_metadata_routing | ||||||||||
if metadata_request_policy is not None: | ||||||||||
local_config["metadata_request_policy"] = metadata_request_policy | ||||||||||
if skip_parameter_validation is not None: | ||||||||||
local_config["skip_parameter_validation"] = skip_parameter_validation | ||||||||||
|
||||||||||
|
@@ -223,6 +240,7 @@ def config_context( | |||||||||
array_api_dispatch=None, | ||||||||||
transform_output=None, | ||||||||||
enable_metadata_routing=None, | ||||||||||
metadata_request_policy=None, | ||||||||||
skip_parameter_validation=None, | ||||||||||
): | ||||||||||
"""Context manager for global scikit-learn configuration. | ||||||||||
|
@@ -306,7 +324,7 @@ def config_context( | |||||||||
.. versionadded:: 1.4 | ||||||||||
`"polars"` option was added. | ||||||||||
|
||||||||||
enable_metadata_routing : bool, default=None | ||||||||||
enable_metadata_routing : bool, str, default=None | ||||||||||
Enable metadata routing. By default this feature is disabled. | ||||||||||
|
||||||||||
Refer to :ref:`metadata routing user guide <metadata_routing>` for more | ||||||||||
|
@@ -318,6 +336,19 @@ def config_context( | |||||||||
|
||||||||||
.. versionadded:: 1.3 | ||||||||||
|
||||||||||
metadata_request_policy : str, default=None | ||||||||||
Configure the default metadata request policy. | ||||||||||
|
||||||||||
The default value of this configuration is "empty". Refer to :ref:`metadata | ||||||||||
routing user guide <metadata_routing>` for more details. | ||||||||||
|
||||||||||
- `"empty"`: No metadata is requested by default. | ||||||||||
- `"auto"`: Metadata is requested if the consumer has flagged it as an | ||||||||||
auto-request. | ||||||||||
- `None`: Configuration is unchanged. | ||||||||||
|
||||||||||
.. versionadded:: 1.8 | ||||||||||
|
||||||||||
skip_parameter_validation : bool, default=None | ||||||||||
If `True`, disable the validation of the hyper-parameters' types and values in | ||||||||||
the fit method of estimators and for arguments passed to public helper | ||||||||||
|
@@ -367,6 +398,7 @@ def config_context( | |||||||||
array_api_dispatch=array_api_dispatch, | ||||||||||
transform_output=transform_output, | ||||||||||
enable_metadata_routing=enable_metadata_routing, | ||||||||||
metadata_request_policy=metadata_request_policy, | ||||||||||
skip_parameter_validation=skip_parameter_validation, | ||||||||||
) | ||||||||||
|
||||||||||
|
Original file line number | Diff line number | Diff line change | ||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -130,6 +130,19 @@ def _routing_enabled(): | |||||||||||||
return get_config().get("enable_metadata_routing", False) | ||||||||||||||
|
||||||||||||||
|
||||||||||||||
def _auto_routing_enabled(): | ||||||||||||||
"""Return whether auto-requested metadata routing is enabled. | ||||||||||||||
|
||||||||||||||
.. versionadded:: 1.8 | ||||||||||||||
|
||||||||||||||
Returns | ||||||||||||||
------- | ||||||||||||||
enabled : bool | ||||||||||||||
Whether auto-requested metadata routing is enabled. | ||||||||||||||
""" | ||||||||||||||
return get_config().get("metadata_request_policy", "empty") == "auto" | ||||||||||||||
|
||||||||||||||
|
||||||||||||||
def _raise_for_params(params, owner, method, allow=None): | ||||||||||||||
"""Raise an error if metadata routing is not enabled and params are passed. | ||||||||||||||
|
||||||||||||||
|
@@ -1385,7 +1398,7 @@ def __init_subclass__(cls, **kwargs): | |||||||||||||
.. [1] https://www.python.org/dev/peps/pep-0487 | ||||||||||||||
""" | ||||||||||||||
try: | ||||||||||||||
requests = cls._get_default_requests() | ||||||||||||||
requests = cls._get_class_requests() | ||||||||||||||
except Exception: | ||||||||||||||
# if there are any issues in the default values, it will be raised | ||||||||||||||
# when ``get_metadata_routing`` is called. Here we are going to | ||||||||||||||
|
@@ -1444,12 +1457,30 @@ def _build_request_for_signature(cls, router, method): | |||||||||||||
return mmr | ||||||||||||||
|
||||||||||||||
@classmethod | ||||||||||||||
def _get_default_requests(cls): | ||||||||||||||
"""Collect default request values. | ||||||||||||||
|
||||||||||||||
This method combines the information present in ``__metadata_request__*`` | ||||||||||||||
class attributes, as well as determining request keys from method | ||||||||||||||
signatures. | ||||||||||||||
def _get_class_requests(cls): | ||||||||||||||
"""Collect class level request values. | ||||||||||||||
|
||||||||||||||
This method serves two purposes: | ||||||||||||||
|
||||||||||||||
1. During class creation via `__init_subclass__`, it determines what metadata | ||||||||||||||
routing methods should be created. It does this by: | ||||||||||||||
- Collecting metadata request info from `__metadata_request__*` class | ||||||||||||||
attributes | ||||||||||||||
- Analyzing method signatures for implicit metadata parameters | ||||||||||||||
The collected information is used to create `set_{method}_request` methods | ||||||||||||||
(e.g. set_fit_request) that allow runtime configuration of metadata | ||||||||||||||
routing. | ||||||||||||||
|
||||||||||||||
2. Before the user sets any specific routing, via `_get_default_requests`, it | ||||||||||||||
provides the default metadata routing configuration for the instance. This | ||||||||||||||
ensures each instance starts with the class-level routing settings before any | ||||||||||||||
instance specific configurations are applied. | ||||||||||||||
|
||||||||||||||
For example, if a method's signature includes `sample_weight`, this method will: | ||||||||||||||
- During class creation: Create a `set_{method}_request` method to configure | ||||||||||||||
how `sample_weight` should be routed | ||||||||||||||
- Right after initialization: Provide the default routing configuration for | ||||||||||||||
`sample_weight` based on class attributes and method signatures | ||||||||||||||
""" | ||||||||||||||
requests = MetadataRequest(owner=cls.__name__) | ||||||||||||||
|
||||||||||||||
|
@@ -1487,12 +1518,54 @@ class attributes, as well as determining request keys from method | |||||||||||||
|
||||||||||||||
return requests | ||||||||||||||
|
||||||||||||||
def _get_metadata_request(self): | ||||||||||||||
def _get_default_requests(self, **auto_requests): | ||||||||||||||
"""Get default request values for this object. | ||||||||||||||
|
||||||||||||||
This method combines class level default values returned by | ||||||||||||||
`_get_class_requests()` with instance specific auto-requested metadata. | ||||||||||||||
|
||||||||||||||
Parameters | ||||||||||||||
---------- | ||||||||||||||
**auto_requests : str or list of str | ||||||||||||||
Auto-requested metadata. These metadata request values override default | ||||||||||||||
request values if `set_config(metadata_request_policy="auto")` is set. | ||||||||||||||
|
||||||||||||||
Call pattern is `_get_default_request(fit="sample_weight", | ||||||||||||||
predict=["metadata1", "metadata2"])`. | ||||||||||||||
|
||||||||||||||
.. versionadded:: 1.8 | ||||||||||||||
|
||||||||||||||
""" | ||||||||||||||
requests = self._get_class_requests() | ||||||||||||||
if _auto_routing_enabled(): | ||||||||||||||
for method, method_requests in auto_requests.items(): | ||||||||||||||
method_requests = ( | ||||||||||||||
[method_requests] | ||||||||||||||
if isinstance(method_requests, str) | ||||||||||||||
else method_requests | ||||||||||||||
) | ||||||||||||||
for param in method_requests: | ||||||||||||||
getattr(requests, method).add_request(param=param, alias=True) | ||||||||||||||
|
||||||||||||||
return requests | ||||||||||||||
|
||||||||||||||
def _get_metadata_request(self, **auto_requests): | ||||||||||||||
"""Get requested data properties. | ||||||||||||||
|
||||||||||||||
Please check :ref:`User Guide <metadata_routing>` on how the routing | ||||||||||||||
mechanism works. | ||||||||||||||
|
||||||||||||||
Parameters | ||||||||||||||
---------- | ||||||||||||||
**auto_requests : str or list of str | ||||||||||||||
Auto-requested metadata. These metadata request values override default | ||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The default are the empty requests, right?
Suggested change
|
||||||||||||||
request values if `set_config(metadata_request_policy="auto")` is set. | ||||||||||||||
|
||||||||||||||
Call pattern is `_get_metadata_request(fit="sample_weight", | ||||||||||||||
predict=["metadata1", "metadata2"])`. | ||||||||||||||
|
||||||||||||||
.. versionadded:: 1.8 | ||||||||||||||
|
||||||||||||||
Returns | ||||||||||||||
------- | ||||||||||||||
request : MetadataRequest | ||||||||||||||
|
@@ -1501,23 +1574,34 @@ def _get_metadata_request(self): | |||||||||||||
if hasattr(self, "_metadata_request"): | ||||||||||||||
requests = get_routing_for_object(self._metadata_request) | ||||||||||||||
else: | ||||||||||||||
requests = self._get_default_requests() | ||||||||||||||
requests = self._get_default_requests(**auto_requests) | ||||||||||||||
|
||||||||||||||
return requests | ||||||||||||||
|
||||||||||||||
def get_metadata_routing(self): | ||||||||||||||
def get_metadata_routing(self, **auto_requests): | ||||||||||||||
"""Get metadata routing of this object. | ||||||||||||||
|
||||||||||||||
Please check :ref:`User Guide <metadata_routing>` on how the routing | ||||||||||||||
mechanism works. | ||||||||||||||
|
||||||||||||||
Parameters | ||||||||||||||
---------- | ||||||||||||||
**auto_requests : str or list of str | ||||||||||||||
Auto-requested metadata. These metadata request values override default | ||||||||||||||
request values if `set_config(metadata_request_policy="auto")` is set. | ||||||||||||||
Comment on lines
+1590
to
+1591
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Maybe explain this a bit better:
Suggested change
|
||||||||||||||
|
||||||||||||||
Call pattern is `get_metadata_routing(fit="sample_weight", | ||||||||||||||
predict=["metadata1", "metadata2"])`. | ||||||||||||||
|
||||||||||||||
.. versionadded:: 1.8 | ||||||||||||||
|
||||||||||||||
Returns | ||||||||||||||
------- | ||||||||||||||
routing : MetadataRequest | ||||||||||||||
A :class:`~sklearn.utils.metadata_routing.MetadataRequest` encapsulating | ||||||||||||||
routing information. | ||||||||||||||
""" | ||||||||||||||
return self._get_metadata_request() | ||||||||||||||
return self._get_metadata_request(**auto_requests) | ||||||||||||||
|
||||||||||||||
|
||||||||||||||
# Process Routing in Routers | ||||||||||||||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could we rename this, since it's not a router?