5
5
# Authors: The scikit-learn developers
6
6
# SPDX-License-Identifier: BSD-3-Clause
7
7
8
+ import re
8
9
from itertools import cycle , product
9
10
10
11
import joblib
@@ -696,6 +697,37 @@ def test_warning_bootstrap_sample_weight():
696
697
reg .fit (X , y , sample_weight = sample_weight )
697
698
698
699
700
+ def test_invalid_sample_weight_max_samples_bootstrap_combinations ():
701
+ X , y = iris .data , iris .target
702
+
703
+ # Case 1: small weights and fractional max_samples would lead to sampling
704
+ # less than 1 sample, which is not allowed.
705
+ clf = BaggingClassifier (max_samples = 1.0 )
706
+ sample_weight = np .ones_like (y ) / (2 * len (y ))
707
+ expected_msg = (
708
+ r"The total sum of sample weights is 0.5(\d*), which prevents resampling with "
709
+ r"a fractional value for max_samples=1\.0\. Either pass max_samples as an "
710
+ r"integer or use a larger sample_weight\."
711
+ )
712
+ with pytest .raises (ValueError , match = expected_msg ):
713
+ clf .fit (X , y , sample_weight = sample_weight )
714
+
715
+ # Case 2: large weights and bootstrap=False would lead to sampling without
716
+ # replacement more than the number of samples, which is not allowed.
717
+ clf = BaggingClassifier (bootstrap = False , max_samples = 1.0 )
718
+ sample_weight = np .ones_like (y )
719
+ sample_weight [- 1 ] = 2
720
+ expected_msg = re .escape (
721
+ "max_samples=151 must be <= n_samples=150 to be able to sample without "
722
+ "replacement."
723
+ )
724
+ with pytest .raises (ValueError , match = expected_msg ):
725
+ with pytest .warns (
726
+ UserWarning , match = "When fitting BaggingClassifier with sample_weight"
727
+ ):
728
+ clf .fit (X , y , sample_weight = sample_weight )
729
+
730
+
699
731
class EstimatorAcceptingSampleWeight (BaseEstimator ):
700
732
"""Fake estimator accepting sample_weight"""
701
733
@@ -724,8 +756,9 @@ def predict(self, X):
724
756
@pytest .mark .parametrize ("bagging_class" , [BaggingRegressor , BaggingClassifier ])
725
757
@pytest .mark .parametrize ("accept_sample_weight" , [False , True ])
726
758
@pytest .mark .parametrize ("metadata_routing" , [False , True ])
759
+ @pytest .mark .parametrize ("max_samples" , [10 , 0.8 ])
727
760
def test_draw_indices_using_sample_weight (
728
- bagging_class , accept_sample_weight , metadata_routing
761
+ bagging_class , accept_sample_weight , metadata_routing , max_samples
729
762
):
730
763
X = np .arange (100 ).reshape (- 1 , 1 )
731
764
y = np .repeat ([0 , 1 ], 50 )
@@ -739,7 +772,15 @@ def test_draw_indices_using_sample_weight(
739
772
base_estimator = EstimatorRejectingSampleWeight ()
740
773
741
774
n_samples , n_features = X .shape
742
- max_samples = 10
775
+
776
+ if isinstance (max_samples , float ):
777
+ # max_samples passed as a fraction of the input data. Since
778
+ # sample_weight are provided, the effective number of samples is the
779
+ # sum of the sample weights.
780
+ expected_integer_max_samples = int (max_samples * sample_weight .sum ())
781
+ else :
782
+ expected_integer_max_samples = max_samples
783
+
743
784
with config_context (enable_metadata_routing = metadata_routing ):
744
785
# TODO(slep006): remove block when default routing is implemented
745
786
if metadata_routing and accept_sample_weight :
@@ -748,7 +789,7 @@ def test_draw_indices_using_sample_weight(
748
789
bagging .fit (X , y , sample_weight = sample_weight )
749
790
for estimator , samples in zip (bagging .estimators_ , bagging .estimators_samples_ ):
750
791
counts = np .bincount (samples , minlength = n_samples )
751
- assert sum (counts ) == len (samples ) == max_samples
792
+ assert sum (counts ) == len (samples ) == expected_integer_max_samples
752
793
# only indices 4 and 5 should appear
753
794
assert np .isin (samples , [4 , 5 ]).all ()
754
795
if accept_sample_weight :
@@ -760,8 +801,8 @@ def test_draw_indices_using_sample_weight(
760
801
assert_allclose (estimator .sample_weight_ , counts )
761
802
else :
762
803
# sampled indices represented through indexing
763
- assert estimator .X_ .shape == (max_samples , n_features )
764
- assert estimator .y_ .shape == (max_samples ,)
804
+ assert estimator .X_ .shape == (expected_integer_max_samples , n_features )
805
+ assert estimator .y_ .shape == (expected_integer_max_samples ,)
765
806
assert_allclose (estimator .X_ , X [samples ])
766
807
assert_allclose (estimator .y_ , y [samples ])
767
808
0 commit comments