Skip to content

FEAT multioutput routes metadata #22986

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 44 commits into from
Jul 17, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
44 commits
Select commit Hold shift + click to select a range
1fbda38
FEAT add metadata routing support to scorers
adrinjalali Mar 10, 2022
51baf9d
clarify doc on not mutating scorers
adrinjalali Mar 10, 2022
78638c4
improve tests
adrinjalali Mar 10, 2022
ea317bc
remove unused import
adrinjalali Mar 10, 2022
6c6c2f1
fix tests
adrinjalali Mar 11, 2022
b24c3ad
Christian's comments
adrinjalali Mar 11, 2022
a663b86
set_score_request -> with_score_request on scorers
adrinjalali Mar 15, 2022
e838a71
warn on overlapping kwargs and metadata
adrinjalali Mar 15, 2022
0c6e2c4
add references to docs
adrinjalali Mar 15, 2022
0c61c05
add a note on custom scorers
adrinjalali Mar 15, 2022
7840a6c
Merge remote-tracking branch 'upstream/sample-props' into slep6-scorers
adrinjalali Mar 22, 2022
ba36307
Revert "set_score_request -> with_score_request on scorers"
adrinjalali Mar 22, 2022
2bea203
set_score_request now mutates the instance
adrinjalali Mar 23, 2022
dba266f
don't test repr
adrinjalali Mar 23, 2022
203b4e4
fix and test _passthrough_scorer
adrinjalali Mar 23, 2022
4e32f31
Joel's comments
adrinjalali Mar 23, 2022
da3c6a6
writing test
adrinjalali Mar 24, 2022
44f91ce
Thomas's comments
adrinjalali Mar 24, 2022
3a5ca34
Merge branch 'slep6-scorers' into slep6/multioutput
adrinjalali Mar 24, 2022
7421fe4
for Thomas
adrinjalali Mar 25, 2022
380b20a
...
adrinjalali Mar 28, 2022
8c19c8c
Merge remote-tracking branch 'upstream/sample-props' into slep6/multi…
adrinjalali Mar 28, 2022
1c64ec0
...
adrinjalali Mar 28, 2022
f15ef55
all
adrinjalali Mar 29, 2022
3e5048e
Merge remote-tracking branch 'upstream/sample-props' into slep6/multi…
adrinjalali Mar 29, 2022
ab9daa8
remove unused imports
adrinjalali Mar 29, 2022
bc46a2e
can't remove sample_weight from signature since we forgot to make the…
adrinjalali Mar 29, 2022
fba9e17
common test sub-estimators request weight by default
adrinjalali Apr 1, 2022
54b093b
fix docstrings
adrinjalali Apr 5, 2022
9a60ded
Merge remote-tracking branch 'upstream/sample-props' into slep6/multi…
adrinjalali May 12, 2022
268ae56
minor edits
adrinjalali May 12, 2022
aeb4ed0
Merge remote-tracking branch 'upstream/sample-props' into slep6/multi…
adrinjalali May 12, 2022
fb36733
staking a stab at param specific deprecation
adrinjalali May 16, 2022
1baea75
fix tests
adrinjalali May 17, 2022
a985e83
add test for new code
adrinjalali May 17, 2022
cb77212
remove unused _assume_requested
adrinjalali May 17, 2022
3053582
document exception parmaters
adrinjalali May 17, 2022
a8bc9ca
try a different URL
adrinjalali Jun 9, 2022
2fd59b0
address comments
adrinjalali Jul 16, 2022
8329a6f
Merge remote-tracking branch 'upstream/sample-props' into slep6/multi…
adrinjalali Jul 16, 2022
c904c28
add docs
adrinjalali Jul 16, 2022
8eb85cf
fix example code
adrinjalali Jul 16, 2022
a056d25
remove extra backtick
adrinjalali Jul 16, 2022
916e92b
test _is_default_request and remove some unused lines
adrinjalali Jul 17, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion doc/metadata_routing.rst
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,7 @@ or not::
... ).fit(X, y, sample_weight=my_weights)
... except ValueError as e:
... print(e)
sample_weight is passed but is not explicitly set as requested or not for
[sample_weight] are passed but are not explicitly set as requested or not for
LogisticRegression.score

