MNT Avoid catastrophic cancellation in mean_variance_axis #19766
Possibly related to #19546?
Absolutely. It aims to solve the same issue. However, it turns out that it's still not as precise as the dense case. I'm digging further :)
I reworked the PR to focus on the catastrophic cancellation described in #19546, caused by
Very nice fix @jeremiedbb! I assume this code has changed too much in main
to consider a backport for 0.24.2, but this is fine with me.
@@ -131,13 +131,19 @@ def _csr_mean_variance_axis0(np.ndarray[floating, ndim=1, mode="c"] X_data,
        np.ndarray[floating, ndim=1] sum_weights_nz = \
            np.zeros(shape=n_features, dtype=dtype)

        np.ndarray[np.uint64_t, ndim=1] counts = np.full(
A more explicit name:
- np.ndarray[np.uint64_t, ndim=1] counts = np.full(
+ np.ndarray[np.uint64_t, ndim=1] counts_nan = np.full(
sum_weights is the sum of all weights where X is not nan
sum_weights_nan is the sum of weights where X is nan
sum_weights_nz is the sum of weights where X is non zero
Following the same scheme:
counts is the number of elements which are not nan
counts_nz is the number of elements which are non zero
I'd rather keep that. Maybe you missed that the increment is negative (counts[col_ind] -= 1).
Let me try to reorder the code so that the match is clearer (and remove sum_weights_nan; we actually need only 2 out of the 3 arrays).
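To make the naming scheme above concrete, here is a minimal pure-Python sketch of the two counters. It is not the Cython code from the PR: the function name column_counts and its arguments are hypothetical, and it assumes that counts starts at the number of rows and that every stored non-nan entry is non zero (no explicitly stored zeros).

```python
import math


def column_counts(data, indices, n_rows, n_features):
    """Per-column counters for a CSR-like matrix (illustrative only).

    data and indices are the stored values and their column indices.
    Assumption: every stored entry that is not nan is non zero.
    """
    # counts: start from the number of rows, decrement for each stored nan
    counts = [n_rows] * n_features
    # counts_nz: start from zero, increment for each stored non-nan entry
    counts_nz = [0] * n_features
    for value, col_ind in zip(data, indices):
        if math.isnan(value):
            counts[col_ind] -= 1  # the negative increment mentioned above
        else:
            counts_nz[col_ind] += 1
    return counts, counts_nz


# 3 x 2 matrix [[1, 0], [nan, 2], [0, 3]] in CSR-like form
data = [1.0, float("nan"), 2.0, 3.0]
indices = [0, 0, 1, 1]
counts, counts_nz = column_counts(data, indices, n_rows=3, n_features=2)
print(counts, counts_nz)
```

With this layout the number of nan entries per column can be recovered as n_rows - counts[col_ind], which is one way to see why a separate nan counter is not strictly needed.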
Yes, I was mistaken with the suggestion. I was thinking about the negation of it, so the suggestion should have been counts_non_nan.
If we do not change the name, I think we should at least put a comment above counts to say that it is the number of elements which are not nan.
I added comments to describe the different arrays
@thomasjpfan merge?
LGTM!
Maybe @ogrisel should look this over one more time because removing sum_weights_nan is a semi-significant change from his last approval.
I gave it another look and LGTM. Let's merge!
Fixes #19546.
Fixes the unexpected lack of precision of the variance when the input is sparse, with weights, and when the variance should actually be 0, as described in #19450 (comment).
With this PR the result is 0, as expected.
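For readers unfamiliar with the term, the following is a generic illustration of catastrophic cancellation in a variance computation. It is a sketch only, not the code path changed in this PR: the helper names variance_naive and variance_two_pass are hypothetical, and the example uses plain unweighted lists rather than sparse, weighted input.

```python
def variance_naive(values):
    """Textbook formula E[x^2] - E[x]^2: subtracts two huge, nearly equal numbers."""
    n = len(values)
    mean = sum(values) / n
    mean_of_squares = sum(v * v for v in values) / n
    return mean_of_squares - mean * mean


def variance_two_pass(values):
    """Subtract the mean first, then average the squared deviations."""
    n = len(values)
    mean = sum(values) / n
    return sum((v - mean) ** 2 for v in values) / n


# Large offset, tiny spread: the exact (population) variance is 2/3.
values = [1e9 + 1.0, 1e9 + 2.0, 1e9 + 3.0]
naive = variance_naive(values)
two_pass = variance_two_pass(values)
print("naive:   ", naive)
print("two-pass:", two_pass)
```

In float64 the squared values are around 1e18, where adjacent representable numbers are 128 apart, so the naive difference cannot resolve a variance of 2/3. Centering the data before squaring keeps the intermediate quantities small and avoids the loss of precision.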