From 5f8d8711819a47f5748a880d4af9634c3cfc4956 Mon Sep 17 00:00:00 2001 From: Stefanie Senger Date: Tue, 26 Mar 2024 15:28:56 +0100 Subject: [PATCH 01/17] metadata routing for stackingclassifier and stackingregressor --- doc/metadata_routing.rst | 6 +- doc/modules/ensemble.rst | 4 +- doc/whats_new/v1.5.rst | 7 +- sklearn/ensemble/_base.py | 2 +- sklearn/ensemble/_stacking.py | 141 ++++++++++++++++++----- sklearn/ensemble/tests/test_stacking.py | 119 +++++++++++++++++++ sklearn/tests/metadata_routing_common.py | 11 +- 7 files changed, 248 insertions(+), 42 deletions(-) diff --git a/doc/metadata_routing.rst b/doc/metadata_routing.rst index aa6580f52982c..e8e51eeadbe43 100644 --- a/doc/metadata_routing.rst +++ b/doc/metadata_routing.rst @@ -274,6 +274,8 @@ Meta-estimators and functions supporting metadata routing: - :class:`sklearn.calibration.CalibratedClassifierCV` - :class:`sklearn.compose.ColumnTransformer` +- :class:`sklearn.ensemble.StackingClassifier` +- :class:`sklearn.ensemble.StackingRegressor` - :class:`sklearn.ensemble.VotingClassifier` - :class:`sklearn.ensemble.VotingRegressor` - :class:`sklearn.ensemble.BaggingClassifier` @@ -314,13 +316,9 @@ Meta-estimators and tools not supporting metadata routing yet: - :class:`sklearn.covariance.GraphicalLassoCV` - :class:`sklearn.ensemble.AdaBoostClassifier` - :class:`sklearn.ensemble.AdaBoostRegressor` -- :class:`sklearn.ensemble.StackingClassifier` -- :class:`sklearn.ensemble.StackingRegressor` - :class:`sklearn.feature_selection.RFE` - :class:`sklearn.feature_selection.RFECV` - :class:`sklearn.feature_selection.SequentialFeatureSelector` -- :class:`sklearn.impute.IterativeImputer` -- :class:`sklearn.linear_model.RANSACRegressor` - :class:`sklearn.model_selection.learning_curve` - :class:`sklearn.model_selection.permutation_test_score` - :class:`sklearn.model_selection.validation_curve` diff --git a/doc/modules/ensemble.rst b/doc/modules/ensemble.rst index 9af9ff39cac06..f697a330442ba 100644 --- a/doc/modules/ensemble.rst +++ b/doc/modules/ensemble.rst @@ -1536,8 +1536,8 @@ availability, tested in the order of preference: `predict_proba`, `decision_function` and `predict`. A :class:`StackingRegressor` and :class:`StackingClassifier` can be used as -any other regressor or classifier, exposing a `predict`, `predict_proba`, and -`decision_function` methods, e.g.:: +any other regressor or classifier, exposing a `predict`, `predict_proba`, or +`decision_function` method, e.g.:: >>> y_pred = reg.predict(X_test) >>> from sklearn.metrics import r2_score diff --git a/doc/whats_new/v1.5.rst b/doc/whats_new/v1.5.rst index bd03cc743f76e..09dc037b29310 100644 --- a/doc/whats_new/v1.5.rst +++ b/doc/whats_new/v1.5.rst @@ -92,6 +92,11 @@ more details. transformers' ``fit`` and ``fit_transform``. :pr:`28205` by :user:`Stefanie Senger `. + - |Feature| :class:`ensemble.StackingClassifier` and + :class:`ensemble.StackingRegressor` now support metadata routing and pass + ``**fit_params`` to the underlying estimators via their `fit` methods. + :pr:`.....` by :user:`Stefanie Senger `. + Changelog --------- @@ -298,7 +303,7 @@ Changelog :func:`preprocessing.quantile_transform` now supports disabling subsampling explicitly. :pr:`27636` by :user:`Ralph Urlus `. - + :mod:`sklearn.tree` ................... diff --git a/sklearn/ensemble/_base.py b/sklearn/ensemble/_base.py index 8410be81c6cbc..117c8a35eeeff 100644 --- a/sklearn/ensemble/_base.py +++ b/sklearn/ensemble/_base.py @@ -20,7 +20,7 @@ def _fit_single_estimator( estimator, X, y, fit_params, message_clsname=None, message=None ): """Private function used to fit an estimator within a job.""" - # TODO(SLEP6): remove if condition for unrouted sample_weight when metadata + # TODO(SLEP6): remove if-condition for unrouted sample_weight when metadata # routing can't be disabled. if not _routing_enabled() and "sample_weight" in fit_params: try: diff --git a/sklearn/ensemble/_stacking.py b/sklearn/ensemble/_stacking.py index 0f093e8a6b51d..74cb536824449 100644 --- a/sklearn/ensemble/_stacking.py +++ b/sklearn/ensemble/_stacking.py @@ -27,8 +27,11 @@ from ..utils._estimator_html_repr import _VisualBlock from ..utils._param_validation import HasMethods, StrOptions from ..utils.metadata_routing import ( - _raise_for_unsupported_routing, - _RoutingNotSupportedMixin, + MetadataRouter, + MethodMapping, + _raise_for_params, + _routing_enabled, + process_routing, ) from ..utils.metaestimators import available_if from ..utils.multiclass import check_classification_targets, type_of_target @@ -36,6 +39,7 @@ from ..utils.validation import ( _check_feature_names_in, _check_response_method, + _deprecate_positional_args, check_is_fitted, column_or_1d, ) @@ -171,7 +175,7 @@ def _method_name(name, estimator, method): # estimators in Stacking*.estimators are not validated yet prefer_skip_nested_validation=False ) - def fit(self, X, y, sample_weight=None): + def fit(self, X, y, **fit_params): """Fit the estimators. Parameters @@ -192,6 +196,17 @@ def fit(self, X, y, sample_weight=None): when not None, `sample_weight` is passed to all underlying estimators + .. deprecated:: 1.4 + `sample_weight` is deprecated in 1.5 and will be removed in 1.7. + + **fit_params : dict + Dict of metadata, potentially containing sample_weight as a + key-value pair. If sample_weight is not existing, then samples are + equally weighted. Note that sample_weight is supported only if all + underlying estimators support sample weights. + + .. versionadded:: 1.5 + Returns ------- self : object @@ -201,16 +216,19 @@ def fit(self, X, y, sample_weight=None): names, all_estimators = self._validate_estimators() self._validate_final_estimator() - # FIXME: when adding support for metadata routing in Stacking*. - # This is a hotfix to make StackingClassifier and StackingRegressor - # pass the tests despite not supporting metadata routing but sharing - # the same base class with VotingClassifier and VotingRegressor. - fit_params = dict() - if sample_weight is not None: - fit_params["sample_weight"] = sample_weight - stack_method = [self.stack_method] * len(all_estimators) + if _routing_enabled(): + routed_params = process_routing(self, "fit", **fit_params) + else: + routed_params = Bunch() + for name in names: + routed_params[name] = Bunch(fit={}) + if "sample_weight" in fit_params: + routed_params[name].fit["sample_weight"] = fit_params[ + "sample_weight" + ] + if self.cv == "prefit": self.estimators_ = [] for estimator in all_estimators: @@ -222,8 +240,10 @@ def fit(self, X, y, sample_weight=None): # base estimators will be used in transform, predict, and # predict_proba. They are exposed publicly. self.estimators_ = Parallel(n_jobs=self.n_jobs)( - delayed(_fit_single_estimator)(clone(est), X, y, fit_params) - for est in all_estimators + delayed(_fit_single_estimator)( + clone(est), X, y, routed_params[name]["fit"] + ) + for name, est in zip(names, all_estimators) if est != "drop" ) @@ -269,10 +289,10 @@ def fit(self, X, y, sample_weight=None): cv=deepcopy(cv), method=meth, n_jobs=self.n_jobs, - params=fit_params, + params=routed_params[name]["fit"], verbose=self.verbose, ) - for est, meth in zip(all_estimators, self.stack_method_) + for name, est, meth in zip(names, all_estimators, self.stack_method_) if est != "drop" ) @@ -370,7 +390,7 @@ def predict(self, X, **predict_params): Parameters to the `predict` called by the `final_estimator`. Note that this may be used to return uncertainties from some estimators with `return_std` or `return_cov`. Be aware that it will only - accounts for uncertainty in the final estimator. + account for uncertainty in the final estimator. Returns ------- @@ -392,8 +412,29 @@ def _sk_visual_block_with_final_estimator(self, final_estimator): ) return _VisualBlock("serial", (parallel, final_block), dash_wrapped=False) + def get_metadata_routing(self): + """Get metadata routing of this object. + Please check :ref:`User Guide ` on how the routing + mechanism works. + .. versionadded:: 1.5 + Returns + ------- + routing : MetadataRouter + A :class:`~sklearn.utils.metadata_routing.MetadataRouter` encapsulating + routing information. + """ + router = MetadataRouter(owner=self.__class__.__name__) + + # `self.estimators` is a list of (name, est) tuples + for name, estimator in self.estimators: + router.add( + **{name: estimator}, + method_mapping=MethodMapping().add(callee="fit", caller="fit"), + ) + return router + -class StackingClassifier(_RoutingNotSupportedMixin, ClassifierMixin, _BaseStacking): +class StackingClassifier(ClassifierMixin, _BaseStacking): """Stack of estimators with a final classifier. Stacked generalization consists in stacking the output of individual @@ -629,7 +670,11 @@ def _validate_estimators(self): return names, estimators - def fit(self, X, y, sample_weight=None): + # TODO(1.7): remove `sample_weight` from the signature after deprecation + # cycle; pop it from `fit_params` before the `_raise_for_params` check and + # reinsert afterwards, for backwards compatibility + @_deprecate_positional_args(version="1.7") + def fit(self, X, y, sample_weight=None, **fit_params): """Fit the estimators. Parameters @@ -649,12 +694,22 @@ def fit(self, X, y, sample_weight=None): Note that this is supported only if all underlying estimators support sample weights. + **fit_params : dict + Parameters to pass to the underlying estimators. + + .. versionadded:: 1.5 + + Only available if `enable_metadata_routing=True`, which can be + set by using ``sklearn.set_config(enable_metadata_routing=True)``. + See :ref:`Metadata Routing User Guide ` for + more details. + Returns ------- self : object Returns a fitted instance of estimator. """ - _raise_for_unsupported_routing(self, "fit", sample_weight=sample_weight) + _raise_for_params(fit_params, self, "fit") check_classification_targets(y) if type_of_target(y) == "multilabel-indicator": self._label_encoder = [LabelEncoder().fit(yk) for yk in y.T] @@ -669,7 +724,10 @@ def fit(self, X, y, sample_weight=None): self._label_encoder = LabelEncoder().fit(y) self.classes_ = self._label_encoder.classes_ y_encoded = self._label_encoder.transform(y) - return super().fit(X, y_encoded, sample_weight) + + if sample_weight is not None: + fit_params["sample_weight"] = sample_weight + return super().fit(X, y_encoded, **fit_params) @available_if(_estimator_has("predict")) def predict(self, X, **predict_params): @@ -685,7 +743,7 @@ def predict(self, X, **predict_params): Parameters to the `predict` called by the `final_estimator`. Note that this may be used to return uncertainties from some estimators with `return_std` or `return_cov`. Be aware that it will only - accounts for uncertainty in the final estimator. + account for uncertainty in the final estimator. Returns ------- @@ -775,7 +833,7 @@ def _sk_visual_block_(self): return super()._sk_visual_block_with_final_estimator(final_estimator) -class StackingRegressor(_RoutingNotSupportedMixin, RegressorMixin, _BaseStacking): +class StackingRegressor(RegressorMixin, _BaseStacking): """Stack of estimators with a final regressor. Stacked generalization consists in stacking the output of individual @@ -944,7 +1002,11 @@ def _validate_final_estimator(self): ) ) - def fit(self, X, y, sample_weight=None): + # TODO(1.7): remove `sample_weight` from the signature after deprecation + # cycle; pop it from `fit_params` before the `_raise_for_params` check and + # reinsert afterwards, for backwards compatibility + @_deprecate_positional_args(version="1.7") + def fit(self, X, y, sample_weight=None, **fit_params): """Fit the estimators. Parameters @@ -961,14 +1023,26 @@ def fit(self, X, y, sample_weight=None): Note that this is supported only if all underlying estimators support sample weights. + **fit_params : dict + Parameters to pass to the underlying estimators. + + .. versionadded:: 1.5 + + Only available if `enable_metadata_routing=True`, which can be + set by using ``sklearn.set_config(enable_metadata_routing=True)``. + See :ref:`Metadata Routing User Guide ` for + more details. + Returns ------- self : object Returns a fitted instance. """ - _raise_for_unsupported_routing(self, "fit", sample_weight=sample_weight) + _raise_for_params(fit_params, self, "fit") y = column_or_1d(y, warn=True) - return super().fit(X, y, sample_weight) + if sample_weight is not None: + fit_params["sample_weight"] = sample_weight + return super().fit(X, y, **fit_params) def transform(self, X): """Return the predictions for X for each estimator. @@ -986,7 +1060,7 @@ def transform(self, X): """ return self._transform(X) - def fit_transform(self, X, y, sample_weight=None): + def fit_transform(self, X, y, sample_weight=None, **fit_params): """Fit the estimators and return the predictions for X for each estimator. Parameters @@ -1003,12 +1077,25 @@ def fit_transform(self, X, y, sample_weight=None): Note that this is supported only if all underlying estimators support sample weights. + **fit_params : dict + Parameters to pass to the underlying estimators. + + .. versionadded:: 1.5 + + Only available if `enable_metadata_routing=True`, which can be + set by using ``sklearn.set_config(enable_metadata_routing=True)``. + See :ref:`Metadata Routing User Guide ` for + more details. + Returns ------- y_preds : ndarray of shape (n_samples, n_estimators) Prediction outputs for each estimator. """ - return super().fit_transform(X, y, sample_weight=sample_weight) + _raise_for_params(fit_params, self, "fit") + if sample_weight is not None: + fit_params["sample_weight"] = sample_weight + return super().fit_transform(X, y, **fit_params) def _sk_visual_block_(self): # If final_estimator's default changes then this should be diff --git a/sklearn/ensemble/tests/test_stacking.py b/sklearn/ensemble/tests/test_stacking.py index 0d1493529e318..bd61746c5a67d 100644 --- a/sklearn/ensemble/tests/test_stacking.py +++ b/sklearn/ensemble/tests/test_stacking.py @@ -3,6 +3,7 @@ # Authors: Guillaume Lemaitre # License: BSD 3 clause +import re from unittest.mock import Mock import numpy as np @@ -38,6 +39,12 @@ from sklearn.neural_network import MLPClassifier from sklearn.preprocessing import scale from sklearn.svm import SVC, LinearSVC, LinearSVR +from sklearn.tests.metadata_routing_common import ( + ConsumingClassifier, + ConsumingRegressor, + _Registry, + check_recorded_metadata, +) from sklearn.utils._mocking import CheckingClassifier from sklearn.utils._testing import ( assert_allclose, @@ -888,3 +895,115 @@ def test_stacking_final_estimator_attribute_error(): clf.fit(X, y).decision_function(X) assert isinstance(exec_info.value.__cause__, AttributeError) assert inner_msg in str(exec_info.value.__cause__) + + +# Metadata Routing Tests +# ====================== + + +@pytest.mark.parametrize( + "Estimator, Child", + [ + (StackingClassifier, ConsumingClassifier), + (StackingRegressor, ConsumingRegressor), + ], +) +def test_routing_passed_metadata_not_supported(Estimator, Child): + """Test that the right error message is raised when metadata is passed while + not supported when `enable_metadata_routing=False`.""" + + with pytest.raises( + ValueError, match="is only supported if enable_metadata_routing=True" + ): + Estimator(["clf", Child()]).fit( + X_iris, y_iris, sample_weight=[1, 1, 1, 1, 1], metadata="a" + ) + + +@pytest.mark.usefixtures("enable_slep006") +@pytest.mark.parametrize( + "Estimator, Child", + [ + (StackingClassifier, ConsumingClassifier), + (StackingRegressor, ConsumingRegressor), + ], +) +def test_get_metadata_routing_without_fit(Estimator, Child): + # Test that metadata_routing() doesn't raise when called before fit. + est = Estimator([("sub_est", Child())]) + est.get_metadata_routing() + + +@pytest.mark.usefixtures("enable_slep006") +@pytest.mark.parametrize( + "Estimator, Child", + [ + (StackingClassifier, ConsumingClassifier), + (StackingRegressor, ConsumingRegressor), + ], +) +@pytest.mark.parametrize("prop", ["sample_weight", "metadata"]) +def test_metadata_routing_for_stacking_estimators(Estimator, Child, prop): + """Test that metadata is routed correctly for Voting*.""" + sample_weight, metadata = np.ones(X_iris.shape[0]), "a" + + est = Estimator( + [ + ( + "sub_est1", + Child(registry=_Registry()).set_fit_request(**{prop: True}), + ), + ( + "sub_est2", + Child(registry=_Registry()).set_fit_request(**{prop: True}), + ), + ], + final_estimator=Child(registry=_Registry()), + ) + + est.fit( + X_iris, y_iris, **{prop: sample_weight if prop == "sample_weight" else metadata} + ) + est.fit_transform( + X_iris, y_iris, **{prop: sample_weight if prop == "sample_weight" else metadata} + ) + + for estimator in est.estimators: + if prop == "sample_weight": + kwargs = {prop: sample_weight} + else: + kwargs = {prop: metadata} + # access sub-estimator in (name, est) with estimator[1]: + registry = estimator[1].registry + assert len(registry) + for sub_est in registry: + check_recorded_metadata( + obj=sub_est, method="fit", split_params=(prop), **kwargs + ) + + +@pytest.mark.usefixtures("enable_slep006") +@pytest.mark.parametrize( + "Estimator, Child", + [ + (StackingClassifier, ConsumingClassifier), + (StackingRegressor, ConsumingRegressor), + ], +) +def test_metadata_routing_error_for_stacking_estimators(Estimator, Child): + """Test that the right error is raised when metadata is not requested.""" + sample_weight, metadata = np.ones(X_iris.shape[0]), "a" + + est = Estimator([("sub_est", Child())]) + + error_message = ( + "[sample_weight, metadata] are passed but are not explicitly set as requested" + f" or not for {Child.__name__}.fit" + ) + + with pytest.raises(ValueError, match=re.escape(error_message)): + est.fit(X_iris, y_iris, sample_weight=sample_weight, metadata=metadata) + + +# End of Metadata Routing Tests +# ============================= diff --git a/sklearn/tests/metadata_routing_common.py b/sklearn/tests/metadata_routing_common.py index 3df47d3f8dd4e..58f080b8f67dd 100644 --- a/sklearn/tests/metadata_routing_common.py +++ b/sklearn/tests/metadata_routing_common.py @@ -259,13 +259,10 @@ def predict(self, X, sample_weight="default", metadata="default"): return np.zeros(shape=(len(X),)) def predict_proba(self, X, sample_weight="default", metadata="default"): - pass # pragma: no cover - - # uncomment when needed - # record_metadata_not_default( - # self, "predict_proba", sample_weight=sample_weight, metadata=metadata - # ) - # return np.asarray([[0.0, 1.0]] * len(X)) + record_metadata_not_default( + self, "predict_proba", sample_weight=sample_weight, metadata=metadata + ) + return np.asarray([[0.0, 1.0]] * len(X)) def predict_log_proba(self, X, sample_weight="default", metadata="default"): pass # pragma: no cover From db318b72e99554990bc9a440c07721398ff16b4f Mon Sep 17 00:00:00 2001 From: Stefanie Senger Date: Tue, 26 Mar 2024 16:07:20 +0100 Subject: [PATCH 02/17] little fixes --- doc/whats_new/v1.5.rst | 3 +-- sklearn/tests/test_metaestimators_metadata_routing.py | 4 ---- 2 files changed, 1 insertion(+), 6 deletions(-) diff --git a/doc/whats_new/v1.5.rst b/doc/whats_new/v1.5.rst index 09dc037b29310..641a1e97a9f84 100644 --- a/doc/whats_new/v1.5.rst +++ b/doc/whats_new/v1.5.rst @@ -95,8 +95,7 @@ more details. - |Feature| :class:`ensemble.StackingClassifier` and :class:`ensemble.StackingRegressor` now support metadata routing and pass ``**fit_params`` to the underlying estimators via their `fit` methods. - :pr:`.....` by :user:`Stefanie Senger `. - + :pr:`28701` by :user:`Stefanie Senger `. Changelog --------- diff --git a/sklearn/tests/test_metaestimators_metadata_routing.py b/sklearn/tests/test_metaestimators_metadata_routing.py index 46758315d5c2d..a74a031931aa0 100644 --- a/sklearn/tests/test_metaestimators_metadata_routing.py +++ b/sklearn/tests/test_metaestimators_metadata_routing.py @@ -14,8 +14,6 @@ AdaBoostRegressor, BaggingClassifier, BaggingRegressor, - StackingClassifier, - StackingRegressor, ) from sklearn.exceptions import UnsetMetadataPassedError from sklearn.experimental import ( @@ -402,8 +400,6 @@ def enable_slep006(): RFECV(ConsumingClassifier()), SelfTrainingClassifier(ConsumingClassifier()), SequentialFeatureSelector(ConsumingClassifier()), - StackingClassifier(ConsumingClassifier()), - StackingRegressor(ConsumingRegressor()), TransformedTargetRegressor(), ] From 5706e197c2009fc7b8a1616069a7112cc5649eed Mon Sep 17 00:00:00 2001 From: Stefanie Senger Date: Wed, 27 Mar 2024 10:15:26 +0100 Subject: [PATCH 03/17] fix docstring --- doc/whats_new/v1.5.rst | 2 +- sklearn/ensemble/_stacking.py | 3 +++ 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/doc/whats_new/v1.5.rst b/doc/whats_new/v1.5.rst index 641a1e97a9f84..90708ef36414a 100644 --- a/doc/whats_new/v1.5.rst +++ b/doc/whats_new/v1.5.rst @@ -92,7 +92,7 @@ more details. transformers' ``fit`` and ``fit_transform``. :pr:`28205` by :user:`Stefanie Senger `. - - |Feature| :class:`ensemble.StackingClassifier` and +- |Feature| :class:`ensemble.StackingClassifier` and :class:`ensemble.StackingRegressor` now support metadata routing and pass ``**fit_params`` to the underlying estimators via their `fit` methods. :pr:`28701` by :user:`Stefanie Senger `. diff --git a/sklearn/ensemble/_stacking.py b/sklearn/ensemble/_stacking.py index 74cb536824449..25dcd0b4572af 100644 --- a/sklearn/ensemble/_stacking.py +++ b/sklearn/ensemble/_stacking.py @@ -414,9 +414,12 @@ def _sk_visual_block_with_final_estimator(self, final_estimator): def get_metadata_routing(self): """Get metadata routing of this object. + Please check :ref:`User Guide ` on how the routing mechanism works. + .. versionadded:: 1.5 + Returns ------- routing : MetadataRouter From b97f5af0432390521820d680c2d2ed60f0ca194e Mon Sep 17 00:00:00 2001 From: Stefanie Senger Date: Wed, 27 Mar 2024 12:32:43 +0100 Subject: [PATCH 04/17] correct deprecation version --- sklearn/ensemble/_stacking.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sklearn/ensemble/_stacking.py b/sklearn/ensemble/_stacking.py index 25dcd0b4572af..97b8429e8d598 100644 --- a/sklearn/ensemble/_stacking.py +++ b/sklearn/ensemble/_stacking.py @@ -196,7 +196,7 @@ def fit(self, X, y, **fit_params): when not None, `sample_weight` is passed to all underlying estimators - .. deprecated:: 1.4 + .. deprecated:: 1.5 `sample_weight` is deprecated in 1.5 and will be removed in 1.7. **fit_params : dict From 109bbd58c9996fd18bb49454b25da858ff1157ff Mon Sep 17 00:00:00 2001 From: Stefanie Senger <91849487+StefanieSenger@users.noreply.github.com> Date: Wed, 10 Apr 2024 14:04:43 +0200 Subject: [PATCH 05/17] Apply suggestions from code review Co-authored-by: Adrin Jalali --- sklearn/ensemble/_stacking.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sklearn/ensemble/_stacking.py b/sklearn/ensemble/_stacking.py index 97b8429e8d598..4d7e91ebe7600 100644 --- a/sklearn/ensemble/_stacking.py +++ b/sklearn/ensemble/_stacking.py @@ -677,7 +677,7 @@ def _validate_estimators(self): # cycle; pop it from `fit_params` before the `_raise_for_params` check and # reinsert afterwards, for backwards compatibility @_deprecate_positional_args(version="1.7") - def fit(self, X, y, sample_weight=None, **fit_params): + def fit(self, X, y, *, sample_weight=None, **fit_params): """Fit the estimators. Parameters @@ -1009,7 +1009,7 @@ def _validate_final_estimator(self): # cycle; pop it from `fit_params` before the `_raise_for_params` check and # reinsert afterwards, for backwards compatibility @_deprecate_positional_args(version="1.7") - def fit(self, X, y, sample_weight=None, **fit_params): + def fit(self, X, y, *, sample_weight=None, **fit_params): """Fit the estimators. Parameters From 98bca6ff6dd1a95d19c50f80298e85fa7e9bd6a9 Mon Sep 17 00:00:00 2001 From: Stefanie Senger Date: Thu, 11 Apr 2024 11:07:07 +0200 Subject: [PATCH 06/17] changes after review --- sklearn/ensemble/_stacking.py | 20 ++++++-------------- 1 file changed, 6 insertions(+), 14 deletions(-) diff --git a/sklearn/ensemble/_stacking.py b/sklearn/ensemble/_stacking.py index 4d7e91ebe7600..8b56936f7c814 100644 --- a/sklearn/ensemble/_stacking.py +++ b/sklearn/ensemble/_stacking.py @@ -187,18 +187,6 @@ def fit(self, X, y, **fit_params): y : array-like of shape (n_samples,) Target values. - sample_weight : array-like of shape (n_samples,) or default=None - Sample weights. If None, then samples are equally weighted. - Note that this is supported only if all underlying estimators - support sample weights. - - .. versionchanged:: 0.23 - when not None, `sample_weight` is passed to all underlying - estimators - - .. deprecated:: 1.5 - `sample_weight` is deprecated in 1.5 and will be removed in 1.7. - **fit_params : dict Dict of metadata, potentially containing sample_weight as a key-value pair. If sample_weight is not existing, then samples are @@ -572,7 +560,7 @@ class StackingClassifier(ClassifierMixin, _BaseStacking): ----- When `predict_proba` is used by each estimator (i.e. most of the time for `stack_method='auto'` or specifically for `stack_method='predict_proba'`), - The first column predicted by each estimator will be dropped in the case + the first column predicted by each estimator will be dropped in the case of a binary classification problem. Indeed, both feature will be perfectly collinear. @@ -1063,7 +1051,11 @@ def transform(self, X): """ return self._transform(X) - def fit_transform(self, X, y, sample_weight=None, **fit_params): + # TODO(1.7): remove `sample_weight` from the signature after deprecation + # cycle; pop it from `fit_params` before the `_raise_for_params` check and + # reinsert afterwards, for backwards compatibility + @_deprecate_positional_args(version="1.7") + def fit_transform(self, X, y, *, sample_weight=None, **fit_params): """Fit the estimators and return the predictions for X for each estimator. Parameters From aaf8df1d288c5cd194d3f8b66b3fc6630c38f23a Mon Sep 17 00:00:00 2001 From: Stefanie Senger Date: Fri, 12 Apr 2024 13:23:06 +0200 Subject: [PATCH 07/17] fix CI --- sklearn/ensemble/tests/test_stacking.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sklearn/ensemble/tests/test_stacking.py b/sklearn/ensemble/tests/test_stacking.py index bd61746c5a67d..b2cef3626b8db 100644 --- a/sklearn/ensemble/tests/test_stacking.py +++ b/sklearn/ensemble/tests/test_stacking.py @@ -998,7 +998,7 @@ def test_metadata_routing_error_for_stacking_estimators(Estimator, Child): error_message = ( "[sample_weight, metadata] are passed but are not explicitly set as requested" - f" or not for {Child.__name__}.fit" + f" or not requested for {Child.__name__}.fit" ) with pytest.raises(ValueError, match=re.escape(error_message)): From 2602fb1e6bcbdbcbaaf561f1647dc5120b8b3dbf Mon Sep 17 00:00:00 2001 From: Stefanie Senger Date: Mon, 15 Apr 2024 11:36:42 +0200 Subject: [PATCH 08/17] buggy routing for predict --- sklearn/ensemble/_stacking.py | 75 +++++++++++++++++++++++- sklearn/ensemble/tests/test_stacking.py | 21 +++++-- sklearn/tests/metadata_routing_common.py | 2 +- 3 files changed, 91 insertions(+), 7 deletions(-) diff --git a/sklearn/ensemble/_stacking.py b/sklearn/ensemble/_stacking.py index 8b56936f7c814..4b9b410211fb2 100644 --- a/sklearn/ensemble/_stacking.py +++ b/sklearn/ensemble/_stacking.py @@ -422,6 +422,12 @@ def get_metadata_routing(self): **{name: estimator}, method_mapping=MethodMapping().add(callee="fit", caller="fit"), ) + + router.add( + final_estimator_=self.final_estimator_, + method_mapping=MethodMapping().add(caller="predict", callee="predict"), + ) + return router @@ -736,12 +742,33 @@ def predict(self, X, **predict_params): with `return_std` or `return_cov`. Be aware that it will only account for uncertainty in the final estimator. + - If `enable_metadata_routing=False` (default): + Parameters directly passed to the `predict` method of the + `final_estimator`- + + - If `enable_metadata_routing=True`: + Parameters safely routed to the `predict` method of the + `final_estimator`. See :ref:`Metadata Routing User Guide + ` for more details. + + .. versionchanged:: 1.5 + `**predict_params` can be routed via metadata routing API. + + Returns ------- y_pred : ndarray of shape (n_samples,) or (n_samples, n_output) Predicted targets. """ - y_pred = super().predict(X, **predict_params) + if _routing_enabled(): + routed_params = process_routing(self, "predict", **predict_params) + else: + # TODO(SLEP6): remove when metadata routing cannot be disabled. + routed_params = Bunch() + routed_params.final_estimator_ = Bunch(predict={}) + routed_params.final_estimator_.predict = predict_params + + y_pred = super().predict(X, **routed_params.final_estimator_["predict"]) if isinstance(self._label_encoder, list): # Handle the multilabel-indicator case y_pred = np.array( @@ -1092,6 +1119,52 @@ def fit_transform(self, X, y, *, sample_weight=None, **fit_params): fit_params["sample_weight"] = sample_weight return super().fit_transform(X, y, **fit_params) + @available_if(_estimator_has("predict")) + def predict(self, X, **predict_params): + """Predict target for X. + + Parameters + ---------- + X : {array-like, sparse matrix} of shape (n_samples, n_features) + Training vectors, where `n_samples` is the number of samples and + `n_features` is the number of features. + + **predict_params : dict of str -> obj + Parameters to the `predict` called by the `final_estimator`. Note + that this may be used to return uncertainties from some estimators + with `return_std` or `return_cov`. Be aware that it will only + account for uncertainty in the final estimator. + + - If `enable_metadata_routing=False` (default): + Parameters directly passed to the `predict` method of the + `final_estimator`- + + - If `enable_metadata_routing=True`: + Parameters safely routed to the `predict` method of the + `final_estimator`. See :ref:`Metadata Routing User Guide + ` for more details. + + .. versionchanged:: 1.5 + `**predict_params` can be routed via metadata routing API. + + + Returns + ------- + y_pred : ndarray of shape (n_samples,) or (n_samples, n_output) + Predicted targets. + """ + if _routing_enabled(): + routed_params = process_routing(self, "predict", **predict_params) + else: + # TODO(SLEP6): remove when metadata routing cannot be disabled. + routed_params = Bunch() + routed_params.final_estimator_ = Bunch(predict={}) + routed_params.final_estimator_.predict = predict_params + + y_pred = super().predict(X, **routed_params.final_estimator_["predict"]) + + return y_pred + def _sk_visual_block_(self): # If final_estimator's default changes then this should be # updated. diff --git a/sklearn/ensemble/tests/test_stacking.py b/sklearn/ensemble/tests/test_stacking.py index b2cef3626b8db..c0c750e3d350f 100644 --- a/sklearn/ensemble/tests/test_stacking.py +++ b/sklearn/ensemble/tests/test_stacking.py @@ -958,7 +958,7 @@ def test_metadata_routing_for_stacking_estimators(Estimator, Child, prop): Child(registry=_Registry()).set_fit_request(**{prop: True}), ), ], - final_estimator=Child(registry=_Registry()), + final_estimator=Child(registry=_Registry()).set_predict_request(**{prop: True}), ) est.fit( @@ -968,11 +968,15 @@ def test_metadata_routing_for_stacking_estimators(Estimator, Child, prop): X_iris, y_iris, **{prop: sample_weight if prop == "sample_weight" else metadata} ) + est.predict( + X_iris, **{prop: sample_weight if prop == "sample_weight" else metadata} + ) + + if prop == "sample_weight": + kwargs = {prop: sample_weight} + else: + kwargs = {prop: metadata} for estimator in est.estimators: - if prop == "sample_weight": - kwargs = {prop: sample_weight} - else: - kwargs = {prop: metadata} # access sub-estimator in (name, est) with estimator[1]: registry = estimator[1].registry assert len(registry) @@ -980,6 +984,13 @@ def test_metadata_routing_for_stacking_estimators(Estimator, Child, prop): check_recorded_metadata( obj=sub_est, method="fit", split_params=(prop), **kwargs ) + # access final_estimator: + registry = est.final_estimator_.registry + assert len(registry) + for sub_est in registry: + check_recorded_metadata( + obj=sub_est, method="predict", split_params=(prop), **kwargs + ) @pytest.mark.usefixtures("enable_slep006") diff --git a/sklearn/tests/metadata_routing_common.py b/sklearn/tests/metadata_routing_common.py index 58f080b8f67dd..c3c1adb9b5aab 100644 --- a/sklearn/tests/metadata_routing_common.py +++ b/sklearn/tests/metadata_routing_common.py @@ -256,7 +256,7 @@ def predict(self, X, sample_weight="default", metadata="default"): record_metadata_not_default( self, "predict", sample_weight=sample_weight, metadata=metadata ) - return np.zeros(shape=(len(X),)) + return np.zeros(shape=(len(X),), dtype="int8") def predict_proba(self, X, sample_weight="default", metadata="default"): record_metadata_not_default( From 733bcacb5dc03dc9b0ba45a603accaafd6615e0a Mon Sep 17 00:00:00 2001 From: Stefanie Senger Date: Mon, 15 Apr 2024 22:30:06 +0200 Subject: [PATCH 09/17] only test last entry in registry for final estimator --- sklearn/ensemble/_stacking.py | 7 ++++++- sklearn/ensemble/tests/test_stacking.py | 7 +++---- 2 files changed, 9 insertions(+), 5 deletions(-) diff --git a/sklearn/ensemble/_stacking.py b/sklearn/ensemble/_stacking.py index 4b9b410211fb2..3e3b0f79a9ecd 100644 --- a/sklearn/ensemble/_stacking.py +++ b/sklearn/ensemble/_stacking.py @@ -423,8 +423,13 @@ def get_metadata_routing(self): method_mapping=MethodMapping().add(callee="fit", caller="fit"), ) + try: + final_estimator_ = self.final_estimator_ + except AttributeError: + final_estimator_ = self.final_estimator + router.add( - final_estimator_=self.final_estimator_, + final_estimator_=final_estimator_, method_mapping=MethodMapping().add(caller="predict", callee="predict"), ) diff --git a/sklearn/ensemble/tests/test_stacking.py b/sklearn/ensemble/tests/test_stacking.py index c0c750e3d350f..730277cc638a4 100644 --- a/sklearn/ensemble/tests/test_stacking.py +++ b/sklearn/ensemble/tests/test_stacking.py @@ -987,10 +987,9 @@ def test_metadata_routing_for_stacking_estimators(Estimator, Child, prop): # access final_estimator: registry = est.final_estimator_.registry assert len(registry) - for sub_est in registry: - check_recorded_metadata( - obj=sub_est, method="predict", split_params=(prop), **kwargs - ) + check_recorded_metadata( + obj=registry[-1], method="predict", split_params=(prop), **kwargs + ) @pytest.mark.usefixtures("enable_slep006") From c6762eb65455c2f7573529de07a119c956b3343d Mon Sep 17 00:00:00 2001 From: Stefanie Senger Date: Tue, 16 Apr 2024 09:27:31 +0200 Subject: [PATCH 10/17] fix docstring --- sklearn/ensemble/_stacking.py | 18 ++++++++---------- 1 file changed, 8 insertions(+), 10 deletions(-) diff --git a/sklearn/ensemble/_stacking.py b/sklearn/ensemble/_stacking.py index 3e3b0f79a9ecd..7f4419fe1b493 100644 --- a/sklearn/ensemble/_stacking.py +++ b/sklearn/ensemble/_stacking.py @@ -749,12 +749,11 @@ def predict(self, X, **predict_params): - If `enable_metadata_routing=False` (default): Parameters directly passed to the `predict` method of the - `final_estimator`- + `final_estimator`. - - If `enable_metadata_routing=True`: - Parameters safely routed to the `predict` method of the - `final_estimator`. See :ref:`Metadata Routing User Guide - ` for more details. + - If `enable_metadata_routing=True`: Parameters safely routed to + the `predict` method of the `final_estimator`. See :ref:`Metadata + Routing User Guide ` for more details. .. versionchanged:: 1.5 `**predict_params` can be routed via metadata routing API. @@ -1142,12 +1141,11 @@ def predict(self, X, **predict_params): - If `enable_metadata_routing=False` (default): Parameters directly passed to the `predict` method of the - `final_estimator`- + `final_estimator`. - - If `enable_metadata_routing=True`: - Parameters safely routed to the `predict` method of the - `final_estimator`. See :ref:`Metadata Routing User Guide - ` for more details. + - If `enable_metadata_routing=True`: Parameters safely routed to + the `predict` method of the `final_estimator`. See :ref:`Metadata + Routing User Guide ` for more details. .. versionchanged:: 1.5 `**predict_params` can be routed via metadata routing API. From 4c271d2ecdbcf639e3675e79839ce139531d3453 Mon Sep 17 00:00:00 2001 From: Stefanie Senger Date: Wed, 17 Apr 2024 09:33:54 +0200 Subject: [PATCH 11/17] fix docstring --- sklearn/ensemble/_stacking.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/sklearn/ensemble/_stacking.py b/sklearn/ensemble/_stacking.py index 7f4419fe1b493..99df0698a2a2d 100644 --- a/sklearn/ensemble/_stacking.py +++ b/sklearn/ensemble/_stacking.py @@ -758,7 +758,6 @@ def predict(self, X, **predict_params): .. versionchanged:: 1.5 `**predict_params` can be routed via metadata routing API. - Returns ------- y_pred : ndarray of shape (n_samples,) or (n_samples, n_output) @@ -1150,7 +1149,6 @@ def predict(self, X, **predict_params): .. versionchanged:: 1.5 `**predict_params` can be routed via metadata routing API. - Returns ------- y_pred : ndarray of shape (n_samples,) or (n_samples, n_output) From b8dd378c05f19fe860bdc49e8b5d56eadadf809a Mon Sep 17 00:00:00 2001 From: Stefanie Senger <91849487+StefanieSenger@users.noreply.github.com> Date: Wed, 8 May 2024 10:35:12 +0200 Subject: [PATCH 12/17] Apply suggestions from code review Co-authored-by: Omar Salman --- sklearn/ensemble/_stacking.py | 2 +- sklearn/ensemble/tests/test_stacking.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/sklearn/ensemble/_stacking.py b/sklearn/ensemble/_stacking.py index dba9a4b6f7843..d91548fb24dbf 100644 --- a/sklearn/ensemble/_stacking.py +++ b/sklearn/ensemble/_stacking.py @@ -189,7 +189,7 @@ def fit(self, X, y, **fit_params): **fit_params : dict Dict of metadata, potentially containing sample_weight as a - key-value pair. If sample_weight is not existing, then samples are + key-value pair. If sample_weight is not present, then samples are equally weighted. Note that sample_weight is supported only if all underlying estimators support sample weights. diff --git a/sklearn/ensemble/tests/test_stacking.py b/sklearn/ensemble/tests/test_stacking.py index dacd3500a7ad1..7e25630c795b6 100644 --- a/sklearn/ensemble/tests/test_stacking.py +++ b/sklearn/ensemble/tests/test_stacking.py @@ -944,7 +944,7 @@ def test_get_metadata_routing_without_fit(Estimator, Child): ) @pytest.mark.parametrize("prop", ["sample_weight", "metadata"]) def test_metadata_routing_for_stacking_estimators(Estimator, Child, prop): - """Test that metadata is routed correctly for Voting*.""" + """Test that metadata is routed correctly for Stacking*.""" sample_weight, metadata = np.ones(X_iris.shape[0]), "a" est = Estimator( From 38a395701f49acf223e79303fc09dca4fb97017b Mon Sep 17 00:00:00 2001 From: Stefanie Senger <91849487+StefanieSenger@users.noreply.github.com> Date: Wed, 8 May 2024 14:43:40 +0200 Subject: [PATCH 13/17] Update sklearn/ensemble/tests/test_stacking.py Co-authored-by: Omar Salman --- sklearn/ensemble/tests/test_stacking.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/sklearn/ensemble/tests/test_stacking.py b/sklearn/ensemble/tests/test_stacking.py index 7e25630c795b6..c368aba057694 100644 --- a/sklearn/ensemble/tests/test_stacking.py +++ b/sklearn/ensemble/tests/test_stacking.py @@ -942,10 +942,9 @@ def test_get_metadata_routing_without_fit(Estimator, Child): (StackingRegressor, ConsumingRegressor), ], ) -@pytest.mark.parametrize("prop", ["sample_weight", "metadata"]) -def test_metadata_routing_for_stacking_estimators(Estimator, Child, prop): +@pytest.mark.parametrize("prop, prop_value", [("sample_weight", np.ones(X_iris.shape[0])), ("metadata", "a")]) +def test_metadata_routing_for_stacking_estimators(Estimator, Child, prop, prop_value): """Test that metadata is routed correctly for Stacking*.""" - sample_weight, metadata = np.ones(X_iris.shape[0]), "a" est = Estimator( [ From 4a8ea4918905c4e846fe6fb3f651d97e0d2e0417 Mon Sep 17 00:00:00 2001 From: Stefanie Senger Date: Wed, 8 May 2024 15:29:27 +0200 Subject: [PATCH 14/17] simplify test --- sklearn/ensemble/tests/test_stacking.py | 24 ++++++++---------------- 1 file changed, 8 insertions(+), 16 deletions(-) diff --git a/sklearn/ensemble/tests/test_stacking.py b/sklearn/ensemble/tests/test_stacking.py index c368aba057694..1c038cd469216 100644 --- a/sklearn/ensemble/tests/test_stacking.py +++ b/sklearn/ensemble/tests/test_stacking.py @@ -942,7 +942,9 @@ def test_get_metadata_routing_without_fit(Estimator, Child): (StackingRegressor, ConsumingRegressor), ], ) -@pytest.mark.parametrize("prop, prop_value", [("sample_weight", np.ones(X_iris.shape[0])), ("metadata", "a")]) +@pytest.mark.parametrize( + "prop, prop_value", [("sample_weight", np.ones(X_iris.shape[0])), ("metadata", "a")] +) def test_metadata_routing_for_stacking_estimators(Estimator, Child, prop, prop_value): """Test that metadata is routed correctly for Stacking*.""" @@ -960,34 +962,24 @@ def test_metadata_routing_for_stacking_estimators(Estimator, Child, prop, prop_v final_estimator=Child(registry=_Registry()).set_predict_request(**{prop: True}), ) - est.fit( - X_iris, y_iris, **{prop: sample_weight if prop == "sample_weight" else metadata} - ) - est.fit_transform( - X_iris, y_iris, **{prop: sample_weight if prop == "sample_weight" else metadata} - ) + est.fit(X_iris, y_iris, **{prop: prop_value}) + est.fit_transform(X_iris, y_iris, **{prop: prop_value}) - est.predict( - X_iris, **{prop: sample_weight if prop == "sample_weight" else metadata} - ) + est.predict(X_iris, **{prop: prop_value}) - if prop == "sample_weight": - kwargs = {prop: sample_weight} - else: - kwargs = {prop: metadata} for estimator in est.estimators: # access sub-estimator in (name, est) with estimator[1]: registry = estimator[1].registry assert len(registry) for sub_est in registry: check_recorded_metadata( - obj=sub_est, method="fit", split_params=(prop), **kwargs + obj=sub_est, method="fit", split_params=(prop), **{prop: prop_value} ) # access final_estimator: registry = est.final_estimator_.registry assert len(registry) check_recorded_metadata( - obj=registry[-1], method="predict", split_params=(prop), **kwargs + obj=registry[-1], method="predict", split_params=(prop), **{prop: prop_value} ) From 75c23d30d7cefd7c9906674b27dcef8d28fa523a Mon Sep 17 00:00:00 2001 From: Stefanie Senger <91849487+StefanieSenger@users.noreply.github.com> Date: Mon, 13 May 2024 08:46:40 +0200 Subject: [PATCH 15/17] Apply suggestions from code review Co-authored-by: Omar Salman --- sklearn/ensemble/_stacking.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/sklearn/ensemble/_stacking.py b/sklearn/ensemble/_stacking.py index d91548fb24dbf..9dc93b6c35975 100644 --- a/sklearn/ensemble/_stacking.py +++ b/sklearn/ensemble/_stacking.py @@ -193,7 +193,7 @@ def fit(self, X, y, **fit_params): equally weighted. Note that sample_weight is supported only if all underlying estimators support sample weights. - .. versionadded:: 1.5 + .. versionadded:: 1.6 Returns ------- @@ -406,7 +406,7 @@ def get_metadata_routing(self): Please check :ref:`User Guide ` on how the routing mechanism works. - .. versionadded:: 1.5 + .. versionadded:: 1.6 Returns ------- @@ -699,7 +699,7 @@ def fit(self, X, y, *, sample_weight=None, **fit_params): **fit_params : dict Parameters to pass to the underlying estimators. - .. versionadded:: 1.5 + .. versionadded:: 1.6 Only available if `enable_metadata_routing=True`, which can be set by using ``sklearn.set_config(enable_metadata_routing=True)``. @@ -755,7 +755,7 @@ def predict(self, X, **predict_params): the `predict` method of the `final_estimator`. See :ref:`Metadata Routing User Guide ` for more details. - .. versionchanged:: 1.5 + .. versionchanged:: 1.6 `**predict_params` can be routed via metadata routing API. Returns @@ -1047,7 +1047,7 @@ def fit(self, X, y, *, sample_weight=None, **fit_params): **fit_params : dict Parameters to pass to the underlying estimators. - .. versionadded:: 1.5 + .. versionadded:: 1.6 Only available if `enable_metadata_routing=True`, which can be set by using ``sklearn.set_config(enable_metadata_routing=True)``. @@ -1105,7 +1105,7 @@ def fit_transform(self, X, y, *, sample_weight=None, **fit_params): **fit_params : dict Parameters to pass to the underlying estimators. - .. versionadded:: 1.5 + .. versionadded:: 1.6 Only available if `enable_metadata_routing=True`, which can be set by using ``sklearn.set_config(enable_metadata_routing=True)``. @@ -1146,7 +1146,7 @@ def predict(self, X, **predict_params): the `predict` method of the `final_estimator`. See :ref:`Metadata Routing User Guide ` for more details. - .. versionchanged:: 1.5 + .. versionchanged:: 1.6 `**predict_params` can be routed via metadata routing API. Returns From ba7359e9e7768fa478c77e8e022befdf0ce4d43a Mon Sep 17 00:00:00 2001 From: Stefanie Senger Date: Mon, 13 May 2024 08:49:28 +0200 Subject: [PATCH 16/17] update changelog --- doc/whats_new/v1.5.rst | 5 ----- doc/whats_new/v1.6.rst | 15 ++++++++++++++- 2 files changed, 14 insertions(+), 6 deletions(-) diff --git a/doc/whats_new/v1.5.rst b/doc/whats_new/v1.5.rst index 2c63881912f14..e50309a330e39 100644 --- a/doc/whats_new/v1.5.rst +++ b/doc/whats_new/v1.5.rst @@ -139,11 +139,6 @@ more details. transformers' ``fit`` and ``fit_transform``. :pr:`28205` by :user:`Stefanie Senger `. -- |Feature| :class:`ensemble.StackingClassifier` and - :class:`ensemble.StackingRegressor` now support metadata routing and pass - ``**fit_params`` to the underlying estimators via their `fit` methods. - :pr:`28701` by :user:`Stefanie Senger `. - - |Fix| Fix an issue when resolving default routing requests set via class attributes. :pr:`28435` by `Adrin Jalali`_. diff --git a/doc/whats_new/v1.6.rst b/doc/whats_new/v1.6.rst index 6eda6717b3d1b..6251e3c949a34 100644 --- a/doc/whats_new/v1.6.rst +++ b/doc/whats_new/v1.6.rst @@ -38,7 +38,20 @@ See :ref:`array_api` for more details. **Classes:** -- +- + + +Metadata Routing +---------------- + +The following models now support metadata routing in one or more of their +methods. Refer to the :ref:`Metadata Routing User Guide ` for +more details. + +- |Feature| :class:`ensemble.StackingClassifier` and + :class:`ensemble.StackingRegressor` now support metadata routing and pass + ``**fit_params`` to the underlying estimators via their `fit` methods. + :pr:`28701` by :user:`Stefanie Senger `. Changelog --------- From 65322ac2015f7c45da602aa53c5e37e1a5773781 Mon Sep 17 00:00:00 2001 From: Stefanie Senger Date: Mon, 13 May 2024 08:50:33 +0200 Subject: [PATCH 17/17] delete line --- doc/whats_new/v1.6.rst | 1 - 1 file changed, 1 deletion(-) diff --git a/doc/whats_new/v1.6.rst b/doc/whats_new/v1.6.rst index 6251e3c949a34..5000866b59c03 100644 --- a/doc/whats_new/v1.6.rst +++ b/doc/whats_new/v1.6.rst @@ -40,7 +40,6 @@ See :ref:`array_api` for more details. - - Metadata Routing ----------------