MNT Avoid catastrophic cancellation in mean_variance_axis #19766

Merged

Conversation

@jeremiedbb (Member) commented Mar 25, 2021

Fixes #19546

Fixes the unexpected loss of precision in the variance when the input is sparse, weights are provided, and the variance should actually be 0, as described in #19450 (comment).

With this PR the result is 0, as expected:

import numpy as np
from scipy.sparse import csr_matrix
from sklearn.utils.sparsefuncs import mean_variance_axis

n_samples = 100
sw = np.random.rand(n_samples)  # random sample weights

# Two constant columns (all zeros, all ones): both variances must be 0.
X = np.zeros(shape=(n_samples, 2))
X[:, 1] = 1.

mean_variance_axis(csr_matrix(X), axis=0, weights=sw)[1]
# array([0., 0.])
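
For comparison, a dense NumPy computation of the same weighted variance (an illustrative sketch added here, not part of the PR; it reuses X and sw from the snippet above) returns exact zeros, because the deviations from the column means are exactly 0:

# Dense reference using numpy.average (illustrative, not scikit-learn code).
mean = np.average(X, axis=0, weights=sw)
var = np.average((X - mean) ** 2, axis=0, weights=sw)
var
# array([0., 0.])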

@ogrisel (Member) commented Mar 26, 2021

Possibly related to #19546?

@jeremiedbb jeremiedbb changed the title MNT Improve precision of variance on sparse input with weights [WIP] MNT Improve precision of variance on sparse input with weights Mar 26, 2021
@jeremiedbb (Member, Author) replied:

> Possibly related to #19546?

Absolutely. It aims to solve the same issue. However, it turns out that it's still not as precise as the dense case. I'm digging further :)

@jeremiedbb jeremiedbb changed the title [WIP] MNT Improve precision of variance on sparse input with weights MNT Avoid catastrophic cancellation in mean_variance_axis Mar 29, 2021
@jeremiedbb (Member, Author) commented:
I reworked the PR to focus on the catastrophic cancellation described in #19546, caused by variances[i] += (sum_weights[i] - sum_weights_nz[i]) * means[i]**2. The precision is now comparable to that of the dense case. Below are the results of the gist linked in #19546 with various RNG seeds; a short sketch of the cancellation pattern follows the results.

## dtype=float64
_incremental_mean_and_var [100.] [0.]
csr_mean_variance_axis0 [100.] [7.2701421e-27]
incr_mean_variance_axis0 csr [100.] [7.2701421e-27]
csc_mean_variance_axis0 [100.] [7.2701421e-27]
incr_mean_variance_axis0 csc [100.] [7.2701421e-27]
## dtype=float32
_incremental_mean_and_var [100.00000577] [3.32692735e-11]
csr_mean_variance_axis0 [99.99997] [9.3132246e-10]
incr_mean_variance_axis0 csr [99.99997] [9.3132246e-10]
csc_mean_variance_axis0 [99.99997] [9.3132246e-10]
incr_mean_variance_axis0 csc [99.99997] [9.3132246e-10]
## dtype=float64
_incremental_mean_and_var [100.] [0.]
csr_mean_variance_axis0 [100.] [0.]
incr_mean_variance_axis0 csr [100.] [0.]
csc_mean_variance_axis0 [100.] [0.]
incr_mean_variance_axis0 csc [100.] [0.]
## dtype=float32
_incremental_mean_and_var [99.99999932] [4.66211111e-13]
csr_mean_variance_axis0 [99.99993] [4.7148196e-09]
incr_mean_variance_axis0 csr [99.99993] [4.7148196e-09]
csc_mean_variance_axis0 [99.99993] [4.7148196e-09]
incr_mean_variance_axis0 csc [99.99993] [4.7148196e-09]
## dtype=float64
_incremental_mean_and_var [100.] [2.01948392e-28]
csr_mean_variance_axis0 [100.] [1.81753553e-27]
incr_mean_variance_axis0 csr [100.] [1.81753553e-27]
csc_mean_variance_axis0 [100.] [1.81753553e-27]
incr_mean_variance_axis0 csc [100.] [1.81753553e-27]
## dtype=float32
_incremental_mean_and_var [99.99999692] [9.51546741e-12]
csr_mean_variance_axis0 [100.] [0.]
incr_mean_variance_axis0 csr [100.] [0.]
csc_mean_variance_axis0 [100.] [0.]
incr_mean_variance_axis0 csc [100.] [0.]
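
To make the failure mode concrete, here is a minimal NumPy sketch on synthetic data (an illustration of the generic cancellation pattern, not the PR's actual Cython code) showing why a "sum of squares minus squared mean" formulation loses nearly all significant digits when the variance is tiny relative to the mean, while a shifted two-pass formulation stays stable:

import numpy as np

rng = np.random.default_rng(0)
# Mean ~100, true variance ~1e-14: the regime exercised by the gist above.
x = 100.0 + rng.normal(scale=1e-7, size=100)

# Naive form: two nearly equal quantities of magnitude ~1e4 are subtracted,
# destroying almost all significant digits (the result can even be negative).
naive = (x ** 2).mean() - x.mean() ** 2

# Shifted two-pass form: deviations are computed first, so nothing cancels.
shifted = ((x - x.mean()) ** 2).mean()

The actual fix lives in the Cython routines touched by the diff below; the sketch only shows the pattern being avoided.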

@ogrisel (Member) left a comment:

Very nice fix @jeremiedbb! I assume this code has changed too much in main to consider a backport for 0.24.2, but this is fine with me.

@@ -131,13 +131,19 @@ def _csr_mean_variance_axis0(np.ndarray[floating, ndim=1, mode="c"] X_data,
np.ndarray[floating, ndim=1] sum_weights_nz = \
np.zeros(shape=n_features, dtype=dtype)

np.ndarray[np.uint64_t, ndim=1] counts = np.full(
@thomasjpfan (Member) commented:

A more explicit name:

Suggested change
np.ndarray[np.uint64_t, ndim=1] counts = np.full(
np.ndarray[np.uint64_t, ndim=1] counts_nan = np.full(

@jeremiedbb (Member, Author) replied:

sum_weights is the sum of all weights where X is not NaN
sum_weights_nan is the sum of weights where X is NaN
sum_weights_nz is the sum of weights where X is non-zero

Following the same scheme:
counts is the number of elements which are not NaN
counts_nz is the number of elements which are non-zero

I'd rather keep that. Maybe you missed that the increment is negative (counts[col_ind] -= 1).
Let me try to reorder the code so that the correspondence is clearer (and remove sum_weights_nan; we actually only need 2 out of the 3 arrays).
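
For readers following along, a hedged Python sketch of the bookkeeping described above (the real code is Cython; the helper below is hypothetical and assumes no explicitly stored zeros):

import numpy as np

def count_arrays(X_data, X_indices, n_samples, n_features):
    # counts starts at n_samples per feature and is decremented for every
    # stored NaN, so it ends up as the number of non-NaN elements
    # (implicit zeros are never NaN, hence the full initialization).
    counts = np.full(n_features, n_samples, dtype=np.uint64)
    # counts_nz counts the stored (non-zero), non-NaN values.
    counts_nz = np.zeros(n_features, dtype=np.uint64)
    for val, col_ind in zip(X_data, X_indices):
        if np.isnan(val):
            counts[col_ind] -= 1  # the negative increment noted above
        else:
            counts_nz[col_ind] += 1
    return counts, counts_nz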

@thomasjpfan (Member) commented Mar 30, 2021:

Yes, I was mistaken with the suggestion. I was thinking about its negation, so the suggestion should have been counts_non_nan.

If we do not change the name, I think we should at least put a comment above counts saying that it is the number of elements which are not NaN.

@jeremiedbb (Member, Author) replied:

I added comments to describe the different arrays.

@ogrisel (Member) commented Mar 30, 2021

@thomasjpfan merge?

@thomasjpfan (Member) left a comment:

LGTM!

Maybe @ogrisel should look this over one more time, because removing sum_weights_nan is a semi-significant change since his last approval.

@ogrisel (Member) left a comment:

I gave it another look and LGTM. Let's merge!

@ogrisel ogrisel merged commit 57d3668 into scikit-learn:main Mar 31, 2021
@glemaitre mentioned this pull request Apr 22, 2021
Development

Successfully merging this pull request may close these issues:
Weighted variance computation for sparse data is not numerically stable