Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 15 additions & 1 deletion doc/whats_new/v1.4.rst
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ Changelog
or *Miscellaneous*.
Entries should end with:
:pr:`123456` by :user:`Joe Bloggs <joeongithub>`.
where 123456 is the *pull request* number, not the issue number.
where 123455 is the *pull request* number, not the issue number.


:mod:`sklearn.base`
Expand All @@ -52,6 +52,12 @@ Changelog
passed to the ``fit`` method of the the estimator. :pr:`26506` by `Adrin
Jalali`_.

- |Enhancement| :meth:`base.TransformerMixin.fit_transform` and
:meth:`base.OutlierMixin.fit_predict` now raise a warning if ``transform`` /
``predict`` consume metadata, but no custom ``fit_transform`` / ``fit_predict``
is defined in the class inheriting from them correspondingly. :pr:`26831` by
`Adrin Jalali`_.

- |Enhancement| :func:`base.clone` now supports `dict` as input and creates a
copy. :pr:`26786` by `Adrin Jalali`_.

Expand Down Expand Up @@ -115,6 +121,14 @@ Changelog
:pr:`13649` by :user:`Samuel Ronsin <samronsin>`, initiated by
:user:`Patrick O'Reilly <pat-oreilly>`.

:mod:`sklearn.utils`
....................

- |Enhancement| :class:`~utils.metadata_routing.MetadataRequest` and
:class:`~utils.metadata_routing.MetadataRouter` now have a ``consumes`` method
which can be used to check whether a given set of parameters would be consumed.
:pr:`26831` by `Adrin Jalali`_.

Code and Documentation Contributors
-----------------------------------

Expand Down
55 changes: 54 additions & 1 deletion sklearn/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from .exceptions import InconsistentVersionWarning
from .utils import _IS_32BIT
from .utils._estimator_html_repr import estimator_html_repr
from .utils._metadata_requests import _MetadataRequester
from .utils._metadata_requests import _MetadataRequester, _routing_enabled
from .utils._param_validation import validate_parameter_constraints
from .utils._set_output import _SetOutputMixin
from .utils._tags import (
Expand Down Expand Up @@ -916,6 +916,33 @@ def fit_transform(self, X, y=None, **fit_params):
"""
# non-optimized default implementation; override when a better
# method is possible for a given clustering algorithm

# we do not route parameters here, since consumers don't route. But
# since it's possible for a `transform` method to also consume
# metadata, we check if that's the case, and we raise a warning telling
# users that they should implement a custom `fit_transform` method
# to forward metadata to `transform` as well.
#
# For that, we calculate routing and check if anything would be routed
# to `transform` if we were to route them.
if _routing_enabled():
transform_params = self.get_metadata_routing().consumes(
method="transform", params=fit_params.keys()
)
if transform_params:
warnings.warn(
(
f"This object ({self.__class__.__name__}) has a `transform`"
" method which consumes metadata, but `fit_transform` does not"
" forward metadata to `transform`. Please implement a custom"
" `fit_transform` method to forward metadata to `transform` as"
" well. Alternatively, you can explicitly do"
" `set_transform_request`and set all values to `False` to"
" disable metadata routed to `transform`, if that's an option."
),
UserWarning,
)

if y is None:
# fit method of arity 1 (unsupervised transformation)
return self.fit(X, **fit_params).transform(X)
Expand Down Expand Up @@ -1042,6 +1069,32 @@ def fit_predict(self, X, y=None, **kwargs):
y : ndarray of shape (n_samples,)
1 for inliers, -1 for outliers.
"""
# we do not route parameters here, since consumers don't route. But
# since it's possible for a `predict` method to also consume
# metadata, we check if that's the case, and we raise a warning telling
# users that they should implement a custom `fit_predict` method
# to forward metadata to `predict` as well.
#
# For that, we calculate routing and check if anything would be routed
# to `predict` if we were to route them.
if _routing_enabled():
transform_params = self.get_metadata_routing().consumes(
method="predict", params=kwargs.keys()
)
if transform_params:
warnings.warn(
(
f"This object ({self.__class__.__name__}) has a `predict` "
"method which consumes metadata, but `fit_predict` does not "
"forward metadata to `predict`. Please implement a custom "
"`fit_predict` method to forward metadata to `predict` as well."
"Alternatively, you can explicitly do `set_predict_request`"
"and set all values to `False` to disable metadata routed to "
"`predict`, if that's an option."
),
UserWarning,
)

# override for transductive outlier detectors like LocalOulierFactor
return self.fit(X, **kwargs).predict(X)

Expand Down
60 changes: 59 additions & 1 deletion sklearn/tests/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,13 @@

import sklearn
from sklearn import config_context, datasets
from sklearn.base import BaseEstimator, TransformerMixin, clone, is_classifier
from sklearn.base import (
BaseEstimator,
OutlierMixin,
TransformerMixin,
clone,
is_classifier,
)
from sklearn.decomposition import PCA
from sklearn.exceptions import InconsistentVersionWarning
from sklearn.model_selection import GridSearchCV
Expand Down Expand Up @@ -861,3 +867,55 @@ def transform(self, X):
df_bad = _convert_container(data, constructor_name, columns_name=bad_names)
with pytest.raises(ValueError, match="The feature names should match"):
no_op.transform(df_bad)


@pytest.mark.usefixtures("enable_slep006")
def test_transformer_fit_transform_with_metadata_in_transform():
"""Test that having a transformer with metadata for transform raises a
warning when calling fit_transform."""

class CustomTransformer(BaseEstimator, TransformerMixin):
def fit(self, X, y=None, prop=None):
return self

def transform(self, X, prop=None):
return X

# passing the metadata to `fit_transform` should raise a warning since it
# could potentially be consumed by `transform`
with pytest.warns(UserWarning, match="`transform` method which consumes metadata"):
CustomTransformer().set_transform_request(prop=True).fit_transform(
[[1]], [1], prop=1
)

# not passing a metadata which can potentially be consumed by `transform` should
# not raise a warning
with warnings.catch_warnings(record=True) as record:
CustomTransformer().set_transform_request(prop=True).fit_transform([[1]], [1])
assert len(record) == 0


@pytest.mark.usefixtures("enable_slep006")
def test_outlier_mixin_fit_predict_with_metadata_in_predict():
"""Test that having an OutlierMixin with metadata for predict raises a
warning when calling fit_predict."""

class CustomOutlierDetector(BaseEstimator, OutlierMixin):
def fit(self, X, y=None, prop=None):
return self

def predict(self, X, prop=None):
return X

# passing the metadata to `fit_predict` should raise a warning since it
# could potentially be consumed by `predict`
with pytest.warns(UserWarning, match="`predict` method which consumes metadata"):
CustomOutlierDetector().set_predict_request(prop=True).fit_predict(
[[1]], [1], prop=1
)

# not passing a metadata which can potentially be consumed by `predict` should
# not raise a warning
with warnings.catch_warnings(record=True) as record:
CustomOutlierDetector().set_predict_request(prop=True).fit_predict([[1]], [1])
assert len(record) == 0
41 changes: 41 additions & 0 deletions sklearn/tests/test_metadata_routing.py
Original file line number Diff line number Diff line change
Expand Up @@ -716,6 +716,47 @@ class Consumer(BaseEstimator):
assert mr.fit.requests == {"prop": None}


def test_metadata_request_consumes_method():
"""Test that MetadataRequest().consumes() method works as expected."""
request = MetadataRouter(owner="test")
assert request.consumes(method="fit", params={"foo"}) == set()

request = MetadataRequest(owner="test")
request.fit.add_request(param="foo", alias=True)
assert request.consumes(method="fit", params={"foo"}) == {"foo"}

request = MetadataRequest(owner="test")
request.fit.add_request(param="foo", alias="bar")
assert request.consumes(method="fit", params={"bar", "foo"}) == {"bar"}


def test_metadata_router_consumes_method():
"""Test that MetadataRouter().consumes method works as expected."""
# having it here instead of parametrizing the test since `set_fit_request`
# is not available while collecting the tests.
cases = [
(
WeightedMetaRegressor(
estimator=RegressorMetadata().set_fit_request(sample_weight=True)
),
{"sample_weight"},
{"sample_weight"},
),
(
WeightedMetaRegressor(
estimator=RegressorMetadata().set_fit_request(
sample_weight="my_weights"
)
),
{"my_weights", "sample_weight"},
{"my_weights"},
),
]

for obj, input, output in cases:
assert obj.get_metadata_routing().consumes(method="fit", params=input) == output


def test_metaestimator_warnings():
class WeightedMetaRegressorWarn(WeightedMetaRegressor):
__metadata_request__fit = {"sample_weight": metadata_routing.WARN}
Expand Down
73 changes: 73 additions & 0 deletions sklearn/utils/_metadata_requests.py
Original file line number Diff line number Diff line change
Expand Up @@ -360,6 +360,28 @@ def _route_params(self, params):
)
return res

