7 changes: 7 additions & 0 deletions doc/whats_new/upcoming_changes/sklearn.ensemble/31414.fix.rst
@@ -0,0 +1,7 @@
- :class:`ensemble.BaggingClassifier`, :class:`ensemble.BaggingRegressor`
and :class:`ensemble.IsolationForest` now use `sample_weight` to draw
the samples instead of forwarding them multiplied by a uniformly sampled
mask to the underlying estimators. Furthermore, `max_samples` is now
interpreted as a fraction of `sample_weight.sum()` instead of `X.shape[0]`
when passed as a float.
By :user:`Antoine Baker <antoinebaker>`.
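A minimal usage sketch of the behavior described in this changelog entry (not part of the diff); the dataset, weights, and hyperparameters are illustrative assumptions:

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.ensemble import BaggingClassifier
from sklearn.tree import DecisionTreeClassifier

X, y = make_classification(n_samples=200, random_state=0)
# Illustrative non-uniform weights.
sample_weight = np.random.RandomState(0).uniform(0.5, 2.0, size=200)

# With this change, sample_weight biases which rows are drawn for each
# bagged estimator, and a float max_samples is interpreted as a fraction
# of sample_weight.sum() rather than of X.shape[0].
clf = BaggingClassifier(
    estimator=DecisionTreeClassifier(),
    max_samples=0.5,
    bootstrap=True,  # recommended when fitting with sample_weight
    random_state=0,
)
clf.fit(X, y, sample_weight=sample_weight)
```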
147 changes: 82 additions & 65 deletions sklearn/ensemble/_bagging.py
@@ -72,6 +72,7 @@ def _generate_bagging_indices(
n_samples,
max_features,
max_samples,
sample_weight,
):
"""Randomly draw feature and sample indices."""
# Get valid random state
@@ -81,18 +82,37 @@ def _generate_bagging_indices(
feature_indices = _generate_indices(
random_state, bootstrap_features, n_features, max_features
)
sample_indices = _generate_indices(
random_state, bootstrap_samples, n_samples, max_samples
)
if sample_weight is None:
sample_indices = _generate_indices(
random_state, bootstrap_samples, n_samples, max_samples
)
else:
normalized_sample_weight = sample_weight / np.sum(sample_weight)
sample_indices = random_state.choice(
n_samples,
max_samples,
replace=bootstrap_samples,
p=normalized_sample_weight,
)

return feature_indices, sample_indices


def _consumes_sample_weight(estimator):
if _routing_enabled():
request_or_router = get_routing_for_object(estimator)
consumes_sample_weight = request_or_router.consumes("fit", ("sample_weight",))
else:
consumes_sample_weight = has_fit_parameter(estimator, "sample_weight")
return consumes_sample_weight


def _parallel_build_estimators(
n_estimators,
ensemble,
X,
y,
sample_weight,
seeds,
total_n_estimators,
verbose,
@@ -108,22 +128,12 @@ def _parallel_build_estimators(
bootstrap_features = ensemble.bootstrap_features
has_check_input = has_fit_parameter(ensemble.estimator_, "check_input")
requires_feature_indexing = bootstrap_features or max_features != n_features
consumes_sample_weight = _consumes_sample_weight(ensemble.estimator_)

# Build estimators
estimators = []
estimators_features = []

# TODO: (slep6) remove if condition for unrouted sample_weight when metadata
# routing can't be disabled.
support_sample_weight = has_fit_parameter(ensemble.estimator_, "sample_weight")
if not _routing_enabled() and (
not support_sample_weight and fit_params.get("sample_weight") is not None
):
raise ValueError(
"The base estimator doesn't support sample weight, but sample_weight is "
"passed to the fit method."
)

for i in range(n_estimators):
if verbose > 1:
print(
@@ -139,7 +149,8 @@ def _parallel_build_estimators(
else:
estimator_fit = estimator.fit

# Draw random feature, sample indices
# Draw random feature, sample indices (using normalized sample_weight
# as probabilities if provided).
features, indices = _generate_bagging_indices(
random_state,
bootstrap_features,
@@ -148,45 +159,22 @@ def _parallel_build_estimators(
n_samples,
max_features,
max_samples,
sample_weight,
)

fit_params_ = fit_params.copy()

# TODO(SLEP6): remove if condition for unrouted sample_weight when metadata
# routing can't be disabled.
# 1. If routing is enabled, we will check if the routing supports sample
# weight and use it if it does.
# 2. If routing is not enabled, we will check if the base
# estimator supports sample_weight and use it if it does.

# Note: Row sampling can be achieved either through setting sample_weight or
# by indexing. The former is more efficient. Therefore, use this method
# by indexing. The former is more memory efficient. Therefore, use this method
# if possible, otherwise use indexing.
if _routing_enabled():
request_or_router = get_routing_for_object(ensemble.estimator_)
consumes_sample_weight = request_or_router.consumes(
"fit", ("sample_weight",)
)
else:
consumes_sample_weight = support_sample_weight
if consumes_sample_weight:
# Draw sub samples, using sample weights, and then fit
curr_sample_weight = _check_sample_weight(
fit_params_.pop("sample_weight", None), X
).copy()

if bootstrap:
sample_counts = np.bincount(indices, minlength=n_samples)
curr_sample_weight *= sample_counts
else:
not_indices_mask = ~indices_to_mask(indices, n_samples)
curr_sample_weight[not_indices_mask] = 0

fit_params_["sample_weight"] = curr_sample_weight
# Row sampling by setting sample_weight
indices_as_sample_weight = np.bincount(indices, minlength=n_samples)
fit_params_["sample_weight"] = indices_as_sample_weight
X_ = X[:, features] if requires_feature_indexing else X
estimator_fit(X_, y, **fit_params_)
else:
# cannot use sample_weight, so use indexing
# Row sampling by indexing
y_ = _safe_indexing(y, indices)
X_ = _safe_indexing(X, indices)
fit_params_ = _check_method_params(X, params=fit_params_, indices=indices)
@@ -354,9 +342,11 @@ def fit(self, X, y, sample_weight=None, **fit_params):
regression).

sample_weight : array-like of shape (n_samples,), default=None
Sample weights. If None, then samples are equally weighted.
Note that this is supported only if the base estimator supports
sample weighting.
Sample weights. If None, then samples are equally weighted. Used as
probabilities to sample the training set. Note that the expected
frequency semantics for the `sample_weight` parameter are only
fulfilled when sampling with replacement (`bootstrap=True`).

**fit_params : dict
Parameters to pass to the underlying estimators.

@@ -386,6 +376,15 @@ def fit(self, X, y, sample_weight=None, **fit_params):
multi_output=True,
)

if sample_weight is not None:
sample_weight = _check_sample_weight(sample_weight, X, dtype=None)

if not self.bootstrap:
warn(
f"When fitting {self.__class__.__name__} with sample_weight "
f"it is recommended to use bootstrap=True, got {self.bootstrap}."
)

return self._fit(
X,
y,
@@ -435,8 +434,6 @@ def _fit(

sample_weight : array-like of shape (n_samples,), default=None
Sample weights. If None, then samples are equally weighted.
Note that this is supported only if the base estimator supports
sample weighting.

**fit_params : dict, default=None
Parameters to pass to the :term:`fit` method of the underlying
@@ -457,30 +454,38 @@ def _fit(
# Check parameters
self._validate_estimator(self._get_estimator())

if sample_weight is not None:
fit_params["sample_weight"] = sample_weight

if _routing_enabled():
routed_params = process_routing(self, "fit", **fit_params)
else:
routed_params = Bunch()
routed_params.estimator = Bunch(fit=fit_params)
if "sample_weight" in fit_params:
routed_params.estimator.fit["sample_weight"] = fit_params[
"sample_weight"
]

if max_depth is not None:
self.estimator_.max_depth = max_depth

# Validate max_samples
if max_samples is None:
max_samples = self.max_samples
elif not isinstance(max_samples, numbers.Integral):
max_samples = int(max_samples * X.shape[0])

if max_samples > X.shape[0]:
raise ValueError("max_samples must be <= n_samples")
if not isinstance(max_samples, numbers.Integral):
if sample_weight is None:
max_samples = max(int(max_samples * X.shape[0]), 1)
else:
sw_sum = np.sum(sample_weight)
if sw_sum <= 1:
raise ValueError(
f"The total sum of sample weights is {sw_sum}, which prevents "
"resampling with a fractional value for max_samples="
f"{max_samples}. Either pass max_samples as an integer or "
"use a larger sample_weight."
)
max_samples = max(int(max_samples * sw_sum), 1)
Comment on lines +474 to +482
Member:
Something doesn't feel right about this approach.
If max_samples is a float in [0,1], I'd interpret it as a fraction of the sum of the weights, and so expect to draw samples whose weights sum on average to max_samples * sw_sum. For instance, if max_samples=0.5, I'd expect to draw samples such that the sum of their weights is on average half the total sum of the weights.
This is not the case here since we're turning it into an int being the number of samples to draw. That's why there's this issue with small weights in particular.

That being said, I don't have any alternative to propose. At least the docstring is clear about how max_samples is related to the actual number of samples drawn. So I guess this is good enough for us.

Member:
This is not the case here since we're turning it into an int being the number of samples to draw. That's why there's this issue with small weights in particular.

I am not sure I follow. Does your comment specifically refer to the edge case where sw_sum >= 1 but int(max_samples * sw_sum) == 0, in which case the max operator uses 1 instead? I think this is really an edge case and we can live with it. We could raise a warning, but the user wouldn't be able to do anything about it. Furthermore, I expect this case to be very rare in practice.

Besides this extreme edge case, I think your expectation that we "draw samples such that the sum of their weight is on average half the total sum of the weights" should be met, no?

Member:
No it's not about this edge case.

Take for instance sw = np.array([1.2, 3.4, 4.7, 5.6, 2.2, 2.9]). We have sw.sum()=20. If I set max_samples=0.5, intuitively I'd expect to draw samples such that the sum of their weights is close to 10 on average. But here max_samples * sw_sum = 10, so we'll sample 10 points, and on average the sum of their weights is 10 * sw.mean() = 33.33, more than 3 times my expectation.

Conversely, if the sample weights sum to a value less than n_samples, we'll draw points such that the sum of their weights is less than expected. Actually I think the expected sum of weights is int(max_samples * sw_sum) * sw_mean, so it only equals int(max_samples * sw_sum) if sw_mean=1. To get the expected sum of weights we should instead draw int(max_samples * n_samples) points, which leads to an average sum of weights of max_samples * sw_sum.

But this was the previous implementation and used to break the equivalence between weighted and repeated.
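A quick numeric check of the arithmetic in this comment (added for clarity, not part of the thread):

```python
import numpy as np

sw = np.array([1.2, 3.4, 4.7, 5.6, 2.2, 2.9])
sw_sum = sw.sum()          # 20.0
max_samples = 0.5

n_draws = int(max_samples * sw_sum)          # 10 points are drawn
expected_weight_sum = n_draws * sw.mean()    # 10 * 3.33... ~= 33.33

print(sw_sum, n_draws, round(expected_weight_sum, 2))  # 20.0 10 33.33
```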

@ogrisel (Member), Jun 13, 2025:
If I set max_samples=0.5, intuitively I'd expect to draw samples such that the sum of their weight is close to 10 on average. But here max_samples * sw_sum = 10 so we'll sample 10 points and on average the sum of their weights is 10 * sw.mean() = 33.33 so more than 3 times my expectation.

I don't think that's what this PR does. What we do is:

  • generate max_samples * sw_sum ~= 10 indices, drawn with replacement (sklearn/ensemble/_bagging.py:90);
  • then pass indices_as_sample_weight = np.bincount(indices) (sklearn/ensemble/_bagging.py:172) as the sample_weight param of the base estimator to simulate fitting on this resampling using sample weights. Note that we do not reuse the sample_weight values passed by the user a second time for this step. This avoids double accounting.

Personally, I don't think there is a problem in the current state of the PR and the statistical tests seem to confirm this.
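For reference, a rough standalone sketch of the two-step resampling described in this comment, using the toy weights from above; the real code lives in sklearn/ensemble/_bagging.py and the names here are illustrative:

```python
import numpy as np

rng = np.random.RandomState(0)
sample_weight = np.array([1.2, 3.4, 4.7, 5.6, 2.2, 2.9])
n_samples = sample_weight.shape[0]
max_samples = int(0.5 * sample_weight.sum())  # ~10 draws

# Step 1: draw indices with replacement, weighted by the normalized user weights.
p = sample_weight / sample_weight.sum()
indices = rng.choice(n_samples, max_samples, replace=True, p=p)

# Step 2: forward the draw counts (not the user weights again) to the base
# estimator as its sample_weight, which avoids double accounting.
indices_as_sample_weight = np.bincount(indices, minlength=n_samples)
print(indices_as_sample_weight, indices_as_sample_weight.sum())  # counts sum to 10
```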

Member:
As I said, I don't have a better alternative to offer, so I'm okay with this.

The issue for me comes from the fact that the parameter is tied to n_samples and not to the weight sum. That's why we're able to keep the equivalence between weighted and repeated samples, but invariance under a rescaling of the weights is much harder. Here you'd get an error if you normalize your weights in advance (then sw_sum = 1), which feels like a non-optimal behavior to me.

Member:
Here you'd get an error if you normalize your weights in advance (then sw_sum = 1), which feels like a non-optimal behavior to me.

Lack of invariance to weight re-scaling (multiplication by a positive constant) is probably one of the properties that distinguishes the "frequency" weight semantics from other kinds of weight semantics. I am personally still not sure whether this is a bug or a feature. One way to decide would be to review the known downstream uses of scikit-learn sample_weight (for instance: #30564 (comment)) to see if any of them would break because of lack of invariance to weight rescaling.
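An illustrative sketch of the rescaling sensitivity discussed here, assuming a scikit-learn build that includes this PR; the data and weights are made up:

```python
import numpy as np
from sklearn.datasets import make_regression
from sklearn.ensemble import BaggingRegressor

X, y = make_regression(n_samples=50, random_state=0)
sw = np.random.RandomState(0).uniform(0.5, 2.0, size=50)

est = BaggingRegressor(max_samples=0.5, random_state=0)
est.fit(X, y, sample_weight=sw)  # fine: sw.sum() > 1

# Rescaling by a small positive constant pushes the weight sum below 1,
# which trips the new guard on fractional max_samples.
try:
    est.fit(X, y, sample_weight=sw / 1000.0)
except ValueError as exc:
    print(exc)
```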


if not self.bootstrap and max_samples > X.shape[0]:
raise ValueError(
f"Effective max_samples={max_samples} must be <= n_samples="
f"{X.shape[0]} to be able to sample without replacement."
)

# Store validated integer row sampling value
self._max_samples = max_samples
@@ -499,6 +504,11 @@ def _fit(
# Store validated integer feature sampling value
self._max_features = max_features

# Store sample_weight (needed in _get_estimators_indices). Note that
# we intentionally do not materialize `sample_weight=None` as an array
# of ones to avoid unnecessarily cluttering trained estimator pickles.
self._sample_weight = sample_weight

# Other checks
if not self.bootstrap and self.oob_score:
raise ValueError("Out of bag estimation only available if bootstrap=True")
@@ -552,6 +562,7 @@ def _fit(
self,
X,
y,
sample_weight,
seeds[starts[i] : starts[i + 1]],
total_n_estimators,
verbose=self.verbose,
@@ -596,6 +607,7 @@ def _get_estimators_indices(self):
self._n_samples,
self._max_features,
self._max_samples,
self._sample_weight,
)

yield feature_indices, sample_indices
@@ -726,7 +738,8 @@ class BaggingClassifier(ClassifierMixin, BaseBagging):
replacement by default, see `bootstrap` for more details).

- If int, then draw `max_samples` samples.
- If float, then draw `max_samples * X.shape[0]` samples.
- If float, then draw `max_samples * X.shape[0]` unweighted samples
or `max_samples * sample_weight.sum()` weighted samples.

max_features : int or float, default=1.0
The number of features to draw from X to train each base estimator (
@@ -737,8 +750,10 @@ class BaggingClassifier(ClassifierMixin, BaseBagging):
- If float, then draw `max(1, int(max_features * n_features_in_))` features.

bootstrap : bool, default=True
Whether samples are drawn with replacement. If False, sampling
without replacement is performed.
Whether samples are drawn with replacement. If False, sampling without
replacement is performed. If fitting with `sample_weight`, it is
strongly recommended to choose True, as only drawing with replacement
will ensure the expected frequency semantics of `sample_weight`.

bootstrap_features : bool, default=False
Whether features are drawn with replacement.
@@ -1245,8 +1260,10 @@ class BaggingRegressor(RegressorMixin, BaseBagging):
- If float, then draw `max(1, int(max_features * n_features_in_))` features.

bootstrap : bool, default=True
Whether samples are drawn with replacement. If False, sampling
without replacement is performed.
Whether samples are drawn with replacement. If False, sampling without
replacement is performed. If fitting with `sample_weight`, it is
strongly recommended to choose True, as only drawing with replacement
will ensure the expected frequency semantics of `sample_weight`.

bootstrap_features : bool, default=False
Whether features are drawn with replacement.
13 changes: 11 additions & 2 deletions sklearn/ensemble/_iforest.py
@@ -20,7 +20,12 @@
from ..utils._chunking import get_chunk_n_rows
from ..utils._param_validation import Interval, RealNotInt, StrOptions
from ..utils.parallel import Parallel, delayed
from ..utils.validation import _num_samples, check_is_fitted, validate_data
from ..utils.validation import (
_check_sample_weight,
_num_samples,
check_is_fitted,
validate_data,
)
from ._bagging import BaseBagging

__all__ = ["IsolationForest"]
@@ -317,6 +322,10 @@ def fit(self, X, y=None, sample_weight=None):
X = validate_data(
self, X, accept_sparse=["csc"], dtype=tree_dtype, ensure_all_finite=False
)

if sample_weight is not None:
sample_weight = _check_sample_weight(sample_weight, X, dtype=None)

if issparse(X):
# Pre-sort indices to avoid that each individual tree of the
# ensemble sorts the indices.
Expand Down Expand Up @@ -350,7 +359,7 @@ def fit(self, X, y=None, sample_weight=None):
super()._fit(
X,
y,
max_samples,
max_samples=max_samples,
max_depth=max_depth,
sample_weight=sample_weight,
check_input=False,