Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
ab14402
routing for VotingClassifier
StefanieSenger Oct 14, 2023
c44b362
routing for all the classifiers VotingClassifier uses
StefanieSenger Oct 19, 2023
7a3e0fa
routing done in parents fit method
StefanieSenger Oct 23, 2023
b06403f
routing for VotingRegressor
StefanieSenger Oct 23, 2023
ecea0ad
immunity for Stacking* and changes after review
StefanieSenger Oct 24, 2023
ebcd1be
Update sklearn/tests/test_metaestimators_metadata_routing.py
StefanieSenger Oct 24, 2023
16c52cd
changes after review
StefanieSenger Oct 24, 2023
0b63677
Merge branch 'main' into routing_VotingClassifier
StefanieSenger Oct 24, 2023
3b4c6a3
added custom test for Voting*
StefanieSenger Oct 27, 2023
a6691df
revert list-support for tests
StefanieSenger Oct 27, 2023
7fe675a
Update sklearn/ensemble/_voting.py
StefanieSenger Jan 5, 2024
5ada54a
Update sklearn/ensemble/_voting.py
StefanieSenger Jan 5, 2024
c85858c
Merge branch 'main' into routing_VotingClassifier
StefanieSenger Jan 5, 2024
82bea9a
Update sklearn/ensemble/_voting.py
StefanieSenger Jan 5, 2024
fd91818
Update sklearn/ensemble/tests/test_voting.py
StefanieSenger Jan 5, 2024
5489d31
changes after review
StefanieSenger Jan 5, 2024
5784c56
ignore FutureWarning
StefanieSenger Jan 5, 2024
277793f
improvements according to review
StefanieSenger Feb 1, 2024
1fbe65f
Merge branch 'main' into routing_VotingClassifier
StefanieSenger Feb 1, 2024
c375792
Merge remote-tracking branch 'origin/main' into pr/StefanieSenger/27584
glemaitre Feb 22, 2024
d7d14c7
nitpicks
glemaitre Feb 22, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions doc/metadata_routing.rst
Original file line number Diff line number Diff line change
Expand Up @@ -274,6 +274,8 @@ Meta-estimators and functions supporting metadata routing:

- :class:`sklearn.calibration.CalibratedClassifierCV`
- :class:`sklearn.compose.ColumnTransformer`
- :class:`sklearn.ensemble.VotingClassifier`
- :class:`sklearn.ensemble.VotingRegressor`
- :class:`sklearn.ensemble.BaggingClassifier`
- :class:`sklearn.ensemble.BaggingRegressor`
- :class:`sklearn.feature_selection.SelectFromModel`
Expand Down Expand Up @@ -310,8 +312,6 @@ Meta-estimators and tools not supporting metadata routing yet:
- :class:`sklearn.ensemble.AdaBoostRegressor`
- :class:`sklearn.ensemble.StackingClassifier`
- :class:`sklearn.ensemble.StackingRegressor`
- :class:`sklearn.ensemble.VotingClassifier`
- :class:`sklearn.ensemble.VotingRegressor`
- :class:`sklearn.feature_selection.RFE`
- :class:`sklearn.feature_selection.RFECV`
- :class:`sklearn.feature_selection.SequentialFeatureSelector`
Expand Down
5 changes: 5 additions & 0 deletions doc/whats_new/v1.5.rst
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,11 @@ more details.
via their `fit` methods.
:pr:`28432` by :user:`Adam Li <adam2392>` and :user:`Benjamin Bossan <BenjaminBossan>`.

- |Feature| :class:`ensemble.VotingClassifier` and
:class:`ensemble.VotingRegressor` now support metadata routing and pass
``**fit_params`` to the underlying estimators via their `fit` methods.
:pr:`27584` by :user:`Stefanie Senger <StefanieSenger>`.

Changelog
---------

Expand Down
1 change: 0 additions & 1 deletion sklearn/compose/tests/test_column_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2517,7 +2517,6 @@ def transform(self, X, sample_weight=None, metadata=None):