def _consumes(self, params):
"""Check whether the given parameters are consumed by this method.

Parameters
----------
params : iterable of str
An iterable of parameters to check.

Returns
-------
consumed : set of str
A set of parameters which are consumed by this method.
"""
params = set(params)
res = set()
for prop, alias in self._requests.items():
if alias is True and prop in params:
res.add(prop)
elif isinstance(alias, str) and alias in params:
res.add(alias)
return res

def _serialize(self):
"""Serialize the object.

Expand Down Expand Up @@ -408,6 +430,26 @@ def __init__(self, owner):
MethodMetadataRequest(owner=owner, method=method),
)

def consumes(self, method, params):
"""Check whether the given parameters are consumed by the given method.

.. versionadded:: 1.4

Parameters
----------
method : str
The name of the method to check.

params : iterable of str
An iterable of parameters to check.

Returns
-------
consumed : set of str
A set of parameters which are consumed by the given method.
"""
return getattr(self, method)._consumes(params=params)

def __getattr__(self, name):
# Called when the default attribute access fails with an AttributeError
# (either __getattribute__() raises an AttributeError because name is
Expand Down Expand Up @@ -736,6 +778,37 @@ def add(self, *, method_mapping, **objs):
)
return self

def consumes(self, method, params):
"""Check whether the given parameters are consumed by the given method.

.. versionadded:: 1.4

Parameters
----------
method : str
The name of the method to check.

params : iterable of str
An iterable of parameters to check.

Returns
-------
consumed : set of str
A set of parameters which are consumed by the given method.
"""
res = set()
if self._self_request:
res = res | self._self_request.consumes(method=method, params=params)

for _, route_mapping in self._route_mappings.items():
for callee, caller in route_mapping.mapping:
if caller == method:
res = res | route_mapping.router.consumes(
method=callee, params=params
)

return res

def _get_param_names(self, *, method, return_alias, ignore_self_request):
"""Get names of all metadata that can be consumed or routed by specified \
method.
Expand Down