The issue can be fixed by explicitly setting the request value::
Expand Down
51 changes: 49 additions & 2 deletions examples/plot_metadata_routing.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
from sklearn.utils.metadata_routing import MethodMapping
from sklearn.utils.metadata_routing import process_routing
from sklearn.utils.validation import check_is_fitted
from sklearn.linear_model import LinearRegression
from sklearn.linear_model import LinearRegression, LogisticRegression

N, M = 100, 4
X = np.random.rand(N, M)
Expand Down Expand Up @@ -585,7 +585,7 @@ def get_metadata_routing(self):


# %%
# When an estimator suports a metadata which wasn't supported before, the
# When an estimator supports a metadata which wasn't supported before, the
# following pattern can be used to warn the users about it.


Expand All @@ -605,6 +605,53 @@ def predict(self, X):
for w in record:
print(w.message)

# %%
# Deprecation to Give Users Time to Adapt their Code
# --------------------------------------------------
# With the introduction of metadata routing, the following user code would
# raise an error:

try:
reg = MetaRegressor(estimator=LinearRegression())
reg.fit(X, y, sample_weight=my_weights)
except Exception as e:
print(e)

# %%
# You might want to give your users a period during which they see a
# ``FutureWarning`` instead in order to have time to adapt to the new API. For
# this, the :class:`~sklearn.utils.metadata_routing.MetadataRouter` provides a
# `warn_on` method:


class WarningMetaRegressor(MetaEstimatorMixin, RegressorMixin, BaseEstimator):
    """Meta-regressor that warns (instead of raising) on unrequested metadata.

    Demonstrates ``MetadataRouter.warn_on``: metadata passed to ``fit`` that
    the sub-estimator has not explicitly requested triggers a ``FutureWarning``
    rather than an error, giving users time to adapt to the routing API.
    """

    def __init__(self, estimator):
        # Stored unmodified, per scikit-learn convention; cloned in `fit`.
        self.estimator = estimator

    def fit(self, X, y, **fit_params):
        # Route the incoming metadata to the sub-estimator's `fit` according
        # to the request values declared in `get_metadata_routing`.
        params = process_routing(self, "fit", fit_params)
        self.estimator_ = clone(self.estimator).fit(X, y, **params.estimator.fit)
        # NOTE(review): scikit-learn convention is to ``return self`` from
        # ``fit`` so calls can be chained — confirm the omission is intended.

    def get_metadata_routing(self):
        router = (
            MetadataRouter(owner=self.__class__.__name__)
            # Map each of this estimator's methods to the same-named method
            # on the sub-estimator.
            .add(estimator=self.estimator, method_mapping="one-to-one")
            # Warn (not raise) when `fit` receives metadata the sub-estimator
            # did not request; `params=None` means "all parameters". The
            # `child` value must match the key used in `add` above.
            .warn_on(child="estimator", method="fit", params=None)
        )
        return router


with warnings.catch_warnings(record=True) as record:
WarningMetaRegressor(estimator=LogisticRegression()).fit(
X, y, sample_weight=my_weights
)
for w in record:
print(w.message)

# %%
# Note that in the above implementation, the value passed to ``child`` is the
# same as the key passed to the ``add`` method, in this case ``"estimator"``.

# %%
# Third Party Development and scikit-learn Dependency
# ---------------------------------------------------
Expand Down
26 changes: 26 additions & 0 deletions sklearn/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,35 @@
"SkipTestWarning",
"UndefinedMetricWarning",
"PositiveSpectrumWarning",
"UnsetMetadataPassedError",
]


