FIX Draw indices using sample_weight in Bagging #31414


Open · wants to merge 15 commits into base: main

Conversation


@antoinebaker antoinebaker commented May 22, 2025

Part of #16298 and alternative to #31165.

What does this implement/fix? Explain your changes.

In Bagging estimators, sample_weight is now used to draw the sample indices and is no longer forwarded to the underlying estimators. Bagging estimators now pass the statistical repeated/weighted equivalence test when bootstrap=True (the default, i.e. drawing with replacement).
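For context, the equivalence being tested has a deterministic analogue: for estimators that handle sample_weight exactly (Ridge is one such case), fitting with integer weights must match fitting on a dataset whose rows are repeated accordingly. A minimal illustration, not taken from the PR itself:

```python
import numpy as np
from sklearn.linear_model import Ridge

# Weighted fit vs. fit on a dataset with rows repeated according to weights.
X = np.array([[0.0], [1.0], [2.0]])
y = np.array([0.0, 1.0, 4.0])
w = np.array([1, 2, 3])

X_rep = np.repeat(X, w, axis=0)
y_rep = np.repeat(y, w)

coef_weighted = Ridge(alpha=1.0).fit(X, y, sample_weight=w).coef_
coef_repeated = Ridge(alpha=1.0).fit(X_rep, y_rep).coef_
assert np.allclose(coef_weighted, coef_repeated)
```

The statistical version of the test checks the same property for randomized estimators, by comparing distributions of fitted coefficients over many seeds.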

Compared to #31165, this PR better decouples the two different uses of sample_weight:

  • sample_weight passed to bagging_estimator.fit is used as probabilities when drawing the indices/rows
  • sample_weight passed to base_estimator.fit is used to represent the drawn indices (more memory efficient than indexing); this is possible only if base_estimator.fit supports sample_weight (through metadata routing or natively)
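A minimal sketch of the two-step scheme described above (a hypothetical helper, not the PR's actual code): draw indices with probability proportional to sample_weight, then re-encode the draw as per-row counts that a base estimator can consume as sample_weight.

```python
import numpy as np

def draw_bootstrap_counts(sample_weight, n_samples_draw, rng):
    """Draw bootstrap indices with probability proportional to sample_weight,
    then re-encode the draw as per-row repetition counts (usable as
    sample_weight by the base estimator, instead of indexing the rows)."""
    p = sample_weight / sample_weight.sum()
    indices = rng.choice(
        len(sample_weight), size=n_samples_draw, replace=True, p=p
    )
    # bincount turns the drawn indices into per-row repetition counts.
    return np.bincount(indices, minlength=len(sample_weight))

rng = np.random.default_rng(0)
sw = np.array([0.0, 1.0, 2.0, 1.0])
counts = draw_bootstrap_counts(sw, n_samples_draw=100, rng=rng)
assert counts[0] == 0        # zero-weight rows are never drawn
assert counts.sum() == 100   # exactly n_samples_draw rows were drawn
```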

#31165 introduced a new sampling_strategy argument to choose between indexing and weighting for row sampling, but it would be better to do this in a dedicated follow-up PR.

cc @ogrisel @GaetandeCast


github-actions bot commented May 22, 2025

✔️ Linting Passed

All linting checks passed. Your pull request is in excellent shape! ☀️

Generated for commit: f449770.


antoinebaker commented May 22, 2025

BaggingRegressor(estimator=Ridge(), max_samples=100) now passes the statistical repeated/weighted equivalence test

[figure: statistical equivalence test results]

The same holds for BaggingClassifier(estimator=LogisticRegression(), max_samples=100) and for varying values of max_samples.

@antoinebaker

However it fails (as expected) for bootstrap=False (drawing without replacement), for example for BaggingRegressor(estimator=Ridge(), bootstrap=False, max_samples=10):

[figure: statistical equivalence test results]


ogrisel commented May 23, 2025

However it fails (as expected) for bootstrap=False (draw without replacement).

Could you please document this known limitation, both in the docstring of the __init__ method for the bootstrap parameter and in the docstring of the fit method for the sample_weight parameter?

Something like: "Note that the expected frequency semantics of the sample_weight parameter are only fulfilled when sampling with replacement (bootstrap=True)."

Maybe we should raise a warning when calling BaggingClassifier(bootstrap=False, max_samples=0.5).fit(X, y, sample_weight=sample_weight) when sample_weight is not None. The warning is already implemented and tested: https://github.com/scikit-learn/scikit-learn/pull/31414/files#diff-b7c01e77fe68ded1e41868f4a7e142190f935261624d4abdb299913ef944cbbbR676-R682.
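The warning discussed above could look roughly like the following (a hypothetical helper, with invented names and wording; the actual message and its placement in scikit-learn may differ):

```python
import warnings

def check_bootstrap_sample_weight(bootstrap, sample_weight):
    # Hypothetical helper: warn when weights are combined with sampling
    # without replacement, where the frequency semantics are not fulfilled.
    if not bootstrap and sample_weight is not None:
        warnings.warn(
            "sample_weight is only fully taken into account when sampling "
            "with replacement (bootstrap=True); with bootstrap=False the "
            "expected frequency semantics are not fulfilled.",
            UserWarning,
        )

check_bootstrap_sample_weight(bootstrap=False, sample_weight=[1.0, 2.0])  # warns
```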


@ogrisel ogrisel left a comment


Here is a pass of review. Could you please add a non-regression test using a small dataset with specifically engineered weights? For instance, you could use a dataset with 100 data points: 98 with zero weight, one with a weight of 1, and one with a weight of 2:

import numpy as np

X = np.arange(100).reshape(-1, 1)
y = (X.ravel() < 99).astype(np.int32)  # ravel so that y is 1D
sample_weight = np.zeros(shape=X.shape[0])
sample_weight[0] = 1
sample_weight[-1] = 2

Then you could fit a BaggingRegressor and a BaggingClassifier with a fake test estimator that simply records the values passed as X, y and sample_weight as fitted attributes, so that assertions can be written in the test.

Ideally this test should pass both with metadata routing enabled and disabled.


ogrisel commented May 28, 2025

BTW @antoinebaker once this PR has been finalized with tests, it would be great to open a similar PR for random forests. I suppose their incorrect handling of sample weights stems from the same root cause, and a similar fix should be applicable.

antoinebaker and others added 2 commits June 2, 2025 09:08
Co-authored-by: Olivier Grisel <olivier.grisel@ensta.org>

@ogrisel ogrisel left a comment


Thanks, @antoinebaker. Here is another pass of feedback but otherwise LGTM.


ogrisel commented Jun 3, 2025

but otherwise LGTM.

Actually no: I tried to run https://github.com/snath-xoc/sample-weight-audit-nondet/blob/main/reports/sklearn_estimators_sample_weight_audit_report.ipynb against this branch and I still get a p-value lower than 1e-33 for this branch. It's an improvement over the < 1e-54 I measured on main but still, the bug does not seem fixed for classifiers.

I confirm the bug is fixed for the regressor though. So I must be missing something.

@antoinebaker

I confirm the bug is fixed for the regressor though. So I must be missing something.

Did you specify max_samples as an integer, e.g. max_samples=10? Otherwise you might get a different number of samples in the repeated/weighted datasets: #31165 (comment)
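To illustrate the point (a small arithmetic sketch, not code from the PR): with a fractional max_samples, the number of rows drawn is proportional to the number of rows in the dataset, which differs between a dataset with repeated rows and its weighted equivalent.

```python
import numpy as np

# A dataset with repeated rows vs. its weighted equivalent.
X_rep = np.arange(5).repeat([1, 2, 1, 3, 1])  # 8 rows after repetition
X_wgt = np.arange(5)                          # 5 rows, weights [1, 2, 1, 3, 1]

max_samples = 0.5
print(int(max_samples * len(X_rep)))  # 4 rows drawn from the repeated data
print(int(max_samples * len(X_wgt)))  # 2 rows drawn from the weighted data
```

An integer max_samples sidesteps the mismatch, since both fits then draw the same number of rows.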


ogrisel commented Jun 3, 2025

I forgot about the max_samples thing. Let me try again.

EDIT: I confirm this works as expected.

antoinebaker and others added 2 commits June 3, 2025 17:58
Co-authored-by: Olivier Grisel <olivier.grisel@ensta.org>


ogrisel commented Jun 10, 2025

@antoinebaker I pushed 28a2bde to make the sample weight semantics consistent between max_samples passed as an absolute or a relative value. I re-ran the statistical tests, and they now always pass, whatever the value of max_samples.

I had to change the code a bit to raise ValueError with explicit messages for degenerate cases, and updated the tests accordingly. I think I prefer this behavior.

Let's see if the CI is green after this commit and I will do a proper review of the PR.

@ogrisel ogrisel moved this to In Progress in Losses and solvers Jun 10, 2025

@ogrisel ogrisel left a comment


Assuming CI is green, the diff LGTM besides the following details:


ogrisel commented Jun 10, 2025

Maybe @jeremiedbb and @snath-xoc would like to review this PR.

@antoinebaker

@antoinebaker I pushed 28a2bde to make the sample weight semantics consistent between max_samples passed as an absolute or a relative value. I re-ran the statistical tests, and they now always pass, whatever the value of max_samples.

I like the new semantics and raising an error if sw_sum < 1.
