diff --git a/doc/metadata_routing.rst b/doc/metadata_routing.rst
index de04ecf022415..56df1e7f158db 100644
--- a/doc/metadata_routing.rst
+++ b/doc/metadata_routing.rst
@@ -205,7 +205,7 @@ or not::
     ...     ).fit(X, y, sample_weight=my_weights)
     ... except ValueError as e:
     ...     print(e)
-    sample_weight is passed but is not explicitly set as requested or not for
+    [sample_weight] are passed but are not explicitly set as requested or not for
    LogisticRegression.score

 The issue can be fixed by explicitly setting the request value::
diff --git a/examples/plot_metadata_routing.py b/examples/plot_metadata_routing.py
index 91c3651699ee4..98895b823003b 100644
--- a/examples/plot_metadata_routing.py
+++ b/examples/plot_metadata_routing.py
@@ -37,7 +37,7 @@
 from sklearn.utils.metadata_routing import MethodMapping
 from sklearn.utils.metadata_routing import process_routing
 from sklearn.utils.validation import check_is_fitted
-from sklearn.linear_model import LinearRegression
+from sklearn.linear_model import LinearRegression, LogisticRegression
 
 N, M = 100, 4
 X = np.random.rand(N, M)
@@ -585,7 +585,7 @@ def get_metadata_routing(self):
 
 # %%
-# When an estimator suports a metadata which wasn't supported before, the
+# When an estimator supports metadata which wasn't supported before, the
 # following pattern can be used to warn the users about it.
 
@@ -605,6 +605,53 @@ def predict(self, X):
 for w in record:
     print(w.message)
 
+# %%
+# Deprecation to Give Users Time to Adapt Their Code
+# --------------------------------------------------
+# With the introduction of metadata routing, the following user code would
+# raise an error:
+
+try:
+    reg = MetaRegressor(estimator=LinearRegression())
+    reg.fit(X, y, sample_weight=my_weights)
+except Exception as e:
+    print(e)
+
+# %%
+# You might want to give your users a period during which they see a
+# ``FutureWarning`` instead, so that they have time to adapt to the new API.
+# For this, the :class:`~sklearn.utils.metadata_routing.MetadataRouter`
+# provides a `warn_on` method:
+
+
+class WarningMetaRegressor(MetaEstimatorMixin, RegressorMixin, BaseEstimator):
+    def __init__(self, estimator):
+        self.estimator = estimator
+
+    def fit(self, X, y, **fit_params):
+        params = process_routing(self, "fit", fit_params)
+        self.estimator_ = clone(self.estimator).fit(X, y, **params.estimator.fit)
+
+    def get_metadata_routing(self):
+        router = (
+            MetadataRouter(owner=self.__class__.__name__)
+            .add(estimator=self.estimator, method_mapping="one-to-one")
+            .warn_on(child="estimator", method="fit", params=None)
+        )
+        return router
+
+
+with warnings.catch_warnings(record=True) as record:
+    WarningMetaRegressor(estimator=LogisticRegression()).fit(
+        X, y, sample_weight=my_weights
+    )
+for w in record:
+    print(w.message)
+
+# %%
+# Note that in the above implementation, the value passed to ``child`` is the
+# same as the key passed to the ``add`` method, in this case ``"estimator"``.
+
 # %%
 # Third Party Development and scikit-learn Dependency
 # ---------------------------------------------------
diff --git a/sklearn/exceptions.py b/sklearn/exceptions.py
index d84c1f6b40526..23789dd68343e 100644
--- a/sklearn/exceptions.py
+++ b/sklearn/exceptions.py
@@ -13,9 +13,35 @@
     "SkipTestWarning",
     "UndefinedMetricWarning",
     "PositiveSpectrumWarning",
+    "UnsetMetadataPassedError",
 ]
 
 
+class UnsetMetadataPassedError(ValueError):
+    """Exception class to raise if metadata is passed which is not explicitly \
+        requested.
+
+    ..
versionadded:: 1.2 + + Parameters + ---------- + message : str + The message + + unrequested_params : dict + A dictionary of parameters and their values which are provided but not + requested. + + routed_params : dict + A dictionary of routed parameters. + """ + + def __init__(self, *, message, unrequested_params, routed_params): + super().__init__(message) + self.unrequested_params = unrequested_params + self.routed_params = routed_params + + class NotFittedError(ValueError, AttributeError): """Exception class to raise if estimator is used before fitting. diff --git a/sklearn/multioutput.py b/sklearn/multioutput.py index 24e4cc8dda7e8..ed33fde4ecb20 100644 --- a/sklearn/multioutput.py +++ b/sklearn/multioutput.py @@ -24,9 +24,10 @@ from .model_selection import cross_val_predict from .utils.metaestimators import available_if from .utils import check_random_state -from .utils.validation import check_is_fitted, has_fit_parameter, _check_fit_params +from .utils.validation import check_is_fitted from .utils.multiclass import check_classification_targets from .utils.fixes import delayed +from .utils.metadata_routing import MetadataRouter, MethodMapping, process_routing __all__ = [ "MultiOutputRegressor", @@ -46,21 +47,16 @@ def _fit_estimator(estimator, X, y, sample_weight=None, **fit_params): def _partial_fit_estimator( - estimator, X, y, classes=None, sample_weight=None, first_time=True + estimator, X, y, classes=None, partial_fit_params=None, first_time=True ): + partial_fit_params = {} if partial_fit_params is None else partial_fit_params if first_time: estimator = clone(estimator) - if sample_weight is not None: - if classes is not None: - estimator.partial_fit(X, y, classes=classes, sample_weight=sample_weight) - else: - estimator.partial_fit(X, y, sample_weight=sample_weight) + if classes is not None: + estimator.partial_fit(X, y, classes=classes, **partial_fit_params) else: - if classes is not None: - estimator.partial_fit(X, y, classes=classes) - else: - estimator.partial_fit(X, y) + estimator.partial_fit(X, y, **partial_fit_params) return estimator @@ -85,7 +81,7 @@ def __init__(self, estimator, *, n_jobs=None): self.n_jobs = n_jobs @_available_if_estimator_has("partial_fit") - def partial_fit(self, X, y, classes=None, sample_weight=None): + def partial_fit(self, X, y, classes=None, sample_weight=None, **partial_fit_params): """Incrementally fit a separate model for each class output. Parameters @@ -110,6 +106,11 @@ def partial_fit(self, X, y, classes=None, sample_weight=None): Only supported if the underlying regressor supports sample weights. + **partial_fit_params : dict of str -> object + Parameters passed to the ``estimator.partial_fit`` method of each step. + + .. versionadded:: 1.2 + Returns ------- self : object @@ -124,10 +125,12 @@ def partial_fit(self, X, y, classes=None, sample_weight=None): "multi-output regression but has only one." 
) - if sample_weight is not None and not has_fit_parameter( - self.estimator, "sample_weight" - ): - raise ValueError("Underlying estimator does not support sample weights.") + routed_params = process_routing( + obj=self, + method="partial_fit", + other_params=partial_fit_params, + sample_weight=sample_weight, + ) first_time = not hasattr(self, "estimators_") @@ -137,8 +140,8 @@ def partial_fit(self, X, y, classes=None, sample_weight=None): X, y[:, i], classes[i] if classes is not None else None, - sample_weight, - first_time, + partial_fit_params=routed_params.estimator.partial_fit, + first_time=first_time, ) for i in range(y.shape[1]) ) @@ -192,16 +195,13 @@ def fit(self, X, y, sample_weight=None, **fit_params): "multi-output regression but has only one." ) - if sample_weight is not None and not has_fit_parameter( - self.estimator, "sample_weight" - ): - raise ValueError("Underlying estimator does not support sample weights.") - - fit_params_validated = _check_fit_params(X, fit_params) + routed_params = process_routing( + obj=self, method="fit", other_params=fit_params, sample_weight=sample_weight + ) self.estimators_ = Parallel(n_jobs=self.n_jobs)( delayed(_fit_estimator)( - self.estimator, X, y[:, i], sample_weight, **fit_params_validated + self.estimator, X, y[:, i], **routed_params.estimator.fit ) for i in range(y.shape[1]) ) @@ -240,6 +240,36 @@ def predict(self, X): def _more_tags(self): return {"multioutput_only": True} + def get_metadata_routing(self): + """Get metadata routing of this object. + + Please check :ref:`User Guide ` on how the routing + mechanism works. + + Returns + ------- + routing : MetadataRouter + A :class:`~utils.metadata_routing.MetadataRouter` encapsulating + routing information. + """ + router = ( + MetadataRouter(owner=self.__class__.__name__).add( + estimator=self.estimator, + method_mapping=MethodMapping() + .add(callee="partial_fit", caller="partial_fit") + .add(callee="fit", caller="fit"), + ) + # the fit method already accepts everything, therefore we don't + # specify parameters. The value passed to ``child`` needs to be the + # same as what's passed to ``add`` above, in this case + # `"estimator"`. + .warn_on(child="estimator", method="fit", params=None) + # the partial_fit method at the time of this change (v1.2) only + # supports sample_weight, therefore we only include this metadata. + .warn_on(child="estimator", method="partial_fit", params=["sample_weight"]) + ) + return router + class MultiOutputRegressor(RegressorMixin, _MultiOutputEstimator): """Multi target regression. @@ -311,7 +341,7 @@ def __init__(self, estimator, *, n_jobs=None): super().__init__(estimator, n_jobs=n_jobs) @_available_if_estimator_has("partial_fit") - def partial_fit(self, X, y, sample_weight=None): + def partial_fit(self, X, y, sample_weight=None, **partial_fit_params): """Incrementally fit the model to data, for each output variable. Parameters @@ -327,12 +357,17 @@ def partial_fit(self, X, y, sample_weight=None): Only supported if the underlying regressor supports sample weights. + **partial_fit_params : dict of str -> object + Parameters passed to the ``estimator.partial_fit`` method of each step. + + .. versionadded:: 1.2 + Returns ------- self : object Returns a fitted instance. 
""" - super().partial_fit(X, y, sample_weight=sample_weight) + super().partial_fit(X, y, sample_weight=sample_weight, **partial_fit_params) class MultiOutputClassifier(ClassifierMixin, _MultiOutputEstimator): @@ -419,7 +454,7 @@ def fit(self, X, Y, sample_weight=None, **fit_params): sample_weight : array-like of shape (n_samples,), default=None Sample weights. If `None`, then samples are equally weighted. - Only supported if the underlying classifier supports sample + Only supported if the underlying regressor supports sample weights. **fit_params : dict of string -> object @@ -432,7 +467,7 @@ def fit(self, X, Y, sample_weight=None, **fit_params): self : object Returns a fitted instance. """ - super().fit(X, Y, sample_weight, **fit_params) + super().fit(X, Y, sample_weight=sample_weight, **fit_params) self.classes_ = [estimator.classes_ for estimator in self.estimators_] return self diff --git a/sklearn/tests/test_metadata_routing.py b/sklearn/tests/test_metadata_routing.py index 6d98d251c67fb..6f0e1bf485a9b 100644 --- a/sklearn/tests/test_metadata_routing.py +++ b/sklearn/tests/test_metadata_routing.py @@ -339,13 +339,11 @@ def test_simple_metadata_routing(): # If the estimator accepts the metadata but doesn't explicitly say it doesn't # need it, there's an error clf = SimpleMetaClassifier(estimator=ClassifierFitMetadata()) - with pytest.raises( - ValueError, - match=( - "sample_weight is passed but is not explicitly set as requested or not for" - " ClassifierFitMetadata.fit" - ), - ): + err_message = ( + "[sample_weight] are passed but are not explicitly set as requested or" + " not for ClassifierFitMetadata.fit" + ) + with pytest.raises(ValueError, match=re.escape(err_message)): clf.fit(X, y, sample_weight=my_weights) # Explicitly saying the estimator doesn't need it, makes the error go away, @@ -587,7 +585,9 @@ def fit(self, X, y, prop=None, **kwargs): def test_method_metadata_request(): - mmr = MethodMetadataRequest(owner="test", method="fit") + mmr = MethodMetadataRequest( + router=MetadataRequest(owner="test"), owner="test", method="fit" + ) with pytest.raises( ValueError, match="alias should be either a valid identifier or" @@ -654,9 +654,9 @@ class RegressorMetadataWarn(RegressorMetadata): "obj, string", [ ( - MethodMetadataRequest(owner="test", method="fit").add_request( - param="foo", alias="bar" - ), + MethodMetadataRequest( + router=MetadataRequest(owner="test"), owner="test", method="fit" + ).add_request(param="foo", alias="bar"), "{'foo': 'bar'}", ), ( @@ -849,6 +849,180 @@ def test_metadata_routing_get_param_names(): ) == router._get_param_names(method="fit", return_alias=False, ignore_self=True) +def test_router_deprecation_warning(): + """This test checks the warning mechanism related to `warn_on`. + + `warn_on` is there to handle backward compatibility in cases where the + meta-estimator is already doing some routing, and SLEP006 would break + existing user code. `warn_on` helps converting some of those errors to + warnings. + + In different scenarios with a meta-estimator and a child estimator we test + if the warning is raised when it should, an error raised when it should, + and the combinations of the above cases. 
+    """
+
+    class MetaEstimator(BaseEstimator, MetaEstimatorMixin):
+        def __init__(self, estimator):
+            self.estimator = estimator
+
+        def fit(self, X, y, **fit_params):
+            routed_params = process_routing(self, "fit", fit_params)
+            self.estimator_ = clone(self.estimator).fit(
+                X, y, **routed_params.estimator.fit
+            )
+
+        def predict(self, X, **predict_params):
+            routed_params = process_routing(self, "predict", predict_params)
+            return self.estimator_.predict(X, **routed_params.estimator.predict)
+
+        def get_metadata_routing(self):
+            return (
+                MetadataRouter(owner=self.__class__.__name__)
+                .add(estimator=self.estimator, method_mapping="one-to-one")
+                .warn_on(
+                    child="estimator",
+                    method="fit",
+                    params=None,
+                    raise_on="1.4",
+                )
+            )
+
+    class Estimator(BaseEstimator):
+        def fit(self, X, y, sample_weight=None, groups=None):
+            return self
+
+        def predict(self, X, sample_weight=None):
+            return np.ones(shape=len(X))
+
+    est = MetaEstimator(estimator=Estimator())
+    # the meta-estimator has used `warn_on` to turn the error on `fit` into a
+    # warning.
+    with pytest.warns(
+        FutureWarning, match="From version 1.4 this results in the following error"
+    ):
+        est.fit(X, y, sample_weight=my_weights)
+
+    err_msg = (
+        "{params} are passed but are not explicitly set as requested or not for {owner}"
+    )
+    warn_msg = "From version 1.4 this results in the following error"
+    # but predict should raise since there is no warn_on set for it.
+    with pytest.raises(
+        ValueError,
+        match=re.escape(
+            err_msg.format(params="[sample_weight]", owner="Estimator.predict")
+        ),
+    ):
+        est.predict(X, sample_weight=my_weights)
+
+    # In this case both a warning and an error are raised. The warning comes
+    # from the MetaEstimator, and the error comes from the
+    # WeightedMetaRegressor since it doesn't have any warn_on set while
+    # sample_weight is passed to its child.
+    est = MetaEstimator(estimator=WeightedMetaRegressor(estimator=RegressorMetadata()))
+    with pytest.raises(
+        ValueError,
+        match=re.escape(
+            err_msg.format(params="[sample_weight]", owner="RegressorMetadata.fit")
+        ),
+    ):
+        with pytest.warns(FutureWarning, match=warn_msg):
+            est.fit(X, y, sample_weight=my_weights)
+
+    class WarningWeightedMetaRegressor(WeightedMetaRegressor):
+        """A WeightedMetaRegressor which warns instead."""
+
+        def get_metadata_routing(self):
+            router = (
+                MetadataRouter(owner=self.__class__.__name__)
+                .add_self(self)
+                .add(estimator=self.estimator, method_mapping="one-to-one")
+                .warn_on(
+                    child="estimator",
+                    method="fit",
+                    params=["sample_weight"],
+                    raise_on="1.4",
+                )
+                .warn_on(
+                    child="estimator",
+                    method="score",
+                    params=["sample_weight"],
+                    raise_on="1.4",
+                )
+            )
+            return router
+
+    # Now there's only a warning since both meta-estimators warn.
+ est = MetaEstimator( + estimator=WarningWeightedMetaRegressor(estimator=RegressorMetadata()) + ) + with pytest.warns(FutureWarning, match=warn_msg): + est.fit(X, y, sample_weight=my_weights) + + # here we should raise because there is no warn_on for groups + with pytest.raises( + ValueError, + match=re.escape( + err_msg.format(params="[sample_weight, groups]", owner="Estimator.fit") + ), + ): + # the sample_weight should still warn + with pytest.warns(FutureWarning, match=warn_msg): + WarningWeightedMetaRegressor(estimator=Estimator()).fit( + X, y, sample_weight=my_weights, groups=1 + ) + + # but if the inner estimator has a non-default request, we fall back to + # raising an error + est = MetaEstimator( + estimator=WarningWeightedMetaRegressor( + estimator=RegressorMetadata().set_fit_request(sample_weight=True) + ) + ) + with pytest.raises( + ValueError, + match=re.escape( + err_msg.format( + params="[sample_weight]", owner="WarningWeightedMetaRegressor.fit" + ) + ), + ): + est.fit(X, y, sample_weight=my_weights) + + +@pytest.mark.parametrize( + "estimator, is_default_request", + [ + (LinearRegression(), True), + (LinearRegression().set_fit_request(sample_weight=True), False), # type: ignore + (WeightedMetaRegressor(estimator=LinearRegression()), True), + ( + WeightedMetaRegressor( + estimator=LinearRegression().set_fit_request( # type: ignore + sample_weight=True + ) + ), + False, + ), + ( + WeightedMetaRegressor( + estimator=LinearRegression() + ).set_fit_request( # type: ignore + sample_weight=True + ), + False, + ), + ], +) +def test_is_default_request(estimator, is_default_request): + """Test the `_is_default_request` machinery. + + It should be `True` only if the user hasn't changed any default values. + + Applies to both `MetadataRouter` and `MetadataRequest`. + """ + assert estimator.get_metadata_routing()._is_default_request == is_default_request + + def test_method_generation(): # Test if all required request methods are generated. 
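As a minimal, self-contained sketch of the transition behavior exercised
above (this assumes the ``warn_on`` API introduced in this diff; the
``Sketch*`` names are hypothetical and only for illustration): routing
metadata to a child whose requests were never set emits a ``FutureWarning``,
while explicitly setting the request silences it.

import warnings

import numpy as np

from sklearn.base import BaseEstimator, MetaEstimatorMixin, RegressorMixin, clone
from sklearn.utils.metadata_routing import MetadataRouter, process_routing


class SketchRegressor(RegressorMixin, BaseEstimator):
    # a consumer with default (unset) metadata requests
    def fit(self, X, y, sample_weight=None):
        return self


class SketchMeta(MetaEstimatorMixin, RegressorMixin, BaseEstimator):
    def __init__(self, estimator):
        self.estimator = estimator

    def fit(self, X, y, **fit_params):
        # route the given metadata to the child estimator
        params = process_routing(self, "fit", fit_params)
        self.estimator_ = clone(self.estimator).fit(X, y, **params.estimator.fit)
        return self

    def get_metadata_routing(self):
        return (
            MetadataRouter(owner=self.__class__.__name__)
            .add(estimator=self.estimator, method_mapping="one-to-one")
            # during the transition period, warn instead of raising for
            # sample_weight passed to fit
            .warn_on(child="estimator", method="fit", params=["sample_weight"])
        )


X = np.random.rand(10, 2)
y = np.random.rand(10)
w = np.random.rand(10)

with warnings.catch_warnings(record=True) as record:
    warnings.simplefilter("always")
    # default request on the child: this warns instead of raising
    SketchMeta(SketchRegressor()).fit(X, y, sample_weight=w)
print([str(r.message) for r in record])

# explicitly requesting the metadata removes the warning (and the future error)
SketchMeta(SketchRegressor().set_fit_request(sample_weight=True)).fit(
    X, y, sample_weight=w
)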
diff --git a/sklearn/tests/test_metaestimators_metadata_routing.py b/sklearn/tests/test_metaestimators_metadata_routing.py
new file mode 100644
index 0000000000000..dae9fbf3aa03a
--- /dev/null
+++ b/sklearn/tests/test_metaestimators_metadata_routing.py
@@ -0,0 +1,125 @@
+import numpy as np
+import pytest
+from sklearn.base import RegressorMixin, ClassifierMixin, BaseEstimator
+from sklearn.multioutput import MultiOutputRegressor, MultiOutputClassifier
+from sklearn.utils.metadata_routing import MetadataRouter
+from sklearn.tests.test_metadata_routing import (
+    record_metadata,
+    check_recorded_metadata,
+    assert_request_is_empty,
+)
+
+N, M = 100, 4
+X = np.random.rand(N, M)
+y = np.random.randint(0, 2, size=N)
+y_multi = np.random.randint(0, 2, size=(N, 3))
+metadata = np.random.randint(0, 10, size=N)
+sample_weight = np.random.rand(N)
+
+
+class ConsumingRegressor(RegressorMixin, BaseEstimator):
+    """A regressor consuming metadata."""
+
+    def partial_fit(self, X, y, sample_weight=None, metadata=None):
+        record_metadata(
+            self, "partial_fit", sample_weight=sample_weight, metadata=metadata
+        )
+        return self
+
+    def fit(self, X, y, sample_weight=None, metadata=None):
+        record_metadata(self, "fit", sample_weight=sample_weight, metadata=metadata)
+        return self
+
+    def predict(self, X, y=None, sample_weight=None, metadata=None):
+        record_metadata(self, "predict", sample_weight=sample_weight, metadata=metadata)
+        return np.zeros(shape=(len(X)))
+
+
+class ConsumingClassifier(ClassifierMixin, BaseEstimator):
+    """A classifier consuming metadata."""
+
+    def partial_fit(self, X, y, sample_weight=None, metadata=None):
+        record_metadata(
+            self, "partial_fit", sample_weight=sample_weight, metadata=metadata
+        )
+        self.classes_ = [1]
+        return self
+
+    def fit(self, X, y, sample_weight=None, metadata=None):
+        record_metadata(self, "fit", sample_weight=sample_weight, metadata=metadata)
+        self.classes_ = [1]
+        return self
+
+    def predict(self, X, y=None, sample_weight=None, metadata=None):
+        record_metadata(self, "predict", sample_weight=sample_weight, metadata=metadata)
+        return np.zeros(shape=(len(X)))
+
+    def predict_proba(self, X, y=None, sample_weight=None, metadata=None):
+        record_metadata(
+            self, "predict_proba", sample_weight=sample_weight, metadata=metadata
+        )
+        return np.zeros(shape=(len(X)))
+
+    def predict_log_proba(self, X, y=None, sample_weight=None, metadata=None):
+        record_metadata(
+            self, "predict_log_proba", sample_weight=sample_weight, metadata=metadata
+        )
+        return np.zeros(shape=(len(X)))
+
+
+def get_empty_metaestimators():
+    yield MultiOutputRegressor(estimator=ConsumingRegressor())
+    yield MultiOutputClassifier(estimator=ConsumingClassifier())
+
+
+@pytest.mark.parametrize(
+    "metaestimator",
+    get_empty_metaestimators(),
+    ids=[str(x) for x in get_empty_metaestimators()],
+)
+def test_default_request(metaestimator):
+    # Check that by default request is empty and the right type
+    assert_request_is_empty(metaestimator.get_metadata_routing())
+    assert isinstance(metaestimator.get_metadata_routing(), MetadataRouter)
+
+
+@pytest.mark.parametrize(
+    "MultiOutput, Estimator",
+    [
+        (MultiOutputClassifier, ConsumingClassifier),
+        (MultiOutputRegressor, ConsumingRegressor),
+    ],
+    ids=["Classifier", "Regressor"],
+)
+def test_multioutput_metadata_routing(MultiOutput, Estimator):
+    # Check routing of metadata
+    metaest = MultiOutput(Estimator())
+    warn_msg = (
+        "You are passing metadata for which the request values are not explicitly"
+        " set: sample_weight, metadata."
+ ) + with pytest.warns(FutureWarning, match=(warn_msg)): + metaest.fit(X, y_multi, sample_weight=sample_weight, metadata=metadata) + check_recorded_metadata( + metaest.estimators_[0], + "fit", + sample_weight=sample_weight, + metadata=metadata, + ) + + metaest = MultiOutput( + Estimator() + .set_fit_request(sample_weight=True, metadata="alias") + .set_partial_fit_request(sample_weight=True, metadata=True) + ).fit(X, y_multi, alias=metadata) + # if an estimator requests a metadata but it's not passed, no errors are + # raised. Therefore here we don't pass `sample_weight` to test nothing is + # raised. + check_recorded_metadata( + metaest.estimators_[0], "fit", sample_weight=None, metadata=metadata + ) + + metaest.partial_fit(X, y_multi, metadata=metadata) + check_recorded_metadata( + metaest.estimators_[0], "partial_fit", sample_weight=None, metadata=metadata + ) diff --git a/sklearn/tests/test_multioutput.py b/sklearn/tests/test_multioutput.py index 25d209223acc1..6ae07d63d9f04 100644 --- a/sklearn/tests/test_multioutput.py +++ b/sklearn/tests/test_multioutput.py @@ -1,4 +1,5 @@ import pytest +import re import numpy as np import scipy.sparse as sp from joblib import cpu_count @@ -27,7 +28,6 @@ from sklearn.base import ClassifierMixin from sklearn.utils import shuffle from sklearn.model_selection import GridSearchCV -from sklearn.dummy import DummyRegressor, DummyClassifier from sklearn.pipeline import make_pipeline from sklearn.impute import SimpleImputer from sklearn.ensemble import StackingRegressor @@ -110,12 +110,14 @@ def test_multi_target_sample_weights_api(): w = [0.8, 0.6] rgr = MultiOutputRegressor(OrthogonalMatchingPursuit()) - msg = "does not support sample weights" - with pytest.raises(ValueError, match=msg): + msg = re.escape("fit got unexpected argument(s) {'sample_weight'}") + with pytest.raises(TypeError, match=msg): rgr.fit(X, y, w) # no exception should be raised if the base estimator supports weights - rgr = MultiOutputRegressor(GradientBoostingRegressor(random_state=0)) + rgr = MultiOutputRegressor( + GradientBoostingRegressor(random_state=0).set_fit_request(sample_weight=True) + ) rgr.fit(X, y, w) @@ -124,12 +126,20 @@ def test_multi_target_sample_weight_partial_fit(): X = [[1, 2, 3], [4, 5, 6]] y = [[3.141, 2.718], [2.718, 3.141]] w = [2.0, 1.0] - rgr_w = MultiOutputRegressor(SGDRegressor(random_state=0, max_iter=5)) + rgr_w = MultiOutputRegressor( + SGDRegressor(random_state=0, max_iter=5).set_partial_fit_request( + sample_weight=True + ) + ) rgr_w.partial_fit(X, y, w) # weighted with different weights w = [2.0, 2.0] - rgr = MultiOutputRegressor(SGDRegressor(random_state=0, max_iter=5)) + rgr = MultiOutputRegressor( + SGDRegressor(random_state=0, max_iter=5).set_partial_fit_request( + sample_weight=True + ) + ) rgr.partial_fit(X, y, w) assert rgr.predict(X)[0][0] != rgr_w.predict(X)[0][0] @@ -140,7 +150,9 @@ def test_multi_target_sample_weights(): Xw = [[1, 2, 3], [4, 5, 6]] yw = [[3.141, 2.718], [2.718, 3.141]] w = [2.0, 1.0] - rgr_w = MultiOutputRegressor(GradientBoostingRegressor(random_state=0)) + rgr_w = MultiOutputRegressor( + GradientBoostingRegressor(random_state=0).set_fit_request(sample_weight=True) + ) rgr_w.fit(Xw, yw, w) # unweighted, but with repeated samples @@ -363,7 +375,9 @@ def test_multi_output_classification_sample_weights(): Xw = [[1, 2, 3], [4, 5, 6]] yw = [[3, 2], [2, 3]] w = np.asarray([2.0, 1.0]) - forest = RandomForestClassifier(n_estimators=10, random_state=1) + forest = RandomForestClassifier(n_estimators=10, 
random_state=1).set_fit_request( + sample_weight=True + ) clf_w = MultiOutputClassifier(forest) clf_w.fit(Xw, yw, w) @@ -383,7 +397,9 @@ def test_multi_output_classification_partial_fit_sample_weights(): Xw = [[1, 2, 3], [4, 5, 6], [1.5, 2.5, 3.5]] yw = [[3, 2], [2, 3], [3, 2]] w = np.asarray([2.0, 1.0, 1.0]) - sgd_linear_clf = SGDClassifier(random_state=1, max_iter=20) + sgd_linear_clf = SGDClassifier(random_state=1, max_iter=20).set_fit_request( + sample_weight=True + ) clf_w = MultiOutputClassifier(sgd_linear_clf) clf_w.fit(Xw, yw, w) @@ -603,40 +619,6 @@ def test_multi_output_classes_(estimator): assert_array_equal(estimator_classes, expected_classes) -class DummyRegressorWithFitParams(DummyRegressor): - def fit(self, X, y, sample_weight=None, **fit_params): - self._fit_params = fit_params - return super().fit(X, y, sample_weight) - - -class DummyClassifierWithFitParams(DummyClassifier): - def fit(self, X, y, sample_weight=None, **fit_params): - self._fit_params = fit_params - return super().fit(X, y, sample_weight) - - -@pytest.mark.filterwarnings("ignore:`n_features_in_` is deprecated") -@pytest.mark.parametrize( - "estimator, dataset", - [ - ( - MultiOutputClassifier(DummyClassifierWithFitParams(strategy="prior")), - datasets.make_multilabel_classification(), - ), - ( - MultiOutputRegressor(DummyRegressorWithFitParams()), - datasets.make_regression(n_targets=3, random_state=0), - ), - ], -) -def test_multioutput_estimator_with_fit_params(estimator, dataset): - X, y = dataset - some_param = np.zeros_like(X) - estimator.fit(X, y, some_param=some_param) - for dummy_estimator in estimator.estimators_: - assert "some_param" in dummy_estimator._fit_params - - def test_regressor_chain_w_fit_params(): # Make sure fit_params are properly propagated to the sub-estimators rng = np.random.RandomState(0) diff --git a/sklearn/utils/_metadata_requests.py b/sklearn/utils/_metadata_requests.py index f6baa09fa22b7..c8be404d7da5a 100644 --- a/sklearn/utils/_metadata_requests.py +++ b/sklearn/utils/_metadata_requests.py @@ -12,6 +12,7 @@ from collections import namedtuple from typing import Union, Optional from ._bunch import Bunch +from ..exceptions import UnsetMetadataPassedError # This namedtuple is used to store a (mapping, routing) pair. Mapping is a # MethodMapping object, and routing is the output of `get_metadata_routing`. @@ -170,8 +171,9 @@ class MethodMetadataRequest: The name of the method to which these requests belong. """ - def __init__(self, owner, method): + def __init__(self, router, owner, method): self._requests = dict() + self.router = router self.owner = owner self.method = method @@ -213,6 +215,9 @@ def add_request( "{None, True, False}, or a RequestType." ) + if alias != self._requests.get(param, None): + self.router._is_default_request = False + if alias == param: alias = RequestType.REQUESTED @@ -290,6 +295,7 @@ def _route_params(self, params=None): corresponding method. 
""" self._check_warnings(params=params) + unrequested = dict() params = {} if params is None else params args = {arg: value for arg, value in params.items() if value is not None} res = Bunch() @@ -302,12 +308,19 @@ def _route_params(self, params=None): elif alias == RequestType.REQUESTED and prop in args: res[prop] = args[prop] elif alias == RequestType.ERROR_IF_PASSED and prop in args: - raise ValueError( - f"{prop} is passed but is not explicitly set as " - f"requested or not for {self.owner}.{self.method}" - ) + unrequested[prop] = args[prop] elif alias in args: res[prop] = args[alias] + if unrequested: + raise UnsetMetadataPassedError( + message=( + f"[{', '.join([key for key in unrequested])}] are passed but are" + " not explicitly set as requested or not for" + f" {self.owner}.{self.method}" + ), + unrequested_params=unrequested, + routed_params=res, + ) return res def _serialize(self): @@ -353,8 +366,14 @@ class MetadataRequest: _type = "metadata_request" def __init__(self, owner): + # this is used to check if the user has set any request values + self._is_default_request = False for method in METHODS: - setattr(self, method, MethodMetadataRequest(owner=owner, method=method)) + setattr( + self, + method, + MethodMetadataRequest(router=self, owner=owner, method=method), + ) def _get_param_names(self, method, return_alias, ignore_self=None): """Get names of all metadata that can be consumed or routed by specified \ @@ -569,8 +588,23 @@ def __init__(self, owner): # `add_self()`) is treated differently from the other objects which are # stored in _route_mappings. self._self = None + # this attribute is used to decide if there should be an error raised + # or a FutureWarning if a metadata is passed which is not requested. + self._warn_on = dict() self.owner = owner + @property + def _is_default_request(self): + """Return ``True`` only if all sub-components have default values.""" + if self._self and not self._self._is_default_request: + return False + + for router_mapping in self._route_mappings.values(): + if not router_mapping.router._is_default_request: + return False + + return True + def add_self(self, obj): """Add `self` (as a consumer) to the routing. @@ -766,11 +800,80 @@ def route_params(self, *, caller, params): res[name] = Bunch() for _callee, _caller in mapping: if _caller == caller: - res[name][_callee] = router._route_params( - params=params, method=_callee + res[name][_callee] = self._route_warn_or_error( + child=name, router=router, params=params, method=_callee ) return res + def _route_warn_or_error(self, *, child, router, params, method): + """Route parameters while handling error or deprecation warning choice. + + This method warns instead of raising an error if the parent object + has set ``warn_on`` for the child object's method and the user has not + set any metadata request for that child object. This is used during the + deprecation cycle for backward compatibility. + + Parameters + ---------- + child : str + The name of the child object. + + router : MetadataRouter or MetadataRequest + The router for the child object. + + params : dict + The parameters to be routed. + + method : str + The name of the callee method. + + Returns + ------- + dict + The routed parameters. + """ + try: + routed_params = router._route_params(params=params, method=method) + except UnsetMetadataPassedError as e: + warn_on = self._warn_on.get(child, {}) + if method not in warn_on: + # there is no warn_on set for this method of this child object, + # we raise as usual. 
+                raise
+            if not router._is_default_request:
+                # the user has set at least one request value for this child
+                # object, but not for all of them. Therefore we raise as usual.
+                raise
+            # now we move everything which has a warn_on flag from
+            # `unrequested_params` to routed_params, and then raise if anything
+            # is left. Otherwise we have a perfectly formed `routed_params` and
+            # we return that.
+            warn_on_params = warn_on.get(method, {"params": [], "raise_on": "1.4"})
+            warn_keys = list(e.unrequested_params.keys())
+            routed_params = e.routed_params
+            # if params is None, we accept and warn on everything.
+            warn_params = warn_on_params["params"]
+            if warn_params is None:
+                warn_params = warn_keys
+
+            for param in warn_params:
+                if param in e.unrequested_params:
+                    routed_params[param] = e.unrequested_params.pop(param)
+
+            # check if anything is left, and if yes, we raise as usual
+            if e.unrequested_params:
+                raise
+
+            # Finally warn before returning the routed parameters.
+            warn(
+                "You are passing metadata for which the request values are not"
+                f" explicitly set: {', '.join(warn_keys)}. From version"
+                f" {warn_on_params['raise_on']} this results in the following error:"
+                f" {str(e)}",
+                FutureWarning,
+            )
+        return routed_params
+
     def validate_metadata(self, *, method, params):
         """Validate given metadata for a method.
 
@@ -801,6 +904,55 @@
                 "not requested metadata in any object."
             )
 
+    def warn_on(self, *, child, method, params, raise_on="1.4"):
+        """Make unset requests warn instead of raise for given methods/params.
+
+        This method is used in meta-estimators during the transition period for
+        backward compatibility. Expected behavior for meta-estimators on code
+        such as ``RFE(Ridge()).fit(X, y, sample_weight=sample_weight)`` is to
+        raise a ``ValueError`` complaining about the fact that ``Ridge()`` has
+        not explicitly set the request values for ``sample_weight``. However,
+        this breaks backward compatibility for existing meta-estimators.
+
+        Calling this method on a ``MetadataRouter`` object such as
+        ``warn_on(child='estimator', method='fit', params=['sample_weight'])``
+        tells the router to raise a ``FutureWarning`` instead of a
+        ``ValueError`` if the child object has no set requests for
+        ``sample_weight`` during ``fit``.
+
+        You can find more information on how to use this method in the related
+        example: :ref:`sphx_glr_auto_examples_plot_metadata_routing.py`.
+
+        Parameters
+        ----------
+        child : str
+            The name of the child object. The names come from the keyword
+            arguments passed to the ``add`` method.
+
+        method : str
+            The method for which there should be a ``FutureWarning``
+            instead of a ``ValueError`` for the given params.
+
+        params : list of str or None
+            The list of parameters for which there should be a
+            ``FutureWarning`` instead of a ``ValueError``. If ``None``, the
+            rule is applied to all parameters.
+
+        raise_on : str, default="1.4"
+            The version after which an error is raised instead of a warning.
+            Used in the warning message to inform users.
+
+        Returns
+        -------
+        self : MetadataRouter
+            Returns `self`.
+        """
+        if child not in self._warn_on:
+            self._warn_on[child] = dict()
+        self._warn_on[child][method] = {"params": params, "raise_on": raise_on}
+        return self
+
     def _serialize(self):
         """Serialize the object.
@@ -932,6 +1084,7 @@ def func(**kw): for prop, alias in kw.items(): if alias is not UNCHANGED: method_metadata_request.add_request(param=prop, alias=alias) + requests._is_default_request = False instance._metadata_request = requests return instance @@ -1014,7 +1167,7 @@ def __init_subclass__(cls, **kwargs): super().__init_subclass__(**kwargs) @classmethod - def _build_request_for_signature(cls, method): + def _build_request_for_signature(cls, router, method): """Build the `MethodMetadataRequest` for a method using its signature. This method takes all arguments from the method signature and uses @@ -1023,6 +1176,8 @@ def _build_request_for_signature(cls, method): Parameters ---------- + router : MetadataRequest + The parent object for the created `MethodMetadataRequest`. method : str The name of the method. @@ -1031,7 +1186,7 @@ def _build_request_for_signature(cls, method): method_request : MethodMetadataRequest The prepared request using the method's signature. """ - mmr = MethodMetadataRequest(owner=cls.__name__, method=method) + mmr = MethodMetadataRequest(router=router, owner=cls.__name__, method=method) # Here we use `isfunction` instead of `ismethod` because calling `getattr` # on a class instead of an instance returns an unbound function. if not hasattr(cls, method) or not inspect.isfunction(getattr(cls, method)): @@ -1058,8 +1213,13 @@ class attributes, as well as determining request keys from method signatures. """ requests = MetadataRequest(owner=cls.__name__) + for method in METHODS: - setattr(requests, method, cls._build_request_for_signature(method=method)) + setattr( + requests, + method, + cls._build_request_for_signature(router=requests, method=method), + ) # Then overwrite those defaults with the ones provided in # __metadata_request__* attributes. Defaults set in @@ -1087,6 +1247,11 @@ class attributes, as well as determining request keys from method method = attr[attr.index(substr) + len(substr) :] for prop, alias in value.items(): getattr(requests, method).add_request(param=prop, alias=alias) + + # this indicates that the user has not set any request values for this + # object + requests._is_default_request = True + return requests def _get_metadata_request(self): @@ -1177,6 +1342,9 @@ def process_routing(obj, method, other_params, **kwargs): # fit_params["sample_weight"] = sample_weight all_params = other_params if other_params is not None else dict() all_params.update(kwargs) + all_params = { + param: value for param, value in all_params.items() if value is not None + } request_routing = get_routing_for_object(obj) request_routing.validate_metadata(params=all_params, method=method) diff --git a/sklearn/utils/estimator_checks.py b/sklearn/utils/estimator_checks.py index 52ebd12eb51f3..ea6fb73825f11 100644 --- a/sklearn/utils/estimator_checks.py +++ b/sklearn/utils/estimator_checks.py @@ -382,6 +382,13 @@ def _get_check_estimator_ids(obj): return re.sub(r"\s", "", str(obj)) +def _weighted(estimator): + """Request sample_weight for fit and score.""" + return estimator.set_fit_request(sample_weight=True).set_score_request( + sample_weight=True + ) + + def _construct_instance(Estimator): """Construct Estimator instance if possible.""" required_parameters = getattr(Estimator, "_required_parameters", []) @@ -392,19 +399,22 @@ def _construct_instance(Estimator): # For common test, we can enforce using `LinearRegression` that # is the default estimator in `RANSACRegressor` instead of `Ridge`. 
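        # With routing enabled, these checks pass `sample_weight` explicitly,
        # so the wrapped estimators must request it: `_weighted` (defined
        # above) marks `sample_weight` as requested for both `fit` and `score`.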
if issubclass(Estimator, RANSACRegressor): - estimator = Estimator(LinearRegression()) + estimator = Estimator(_weighted(LinearRegression())) elif issubclass(Estimator, RegressorMixin): - estimator = Estimator(Ridge()) + estimator = Estimator(_weighted(Ridge())) elif issubclass(Estimator, SelectFromModel): # Increases coverage because SGDRegressor has partial_fit - estimator = Estimator(SGDRegressor(random_state=0)) + estimator = Estimator(_weighted(SGDRegressor(random_state=0))) else: - estimator = Estimator(LogisticRegression(C=1)) + estimator = Estimator(_weighted(LogisticRegression(C=1))) elif required_parameters in (["estimators"],): # Heterogeneous ensemble classes (i.e. stacking, voting) if issubclass(Estimator, RegressorMixin): estimator = Estimator( - estimators=[("est1", Ridge(alpha=0.1)), ("est2", Ridge(alpha=1))] + estimators=[ + ("est1", _weighted(Ridge(alpha=0.1))), + ("est2", _weighted(Ridge(alpha=1))), + ] ) else: estimator = Estimator(