FIX Draw indices using sample_weight in Bagging #31414


Open · wants to merge 16 commits into main
Conversation

@antoinebaker (Contributor) commented May 22, 2025

Part of #16298 and alternative to #31165.

What does this implement/fix? Explain your changes.

In Bagging estimators, sample_weight is now used to draw the samples and no longer forwarded to the underlying estimators. Bagging estimators now pass the statistical repeated/weighted equivalence test when bootstrap=True (the default, i.e., draw with replacement).

Compared to #31165, it better decouples two different usages of sample_weight:

  • sample_weight in bagging_estimator.fit is used as probabilities to draw the indices/rows
  • sample_weight in base_estimator.fit is used to represent the drawn indices as counts (more memory efficient than indexing); this is possible only if base_estimator.fit supports sample_weight (through metadata routing or natively).

#31165 introduced a new sampling_strategy argument to choose indexing/weighting for row sampling, but it would be better to do this in a dedicated follow-up PR.
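The two-step scheme described above can be sketched as follows. This is a simplified, hypothetical illustration of the idea, not the actual scikit-learn implementation; `fit_one_bagging_member` and `n_draws` are made-up names:

```python
import numpy as np
from sklearn.linear_model import Ridge

def fit_one_bagging_member(estimator, X, y, sample_weight, n_draws, rng):
    """Sketch of one bagging member: sample_weight drives the index draw,
    and the resulting counts are what the base estimator sees."""
    p = sample_weight / sample_weight.sum()
    # Draw indices with replacement, with probability proportional to weight.
    indices = rng.choice(len(X), size=n_draws, replace=True, p=p)
    # Encode the resample as per-row counts (more memory efficient than
    # indexing), and pass them as the base estimator's sample_weight.
    counts = np.bincount(indices, minlength=len(X))
    return estimator.fit(X, y, sample_weight=counts)

rng = np.random.default_rng(0)
X = rng.normal(size=(100, 2))
y = rng.normal(size=100)
sw = rng.uniform(0.0, 2.0, size=100)
model = fit_one_bagging_member(Ridge(), X, y, sw, n_draws=100, rng=rng)
```

Note that the user-supplied weights are used only for the draw; they are not multiplied into the counts a second time, which avoids double accounting.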

cc @ogrisel @GaetandeCast

github-actions bot commented May 22, 2025

✔️ Linting Passed

All linting checks passed. Your pull request is in excellent shape! ☀️

Generated for commit 0318a51.

@antoinebaker (Contributor, Author) commented May 22, 2025

BaggingRegressor(estimator=Ridge(), max_samples=100) now passes the statistical repeated/weighted equivalence test

[figure: repeated/weighted equivalence test results for BaggingRegressor]

The same holds for BaggingClassifier(estimator=LogisticRegression(), max_samples=100) and varying max_samples.

@antoinebaker (Contributor, Author):

However it fails (as expected) for bootstrap=False (draw without replacement), for example BaggingRegressor(estimator=Ridge(), bootstrap=False, max_samples=10)

[figure: equivalence test results with bootstrap=False, showing the expected failure]

@ogrisel (Member) commented May 23, 2025

However it fails (as expected) for bootstrap=False (draw without replacement).

Could you please document this known limitation, both in the docstring of the __init__ method for the bootstrap parameter and in the docstring of the fit method for the sample_weight parameter?

Something like: "Note that the expected frequency semantics for the sample_weight parameter are only fulfilled when sampling with replacement (bootstrap=True)."

Maybe we should raise a warning when calling BaggingClassifier(bootstrap=False, max_samples=0.5).fit(X, y, sample_weight=sample_weight) when sample_weight is not None. The warning is already implemented and tested: https://github.com/scikit-learn/scikit-learn/pull/31414/files#diff-b7c01e77fe68ded1e41868f4a7e142190f935261624d4abdb299913ef944cbbbR676-R682.

@ogrisel (Member) left a comment

Here is a pass of review. Could you please add a non-regression test using a small dataset with specifically engineered weights? For instance, you could have a dataset with 100 data points: 98 with a null weight, 1 with a weight of 1, and 1 with a weight of 2:

```python
import numpy as np

X = np.arange(100).reshape(-1, 1)
y = (X < 99).astype(np.int32)
sample_weight = np.zeros(shape=X.shape[0])
sample_weight[0] = 1
sample_weight[-1] = 2
```

Then you could fit a BaggingRegressor and a BaggingClassifier with a fake test estimator that just records the values passed as X, y and sample_weight as fitted attributes, to be able to write assertions in the test.

Ideally this test should pass both with metadata routing enabled and disabled.

Co-authored-by: Olivier Grisel <olivier.grisel@ensta.org>
@ogrisel (Member) commented May 28, 2025

BTW @antoinebaker once this PR has been finalized with tests, it would be great to open a similar PR for random forests. I suppose their bad handling of sample weights stems from the same root cause and a similar fix should be applicable.

antoinebaker and others added 2 commits June 2, 2025 09:08
Co-authored-by: Olivier Grisel <olivier.grisel@ensta.org>
@ogrisel (Member) left a comment

Thanks, @antoinebaker. Here is another pass of feedback but otherwise LGTM.

@ogrisel (Member) commented Jun 3, 2025

but otherwise LGTM.

Actually no: I tried to run https://github.com/snath-xoc/sample-weight-audit-nondet/blob/main/reports/sklearn_estimators_sample_weight_audit_report.ipynb against this branch and I still get a p-value lower than 1e-33 for this branch. It's an improvement over the < 1e-54 I measured on main but still, the bug does not seem fixed for classifiers.

I confirm the bug is fixed for the regressor though. So I must be missing something.

@antoinebaker (Contributor, Author):

I confirm the bug is fixed for the regressor though. So I must be missing something.

Did you specify max_samples as an integer, e.g. max_samples=10? Otherwise you might get a different number of samples in the repeated/weighted datasets #31165 (comment)

Co-authored-by: Olivier Grisel <olivier.grisel@ensta.org>
@ogrisel (Member) commented Jun 3, 2025

I forgot about the max_samples thing. Let me try again.

EDIT: I confirm this works as expected.

antoinebaker and others added 2 commits June 3, 2025 17:58
Co-authored-by: Olivier Grisel <olivier.grisel@ensta.org>
@ogrisel (Member) commented Jun 10, 2025

@antoinebaker I pushed 28a2bde to make the sample weight semantics consistent between max_samples passed as absolute or relative values. I re-ran the statistical tests, and they now always pass, whatever the value of max_samples.

I had to change the code a bit to raise ValueError with explicit messages for degenerate cases, and updated the tests accordingly. I think I prefer this behavior.

Let's see if the CI is green after this commit and I will do a proper review of the PR.

@ogrisel ogrisel moved this to In Progress in Losses and solvers Jun 10, 2025
@ogrisel (Member) left a comment

Assuming CI is green, the diff LGTM besides the following details:

@ogrisel (Member) commented Jun 10, 2025

Maybe @jeremiedbb and @snath-xoc would like to review this PR.

Co-authored-by: Olivier Grisel <olivier.grisel@ensta.org>
@antoinebaker (Contributor, Author) commented Jun 10, 2025

@antoinebaker I pushed 28a2bde to make the sample weight semantics consistent between max_samples passed as absolute or relative values. I re-ran the statistical tests, and they now always pass, whatever the value of max_samples.

I like the new semantics and raising an error if sw_sum <= 1.

@snath-xoc (Contributor):

From an initial pass LGTM, all tests pass for me as well. Will wait for @jeremiedbb to review but otherwise looks good to go!

@jeremiedbb (Member) left a comment

LGTM. Just a nitpick and a remark we can't really do anything about :)

Comment on lines +473 to +481

```python
sw_sum = np.sum(sample_weight)
if sw_sum <= 1:
    raise ValueError(
        f"The total sum of sample weights is {sw_sum}, which prevents "
        "resampling with a fractional value for max_samples="
        f"{max_samples}. Either pass max_samples as an integer or "
        "use a larger sample_weight."
    )
max_samples = max(int(max_samples * sw_sum), 1)
```
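As a standalone paraphrase of the logic in the diff (not the actual scikit-learn code; `resolve_max_samples` is a made-up name), a fractional max_samples is resolved against the weight sum, and degenerate weight sums raise an explicit error. In particular, weights normalized to sum to 1 trigger the error:

```python
import numpy as np

def resolve_max_samples(max_samples, sample_weight):
    """Resolve a fractional max_samples against the sum of the weights."""
    if isinstance(max_samples, float):
        sw_sum = np.sum(sample_weight)
        if sw_sum <= 1:
            raise ValueError(
                f"The total sum of sample weights is {sw_sum}, which "
                "prevents resampling with a fractional value for "
                f"max_samples={max_samples}. Either pass max_samples as an "
                "integer or use a larger sample_weight."
            )
        # Fraction of the weight sum, floored, with at least one draw.
        max_samples = max(int(max_samples * sw_sum), 1)
    return max_samples

# Normalized weights sum to 1.0 and therefore trigger the error:
sw = np.full(10, 0.1)
try:
    resolve_max_samples(0.5, sw)
    message = None
except ValueError as exc:
    message = str(exc)
```

An integer max_samples bypasses the check entirely, which is why passing max_samples as an integer is suggested as a workaround in the error message.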
Member commented:

Something doesn't feel right about this approach.
If max_samples is a float in [0, 1], I'd interpret it as a fraction of the sum of the weights, and so expect to draw samples whose weights sum on average to max_samples * sw_sum. For instance, if max_samples=0.5, I'd expect to draw samples such that the sum of their weights is on average half the total sum of the weights.
This is not the case here, since we're turning it into an int that is the number of samples to draw. That's why there's this issue with small weights in particular.

That being said, I don't have any alternative to propose. At least the docstring is clear about how max_samples relates to the actual number of samples drawn. So I guess this is good enough for us.

Member commented:

This is not the case here since we're turning it into an int being the number of samples to draw. That's why there's this issue with small weights in particular.

I am not sure I follow. Does your comment specifically refer to the edge case where sw_sum >= 1 but int(max_samples * sw_sum) == 0, in which case the max operator uses 1 instead? I think this is really an edge case. We could raise a warning, but the user wouldn't be able to do anything about it. Furthermore, I expect this case to be very rare in practice.

Besides this extreme edge case, I think your expectation that we "draw samples such that the sum of their weight is on average half the total sum of the weights" should be met, no?

Member commented:

No it's not about this edge case.

Take for instance sw = np.array([1.2, 3.4, 4.7, 5.6, 2.2, 2.9]). We have sw.sum()=20. If I set max_samples=0.5, intuitively I'd expect to draw samples such that the sum of their weight is close to 10 on average. But here max_samples * sw_sum = 10 so we'll sample 10 points and on average the sum of their weights is 10 * sw.mean() = 33.33 so more than 3 times my expectation.

Conversely, if the sample weights sum to a value less than n_samples, we'll draw points such that the sum of their weights is less than expected. Actually, I think the expected sum of weights is int(max_samples * sw_sum) * sw_mean, so it only equals int(max_samples * sw_sum) if sw_mean=1. To get the expected sum of weights we should then draw int(max_samples * n_samples) points, which leads to an average sum of weights of max_samples * sw_sum.

But this was the previous implementation and used to break the equivalence between weighted and repeated.

@ogrisel (Member) commented Jun 13, 2025

If I set max_samples=0.5, intuitively I'd expect to draw samples such that the sum of their weight is close to 10 on average. But here max_samples * sw_sum = 10 so we'll sample 10 points and on average the sum of their weights is 10 * sw.mean() = 33.33 so more than 3 times my expectation.

I don't think that's what this PR does. What we do is:

  • generate indices with max_samples * sw_sum ~= 10 elements, drawn with replacement (sklearn/ensemble/_bagging.py:90);
  • then pass indices_as_sample_weight = np.bincount(indices) (sklearn/ensemble/_bagging.py:172) as the sample_weight param of the base estimator to simulate fitting on this resampling using sample weights. Note that we do not reuse the sample_weight values passed by the user a second time for this step. This avoids double accounting.

Personally, I don't think there is a problem in the current state of the PR and the statistical tests seem to confirm this.

Member commented:

As I said I don't have a better alternative to offer so I'm okay with this.

The issue for me comes from the fact that the parameter is tied to n_samples and not to the weight sum. That's why we're able to have the equivalence between weighted and repeated, but it's a lot harder to also have invariance under a rescaling of the weights. Here you'd get an error if you normalize your weights in advance (then sw_sum = 1), which feels like non-optimal behavior to me.

Co-authored-by: Jérémie du Boisberranger <jeremie@probabl.ai>
@snath-xoc (Contributor) commented Jun 13, 2025

I adapted the stochastic sample weighting test to check the performance when sample_weight.sum()<1 and max_sample=1 (see here), with the following results:

[figure: equivalence test results when sample_weight.sum() < 1 and max_samples=1]

I agree this is probably an edge case which should raise a warning (can't think of an obvious solution to it).

NOTE: I had to rescale non-integer weights when constructing the repeated dataset and scale them back again... hopefully this doesn't conflate the two things (it may though)

@antoinebaker (Contributor, Author) commented Jun 13, 2025

I adapted the stochastic sample weighting test to check the performance when sample_weight.sum()<1 and max_sample=1

Is it absolute max_samples=1 (int), which would mean fitting each estimator with only one sample, or relative max_samples=1.0 (float), which after 28a2bde should raise an error?

What I find nice with the redefinition of max_samples in #31414 (comment) is that Bagging estimators should now pass the repeated/weighted equivalence test with integer weights, both for max_samples relative (float) or absolute (integer). So it would be nice to add the two cases, e.g. max_samples=20 and max_samples=0.5, in the test suite.

They should also pass under a global rescaling of the weights. However, we should be cautious with the rescaling of max_samples: no rescaling when absolute (int), rescaled when relative (float).
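The deterministic ingredient behind the repeated/weighted equivalence can be illustrated with the base estimator alone: for Ridge, fitting with integer sample_weight is exactly equivalent to fitting on the correspondingly repeated dataset. A minimal sketch (not the actual test-suite code):

```python
import numpy as np
from sklearn.linear_model import Ridge

rng = np.random.default_rng(0)
X = rng.normal(size=(20, 3))
y = rng.normal(size=20)
sw = rng.integers(1, 5, size=20)  # integer weights

# Repeat each row sw[i] times to build the equivalent unweighted dataset.
X_rep = np.repeat(X, sw, axis=0)
y_rep = np.repeat(y, sw)

m_weighted = Ridge(alpha=1.0).fit(X, y, sample_weight=sw)
m_repeated = Ridge(alpha=1.0).fit(X_rep, y_rep)
```

The stochastic bagging version of this check additionally requires the random resampling to be comparable between the two datasets, which is where the handling of max_samples discussed above comes in.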

@snath-xoc (Contributor):

I adapted the stochastic sample weighting test to check the performance when sample_weight.sum()<1 and max_sample=1

Is it absolute max_samples=1 (int), which would mean fitting each estimator with only one sample, or relative max_samples=1.0 (float), which after 28a2bde should raise an error?

Yes, it is passed as an int; otherwise the error is raised indeed.

Projects — Status: In Progress

4 participants