FIX Draw indices using sample_weight in Bagging #31414
Conversation
Could you please document this known limitation, both in the docstring of the [...]. Something like: "Note that the expected frequency semantics for the [...]"
Here is a pass of review. Could you please add a non-regression test using a small dataset with specifically engineered weights? For instance, you could have a dataset with 100 data points: 98 data points with a null weight, 1 data point with a weight of 1, and 1 with a weight of 2:
```python
import numpy as np

X = np.arange(100).reshape(-1, 1)
y = (X < 99).astype(np.int32).ravel()  # ravel to get a 1d target vector
sample_weight = np.zeros(shape=X.shape[0])
sample_weight[0] = 1
sample_weight[-1] = 2
```
Then you could fit a `BaggingRegressor` and a `BaggingClassifier` with a fake test estimator that just records the values passed as `X`, `y` and `sample_weight` as fitted attributes, to be able to write assertions in the test. Ideally this test should pass both with metadata routing enabled and disabled.
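A minimal sketch of such a test (the `RecordingRegressor` name and the exact assertions are illustrative, not taken from the PR):

```python
import numpy as np
from sklearn.base import BaseEstimator, RegressorMixin
from sklearn.ensemble import BaggingRegressor


class RecordingRegressor(RegressorMixin, BaseEstimator):
    """Fake estimator that records the arguments passed to fit."""

    def fit(self, X, y, sample_weight=None):
        self.fit_X_ = X
        self.fit_y_ = y
        self.fit_sample_weight_ = sample_weight
        return self

    def predict(self, X):
        return np.zeros(X.shape[0])


X = np.arange(100).reshape(-1, 1)
y = (X < 99).astype(np.int32).ravel()
sample_weight = np.zeros(shape=X.shape[0])
sample_weight[0] = 1
sample_weight[-1] = 2

reg = BaggingRegressor(RecordingRegressor(), n_estimators=10, random_state=0)
reg.fit(X, y, sample_weight=sample_weight)

for est in reg.estimators_:
    # With the fix, only the two rows with a non-null user weight should
    # ever receive a non-zero effective weight in the base estimator.
    drawn_rows = np.flatnonzero(est.fit_sample_weight_)
    assert set(drawn_rows.tolist()) <= {0, 99}
```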
Co-authored-by: Olivier Grisel <olivier.grisel@ensta.org>
BTW @antoinebaker once this PR has been finalized with tests, it would be great to open a similar PR for random forests. I suppose their bad handling of sample weights stems from the same root cause and a similar fix should be applicable.
Co-authored-by: Olivier Grisel <olivier.grisel@ensta.org>
Thanks, @antoinebaker. Here is another pass of feedback but otherwise LGTM.
Actually no: I tried to run https://github.com/snath-xoc/sample-weight-audit-nondet/blob/main/reports/sklearn_estimators_sample_weight_audit_report.ipynb against this branch and I still get a p-value lower than 1e-33. It's an improvement over the < 1e-54 I measured on [...]. I confirm the bug is fixed for the regressor though. So I must be missing something.
Did you specify [...]?
Co-authored-by: Olivier Grisel <olivier.grisel@ensta.org>
I forgot about the [...]. EDIT: I confirm this works as expected.
Co-authored-by: Olivier Grisel <olivier.grisel@ensta.org>
The CI test failure seems unrelated to this PR (forbidden request in [...]).
@antoinebaker I pushed 28a2bde to make the sample weight semantics consistent between [...]. I had to change the code a bit to raise [...]. Let's see if the CI is green after this commit and I will do a proper review of the PR.
Assuming CI is green, the diff LGTM besides the following details: [...]

Maybe @jeremiedbb and @snath-xoc would like to review this PR.
Co-authored-by: Olivier Grisel <olivier.grisel@ensta.org>
I like the new semantics and raising an error if sw_sum < 1.
From an initial pass LGTM, all tests pass for me as well. Will wait for @jeremiedbb to review but otherwise looks good to go!
LGTM. Just a nitpick and a remark we can't really do anything about :)
```python
sw_sum = np.sum(sample_weight)
if sw_sum <= 1:
    raise ValueError(
        f"The total sum of sample weights is {sw_sum}, which prevents "
        "resampling with a fractional value for max_samples="
        f"{max_samples}. Either pass max_samples as an integer or "
        "use a larger sample_weight."
    )
max_samples = max(int(max_samples * sw_sum), 1)
```
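Concretely (plain arithmetic around the snippet above, not additional code from the PR), the guard trips exactly when the weights have been normalized to sum to at most 1:

```python
import numpy as np

sample_weight = np.array([1.2, 3.4, 4.7, 5.6, 2.2, 2.9])
max_samples = 0.5

sw_sum = np.sum(sample_weight)               # 20.0
n_draws = max(int(max_samples * sw_sum), 1)  # 10 indices are drawn

# After normalizing the same weights, sw_sum == 1.0, so the
# ValueError above would be raised for a fractional max_samples.
normalized = sample_weight / sw_sum
```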
Something doesn't feel right about this approach.

If `max_samples` is a float in [0, 1], I'd interpret it as a fraction of the sum of the weights, and so expect to draw a number of samples whose weights sum on average to `max_samples * sw_sum`. For instance, if `max_samples=0.5`, I'd expect to draw samples such that the sum of their weights is on average half the total sum of the weights.

This is not the case here, since we turn it into an int that is the number of samples to draw. That's why there's this issue with small weights in particular.

That being said, I don't have any alternative to propose. At least the docstring is clear about how `max_samples` relates to the actual number of samples drawn. So I guess this is good enough for us.
> This is not the case here since we're turning it into an int being the number of samples to draw. That's why there's this issue with small weights in particular.

I am not sure I follow. Does your comment refer specifically to the edge case where `sw_sum >= 1` but `int(max_samples * sw_sum) == 0`, in which case the `max` operator uses 1 instead? I think this is really an edge case and we can accept it. We could raise a warning, but the user wouldn't be able to do anything about it. Furthermore, I expect this case to be very rare in practice.

Besides this extreme edge case, I think your expectation that we "draw samples such that the sum of their weight is on average half the total sum of the weights" should be met, no?
No, it's not about this edge case.

Take for instance `sw = np.array([1.2, 3.4, 4.7, 5.6, 2.2, 2.9])`. We have `sw.sum() = 20`. If I set `max_samples=0.5`, intuitively I'd expect to draw samples such that the sum of their weights is close to 10 on average. But here `max_samples * sw_sum = 10`, so we'll sample 10 points, and on average the sum of their weights is `10 * sw.mean() = 33.33`, more than 3 times my expectation.

On the opposite, if the sample weights sum to a value less than `n_samples`, we'll draw points such that the sum of their weights is less than expected. Actually, I think the expected sum of weights is `int(max_samples * sw_sum) * sw_mean`, so it only equals `int(max_samples * sw_sum)` if `sw_mean = 1`. To get the expected sum of weights we should then draw `int(max_samples * n_samples)` points, which leads to an average sum of weights of `max_samples * sw_sum`.

But this was the previous implementation, and it used to break the equivalence between weighted and repeated.
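For reference, the arithmetic in this comment as a plain numpy check (under the reading that the 10 points are drawn uniformly, so the average weight of a drawn point is `sw.mean()`):

```python
import numpy as np

sw = np.array([1.2, 3.4, 4.7, 5.6, 2.2, 2.9])
sw_sum = sw.sum()            # 20.0
n_draws = int(0.5 * sw_sum)  # 10 points sampled

# Expected sum of the drawn points' weights under uniform drawing:
expected = n_draws * sw.mean()  # 10 * 3.33... ~= 33.33
```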
> If I set `max_samples=0.5`, intuitively I'd expect to draw samples such that the sum of their weight is close to 10 on average. But here `max_samples * sw_sum = 10` so we'll sample 10 points and on average the sum of their weights is `10 * sw.mean() = 33.33`, so more than 3 times my expectation.
I don't think that's what this PR does. What we do is:

- generate `indices` (with replacement, drawn using the user-provided `sample_weight`) with `max_samples * sw_sum ~= 10` elements (`sklearn/ensemble/_bagging.py:90`);
- then pass `indices_as_sample_weight = np.bincount(indices)` (`sklearn/ensemble/_bagging.py:172`) as the `sample_weight` param of the base estimator, to simulate fitting on this resampling using sample weights. Note that we do not reuse the `sample_weight` values passed by the user a second time for this step. This avoids double accounting.

Personally, I don't think there is a problem in the current state of the PR, and the statistical tests seem to confirm this. A self-contained sketch of the two steps is shown below.
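Sketch of the two steps (variable names illustrative; the real code lives in `sklearn/ensemble/_bagging.py`):

```python
import numpy as np

rng = np.random.RandomState(0)

sample_weight = np.array([1.2, 3.4, 4.7, 5.6, 2.2, 2.9])
sw_sum = sample_weight.sum()
n_samples = len(sample_weight)
max_samples = 0.5

# Step 1: draw indices with replacement, using the user-provided
# sample_weight as drawing probabilities.
n_draws = max(int(max_samples * sw_sum), 1)
indices = rng.choice(
    n_samples, size=n_draws, replace=True, p=sample_weight / sw_sum
)

# Step 2: pass the multiplicity of each drawn index as the base
# estimator's sample_weight; the user weights are not reused here,
# which avoids double accounting.
indices_as_sample_weight = np.bincount(indices, minlength=n_samples)
# base_estimator.fit(X, y, sample_weight=indices_as_sample_weight)
```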
As I said, I don't have a better alternative to offer, so I'm okay with this.

The issue for me comes from the fact that a parameter is tied to `n_samples` and not to the weight sum. That's why we're able to have the equivalence between weighted and repeated, but it's a lot harder with a rescaling of the weights. Here you'd get an error if you normalize your weights in advance (then `sw_sum = 1`), which feels like a non-optimal behavior to me.
Co-authored-by: Jérémie du Boisberranger <jeremie@probabl.ai>
I adapted the stochastic sample weighting test to check the performance when `sample_weight.sum() < 1` and `max_samples=1` (see here), with the following results: [...]

I agree this is probably an edge case which should raise a warning (can't think of an obvious solution to it).

NOTE: I had to rescale non-integer weights when constructing the repeated dataset and scale them back again... hopefully this doesn't conflate the two things (it may though). A rough sketch of that rescaling is shown below.
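The rescaling trick mentioned in the note could look roughly like this (a hypothetical helper, assuming non-integer weights are rounded to integer repetition counts at some scale, which is then divided back out):

```python
import numpy as np

def repeat_with_rescaling(X, y, sample_weight, scale=10):
    # Rescale non-integer weights to integer repetition counts...
    counts = np.round(np.asarray(sample_weight) * scale).astype(int)
    X_rep = np.repeat(X, counts, axis=0)
    y_rep = np.repeat(np.asarray(y), counts)
    # ...then scale back so the total weight is unchanged.
    w_rep = np.full(X_rep.shape[0], 1.0 / scale)
    return X_rep, y_rep, w_rep
```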
Is it absolute [...]? What I find nice with the redefinition of [...]: they should also pass the global rescaling of the weights. However, we should be cautious with the rescaling of [...].
Yes, it is passed as an `int`, otherwise the error is raised indeed.
Part of #16298 and alternative to #31165.
What does this implement/fix? Explain your changes.
In Bagging estimators, `sample_weight` is now used to draw the samples and is no longer forwarded to the underlying estimators. Bagging estimators now pass the statistical repeated/weighted equivalence test when `bootstrap=True` (the default, i.e. draw with replacement).

Compared to #31165, it better decouples two different usages of `sample_weight`:

- `sample_weight` in `bagging_estimator.fit` is used as probabilities to draw the indices/rows;
- `sample_weight` in `base_estimator.fit` is used to represent the indices (more memory efficient than indexing); this is possible only if `base_estimator.fit` supports `sample_weight` (through metadata routing or natively).

#31165 introduced a new `sampling_strategy` argument to choose indexing/weighting for row sampling, but it would be better to do this in a dedicated follow-up PR. A usage-level sketch of the targeted weighted/repeated equivalence is shown below.

cc @ogrisel @GaetandeCast
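As a usage-level illustration of the repeated/weighted equivalence targeted here (a sketch, not one of the PR's actual tests; the dataset and hyperparameters are arbitrary):

```python
import numpy as np
from sklearn.ensemble import BaggingRegressor
from sklearn.tree import DecisionTreeRegressor

rng = np.random.RandomState(42)
X = rng.uniform(size=(30, 2))
y = X @ np.array([1.0, -2.0]) + rng.normal(scale=0.1, size=30)
w = rng.randint(0, 3, size=30)  # integer weights, some null

# Weighted fit: sample_weight drives the bootstrap draws.
weighted = BaggingRegressor(
    DecisionTreeRegressor(), n_estimators=50, random_state=0
).fit(X, y, sample_weight=w)

# Repeated fit: the same information, materialized by repetition.
X_rep, y_rep = np.repeat(X, w, axis=0), np.repeat(y, w)
repeated = BaggingRegressor(
    DecisionTreeRegressor(), n_estimators=50, random_state=0
).fit(X_rep, y_rep)

# With bootstrap=True (the default), the two models should agree
# statistically (in distribution over seeds, not draw-for-draw).
```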