class UnsetMetadataPassedError(ValueError):
    """Raised when metadata is passed but not explicitly requested.

    With metadata routing enabled, passing a parameter (e.g.
    ``sample_weight``) that no consumer has explicitly requested — or
    explicitly declined — is an error; this exception carries the details.

    .. versionadded:: 1.2

    Parameters
    ----------
    message : str
        The human-readable error message.

    unrequested_params : dict
        Parameters (and their values) which were provided but not requested.

    routed_params : dict
        The parameters which were successfully routed.
    """

    def __init__(self, *, message, unrequested_params, routed_params):
        # ValueError carries the message; the two dicts are exposed as
        # attributes so callers can inspect what went wrong programmatically.
        super().__init__(message)
        self.routed_params = routed_params
        self.unrequested_params = unrequested_params


class NotFittedError(ValueError, AttributeError):
"""Exception class to raise if estimator is used before fitting.

Expand Down
93 changes: 64 additions & 29 deletions sklearn/multioutput.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,10 @@
from .model_selection import cross_val_predict
from .utils.metaestimators import available_if
from .utils import check_random_state
from .utils.validation import check_is_fitted, has_fit_parameter, _check_fit_params
from .utils.validation import check_is_fitted
from .utils.multiclass import check_classification_targets
from .utils.fixes import delayed
from .utils.metadata_routing import MetadataRouter, MethodMapping, process_routing

__all__ = [
"MultiOutputRegressor",
Expand All @@ -46,21 +47,16 @@ def _fit_estimator(estimator, X, y, sample_weight=None, **fit_params):


def _partial_fit_estimator(
estimator, X, y, classes=None, sample_weight=None, first_time=True
estimator, X, y, classes=None, partial_fit_params=None, first_time=True
):
partial_fit_params = {} if partial_fit_params is None else partial_fit_params
if first_time:
estimator = clone(estimator)

if sample_weight is not None:
if classes is not None:
estimator.partial_fit(X, y, classes=classes, sample_weight=sample_weight)
else:
estimator.partial_fit(X, y, sample_weight=sample_weight)
if classes is not None:
estimator.partial_fit(X, y, classes=classes, **partial_fit_params)
else:
if classes is not None:
estimator.partial_fit(X, y, classes=classes)
else:
estimator.partial_fit(X, y)
estimator.partial_fit(X, y, **partial_fit_params)
return estimator


Expand All @@ -85,7 +81,7 @@ def __init__(self, estimator, *, n_jobs=None):
self.n_jobs = n_jobs

@_available_if_estimator_has("partial_fit")
def partial_fit(self, X, y, classes=None, sample_weight=None):
def partial_fit(self, X, y, classes=None, sample_weight=None, **partial_fit_params):
"""Incrementally fit a separate model for each class output.

Parameters
Expand All @@ -110,6 +106,11 @@ def partial_fit(self, X, y, classes=None, sample_weight=None):
Only supported if the underlying regressor supports sample
weights.

**partial_fit_params : dict of str -> object
Parameters passed to the ``estimator.partial_fit`` method of each step.

.. versionadded:: 1.2

Returns
-------
self : object
Expand All @@ -124,10 +125,12 @@ def partial_fit(self, X, y, classes=None, sample_weight=None):
"multi-output regression but has only one."
)

if sample_weight is not None and not has_fit_parameter(
self.estimator, "sample_weight"
):
raise ValueError("Underlying estimator does not support sample weights.")
routed_params = process_routing(
obj=self,
method="partial_fit",
other_params=partial_fit_params,
sample_weight=sample_weight,
)

first_time = not hasattr(self, "estimators_")

Expand All @@ -137,8 +140,8 @@ def partial_fit(self, X, y, classes=None, sample_weight=None):
X,
y[:, i],
classes[i] if classes is not None else None,
sample_weight,
first_time,
partial_fit_params=routed_params.estimator.partial_fit,
first_time=first_time,
)
for i in range(y.shape[1])
)
Expand Down Expand Up @@ -192,16 +195,13 @@ def fit(self, X, y, sample_weight=None, **fit_params):
"multi-output regression but has only one."
)

