FIX MultiOutput* when sub-estimator does not accept metadata #28240

Merged 7 commits on Feb 6, 2024
8 changes: 8 additions & 0 deletions doc/whats_new/v1.4.rst
@@ -49,6 +49,14 @@ Changes impacting all modules
   columns of the returned dataframe.
   :pr:`28262` by :user:`Guillaume Lemaitre <glemaitre>`.
 
+Metadata Routing
+----------------
+
+- |Fix| Fix :class:`multioutput.MultiOutputRegressor` and
+  :class:`multioutput.MultiOutputClassifier` to work with estimators that don't
+  consume any metadata when metadata routing is enabled.
+  :pr:`28240` by `Adrin Jalali`_.
+
 Changelog
 ---------
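For context, a minimal sketch (not part of the diff) of the failure mode this changelog entry describes, assuming the scikit-learn 1.4 routing API; KNeighborsClassifier stands in here for any estimator whose fit() accepts no metadata:

    import numpy as np
    import sklearn
    from sklearn.multioutput import MultiOutputClassifier
    from sklearn.neighbors import KNeighborsClassifier  # fit(X, y) takes no sample_weight

    sklearn.set_config(enable_metadata_routing=True)

    rng = np.random.RandomState(0)
    X = rng.rand(20, 4)
    y = rng.randint(0, 2, size=(20, 2))  # two output columns

    # Before this fix, the line below raised a routing error even though no
    # metadata was passed; with the fix it fits cleanly.
    MultiOutputClassifier(estimator=KNeighborsClassifier(n_neighbors=3)).fit(X, y)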
6 changes: 4 additions & 2 deletions sklearn/multioutput.py
@@ -162,10 +162,11 @@ def partial_fit(self, X, y, classes=None, sample_weight=None, **partial_fit_params):
         )
 
         if _routing_enabled():
+            if sample_weight is not None:
+                partial_fit_params["sample_weight"] = sample_weight
             routed_params = process_routing(
                 self,
                 "partial_fit",
-                sample_weight=sample_weight,
                 **partial_fit_params,
             )
         else:
@@ -248,10 +249,11 @@ def fit(self, X, y, sample_weight=None, **fit_params):
         )
 
         if _routing_enabled():
+            if sample_weight is not None:
+                fit_params["sample_weight"] = sample_weight
             routed_params = process_routing(
                 self,
                 "fit",
-                sample_weight=sample_weight,
                 **fit_params,
             )
         else:
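Both hunks apply the same pattern: optional metadata goes into the routed params dict only when the caller actually supplied it, so process_routing never sees a sample_weight key that the sub-estimator did not request. A standalone sketch of that pattern (illustrative helper name, not scikit-learn API):

    def _collect_fit_params(sample_weight=None, **fit_params):
        # Forward sample_weight only when provided; an explicit None would be
        # treated as passed metadata and rejected for sub-estimators that
        # never requested it.
        if sample_weight is not None:
            fit_params["sample_weight"] = sample_weight
        return fit_params

    assert _collect_fit_params() == {}
    assert _collect_fit_params(sample_weight=[1.0, 2.0]) == {"sample_weight": [1.0, 2.0]}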
26 changes: 21 additions & 5 deletions sklearn/tests/metadata_routing_common.py
@@ -168,16 +168,32 @@ def predict(self, X, sample_weight="default", metadata="default"):
 class NonConsumingClassifier(ClassifierMixin, BaseEstimator):
     """A classifier which accepts no metadata on any method."""
 
-    def __init__(self, registry=None):
-        self.registry = registry
+    def __init__(self, alpha=0.0):
+        self.alpha = alpha
 
     def fit(self, X, y):
-        if self.registry is not None:
-            self.registry.append(self)
-
         self.classes_ = np.unique(y)
         return self
 
+    def partial_fit(self, X, y, classes=None):
+        return self
+
+    def decision_function(self, X):
+        return self.predict(X)
+
     def predict(self, X):
         return np.ones(len(X))
+
+
+class NonConsumingRegressor(RegressorMixin, BaseEstimator):
+    """A regressor which accepts no metadata on any method."""
+
+    def fit(self, X, y):
+        return self
+
+    def partial_fit(self, X, y):
+        return self
+
+    def predict(self, X):
+        return np.ones(len(X))  # pragma: no cover
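A rough usage sketch of the stubs above (assumed usage, not part of the diff; the classes live in sklearn/tests/metadata_routing_common.py):

    import numpy as np

    X = np.asarray([[0.0], [1.0], [2.0], [3.0]])
    y = np.asarray([0, 1, 0, 1])

    # The stub requests no metadata, so a routing-aware meta-estimator that
    # wraps it must work without any set_fit_request calls.
    clf = NonConsumingClassifier().fit(X, y)
    print(clf.classes_)    # [0 1]
    print(clf.predict(X))  # [1. 1. 1. 1.]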
1 change: 0 additions & 1 deletion sklearn/tests/test_metadata_routing.py
@@ -151,7 +151,6 @@ def test_assert_request_is_empty():
         ConsumingClassifier(registry=_Registry()),
         ConsumingRegressor(registry=_Registry()),
         ConsumingTransformer(registry=_Registry()),
-        NonConsumingClassifier(registry=_Registry()),
         WeightedMetaClassifier(estimator=ConsumingClassifier(), registry=_Registry()),
         WeightedMetaRegressor(estimator=ConsumingRegressor(), registry=_Registry()),
     ],
96 changes: 75 additions & 21 deletions sklearn/tests/test_metaestimators_metadata_routing.py
@@ -68,6 +68,8 @@
     ConsumingRegressor,
     ConsumingScorer,
     ConsumingSplitter,
+    NonConsumingClassifier,
+    NonConsumingRegressor,
     _Registry,
     assert_request_is_empty,
     check_recorded_metadata,
@@ -97,15 +99,15 @@ def enable_slep006():
     {
         "metaestimator": MultiOutputRegressor,
         "estimator_name": "estimator",
-        "estimator": ConsumingRegressor,
+        "estimator": "regressor",
         "X": X,
         "y": y_multi,
         "estimator_routing_methods": ["fit", "partial_fit"],
     },
     {
         "metaestimator": MultiOutputClassifier,
         "estimator_name": "estimator",
-        "estimator": ConsumingClassifier,
+        "estimator": "classifier",
         "X": X,
         "y": y_multi,
         "estimator_routing_methods": ["fit", "partial_fit"],
@@ -114,7 +116,7 @@ def enable_slep006():
     {
         "metaestimator": CalibratedClassifierCV,
         "estimator_name": "estimator",
-        "estimator": ConsumingClassifier,
+        "estimator": "classifier",
         "X": X,
         "y": y,
         "estimator_routing_methods": ["fit"],
@@ -123,15 +125,15 @@ def enable_slep006():
     {
         "metaestimator": ClassifierChain,
         "estimator_name": "base_estimator",
-        "estimator": ConsumingClassifier,
+        "estimator": "classifier",
         "X": X,
         "y": y_multi,
         "estimator_routing_methods": ["fit"],
     },
     {
         "metaestimator": RegressorChain,
         "estimator_name": "base_estimator",
-        "estimator": ConsumingRegressor,
+        "estimator": "regressor",
         "X": X,
         "y": y_multi,
         "estimator_routing_methods": ["fit"],
@@ -148,7 +150,7 @@ def enable_slep006():
     {
         "metaestimator": GridSearchCV,
         "estimator_name": "estimator",
-        "estimator": ConsumingClassifier,
+        "estimator": "classifier",
         "init_args": {"param_grid": {"alpha": [0.1, 0.2]}},
         "X": X,
         "y": y,
@@ -162,7 +164,7 @@ def enable_slep006():
     {
         "metaestimator": RandomizedSearchCV,
         "estimator_name": "estimator",
-        "estimator": ConsumingClassifier,
+        "estimator": "classifier",
         "init_args": {"param_distributions": {"alpha": [0.1, 0.2]}},
         "X": X,
         "y": y,
@@ -176,7 +178,7 @@ def enable_slep006():
     {
         "metaestimator": HalvingGridSearchCV,
         "estimator_name": "estimator",
-        "estimator": ConsumingClassifier,
+        "estimator": "classifier",
         "init_args": {"param_grid": {"alpha": [0.1, 0.2]}},
         "X": X,
         "y": y,
@@ -190,7 +192,7 @@ def enable_slep006():
     {
         "metaestimator": HalvingRandomSearchCV,
         "estimator_name": "estimator",
-        "estimator": ConsumingClassifier,
+        "estimator": "classifier",
         "init_args": {"param_distributions": {"alpha": [0.1, 0.2]}},
         "X": X,
         "y": y,
@@ -204,7 +206,7 @@ def enable_slep006():
     {
         "metaestimator": OneVsRestClassifier,
         "estimator_name": "estimator",
-        "estimator": ConsumingClassifier,
+        "estimator": "classifier",
         "X": X,
         "y": y,
         "estimator_routing_methods": ["fit", "partial_fit"],
@@ -213,7 +215,7 @@ def enable_slep006():
     {
         "metaestimator": OneVsOneClassifier,
         "estimator_name": "estimator",
-        "estimator": ConsumingClassifier,
+        "estimator": "classifier",
         "X": X,
         "y": y,
         "estimator_routing_methods": ["fit", "partial_fit"],
@@ -223,7 +225,7 @@ def enable_slep006():
     {
         "metaestimator": OutputCodeClassifier,
         "estimator_name": "estimator",
-        "estimator": ConsumingClassifier,
+        "estimator": "classifier",
         "init_args": {"random_state": 42},
         "X": X,
         "y": y,
@@ -232,7 +234,7 @@ def enable_slep006():
     {
         "metaestimator": SelectFromModel,
         "estimator_name": "estimator",
-        "estimator": ConsumingClassifier,
+        "estimator": "classifier",
         "X": X,
         "y": y,
         "estimator_routing_methods": ["fit", "partial_fit"],
@@ -294,7 +296,7 @@ def enable_slep006():
 
 - metaestimator: The metaestimator to be tested
 - estimator_name: The name of the argument for the sub-estimator
-- estimator: The sub-estimator
+- estimator: The sub-estimator type, either "regressor" or "classifier"
 - init_args: The arguments to be passed to the metaestimator's constructor
 - X: X-data to fit and predict
 - y: y-data to fit
@@ -345,13 +347,21 @@ def enable_slep006():
 ]
 
 
-def get_init_args(metaestimator_info):
+def get_init_args(metaestimator_info, sub_estimator_consumes):
     """Get the init args for a metaestimator
 
     This is a helper function to get the init args for a metaestimator from
     the METAESTIMATORS list. It returns an empty dict if no init args are
     required.
 
+    Parameters
+    ----------
+    metaestimator_info : dict
+        The metaestimator info from METAESTIMATORS
+
+    sub_estimator_consumes : bool
+        Whether the sub-estimator consumes metadata or not.
+
     Returns
     -------
     kwargs : dict
@@ -373,7 +383,17 @@ def get_init_args(metaestimator_info):
     if "estimator" in metaestimator_info:
         estimator_name = metaestimator_info["estimator_name"]
         estimator_registry = _Registry()
-        estimator = metaestimator_info["estimator"](estimator_registry)
+        sub_estimator_type = metaestimator_info["estimator"]
+        if sub_estimator_consumes:
+            if sub_estimator_type == "regressor":
+                estimator = ConsumingRegressor(estimator_registry)
+            else:
+                estimator = ConsumingClassifier(estimator_registry)
+        else:
+            if sub_estimator_type == "regressor":
+                estimator = NonConsumingRegressor()
+            else:
+                estimator = NonConsumingClassifier()
         kwargs[estimator_name] = estimator
     if "scorer_name" in metaestimator_info:
         scorer_name = metaestimator_info["scorer_name"]
@@ -429,7 +449,7 @@ def test_registry_copy():
 def test_default_request(metaestimator):
     # Check that by default request is empty and the right type
     cls = metaestimator["metaestimator"]
-    kwargs, *_ = get_init_args(metaestimator)
+    kwargs, *_ = get_init_args(metaestimator, sub_estimator_consumes=True)
     instance = cls(**kwargs)
     if "cv_name" in metaestimator:
         # Our GroupCV splitters request groups by default, which we should
@@ -457,7 +477,9 @@ def test_error_on_missing_requests_for_sub_estimator(metaestimator):
 
     for method_name in routing_methods:
         for key in ["sample_weight", "metadata"]:
-            kwargs, (estimator, _), (scorer, _), *_ = get_init_args(metaestimator)
+            kwargs, (estimator, _), (scorer, _), *_ = get_init_args(
+                metaestimator, sub_estimator_consumes=True
+            )
             if scorer:
                 scorer.set_score_request(**{key: True})
             val = {"sample_weight": sample_weight, "metadata": metadata}[key]
@@ -501,7 +523,7 @@ def set_request(estimator, method_name):
             method_kwargs = {key: val}
 
             kwargs, (estimator, registry), (scorer, _), (cv, _) = get_init_args(
-                metaestimator
+                metaestimator, sub_estimator_consumes=True
             )
             if scorer:
                 set_request(scorer, "score")
)


@pytest.mark.parametrize("metaestimator", METAESTIMATORS, ids=METAESTIMATOR_IDS)
def test_non_consuming_estimator_works(metaestimator):
# Test that when a non-consuming estimator is given, the meta-estimator
# works w/o setting any requests.
# Regression test for https://github.com/scikit-learn/scikit-learn/issues/28239
if "estimator" not in metaestimator:
# This test only makes sense for metaestimators which have a
# sub-estimator, e.g. MyMetaEstimator(estimator=MySubEstimator())
return

def set_request(estimator, method_name):
# e.g. call set_fit_request on estimator
if is_classifier(estimator) and method_name == "partial_fit":
estimator.set_partial_fit_request(classes=True)

cls = metaestimator["metaestimator"]
X = metaestimator["X"]
y = metaestimator["y"]
routing_methods = metaestimator["estimator_routing_methods"]

for method_name in routing_methods:
kwargs, (estimator, _), (_, _), (_, _) = get_init_args(
metaestimator, sub_estimator_consumes=False
)
instance = cls(**kwargs)
set_request(estimator, method_name)
method = getattr(instance, method_name)
extra_method_args = metaestimator.get("method_args", {}).get(method_name, {})
# This following line should pass w/o raising a routing error.
method(X, y, **extra_method_args)


@pytest.mark.parametrize("metaestimator", METAESTIMATORS, ids=METAESTIMATOR_IDS)
def test_metadata_is_routed_correctly_to_scorer(metaestimator):
"""Test that any requested metadata is correctly routed to the underlying
@@ -544,7 +598,7 @@ def test_metadata_is_routed_correctly_to_scorer(metaestimator):
 
     for method_name in routing_methods:
         kwargs, (estimator, _), (scorer, registry), (cv, _) = get_init_args(
-            metaestimator
+            metaestimator, sub_estimator_consumes=True
         )
         if estimator:
             estimator.set_fit_request(sample_weight=True, metadata=True)
@@ -584,7 +638,7 @@ def test_metadata_is_routed_correctly_to_splitter(metaestimator):
 
     for method_name in routing_methods:
         kwargs, (estimator, _), (scorer, _), (cv, registry) = get_init_args(
-            metaestimator
+            metaestimator, sub_estimator_consumes=True
        )
         if estimator:
             estimator.set_fit_request(sample_weight=False, metadata=False)