diff --git a/doc/metadata_routing.rst b/doc/metadata_routing.rst index 7f2719606a4dd..4174f95e65ba0 100644 --- a/doc/metadata_routing.rst +++ b/doc/metadata_routing.rst @@ -274,6 +274,8 @@ Meta-estimators and functions supporting metadata routing: - :class:`sklearn.calibration.CalibratedClassifierCV` - :class:`sklearn.compose.ColumnTransformer` +- :class:`sklearn.ensemble.VotingClassifier` +- :class:`sklearn.ensemble.VotingRegressor` - :class:`sklearn.ensemble.BaggingClassifier` - :class:`sklearn.ensemble.BaggingRegressor` - :class:`sklearn.feature_selection.SelectFromModel` @@ -310,8 +312,6 @@ Meta-estimators and tools not supporting metadata routing yet: - :class:`sklearn.ensemble.AdaBoostRegressor` - :class:`sklearn.ensemble.StackingClassifier` - :class:`sklearn.ensemble.StackingRegressor` -- :class:`sklearn.ensemble.VotingClassifier` -- :class:`sklearn.ensemble.VotingRegressor` - :class:`sklearn.feature_selection.RFE` - :class:`sklearn.feature_selection.RFECV` - :class:`sklearn.feature_selection.SequentialFeatureSelector` diff --git a/doc/whats_new/v1.5.rst b/doc/whats_new/v1.5.rst index ba50134f744bb..53f0fbd8a74e8 100644 --- a/doc/whats_new/v1.5.rst +++ b/doc/whats_new/v1.5.rst @@ -48,6 +48,11 @@ more details. via their `fit` methods. :pr:`28432` by :user:`Adam Li <adam2392>` and :user:`Benjamin Bossan <BenjaminBossan>`. +- |Feature| :class:`ensemble.VotingClassifier` and + :class:`ensemble.VotingRegressor` now support metadata routing and pass + ``**fit_params`` to the underlying estimators via their `fit` methods. + :pr:`27584` by :user:`Stefanie Senger <StefanieSenger>`. 
+ Changelog --------- diff --git a/sklearn/compose/tests/test_column_transformer.py b/sklearn/compose/tests/test_column_transformer.py index 5dd84d3c057a2..ec22ddf2f3ae0 100644 --- a/sklearn/compose/tests/test_column_transformer.py +++ b/sklearn/compose/tests/test_column_transformer.py @@ -2517,7 +2517,6 @@ def transform(self, X, sample_weight=None, metadata=None): X = np.array([[0, 1, 2], [2, 4, 6]]).T y = [1, 2, 3] - _Registry() sample_weight, metadata = [1], "a" trs = ColumnTransformer( [ diff --git a/sklearn/ensemble/_base.py b/sklearn/ensemble/_base.py index 1fa05d90975cd..8410be81c6cbc 100644 --- a/sklearn/ensemble/_base.py +++ b/sklearn/ensemble/_base.py @@ -12,17 +12,20 @@ from ..base import BaseEstimator, MetaEstimatorMixin, clone, is_classifier, is_regressor from ..utils import Bunch, _print_elapsed_time, check_random_state from ..utils._tags import _safe_tags +from ..utils.metadata_routing import _routing_enabled from ..utils.metaestimators import _BaseComposition def _fit_single_estimator( - estimator, X, y, sample_weight=None, message_clsname=None, message=None + estimator, X, y, fit_params, message_clsname=None, message=None ): """Private function used to fit an estimator within a job.""" - if sample_weight is not None: + # TODO(SLEP6): remove if condition for unrouted sample_weight when metadata + # routing can't be disabled. 
+ if not _routing_enabled() and "sample_weight" in fit_params: try: with _print_elapsed_time(message_clsname, message): - estimator.fit(X, y, sample_weight=sample_weight) + estimator.fit(X, y, sample_weight=fit_params["sample_weight"]) except TypeError as exc: if "unexpected keyword argument 'sample_weight'" in str(exc): raise TypeError( @@ -33,7 +36,7 @@ def _fit_single_estimator( raise else: with _print_elapsed_time(message_clsname, message): - estimator.fit(X, y) + estimator.fit(X, y, **fit_params) return estimator diff --git a/sklearn/ensemble/_stacking.py b/sklearn/ensemble/_stacking.py index c028e85895b14..0f093e8a6b51d 100644 --- a/sklearn/ensemble/_stacking.py +++ b/sklearn/ensemble/_stacking.py @@ -201,6 +201,14 @@ def fit(self, X, y, sample_weight=None): names, all_estimators = self._validate_estimators() self._validate_final_estimator() + # FIXME: when adding support for metadata routing in Stacking*. + # This is a hotfix to make StackingClassifier and StackingRegressor + # pass the tests despite not supporting metadata routing but sharing + # the same base class with VotingClassifier and VotingRegressor. + fit_params = dict() + if sample_weight is not None: + fit_params["sample_weight"] = sample_weight + stack_method = [self.stack_method] * len(all_estimators) if self.cv == "prefit": @@ -214,7 +222,7 @@ def fit(self, X, y, sample_weight=None): # base estimators will be used in transform, predict, and # predict_proba. They are exposed publicly. 
self.estimators_ = Parallel(n_jobs=self.n_jobs)( - delayed(_fit_single_estimator)(clone(est), X, y, sample_weight) + delayed(_fit_single_estimator)(clone(est), X, y, fit_params) for est in all_estimators if est != "drop" ) @@ -253,9 +261,6 @@ def fit(self, X, y, sample_weight=None): if hasattr(cv, "random_state") and cv.random_state is None: cv.random_state = np.random.RandomState() - fit_params = ( - {"sample_weight": sample_weight} if sample_weight is not None else None - ) predictions = Parallel(n_jobs=self.n_jobs)( delayed(cross_val_predict)( clone(est), @@ -280,9 +285,7 @@ def fit(self, X, y, sample_weight=None): ] X_meta = self._concatenate_predictions(X, predictions) - _fit_single_estimator( - self.final_estimator_, X_meta, y, sample_weight=sample_weight - ) + _fit_single_estimator(self.final_estimator_, X_meta, y, fit_params=fit_params) return self diff --git a/sklearn/ensemble/_voting.py b/sklearn/ensemble/_voting.py index 48cb104019e85..4e7c7af369ab0 100644 --- a/sklearn/ensemble/_voting.py +++ b/sklearn/ensemble/_voting.py @@ -31,14 +31,18 @@ from ..utils._estimator_html_repr import _VisualBlock from ..utils._param_validation import StrOptions from ..utils.metadata_routing import ( - _raise_for_unsupported_routing, - _RoutingNotSupportedMixin, + MetadataRouter, + MethodMapping, + _raise_for_params, + _routing_enabled, + process_routing, ) from ..utils.metaestimators import available_if from ..utils.multiclass import type_of_target from ..utils.parallel import Parallel, delayed from ..utils.validation import ( _check_feature_names_in, + _deprecate_positional_args, check_is_fitted, column_or_1d, ) @@ -76,7 +80,7 @@ def _predict(self, X): return np.asarray([est.predict(X) for est in self.estimators_]).T @abstractmethod - def fit(self, X, y, sample_weight=None): + def fit(self, X, y, **fit_params): """Get common fit operations.""" names, clfs = self._validate_estimators() @@ -86,16 +90,27 @@ def fit(self, X, y, sample_weight=None): f" {len(self.weights)} 
weights, {len(self.estimators)} estimators" ) + if _routing_enabled(): + routed_params = process_routing(self, "fit", **fit_params) + else: + routed_params = Bunch() + for name in names: + routed_params[name] = Bunch(fit={}) + if "sample_weight" in fit_params: + routed_params[name].fit["sample_weight"] = fit_params[ + "sample_weight" + ] + self.estimators_ = Parallel(n_jobs=self.n_jobs)( delayed(_fit_single_estimator)( clone(clf), X, y, - sample_weight=sample_weight, + fit_params=routed_params[name]["fit"], message_clsname="Voting", - message=self._log_message(names[idx], idx + 1, len(clfs)), + message=self._log_message(name, idx + 1, len(clfs)), ) - for idx, clf in enumerate(clfs) + for idx, (name, clf) in enumerate(zip(names, clfs)) if clf != "drop" ) @@ -156,8 +171,32 @@ def _sk_visual_block_(self): names, estimators = zip(*self.estimators) return _VisualBlock("parallel", estimators, names=names) + def get_metadata_routing(self): + """Get metadata routing of this object. + + Please check :ref:`User Guide ` on how the routing + mechanism works. + + .. versionadded:: 1.5 + + Returns + ------- + routing : MetadataRouter + A :class:`~sklearn.utils.metadata_routing.MetadataRouter` encapsulating + routing information. + """ + router = MetadataRouter(owner=self.__class__.__name__) + + # `self.estimators` is a list of (name, est) tuples + for name, estimator in self.estimators: + router.add( + **{name: estimator}, + method_mapping=MethodMapping().add(callee="fit", caller="fit"), + ) + return router + -class VotingClassifier(_RoutingNotSupportedMixin, ClassifierMixin, _BaseVoting): +class VotingClassifier(ClassifierMixin, _BaseVoting): """Soft Voting/Majority Rule classifier for unfitted estimators. Read more in the :ref:`User Guide `. 
@@ -317,7 +356,11 @@ def __init__( # estimators in VotingClassifier.estimators are not validated yet prefer_skip_nested_validation=False ) - def fit(self, X, y, sample_weight=None): + # TODO(1.7): remove `sample_weight` from the signature after deprecation + # cycle; pop it from `fit_params` before the `_raise_for_params` check and + # reinsert later, for backwards compatibility + @_deprecate_positional_args(version="1.7") + def fit(self, X, y, *, sample_weight=None, **fit_params): """Fit the estimators. Parameters @@ -336,12 +379,23 @@ def fit(self, X, y, sample_weight=None): .. versionadded:: 0.18 + **fit_params : dict + Parameters to pass to the underlying estimators. + + .. versionadded:: 1.5 + + Only available if `enable_metadata_routing=True`, + which can be set by using + ``sklearn.set_config(enable_metadata_routing=True)``. + See :ref:`Metadata Routing User Guide ` for + more details. + Returns ------- self : object Returns the instance itself. """ - _raise_for_unsupported_routing(self, "fit", sample_weight=sample_weight) + _raise_for_params(fit_params, self, "fit") y_type = type_of_target(y, input_name="y") if y_type in ("unknown", "continuous"): # raise a specific ValueError for non-classification tasks @@ -363,7 +417,10 @@ def fit(self, X, y, sample_weight=None): self.classes_ = self.le_.classes_ transformed_y = self.le_.transform(y) - return super().fit(X, transformed_y, sample_weight) + if sample_weight is not None: + fit_params["sample_weight"] = sample_weight + + return super().fit(X, transformed_y, **fit_params) def predict(self, X): """Predict class labels for X. @@ -495,7 +552,7 @@ def get_feature_names_out(self, input_features=None): return np.asarray(names_out, dtype=object) -class VotingRegressor(_RoutingNotSupportedMixin, RegressorMixin, _BaseVoting): +class VotingRegressor(RegressorMixin, _BaseVoting): """Prediction voting regressor for unfitted estimators. 
A voting regressor is an ensemble meta-estimator that fits several base @@ -596,7 +653,11 @@ def __init__(self, estimators, *, weights=None, n_jobs=None, verbose=False): # estimators in VotingRegressor.estimators are not validated yet prefer_skip_nested_validation=False ) - def fit(self, X, y, sample_weight=None): + # TODO(1.7): remove `sample_weight` from the signature after deprecation cycle; + # pop it from `fit_params` before the `_raise_for_params` check and reinsert later, + # for backwards compatibility + @_deprecate_positional_args(version="1.7") + def fit(self, X, y, *, sample_weight=None, **fit_params): """Fit the estimators. Parameters @@ -613,14 +674,27 @@ def fit(self, X, y, sample_weight=None): Note that this is supported only if all underlying estimators support sample weights. + **fit_params : dict + Parameters to pass to the underlying estimators. + + .. versionadded:: 1.5 + + Only available if `enable_metadata_routing=True`, + which can be set by using + ``sklearn.set_config(enable_metadata_routing=True)``. + See :ref:`Metadata Routing User Guide ` for + more details. + Returns ------- self : object Fitted estimator. """ - _raise_for_unsupported_routing(self, "fit", sample_weight=sample_weight) + _raise_for_params(fit_params, self, "fit") y = column_or_1d(y, warn=True) - return super().fit(X, y, sample_weight) + if sample_weight is not None: + fit_params["sample_weight"] = sample_weight + return super().fit(X, y, **fit_params) def predict(self, X): """Predict regression target for X. 
diff --git a/sklearn/ensemble/tests/test_voting.py b/sklearn/ensemble/tests/test_voting.py index 011d9b40077e1..2f4c412bd6466 100644 --- a/sklearn/ensemble/tests/test_voting.py +++ b/sklearn/ensemble/tests/test_voting.py @@ -23,12 +23,18 @@ from sklearn.neighbors import KNeighborsClassifier from sklearn.preprocessing import StandardScaler from sklearn.svm import SVC +from sklearn.tests.metadata_routing_common import ( + ConsumingClassifier, + ConsumingRegressor, + _Registry, + check_recorded_metadata, +) from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor from sklearn.utils._testing import ( - _convert_container, assert_almost_equal, assert_array_almost_equal, assert_array_equal, + ignore_warnings, ) # Load datasets @@ -255,19 +261,19 @@ def test_predict_proba_on_toy_problem(): assert inner_msg in str(exec_info.value.__cause__) -@pytest.mark.parametrize("container_type", ["list", "array", "dataframe"]) -def test_multilabel(container_type): +def test_multilabel(): """Check if error is raised for multilabel classification.""" X, y = make_multilabel_classification( n_classes=2, n_labels=1, allow_unlabeled=False, random_state=123 ) - y = _convert_container(y, container_type) clf = OneVsRestClassifier(SVC(kernel="linear")) eclf = VotingClassifier(estimators=[("ovr", clf)], voting="hard") - err_msg = "only supports binary or multiclass classification" - with pytest.raises(NotImplementedError, match=err_msg): + + try: eclf.fit(X, y) + except NotImplementedError: + return def test_gridsearch(): @@ -308,6 +314,7 @@ def test_parallel_fit(global_random_seed): assert_array_almost_equal(eclf1.predict_proba(X), eclf2.predict_proba(X)) +@ignore_warnings(category=FutureWarning) def test_sample_weight(global_random_seed): """Tests sample_weight parameter of VotingClassifier""" clf1 = LogisticRegression(random_state=global_random_seed) @@ -682,3 +689,100 @@ def test_get_features_names_out_classifier_error(): ) with pytest.raises(ValueError, match=msg): 
voting.get_feature_names_out() + + +# Metadata Routing Tests +# ====================== + + +@pytest.mark.parametrize( + "Estimator, Child", + [(VotingClassifier, ConsumingClassifier), (VotingRegressor, ConsumingRegressor)], +) +def test_routing_passed_metadata_not_supported(Estimator, Child): + """Test that the right error message is raised when metadata is passed while + not supported when `enable_metadata_routing=False`.""" + + X = np.array([[0, 1], [2, 2], [4, 6]]) + y = [1, 2, 3] + + with pytest.raises( + ValueError, match="is only supported if enable_metadata_routing=True" + ): + Estimator(["clf", Child()]).fit(X, y, sample_weight=[1, 1, 1], metadata="a") + + +@pytest.mark.usefixtures("enable_slep006") +@pytest.mark.parametrize( + "Estimator, Child", + [(VotingClassifier, ConsumingClassifier), (VotingRegressor, ConsumingRegressor)], +) +def test_get_metadata_routing_without_fit(Estimator, Child): + # Test that metadata_routing() doesn't raise when called before fit. + est = Estimator([("sub_est", Child())]) + est.get_metadata_routing() + + +@pytest.mark.usefixtures("enable_slep006") +@pytest.mark.parametrize( + "Estimator, Child", + [(VotingClassifier, ConsumingClassifier), (VotingRegressor, ConsumingRegressor)], +) +@pytest.mark.parametrize("prop", ["sample_weight", "metadata"]) +def test_metadata_routing_for_voting_estimators(Estimator, Child, prop): + """Test that metadata is routed correctly for Voting*.""" + X = np.array([[0, 1], [2, 2], [4, 6]]) + y = [1, 2, 3] + sample_weight, metadata = [1, 1, 1], "a" + + est = Estimator( + [ + ( + "sub_est1", + Child(registry=_Registry()).set_fit_request(**{prop: True}), + ), + ( + "sub_est2", + Child(registry=_Registry()).set_fit_request(**{prop: True}), + ), + ] + ) + + est.fit(X, y, **{prop: sample_weight if prop == "sample_weight" else metadata}) + + for estimator in est.estimators: + if prop == "sample_weight": + kwargs = {prop: sample_weight} + else: + kwargs = {prop: metadata} + # access sub-estimator in (name, 
est) with estimator[1] + registry = estimator[1].registry + assert len(registry) + for sub_est in registry: + check_recorded_metadata(obj=sub_est, method="fit", **kwargs) + + +@pytest.mark.usefixtures("enable_slep006") +@pytest.mark.parametrize( + "Estimator, Child", + [(VotingClassifier, ConsumingClassifier), (VotingRegressor, ConsumingRegressor)], +) +def test_metadata_routing_error_for_voting_estimators(Estimator, Child): + """Test that the right error is raised when metadata is not requested.""" + X = np.array([[0, 1], [2, 2], [4, 6]]) + y = [1, 2, 3] + sample_weight, metadata = [1, 1, 1], "a" + + est = Estimator([("sub_est", Child())]) + + error_message = ( + "[sample_weight, metadata] are passed but are not explicitly set as requested" + f" or not for {Child.__name__}.fit" + ) + + with pytest.raises(ValueError, match=re.escape(error_message)): + est.fit(X, y, sample_weight=sample_weight, metadata=metadata) + + +# End of Metadata Routing Tests +# ============================= diff --git a/sklearn/tests/metadata_routing_common.py b/sklearn/tests/metadata_routing_common.py index 94473566ad0c7..dc0387eb38f93 100644 --- a/sklearn/tests/metadata_routing_common.py +++ b/sklearn/tests/metadata_routing_common.py @@ -55,6 +55,7 @@ def check_recorded_metadata(obj, method, split_params=tuple(), **kwargs): split_params : tuple, default=empty specifies any parameters which are to be checked as being a subset of the original values. 
+ **kwargs : metadata to check """ records = getattr(obj, "_records", dict()).get(method, dict()) assert set(kwargs.keys()) == set( @@ -243,6 +244,7 @@ def fit(self, X, y, sample_weight="default", metadata="default"): record_metadata_not_default( self, "fit", sample_weight=sample_weight, metadata=metadata ) + self.classes_ = np.unique(y) return self diff --git a/sklearn/tests/test_metaestimators_metadata_routing.py b/sklearn/tests/test_metaestimators_metadata_routing.py index 2cffa5125e3c2..08a7e0ef9952a 100644 --- a/sklearn/tests/test_metaestimators_metadata_routing.py +++ b/sklearn/tests/test_metaestimators_metadata_routing.py @@ -16,8 +16,6 @@ BaggingRegressor, StackingClassifier, StackingRegressor, - VotingClassifier, - VotingRegressor, ) from sklearn.exceptions import UnsetMetadataPassedError from sklearn.experimental import ( @@ -366,8 +364,6 @@ def enable_slep006(): StackingClassifier(ConsumingClassifier()), StackingRegressor(ConsumingRegressor()), TransformedTargetRegressor(), - VotingClassifier(ConsumingClassifier()), - VotingRegressor(ConsumingRegressor()), ] @@ -508,12 +504,11 @@ def test_error_on_missing_requests_for_sub_estimator(metaestimator): scorer.set_score_request(**{key: True}) val = {"sample_weight": sample_weight, "metadata": metadata}[key] method_kwargs = {key: val} + instance = cls(**kwargs) msg = ( f"[{key}] are passed but are not explicitly set as requested or not" f" for {estimator.__class__.__name__}.{method_name}" ) - - instance = cls(**kwargs) with pytest.raises(UnsetMetadataPassedError, match=re.escape(msg)): method = getattr(instance, method_name) method(X, y, **method_kwargs)