Skip to content

Commit 28a2bde

Browse files
committed
Make max_samples a fraction of sample_weight.sum() instead of X.shape[0]
1 parent 0a2b38d commit 28a2bde

File tree

2 files changed

+67
-10
lines changed

2 files changed

+67
-10
lines changed

sklearn/ensemble/_bagging.py

Lines changed: 21 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -465,11 +465,26 @@ def _fit(
465465
# Validate max_samples
466466
if max_samples is None:
467467
max_samples = self.max_samples
468-
elif not isinstance(max_samples, numbers.Integral):
469-
max_samples = int(max_samples * X.shape[0])
470468

471-
if max_samples > X.shape[0]:
472-
raise ValueError("max_samples must be <= n_samples")
469+
if not isinstance(max_samples, numbers.Integral):
470+
if sample_weight is None:
471+
max_samples = max(int(max_samples * X.shape[0]), 1)
472+
else:
473+
sw_sum = np.sum(sample_weight)
474+
if sw_sum <= 1:
475+
raise ValueError(
476+
f"The total sum of sample weights is {sw_sum}, which prevents "
477+
"resampling with a fractional value for max_samples="
478+
f"{max_samples}. Either pass max_samples as an integer or "
479+
"use a larger sample_weight."
480+
)
481+
max_samples = max(int(max_samples * sw_sum), 1)
482+
483+
if not self.bootstrap and max_samples > X.shape[0]:
484+
raise ValueError(
485+
f"Effective max_samples={max_samples} must be <= n_samples="
486+
f"{X.shape[0]} to be able to sample without replacement."
487+
)
473488

474489
# Store validated integer row sampling value
475490
self._max_samples = max_samples
@@ -722,7 +737,8 @@ class BaggingClassifier(ClassifierMixin, BaseBagging):
722737
replacement by default, see `bootstrap` for more details).
723738
724739
- If int, then draw `max_samples` samples.
725-
- If float, then draw `max_samples * X.shape[0]` samples.
740+
- If float, then draw `max_samples * X.shape[0]` unweighted samples
741+
or `max_samples * sample_weight.sum()` weighted samples.
726742
727743
max_features : int or float, default=1.0
728744
The number of features to draw from X to train each base estimator (

sklearn/ensemble/tests/test_bagging.py

Lines changed: 46 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -5,6 +5,7 @@
55
# Authors: The scikit-learn developers
66
# SPDX-License-Identifier: BSD-3-Clause
77

8+
import re
89
from itertools import cycle, product
910

1011
import joblib
@@ -696,6 +697,37 @@ def test_warning_bootstrap_sample_weight():
696697
reg.fit(X, y, sample_weight=sample_weight)
697698

698699

700+
def test_invalid_sample_weight_max_samples_bootstrap_combinations():
701+
X, y = iris.data, iris.target
702+
703+
# Case 1: small weights and fractional max_samples would lead to sampling
704+
# less than 1 sample, which is not allowed.
705+
clf = BaggingClassifier(max_samples=1.0)
706+
sample_weight = np.ones_like(y) / (2 * len(y))
707+
expected_msg = (
708+
r"The total sum of sample weights is 0.5(\d*), which prevents resampling with "
709+
r"a fractional value for max_samples=1\.0\. Either pass max_samples as an "
710+
r"integer or use a larger sample_weight\."
711+
)
712+
with pytest.raises(ValueError, match=expected_msg):
713+
clf.fit(X, y, sample_weight=sample_weight)
714+
715+
# Case 2: large weights and bootstrap=False would lead to sampling without
716+
# replacement more than the number of samples, which is not allowed.
717+
clf = BaggingClassifier(bootstrap=False, max_samples=1.0)
718+
sample_weight = np.ones_like(y)
719+
sample_weight[-1] = 2
720+
expected_msg = re.escape(
721+
"max_samples=151 must be <= n_samples=150 to be able to sample without "
722+
"replacement."
723+
)
724+
with pytest.raises(ValueError, match=expected_msg):
725+
with pytest.warns(
726+
UserWarning, match="When fitting BaggingClassifier with sample_weight"
727+
):
728+
clf.fit(X, y, sample_weight=sample_weight)
729+
730+
699731
class EstimatorAcceptingSampleWeight(BaseEstimator):
700732
"""Fake estimator accepting sample_weight"""
701733

@@ -724,8 +756,9 @@ def predict(self, X):
724756
@pytest.mark.parametrize("bagging_class", [BaggingRegressor, BaggingClassifier])
725757
@pytest.mark.parametrize("accept_sample_weight", [False, True])
726758
@pytest.mark.parametrize("metadata_routing", [False, True])
759+
@pytest.mark.parametrize("max_samples", [10, 0.8])
727760
def test_draw_indices_using_sample_weight(
728-
bagging_class, accept_sample_weight, metadata_routing
761+
bagging_class, accept_sample_weight, metadata_routing, max_samples
729762
):
730763
X = np.arange(100).reshape(-1, 1)
731764
y = np.repeat([0, 1], 50)
@@ -739,7 +772,15 @@ def test_draw_indices_using_sample_weight(
739772
base_estimator = EstimatorRejectingSampleWeight()
740773

741774
n_samples, n_features = X.shape
742-
max_samples = 10
775+
776+
if isinstance(max_samples, float):
777+
# max_samples passed as a fraction of the input data. Since
778+
# sample_weight are provided, the effective number of samples is the
779+
# sum of the sample weights.
780+
expected_integer_max_samples = int(max_samples * sample_weight.sum())
781+
else:
782+
expected_integer_max_samples = max_samples
783+
743784
with config_context(enable_metadata_routing=metadata_routing):
744785
# TODO(slep006): remove block when default routing is implemented
745786
if metadata_routing and accept_sample_weight:
@@ -748,7 +789,7 @@ def test_draw_indices_using_sample_weight(
748789
bagging.fit(X, y, sample_weight=sample_weight)
749790
for estimator, samples in zip(bagging.estimators_, bagging.estimators_samples_):
750791
counts = np.bincount(samples, minlength=n_samples)
751-
assert sum(counts) == len(samples) == max_samples
792+
assert sum(counts) == len(samples) == expected_integer_max_samples
752793
# only indices 4 and 5 should appear
753794
assert np.isin(samples, [4, 5]).all()
754795
if accept_sample_weight:
@@ -760,8 +801,8 @@ def test_draw_indices_using_sample_weight(
760801
assert_allclose(estimator.sample_weight_, counts)
761802
else:
762803
# sampled indices represented through indexing
763-
assert estimator.X_.shape == (max_samples, n_features)
764-
assert estimator.y_.shape == (max_samples,)
804+
assert estimator.X_.shape == (expected_integer_max_samples, n_features)
805+
assert estimator.y_.shape == (expected_integer_max_samples,)
765806
assert_allclose(estimator.X_, X[samples])
766807
assert_allclose(estimator.y_, y[samples])
767808

0 commit comments

Comments (0)