X = np.array([[0, 1, 2], [2, 4, 6]]).T
y = [1, 2, 3]
_Registry()
sample_weight, metadata = [1], "a"
trs = ColumnTransformer(
[
Expand Down
11 changes: 7 additions & 4 deletions sklearn/ensemble/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,17 +12,20 @@
from ..base import BaseEstimator, MetaEstimatorMixin, clone, is_classifier, is_regressor
from ..utils import Bunch, _print_elapsed_time, check_random_state
from ..utils._tags import _safe_tags
from ..utils.metadata_routing import _routing_enabled
from ..utils.metaestimators import _BaseComposition


def _fit_single_estimator(
estimator, X, y, sample_weight=None, message_clsname=None, message=None
estimator, X, y, fit_params, message_clsname=None, message=None
):
"""Private function used to fit an estimator within a job."""
if sample_weight is not None:
# TODO(SLEP6): remove if condition for unrouted sample_weight when metadata
# routing can't be disabled.
if not _routing_enabled() and "sample_weight" in fit_params:
try:
with _print_elapsed_time(message_clsname, message):
estimator.fit(X, y, sample_weight=sample_weight)
estimator.fit(X, y, sample_weight=fit_params["sample_weight"])
except TypeError as exc:
if "unexpected keyword argument 'sample_weight'" in str(exc):
raise TypeError(
Expand All @@ -33,7 +36,7 @@ def _fit_single_estimator(
raise
else:
with _print_elapsed_time(message_clsname, message):
estimator.fit(X, y)
estimator.fit(X, y, **fit_params)
return estimator


Expand Down
17 changes: 10 additions & 7 deletions sklearn/ensemble/_stacking.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,14 @@ def fit(self, X, y, sample_weight=None):
names, all_estimators = self._validate_estimators()
self._validate_final_estimator()

# FIXME: when adding support for metadata routing in Stacking*.
# This is a hotfix to make StackingClassifier and StackingRegressor
# pass the tests despite not supporting metadata routing but sharing
# the same base class with VotingClassifier and VotingRegressor.
fit_params = dict()
if sample_weight is not None:
fit_params["sample_weight"] = sample_weight

stack_method = [self.stack_method] * len(all_estimators)

if self.cv == "prefit":
Expand All @@ -214,7 +222,7 @@ def fit(self, X, y, sample_weight=None):
# base estimators will be used in transform, predict, and
# predict_proba. They are exposed publicly.
self.estimators_ = Parallel(n_jobs=self.n_jobs)(
delayed(_fit_single_estimator)(clone(est), X, y, sample_weight)
delayed(_fit_single_estimator)(clone(est), X, y, fit_params)
for est in all_estimators
if est != "drop"
)
Expand Down Expand Up @@ -253,9 +261,6 @@ def fit(self, X, y, sample_weight=None):
if hasattr(cv, "random_state") and cv.random_state is None:
cv.random_state = np.random.RandomState()

fit_params = (
{"sample_weight": sample_weight} if sample_weight is not None else None
)
predictions = Parallel(n_jobs=self.n_jobs)(
delayed(cross_val_predict)(
clone(est),
Expand All @@ -280,9 +285,7 @@ def fit(self, X, y, sample_weight=None):
]

X_meta = self._concatenate_predictions(X, predictions)
_fit_single_estimator(
self.final_estimator_, X_meta, y, sample_weight=sample_weight
)
_fit_single_estimator(self.final_estimator_, X_meta, y, fit_params=fit_params)

return self

Expand Down
102 changes: 88 additions & 14 deletions sklearn/ensemble/_voting.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,14 +31,18 @@
from ..utils._estimator_html_repr import _VisualBlock
from ..utils._param_validation import StrOptions
from ..utils.metadata_routing import (
_raise_for_unsupported_routing,
_RoutingNotSupportedMixin,
MetadataRouter,
MethodMapping,
_raise_for_params,
_routing_enabled,
process_routing,
)
from ..utils.metaestimators import available_if
from ..utils.multiclass import type_of_target
from ..utils.parallel import Parallel, delayed
from ..utils.validation import (
_check_feature_names_in,
_deprecate_positional_args,
check_is_fitted,
column_or_1d,
)
Expand Down Expand Up @@ -76,7 +80,7 @@ def _predict(self, X):
return np.asarray([est.predict(X) for est in self.estimators_]).T

@abstractmethod
def fit(self, X, y, sample_weight=None):
def fit(self, X, y, **fit_params):
"""Get common fit operations."""
names, clfs = self._validate_estimators()

Expand All @@ -86,16 +90,27 @@ def fit(self, X, y, sample_weight=None):
f" {len(self.weights)} weights, {len(self.estimators)} estimators"
)

if _routing_enabled():
routed_params = process_routing(self, "fit", **fit_params)
else:
routed_params = Bunch()
for name in names:
routed_params[name] = Bunch(fit={})
if "sample_weight" in fit_params:
routed_params[name].fit["sample_weight"] = fit_params[
"sample_weight"
]

self.estimators_ = Parallel(n_jobs=self.n_jobs)(
delayed(_fit_single_estimator)(
clone(clf),
X,
y,
sample_weight=sample_weight,
fit_params=routed_params[name]["fit"],
message_clsname="Voting",
message=self._log_message(names[idx], idx + 1, len(clfs)),
message=self._log_message(name, idx + 1, len(clfs)),
)
for idx, clf in enumerate(clfs)
for idx, (name, clf) in enumerate(zip(names, clfs))
if clf != "drop"
)

Expand Down Expand Up @@ -156,8 +171,32 @@ def _sk_visual_block_(self):
names, estimators = zip(*self.estimators)
return _VisualBlock("parallel", estimators, names=names)

def get_metadata_routing(self):
    """Get metadata routing of this object.

    Please check :ref:`User Guide <metadata_routing>` on how the routing
    mechanism works.

    .. versionadded:: 1.5

    Returns
    -------
    routing : MetadataRouter
        A :class:`~sklearn.utils.metadata_routing.MetadataRouter` encapsulating
        routing information.
    """
    router = MetadataRouter(owner=self.__class__.__name__)

    # Register each sub-estimator under its configured name so that metadata
    # passed to this ensemble's `fit` is routed to each estimator's `fit`.
    # `self.estimators` holds `(name, estimator)` pairs.
    for est_name, est in self.estimators:
        fit_mapping = MethodMapping().add(callee="fit", caller="fit")
        router.add(**{est_name: est}, method_mapping=fit_mapping)
    return router


class VotingClassifier(_RoutingNotSupportedMixin, ClassifierMixin, _BaseVoting):
class VotingClassifier(ClassifierMixin, _BaseVoting):
"""Soft Voting/Majority Rule classifier for unfitted estimators.

Read more in the :ref:`User Guide <voting_classifier>`.
Expand Down Expand Up @@ -317,7 +356,11 @@ def __init__(
# estimators in VotingClassifier.estimators are not validated yet
prefer_skip_nested_validation=False
)
def fit(self, X, y, sample_weight=None):
# TODO(1.7): remove `sample_weight` from the signature after deprecation
# cycle; pop it from `fit_params` before the `_raise_for_params` check and
# reinsert later, for backwards compatibility
@_deprecate_positional_args(version="1.7")
def fit(self, X, y, *, sample_weight=None, **fit_params):
"""Fit the estimators.

Parameters
Expand All @@ -336,12 +379,23 @@ def fit(self, X, y, sample_weight=None):

.. versionadded:: 0.18

**fit_params : dict
Parameters to pass to the underlying estimators.

.. versionadded:: 1.5

Only available if `enable_metadata_routing=True`,
which can be set by using
``sklearn.set_config(enable_metadata_routing=True)``.
See :ref:`Metadata Routing User Guide <metadata_routing>` for
more details.

Returns
-------
self : object
Returns the instance itself.
"""
_raise_for_unsupported_routing(self, "fit", sample_weight=sample_weight)
_raise_for_params(fit_params, self, "fit")
y_type = type_of_target(y, input_name="y")
if y_type in ("unknown", "continuous"):
# raise a specific ValueError for non-classification tasks
Expand All @@ -363,7 +417,10 @@ def fit(self, X, y, sample_weight=None):
self.classes_ = self.le_.classes_
transformed_y = self.le_.transform(y)

return super().fit(X, transformed_y, sample_weight)
if sample_weight is not None:
fit_params["sample_weight"] = sample_weight

return super().fit(X, transformed_y, **fit_params)

def predict(self, X):
"""Predict class labels for X.
Expand Down Expand Up @@ -495,7 +552,7 @@ def get_feature_names_out(self, input_features=None):
return np.asarray(names_out, dtype=object)


class VotingRegressor(_RoutingNotSupportedMixin, RegressorMixin, _BaseVoting):
class VotingRegressor(RegressorMixin, _BaseVoting):
"""Prediction voting regressor for unfitted estimators.

A voting regressor is an ensemble meta-estimator that fits several base
Expand Down Expand Up @@ -596,7 +653,11 @@ def __init__(self, estimators, *, weights=None, n_jobs=None, verbose=False):
# estimators in VotingRegressor.estimators are not validated yet
prefer_skip_nested_validation=False
)
def fit(self, X, y, sample_weight=None):
# TODO(1.7): remove `sample_weight` from the signature after deprecation cycle;
# pop it from `fit_params` before the `_raise_for_params` check and reinsert later,
# for backwards compatibility
@_deprecate_positional_args(version="1.7")
def fit(self, X, y, *, sample_weight=None, **fit_params):
"""Fit the estimators.

Parameters
Expand All @@ -613,14 +674,27 @@ def fit(self, X, y, sample_weight=None):
Note that this is supported only if all underlying estimators
support sample weights.

**fit_params : dict
Parameters to pass to the underlying estimators.

.. versionadded:: 1.5

Only available if `enable_metadata_routing=True`,
which can be set by using
``sklearn.set_config(enable_metadata_routing=True)``.
See :ref:`Metadata Routing User Guide <metadata_routing>` for
more details.

Returns
-------
self : object
Fitted estimator.
"""
_raise_for_unsupported_routing(self, "fit", sample_weight=sample_weight)
_raise_for_params(fit_params, self, "fit")
y = column_or_1d(y, warn=True)
return super().fit(X, y, sample_weight)
if sample_weight is not None:
fit_params["sample_weight"] = sample_weight
return super().fit(X, y, **fit_params)

def predict(self, X):
"""Predict regression target for X.
Expand Down
Loading