if sample_weight is not None and not has_fit_parameter(
self.estimator, "sample_weight"
):
raise ValueError("Underlying estimator does not support sample weights.")

fit_params_validated = _check_fit_params(X, fit_params)
routed_params = process_routing(
obj=self, method="fit", other_params=fit_params, sample_weight=sample_weight
)

self.estimators_ = Parallel(n_jobs=self.n_jobs)(
delayed(_fit_estimator)(
self.estimator, X, y[:, i], sample_weight, **fit_params_validated
self.estimator, X, y[:, i], **routed_params.estimator.fit
)
for i in range(y.shape[1])
)
Expand Down Expand Up @@ -240,6 +240,36 @@ def predict(self, X):
def _more_tags(self):
return {"multioutput_only": True}

def get_metadata_routing(self):
    """Get metadata routing of this object.

    Please check :ref:`User Guide <metadata_routing>` on how the routing
    mechanism works.

    Returns
    -------
    routing : MetadataRouter
        A :class:`~utils.metadata_routing.MetadataRouter` encapsulating
        routing information.
    """
    router = (
        MetadataRouter(owner=self.__class__.__name__).add(
            estimator=self.estimator,
            # Route this estimator's `fit`/`partial_fit` to the same-named
            # methods on the sub-estimator.
            method_mapping=MethodMapping()
            .add(callee="partial_fit", caller="partial_fit")
            .add(callee="fit", caller="fit"),
        )
        # the fit method already accepts everything, therefore we don't
        # specify parameters (params=None means "warn for all"). The value
        # passed to ``child`` needs to be the same as what's passed to
        # ``add`` above, in this case `"estimator"`.
        .warn_on(child="estimator", method="fit", params=None)
        # the partial_fit method at the time of this change (v1.2) only
        # supports sample_weight, therefore we only include this metadata.
        .warn_on(child="estimator", method="partial_fit", params=["sample_weight"])
    )
    return router


class MultiOutputRegressor(RegressorMixin, _MultiOutputEstimator):
"""Multi target regression.
Expand Down Expand Up @@ -311,7 +341,7 @@ def __init__(self, estimator, *, n_jobs=None):
super().__init__(estimator, n_jobs=n_jobs)

@_available_if_estimator_has("partial_fit")
def partial_fit(self, X, y, sample_weight=None):
def partial_fit(self, X, y, sample_weight=None, **partial_fit_params):
"""Incrementally fit the model to data, for each output variable.

Parameters
Expand All @@ -327,12 +357,17 @@ def partial_fit(self, X, y, sample_weight=None):
Only supported if the underlying regressor supports sample
weights.

**partial_fit_params : dict of str -> object
Parameters passed to the ``estimator.partial_fit`` method of each step.

.. versionadded:: 1.2

Returns
-------
self : object
Returns a fitted instance.
"""
super().partial_fit(X, y, sample_weight=sample_weight)
super().partial_fit(X, y, sample_weight=sample_weight, **partial_fit_params)


class MultiOutputClassifier(ClassifierMixin, _MultiOutputEstimator):
Expand Down Expand Up @@ -419,7 +454,7 @@ def fit(self, X, Y, sample_weight=None, **fit_params):

sample_weight : array-like of shape (n_samples,), default=None
Sample weights. If `None`, then samples are equally weighted.
Only supported if the underlying classifier supports sample
Only supported if the underlying regressor supports sample
weights.

**fit_params : dict of string -> object
Expand All @@ -432,7 +467,7 @@ def fit(self, X, Y, sample_weight=None, **fit_params):
self : object
Returns a fitted instance.
"""
super().fit(X, Y, sample_weight, **fit_params)
super().fit(X, Y, sample_weight=sample_weight, **fit_params)
self.classes_ = [estimator.classes_ for estimator in self.estimators_]
return self

Expand Down
Loading