group mean over several axes of array with nans is wrong #1118

gdementen opened this issue Oct 16, 2024 · 1 comment
gdementen commented Oct 16, 2024

What happens is that the mean is computed over each axis in turn (a mean of means). When no nans are involved, this is mathematically equivalent to a single mean over all axes; in practice we lose some precision, but that was deemed acceptable so far. However, when nans are involved, the result is significantly wrong.

>>> arr = Array([[1, 3], [4, nan]], [Axis('a=a0,a1'), Axis('b=b0,b1')])
>>> arr
a\b   b0   b1
 a0  1.0  3.0
 a1  4.0  nan
>>> arr.mean("a0,a1 >> a01", "b0,b1 >> b01")
2.75

The correct result is 2.6666... What it actually computes is:

>>> (((1 + 4) / 2) + ((3 + 0) / 1)) / 2
2.75
>>> 1/4 + 4/4 + 3/2
2.75

Instead of:

>>> (1 + 4 + 3) / 3
2.6666666666666665
>>> 1/3 + 4/3 + 3/3
2.6666666666666665

As a workaround until larray 0.35 is released, I have recommended using:

>>> # TODO: drop this function once larray 0.35 is available
... def nd_mean(array, axes_or_groups):
...     """
...     Compute the mean of array over axes_or_groups.
...
...     This function is temporarily necessary because larray versions up to (and including) 0.34.x
...     give wrong results when computing means over groups spanning several dimensions
...     when some values are nan. See https://github.com/larray-project/larray/issues/1118
...     """
...     return array.sum(*axes_or_groups) / (~isnan(array)).sum(*axes_or_groups)
>>> nd_mean(arr, ("a0,a1 >> a01", "b0,b1 >> b01"))
2.6666666666666665
gdementen commented
Here is a version which does not emit a warning on all-nan slices. I am unsure it is a good idea to do this by default though, so I added a "warn_all_nan_slices" argument defaulting to True. I fear our users will always want it to be False, but they will then miss helpful warnings when there is an actual problem with their data, so I am unsure the argument makes sense at all. What I do know is that defaulting it to False is not worth it, because nobody would use the argument unless they had already been bitten by a bad case of this.

# TODO: drop this function once larray 0.35 is available
from larray import isnan, nan, where  # names used below (also available via `from larray import *`)

def nd_mean(array, axes_or_groups, warn_all_nan_slices=True):
    """
    Compute the mean of array over axes_or_groups.

    This function is temporarily necessary because larray versions up to (and including) 0.34.x
    give wrong results when computing means over groups spanning several dimensions
    when some values are nan. See https://github.com/larray-project/larray/issues/1118
    """
    value_sums = array.sum(*axes_or_groups)
    counts = (~isnan(array)).sum(*axes_or_groups)
    if warn_all_nan_slices:
        return value_sums / counts
    else:
        return where(counts > 0, value_sums.divnot0(counts), nan)
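For readers without larray at hand, the same sum/count logic can be sketched in plain NumPy (`nd_nanmean` and its arguments are my own names, not larray API; `np.divide` with `out=`/`where=` plays the role of `divnot0`):

```python
import numpy as np

def nd_nanmean(values, axis=None, warn_all_nan_slices=True):
    # Sum the non-nan values and count them separately, then divide once,
    # instead of chaining per-axis means.
    sums = np.nansum(values, axis=axis)
    counts = (~np.isnan(values)).sum(axis=axis)
    if warn_all_nan_slices:
        # 0 / 0 on an all-nan slice emits a RuntimeWarning and yields nan
        return sums / counts
    # divnot0-style division: divide only where the count is non-zero,
    # leave nan elsewhere, without any warning
    return np.divide(sums, counts,
                     out=np.full_like(np.asarray(sums, dtype=float), np.nan),
                     where=counts > 0)

data = np.array([[1.0, np.nan],
                 [4.0, np.nan]])  # second column is all nan
```

With `warn_all_nan_slices=False`, the all-nan column silently yields nan while the valid column gets its correct mean.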
