Skip to content

Commit fb3c1d3

Browse files
adrinjalaliOmarManzoor
authored andcommitted
ENH warn if {transform, predict} consume metadata but no custom fit_{transform, predict} is defined (scikit-learn#26831)
Co-authored-by: Omar Salman <omar.salman@arbisoft.com>
1 parent bf4d512 commit fb3c1d3

File tree

5 files changed

+242
-3
lines changed

5 files changed

+242
-3
lines changed

doc/whats_new/v1.4.rst

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ Changelog
4141
or *Miscellaneous*.
4242
Entries should end with:
4343
:pr:`123456` by :user:`Joe Bloggs <joeongithub>`.
44-
where 123456 is the *pull request* number, not the issue number.
44+
where 123455 is the *pull request* number, not the issue number.
4545
4646

4747
:mod:`sklearn.base`
@@ -52,6 +52,12 @@ Changelog
5252
passed to the ``fit`` method of the the estimator. :pr:`26506` by `Adrin
5353
Jalali`_.
5454

55+
- |Enhancement| :meth:`base.TransformerMixin.fit_transform` and
56+
:meth:`base.OutlierMixin.fit_predict` now raise a warning if ``transform`` /
57+
``predict`` consume metadata, but no custom ``fit_transform`` / ``fit_predict``
58+
is defined in the class inheriting from them correspondingly. :pr:`26831` by
59+
`Adrin Jalali`_.
60+
5561
- |Enhancement| :func:`base.clone` now supports `dict` as input and creates a
5662
copy. :pr:`26786` by `Adrin Jalali`_.
5763

@@ -115,6 +121,14 @@ Changelog
115121
:pr:`13649` by :user:`Samuel Ronsin <samronsin>`, initiated by
116122
:user:`Patrick O'Reilly <pat-oreilly>`.
117123

124+
:mod:`sklearn.utils`
125+
....................
126+
127+
- |Enhancement| :class:`~utils.metadata_routing.MetadataRequest` and
128+
:class:`~utils.metadata_routing.MetadataRouter` now have a ``consumes`` method
129+
which can be used to check whether a given set of parameters would be consumed.
130+
:pr:`26831` by `Adrin Jalali`_.
131+
118132
Code and Documentation Contributors
119133
-----------------------------------
120134

sklearn/base.py

Lines changed: 54 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
from .exceptions import InconsistentVersionWarning
1919
from .utils import _IS_32BIT
2020
from .utils._estimator_html_repr import estimator_html_repr
21-
from .utils._metadata_requests import _MetadataRequester
21+
from .utils._metadata_requests import _MetadataRequester, _routing_enabled
2222
from .utils._param_validation import validate_parameter_constraints
2323
from .utils._set_output import _SetOutputMixin
2424
from .utils._tags import (
@@ -916,6 +916,33 @@ def fit_transform(self, X, y=None, **fit_params):
916916
"""
917917
# non-optimized default implementation; override when a better
918918
# method is possible for a given clustering algorithm
919+
920+
# we do not route parameters here, since consumers don't route. But
921+
# since it's possible for a `transform` method to also consume
922+
# metadata, we check if that's the case, and we raise a warning telling
923+
# users that they should implement a custom `fit_transform` method
924+
# to forward metadata to `transform` as well.
925+
#
926+
# For that, we calculate routing and check if anything would be routed
927+
# to `transform` if we were to route them.
928+
if _routing_enabled():
929+
transform_params = self.get_metadata_routing().consumes(
930+
method="transform", params=fit_params.keys()
931+
)
932+
if transform_params:
933+
warnings.warn(
934+
(
935+
f"This object ({self.__class__.__name__}) has a `transform`"
936+
" method which consumes metadata, but `fit_transform` does not"
937+
" forward metadata to `transform`. Please implement a custom"
938+
" `fit_transform` method to forward metadata to `transform` as"
939+
" well. Alternatively, you can explicitly do"
940+
" `set_transform_request`and set all values to `False` to"
941+
" disable metadata routed to `transform`, if that's an option."
942+
),
943+
UserWarning,
944+
)
945+
919946
if y is None:
920947
# fit method of arity 1 (unsupervised transformation)
921948
return self.fit(X, **fit_params).transform(X)
@@ -1042,6 +1069,32 @@ def fit_predict(self, X, y=None, **kwargs):
10421069
y : ndarray of shape (n_samples,)
10431070
1 for inliers, -1 for outliers.
10441071
"""
1072+
# we do not route parameters here, since consumers don't route. But
1073+
# since it's possible for a `predict` method to also consume
1074+
# metadata, we check if that's the case, and we raise a warning telling
1075+
# users that they should implement a custom `fit_predict` method
1076+
# to forward metadata to `predict` as well.
1077+
#
1078+
# For that, we calculate routing and check if anything would be routed
1079+
# to `predict` if we were to route them.
1080+
if _routing_enabled():
1081+
transform_params = self.get_metadata_routing().consumes(
1082+
method="predict", params=kwargs.keys()
1083+
)
1084+
if transform_params:
1085+
warnings.warn(
1086+
(
1087+
f"This object ({self.__class__.__name__}) has a `predict` "
1088+
"method which consumes metadata, but `fit_predict` does not "
1089+
"forward metadata to `predict`. Please implement a custom "
1090+
"`fit_predict` method to forward metadata to `predict` as well."
1091+
"Alternatively, you can explicitly do `set_predict_request`"
1092+
"and set all values to `False` to disable metadata routed to "
1093+
"`predict`, if that's an option."
1094+
),
1095+
UserWarning,
1096+
)
1097+
10451098
# override for transductive outlier detectors like LocalOulierFactor
10461099
return self.fit(X, **kwargs).predict(X)
10471100

sklearn/tests/test_base.py

Lines changed: 59 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,13 @@
1212

1313
import sklearn
1414
from sklearn import config_context, datasets
15-
from sklearn.base import BaseEstimator, TransformerMixin, clone, is_classifier
15+
from sklearn.base import (
16+
BaseEstimator,
17+
OutlierMixin,
18+
TransformerMixin,
19+
clone,
20+
is_classifier,
21+
)
1622
from sklearn.decomposition import PCA
1723
from sklearn.exceptions import InconsistentVersionWarning
1824
from sklearn.model_selection import GridSearchCV
@@ -861,3 +867,55 @@ def transform(self, X):
861867
df_bad = _convert_container(data, constructor_name, columns_name=bad_names)
862868
with pytest.raises(ValueError, match="The feature names should match"):
863869
no_op.transform(df_bad)
870+
871+
872+
@pytest.mark.usefixtures("enable_slep006")
873+
def test_transformer_fit_transform_with_metadata_in_transform():
874+
"""Test that having a transformer with metadata for transform raises a
875+
warning when calling fit_transform."""
876+
877+
class CustomTransformer(BaseEstimator, TransformerMixin):
878+
def fit(self, X, y=None, prop=None):
879+
return self
880+
881+
def transform(self, X, prop=None):
882+
return X
883+
884+
# passing the metadata to `fit_transform` should raise a warning since it
885+
# could potentially be consumed by `transform`
886+
with pytest.warns(UserWarning, match="`transform` method which consumes metadata"):
887+
CustomTransformer().set_transform_request(prop=True).fit_transform(
888+
[[1]], [1], prop=1
889+
)
890+
891+
# not passing a metadata which can potentially be consumed by `transform` should
892+
# not raise a warning
893+
with warnings.catch_warnings(record=True) as record:
894+
CustomTransformer().set_transform_request(prop=True).fit_transform([[1]], [1])
895+
assert len(record) == 0
896+
897+
898+
@pytest.mark.usefixtures("enable_slep006")
899+
def test_outlier_mixin_fit_predict_with_metadata_in_predict():
900+
"""Test that having an OutlierMixin with metadata for predict raises a
901+
warning when calling fit_predict."""
902+
903+
class CustomOutlierDetector(BaseEstimator, OutlierMixin):
904+
def fit(self, X, y=None, prop=None):
905+
return self
906+
907+
def predict(self, X, prop=None):
908+
return X
909+
910+
# passing the metadata to `fit_predict` should raise a warning since it
911+
# could potentially be consumed by `predict`
912+
with pytest.warns(UserWarning, match="`predict` method which consumes metadata"):
913+
CustomOutlierDetector().set_predict_request(prop=True).fit_predict(
914+
[[1]], [1], prop=1
915+
)
916+
917+
# not passing a metadata which can potentially be consumed by `predict` should
918+
# not raise a warning
919+
with warnings.catch_warnings(record=True) as record:
920+
CustomOutlierDetector().set_predict_request(prop=True).fit_predict([[1]], [1])
921+
assert len(record) == 0

sklearn/tests/test_metadata_routing.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -716,6 +716,47 @@ class Consumer(BaseEstimator):
716716
assert mr.fit.requests == {"prop": None}
717717

718718

719+
def test_metadata_request_consumes_method():
720+
"""Test that MetadataRequest().consumes() method works as expected."""
721+
request = MetadataRouter(owner="test")
722+
assert request.consumes(method="fit", params={"foo"}) == set()
723+
724+
request = MetadataRequest(owner="test")
725+
request.fit.add_request(param="foo", alias=True)
726+
assert request.consumes(method="fit", params={"foo"}) == {"foo"}
727+
728+
request = MetadataRequest(owner="test")
729+
request.fit.add_request(param="foo", alias="bar")
730+
assert request.consumes(method="fit", params={"bar", "foo"}) == {"bar"}
731+
732+
733+
def test_metadata_router_consumes_method():
734+
"""Test that MetadataRouter().consumes method works as expected."""
735+
# having it here instead of parametrizing the test since `set_fit_request`
736+
# is not available while collecting the tests.
737+
cases = [
738+
(
739+
WeightedMetaRegressor(
740+
estimator=RegressorMetadata().set_fit_request(sample_weight=True)
741+
),
742+
{"sample_weight"},
743+
{"sample_weight"},
744+
),
745+
(
746+
WeightedMetaRegressor(
747+
estimator=RegressorMetadata().set_fit_request(
748+
sample_weight="my_weights"
749+
)
750+
),
751+
{"my_weights", "sample_weight"},
752+
{"my_weights"},
753+
),
754+
]
755+
756+
for obj, input, output in cases:
757+
assert obj.get_metadata_routing().consumes(method="fit", params=input) == output
758+
759+
719760
def test_metaestimator_warnings():
720761
class WeightedMetaRegressorWarn(WeightedMetaRegressor):
721762
__metadata_request__fit = {"sample_weight": metadata_routing.WARN}

sklearn/utils/_metadata_requests.py

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -360,6 +360,28 @@ def _route_params(self, params):
360360
)
361361
return res
362362

363+
def _consumes(self, params):
364+
"""Check whether the given parameters are consumed by this method.
365+
366+
Parameters
367+
----------
368+
params : iterable of str
369+
An iterable of parameters to check.
370+
371+
Returns
372+
-------
373+
consumed : set of str
374+
A set of parameters which are consumed by this method.
375+
"""
376+
params = set(params)
377+
res = set()
378+
for prop, alias in self._requests.items():
379+
if alias is True and prop in params:
380+
res.add(prop)
381+
elif isinstance(alias, str) and alias in params:
382+
res.add(alias)
383+
return res
384+
363385
def _serialize(self):
364386
"""Serialize the object.
365387
@@ -408,6 +430,26 @@ def __init__(self, owner):
408430
MethodMetadataRequest(owner=owner, method=method),
409431
)
410432

433+
def consumes(self, method, params):
434+
"""Check whether the given parameters are consumed by the given method.
435+
436+
.. versionadded:: 1.4
437+
438+
Parameters
439+
----------
440+
method : str
441+
The name of the method to check.
442+
443+
params : iterable of str
444+
An iterable of parameters to check.
445+
446+
Returns
447+
-------
448+
consumed : set of str
449+
A set of parameters which are consumed by the given method.
450+
"""
451+
return getattr(self, method)._consumes(params=params)
452+
411453
def __getattr__(self, name):
412454
# Called when the default attribute access fails with an AttributeError
413455
# (either __getattribute__() raises an AttributeError because name is
@@ -736,6 +778,37 @@ def add(self, *, method_mapping, **objs):
736778
)
737779
return self
738780

781+
def consumes(self, method, params):
782+
"""Check whether the given parameters are consumed by the given method.
783+
784+
.. versionadded:: 1.4
785+
786+
Parameters
787+
----------
788+
method : str
789+
The name of the method to check.
790+
791+
params : iterable of str
792+
An iterable of parameters to check.
793+
794+
Returns
795+
-------
796+
consumed : set of str
797+
A set of parameters which are consumed by the given method.
798+
"""
799+
res = set()
800+
if self._self_request:
801+
res = res | self._self_request.consumes(method=method, params=params)
802+
803+
for _, route_mapping in self._route_mappings.items():
804+
for callee, caller in route_mapping.mapping:
805+
if caller == method:
806+
res = res | route_mapping.router.consumes(
807+
method=callee, params=params
808+
)
809+
810+
return res
811+
739812
def _get_param_names(self, *, method, return_alias, ignore_self_request):
740813
"""Get names of all metadata that can be consumed or routed by specified \
741814
method.

0 commit comments

Comments
 (0)