FEAT (alt3) allow setting auto routed strategy on objects #31413

Open
wants to merge 15 commits into
base: main

adrinjalali
Member

This alternative allows the case for DefaultRoutingClassifier4 in the following script. I think I prefer this alternative to the other ones.

import pytest
import numpy as np
import sklearn
from sklearn.base import BaseEstimator, ClassifierMixin
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.datasets import make_classification

sklearn.set_config(enable_metadata_routing=True, metadata_request_policy="auto")


class BaseMyClassifier(ClassifierMixin, BaseEstimator):
    __metadata_request__fit = {"sample_weight": True}

    def fit(self, X, y, sample_weight=None):
        self.classes_ = np.array([0, 1])
        print(sample_weight)
        return self

    def predict(self, X, groups=None):
        print(groups)
        return np.ones(len(X))


class DefaultRoutingClassifier1(BaseMyClassifier):
    def get_metadata_routing(self, **auto_requests):
        # Each instance can configure metadata which should be requested by default if
        # `set_config(metadata_request_policy="auto")` is set. These request values are
        # passed to the parent's `get_metadata_routing` method.
        return super().get_metadata_routing(predict="groups")


class DefaultRoutingClassifier2(BaseMyClassifier):
    def get_metadata_routing(self):
        return super().get_metadata_routing_with_auto_requests(predict="groups")


class DefaultRoutingClassifier3(BaseMyClassifier):
    def __sklearn_get_auto_requests__(self):
        return {"predict": ["groups"]}


class DefaultRoutingClassifier4(BaseMyClassifier):
    def get_metadata_routing(self):
        requests = super().get_metadata_routing()
        requests.predict.add_auto_request("groups")
        return requests


X, y = make_classification()

print("case 1")
pipeline = Pipeline(
    [
        (
            "scaler",
            StandardScaler()
            .set_fit_request(sample_weight=False)
            .set_partial_fit_request(sample_weight=True),
        ),
        (
            "classifier",
            DefaultRoutingClassifier4().set_predict_request(groups="my_groups"),
        ),
    ]
)
pipeline.fit(X, y, sample_weight=np.ones(len(X)))
pipeline.predict(X, my_groups=np.ones(len(X)) + 1)

print("case 2")
pipeline = Pipeline(
    [
        ("scaler", StandardScaler().set_fit_request(sample_weight=False)),
        ("classifier", DefaultRoutingClassifier4()),
    ]
)
pipeline.fit(X, y, sample_weight=np.ones(len(X)))
pipeline.predict(X, groups=np.ones(len(X)) + 1)

print("case 3")
pipeline = Pipeline(
    [
        ("scaler", StandardScaler().set_fit_request(sample_weight=False)),
        ("classifier", DefaultRoutingClassifier4().set_predict_request(groups=False)),
    ]
)
pipeline.fit(X, y, sample_weight=np.ones(len(X)))
with pytest.raises(TypeError, match="Pipeline.predict got unexpected argument"):
    pipeline.predict(X, groups=np.ones(len(X)) + 1)

cc @StefanieSenger @ogrisel @antoinebaker

github-actions bot commented May 22, 2025

✔️ Linting Passed

All linting checks passed. Your pull request is in excellent shape! ☀️

Generated for commit: a27cee9. Link to the linter CI: here

Contributor

@StefanieSenger StefanieSenger left a comment

I have looked through this PR and commented a bit, some questions, some suggestions for improvement.

As I think I have missed this: What is the difference in usage for the users between this PR, and the previous version of defining an auto request strategy (#31401)? Edit from 5 days later: I of course meant the difference between this PR and what #31401 was supposed to become, not what it is.

Comment on lines +573 to +576
def actualize_auto_requests(self):
    for method in SIMPLE_METHODS:
        getattr(self, method).actualize_auto_requests()
    return self
Contributor

For simplicity, could this be a private method instead?

I don't think many people would really want to use such a low-level method.

I believe that get_metadata_routing().{method_name}.add_auto_request(), combined with get_routing_for_object(obj), gives enough room for fine-tuning and inspection. Exposing this intermediate step could be confusing and would require too much explanation of when it is useful.

Member Author

I don't think it matters. We have documented in the example file how people should be using it anyways.

Contributor

I might be misunderstanding, but I don’t see actualize_auto_requests documented in the example file.

My thought was that when third party developers call add_auto_request() and also use get_routing_for_object() (which internally calls actualize_auto_requests()) there’s no clear external use case for calling this method directly.
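
To make the two steps under discussion concrete, here is a minimal, self-contained sketch of the record-then-actualize pattern. It is a toy model, not scikit-learn's code: `ToyMethodRequest` and its `requests` dict are hypothetical names, and only the method names `add_auto_request` and `actualize_auto_requests` are taken from this PR's diff.

```python
class ToyMethodRequest:
    """Toy stand-in for a per-method request object (hypothetical)."""

    def __init__(self):
        self.requests = {}           # param name -> request value (None = unset)
        self._auto_requests = set()  # names to request only under the "auto" policy

    def add_auto_request(self, *params):
        # Only record the names; nothing is requested yet.
        self._auto_requests.update(params)
        return self

    def actualize_auto_requests(self):
        # Turn every recorded name into an actual request.
        for param in self._auto_requests:
            self.requests[param] = True
        return self


predict = ToyMethodRequest()
predict.requests["groups"] = None
predict.add_auto_request("groups")
before = dict(predict.requests)   # snapshot before actualizing
predict.actualize_auto_requests()
after = dict(predict.requests)    # snapshot after actualizing
print(before, after)
```

In this toy, a developer only ever calls `add_auto_request`; the second step is something a routing helper would trigger, which is the argument for keeping it private.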

Comment on lines +380 to +387
def add_auto_request(self, *params):
    self._auto_requests.update(params)
    return self

def actualize_auto_requests(self):
    for param in self._auto_requests:
        self.add_request(param=param, alias=True)
    return self
Contributor

Would it be more convenient to add the auto-requests directly to the MetadataRequest object itself (as a **kwarg)? This way, a user can (and has to) define all the auto-set metadata in the same place, which may be easier to read with line breaks and black-formatted.

Member Author

those objects are not created by the user. They're created in our functions here. The user only modifies them.

Contributor

@StefanieSenger StefanieSenger Jun 5, 2025

I meant developers who define a custom meta-estimator, not end-users.

It would then look something like this:

            requests = super().get_metadata_routing()
            requests.add_auto_request(fit="prop", predict="prop")

rather than like this, how it is now:

            requests = super().get_metadata_routing()
            requests.fit.add_auto_request("prop")
            requests.predict.add_auto_request("prop")

And add_auto_request() would be a method on MetadataRequest instead of on MethodMetadataRequest, which I think spares a few iterations?
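
A rough sketch of how such a keyword-based method on the container could forward to the per-method objects follows. Everything here is hypothetical (`ToyMetadataRequest`, `ToyMethodRequest`, the `METHODS` tuple); it only illustrates the call shape proposed in this comment and is not the PR's implementation.

```python
class ToyMethodRequest:
    """Toy per-method request object that only records auto-requests."""

    def __init__(self):
        self._auto_requests = set()

    def add_auto_request(self, *params):
        self._auto_requests.update(params)
        return self


class ToyMetadataRequest:
    """Toy container holding one request object per method (hypothetical)."""

    METHODS = ("fit", "predict", "score")

    def __init__(self):
        for method in self.METHODS:
            setattr(self, method, ToyMethodRequest())

    def add_auto_request(self, **method_params):
        # Each keyword is a method name; its value is one name or several names.
        for method, params in method_params.items():
            if method not in self.METHODS:
                raise ValueError(f"Unknown method: {method!r}")
            if isinstance(params, str):
                params = [params]
            getattr(self, method).add_auto_request(*params)
        return self


requests = ToyMetadataRequest()
requests.add_auto_request(fit="prop", predict=["prop", "groups"])
print(requests.fit._auto_requests, requests.predict._auto_requests)
```

The per-method form stays available in this sketch, so the keyword form is a convenience layer rather than a replacement.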

Comment on lines +869 to +870
elif isinstance(obj, MetadataRouter):
    self._self_request = deepcopy(obj._self_request)
Contributor

Could you please explain this part, @adrinjalali:

        elif isinstance(obj, MetadataRouter):
            self._self_request = deepcopy(obj._self_request)

If a MetadataRouter object is passed to another MetadataRouter object's add_self_request method, the latter takes over the former's _self_request attribute? In which case would this happen?

Member Author

it really shouldn't happen. This only happens in our tests and here it covers the edge case if it happens; but I don't think it would.

Comment on lines 474 to 479
# 1. Class-level defaults using `__metadata_request__method` class attributes,
# which set default request values for all instances of a class, and can even
# remove a metadata from the metadata routing machinery if necessary.
# 2. Instance-level defaults via the `add_auto_request` method, which would only
# request the metadata if ``set_config(metadata_request_policy="auto")`` is
# set.
Contributor

@StefanieSenger StefanieSenger May 23, 2025

Let's also document right here that the instance level requests will overwrite the class level requests and can be overwritten by user-set requests.
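
The precedence described here can be sketched as a simple layering of three sources. This is an illustrative toy under the assumption that later layers win; `resolve_requests` and its arguments are hypothetical names, not functions from scikit-learn or from this PR.

```python
def resolve_requests(class_defaults, auto_requests, user_requests, policy):
    """Layer request values for one method.

    Assumed order, lowest to highest priority:
    class-level defaults < instance-level auto-requests < user-set requests.
    """
    resolved = dict(class_defaults)
    if policy == "auto":
        # Auto-requests only take effect under the "auto" policy.
        resolved.update({param: True for param in auto_requests})
    # Whatever the user set explicitly always wins.
    resolved.update(user_requests)
    return resolved


class_defaults = {"groups": None}
auto_requests = {"groups"}

only_class = resolve_requests(class_defaults, auto_requests, {}, policy="empty")
with_auto = resolve_requests(class_defaults, auto_requests, {}, policy="auto")
user_off = resolve_requests(class_defaults, auto_requests, {"groups": False}, policy="auto")
user_alias = resolve_requests(class_defaults, auto_requests, {"groups": "my_groups"}, policy="auto")
print(only_class, with_auto, user_off, user_alias)
```

The last two calls mirror the explicit `groups=False` and `groups="my_groups"` cases from the script in the PR description.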

Comment on lines +477 to +479
# 2. Instance-level defaults via the `add_auto_request` method, which would only
# request the metadata if ``set_config(metadata_request_policy="auto")`` is
# set.
Contributor

@StefanieSenger StefanieSenger May 23, 2025

Maybe it is fine to only communicate the class-level auto default setting and not burden third-party developers with the instance-level option.

Here in our case, the instance-level auto default setting seems to depend only on how set_config() is configured and, while technically at instance level, doesn't really come with the flexibility of instance-level adjustments, it seems to me. (Maybe I am lacking the imagination for what else people could do with that, though.)

Member Author

Since it's a feature, I'd rather document it properly instead of skipping it.

Contributor

@StefanieSenger StefanieSenger Jun 5, 2025

What I had meant was that the instance-level default setting doesn't need to be a feature to begin with.

Not sure if I get this right, but what I am thinking is:
Third party developers could only make use of it if they also have a set_config in their library that also checks if auto requests are enabled. The same for developers who customise a scikit-learn compatible estimator for their corporate project.

And then the functionality of setting instance-level requests would be very limited. It's very similar to setting defaults at class level, only it adds an extra little condition of having auto requests enabled in the configs, no? 🤷‍♀️

(e.g. set_fit_request) that allow runtime configuration of metadata
routing.

2. Before the user sets any specific routing, via `_get_default_requests`, it
Contributor

add_auto_request is meant, right?

Also, the term "user" means too many different things here, can we express this in a more precise way?

Suggested change
- 2. Before the user sets any specific routing, via `_get_default_requests`, it
+ 2. Before the developer of an estimator sets any specific routing,
+ via `add_auto_request`, it

# %%
# And now with default routing enabled:
with config_context(enable_metadata_routing="default_routing"):
    print_routing(clf)
Contributor

@StefanieSenger StefanieSenger May 23, 2025

print_routing(), using the corrected config like this:

with config_context(metadata_request_policy="auto"):
    print_routing(clf)

actually prints:

{'fit': {'sample_weight': True},
 'predict': {'groups': None},
 'score': {'sample_weight': None}}

I think it should print 'predict': {'groups': True}, and it seems that print_routing is using the estimator without calling actualize_auto_requests on its MetadataRequest?

print(get_routing_for_object(clf)) works as expected.

return np.ones(len(X))


# Let's see the default routing configuration
Contributor

I would not use the word "default" here. The default global configuration is "empty", as we have already established in config_context and in set_config.

Here, the default has been changed by adding a class level attribute that defines that sample_weight is getting routed to fit().

# %%
# .. _metadata_routing_auto_request:
#
# Auto-Requesting Metadata
Contributor

@StefanieSenger StefanieSenger May 25, 2025

I was thinking about the term "auto request".

I don't think it makes a difference to the third party developers, that this is accomplished setting requests via the instance. Also, auto-requests is only the second of two points mentioned here. The result of both is an automated routing for the users.

In the docs (specifically here) we could talk about different ways to define an automated routing (instead of request).

And to distinguish set_predict_request(groups=True) and requests.predict.add_auto_request("groups"), I think the latter could instead be called add_auto_routing.

Member Author

We're not adding routing here per se; we're adding a default request on the consumer. Routing is set on routers, while here we change default values for consumers.

For third party developers the difference between the two ways is that one is always there, the other is only there if the user sets a global flag.

Contributor

Yes thanks, I can see that now.

Co-authored-by: Stefanie Senger <91849487+StefanieSenger@users.noreply.github.com>
Member Author

@adrinjalali adrinjalali left a comment

> As I think I have missed this: What is the difference in usage for the users between this PR, and the previous version of defining an auto request strategy (#31401)?

There's no difference for users. The difference is only on how it's implemented.

@antoinebaker
Contributor

Hello @adrinjalali and @StefanieSenger ! Sorry I think I'm a bit lost ;)

How does this new mechanism address the original intent of #30887, i.e. to have a config where scorers and estimators always request sample_weight (if they can)?

For example, how could we implement a Mixin so that for all sklearn estimators, if the fit method has sample_weight in its signature, and the metadata_request_policy="auto" config is enabled, then it will automatically add sample_weight to the fit request ?

@adrinjalali
Member Author

> For example, how could we implement a Mixin so that for all sklearn estimators, if the fit method has sample_weight in its signature, and the metadata_request_policy="auto" config is enabled, then it will automatically add sample_weight to the fit request ?

The idea is not to have something that automatically requests sample_weight if it's present in the signature. The idea is to have a Mixin which adds an auto request on sample_weight and we'd use that mixin only in places where we want to have it.

Comment on lines +133 to +134
def _auto_routing_enabled():
    """Return whether auto-requested metadata routing is enabled.
Contributor

Maybe it should be called _auto_requests_enabled ?

Comment on lines +384 to +387
def actualize_auto_requests(self):
    for param in self._auto_requests:
        self.add_request(param=param, alias=True)
    return self
Contributor

Should we prevent adding auto requests when metadata_request_policy!="auto" ?

Suggested change
- def actualize_auto_requests(self):
-     for param in self._auto_requests:
-         self.add_request(param=param, alias=True)
-     return self
+ def actualize_auto_requests(self):
+     if _auto_routing_enabled():
+         for param in self._auto_requests:
+             self.add_request(param=param, alias=True)
+     return self

Comment on lines +1195 to +1199
if _auto_routing_enabled():
    if hasattr(requests, "actualize_auto_requests"):
        requests.actualize_auto_requests()
    if getattr(requests, "_self_request", None):
        requests._self_request.actualize_auto_requests()
Contributor

I guess here is where all the magic happens :)

The auto requests, defined at the instance level inside obj.get_metadata_routing(), are taken into consideration only when the auto request policy is enabled, and then added to the "default" requests, possibly overriding the class requests. The user can still specify requests afterwards. Did I get this right ?

Contributor

I believe so. This is at least how I also understood it.
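
As a way to check this reading, here is a small stand-alone sketch of config-gated actualization. It is a toy under stated assumptions: `_toy_config`, `toy_config_context`, `ToyConsumer` and `toy_routing_for` are hypothetical stand-ins (not `sklearn.config_context` or `get_routing_for_object`), and working on a copy of the requests is a choice made for this sketch.

```python
from contextlib import contextmanager
from copy import deepcopy

# Toy global configuration; "empty" means auto-requests are ignored.
_toy_config = {"metadata_request_policy": "empty"}


@contextmanager
def toy_config_context(**new_config):
    """Temporarily override the toy global configuration."""
    old = dict(_toy_config)
    _toy_config.update(new_config)
    try:
        yield
    finally:
        _toy_config.clear()
        _toy_config.update(old)


class ToyConsumer:
    def __init__(self):
        self.requests = {"groups": None}  # default: not requested
        self.auto_requests = {"groups"}   # recorded at instance level


def toy_routing_for(consumer):
    """Return the effective requests without mutating the consumer."""
    requests = deepcopy(consumer.requests)
    if _toy_config["metadata_request_policy"] == "auto":
        for param in consumer.auto_requests:
            requests[param] = True
    return requests


clf = ToyConsumer()
outside = toy_routing_for(clf)
with toy_config_context(metadata_request_policy="auto"):
    inside = toy_routing_for(clf)
restored = toy_routing_for(clf)
print(outside, inside, restored)
```

Because the toy resolves requests on a copy, leaving the context restores the original behaviour without any cleanup on the estimator.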

Contributor

@StefanieSenger StefanieSenger left a comment

Thanks for your clarifications, @adrinjalali.

> For example, how could we implement a Mixin so that for all sklearn estimators, if the fit method has sample_weight in its signature, and the metadata_request_policy="auto" config is enabled, then it will automatically add sample_weight to the fit request ?

I was thinking maybe something like combining available_if(_estimator_has("fit")) with inspect.signature(func) to set the class requests in a mixin could work? And maybe use the instance-level requests for a transition period until auto routing would be the default?

@antoinebaker I might be missing some context, but it seems you were asking whether setting class requests via mixins would be possible while auto metadata routing is not the default yet. Does such a roadmap exist, @adrinjalali?
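
For reference, the signature inspection mentioned here needs only the standard library. The sketch below shows just that piece; `fit_accepts` and the two example classes are hypothetical, and it does not touch scikit-learn's `available_if` or the request machinery.

```python
import inspect


def fit_accepts(estimator_cls, param):
    """Return True if `estimator_cls.fit` has a parameter named `param`."""
    fit = getattr(estimator_cls, "fit", None)
    if fit is None:
        return False
    return param in inspect.signature(fit).parameters


class WithWeights:
    def fit(self, X, y, sample_weight=None):
        return self


class WithoutWeights:
    def fit(self, X, y):
        return self


print(fit_accepts(WithWeights, "sample_weight"))
print(fit_accepts(WithoutWeights, "sample_weight"))
```

Note that a `fit(self, X, y, **kwargs)` signature would not be detected by this check, which is one reason a name-based inspection may be less reliable than an explicit declaration.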

@adrinjalali
Member Author

> I was thinking maybe something like combining available_if(_estimator_has("fit")) with inspect.signature(func) to set the class requests in a mixin could work? And maybe use the instance-level requests for a transition period until auto routing would be the default?

It's certainly possible to inspect during runtime and do the magic, but I'd be opposed to the idea. I'd rather have the mixin be included only when we want to add the request, instead of checking the signature. If we want to check the signature and request metadata, then we can have a very different approach where metadata is always requested, no matter which metadata. I'm not opposed to that option, but that's a very different approach. I just don't want to special case sample_weight as a metadata in this approach. And as far as I remember from our meeting, we decided not to go down that path.

@antoinebaker
Contributor

Thanks for the clarification @adrinjalali and @StefanieSenger, I was confused about the goal of the auto request policy.

> The idea is not to have something that automatically requests sample_weight if it's present in the signature. The idea is to have a Mixin which adds an auto request on sample_weight and we'd use that mixin only in places where we want to have it.

If I understand correctly, would you suggest the following plan to tackle #30887 ?

  1. Have a mixin that adds the auto request for sample_weight in fit:
class AutoRequestSampleWeightFit():
    def get_metadata_routing(self):
        requests = super().get_metadata_routing()
        requests.fit.add_auto_request("sample_weight")
        return requests
  2. Add this mixin to the sklearn estimators that need it (currently we think all of them that support sample_weight in fit).

  3. Add a common estimator check that inspects the fit signature and the auto request of sample_weight to make sure that we did not forget anyone.

  4. Do the same for scorers.

  5. From the user perspective in ENH default routing policy for sample weight #30887, having to explicitly request the sample_weight for all consumers or not is a matter of config.

# current policy: need the verbose requests
sklearn.set_config(enable_metadata_routing=True, metadata_request_policy="empty")
scaler.set_fit_request(sample_weight=True)
spline.set_fit_request(sample_weight=True)
logistic.set_fit_request(sample_weight=True)
logistic.set_score_request(sample_weight=True)
pipe.set_score_request(sample_weight=True)
# new policy: no need for requests, `sample_weight` in meta-estimator works out of the box
sklearn.set_config(enable_metadata_routing=True, metadata_request_policy="auto")
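
The common check proposed in the plan above (comparing the `fit` signature with the declared auto-request) might look roughly like the following. This is a toy under stated assumptions: the `_toy_auto_requests` class attribute is a hypothetical stand-in for however the auto-request would really be recorded, and none of these names exist in scikit-learn.

```python
import inspect


class ToyAutoRequestSampleWeightFit:
    """Toy mixin declaring an auto-request for `sample_weight` in `fit`."""

    _toy_auto_requests = {"fit": {"sample_weight"}}


class GoodEstimator(ToyAutoRequestSampleWeightFit):
    def fit(self, X, y, sample_weight=None):
        return self


class ForgottenEstimator:
    # Accepts sample_weight but does not include the mixin.
    def fit(self, X, y, sample_weight=None):
        return self


def check_sample_weight_auto_request(estimator_cls):
    """Return True when the signature and the declared auto-request agree."""
    accepts = "sample_weight" in inspect.signature(estimator_cls.fit).parameters
    declared_for_fit = getattr(estimator_cls, "_toy_auto_requests", {}).get("fit", set())
    return accepts == ("sample_weight" in declared_for_fit)


results = {
    cls.__name__: check_sample_weight_auto_request(cls)
    for cls in (GoodEstimator, ForgottenEstimator)
}
print(results)
```

Here the inspection lives only in the test, so the estimators themselves still opt in explicitly through the mixin.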

@adrinjalali
Member Author

Yep, that's how I see it @antoinebaker .
