From b775376ed55de118b66d71eb535c6ce8ac570d50 Mon Sep 17 00:00:00 2001 From: Allan Haldane Date: Tue, 5 Apr 2016 14:42:36 -0400 Subject: [PATCH] BUG: MaskedArray.count treats negative axes incorrectly Follow up to #5706. Fixes #7509 --- numpy/ma/core.py | 5 +++++ numpy/ma/tests/test_core.py | 8 +++++++- 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/numpy/ma/core.py b/numpy/ma/core.py index 97dcae1cf045..6306557e67ef 100644 --- a/numpy/ma/core.py +++ b/numpy/ma/core.py @@ -4305,6 +4305,11 @@ def count(self, axis=None, keepdims=np._NoValue): return self.size axes = axis if isinstance(axis, tuple) else (axis,) + axes = tuple(a if a >= 0 else self.ndim + a for a in axes) + if len(axes) != len(set(axes)): + raise ValueError("duplicate value in 'axis'") + if np.any([a < 0 or a >= self.ndim for a in axes]): + raise ValueError("'axis' entry is out of bounds") items = 1 for ax in axes: items *= self.shape[ax] diff --git a/numpy/ma/tests/test_core.py b/numpy/ma/tests/test_core.py index 723257d70953..e377ff0afb17 100644 --- a/numpy/ma/tests/test_core.py +++ b/numpy/ma/tests/test_core.py @@ -1004,7 +1004,7 @@ def test_count_func(self): res = count(ott, 0) assert_(isinstance(res, ndarray)) assert_(res.dtype.type is np.intp) - assert_raises(IndexError, ott.count, axis=1) + assert_raises(ValueError, ott.count, axis=1) def test_minmax_func(self): # Tests minimum and maximum. @@ -4295,6 +4295,9 @@ def test_count(self): assert_equal(count(a, keepdims=True), 16*ones((1,1,1))) assert_equal(count(a, axis=1, keepdims=True), 2*ones((2,1,4))) assert_equal(count(a, axis=(0,1), keepdims=True), 4*ones((1,1,4))) + assert_equal(count(a, axis=-2), 2*ones((2,4))) + assert_raises(ValueError, count, a, axis=(1,1)) + assert_raises(ValueError, count, a, axis=3) # check the 'nomask' path a = np.ma.array(d, mask=nomask) @@ -4305,6 +4308,9 @@ def test_count(self): assert_equal(count(a, keepdims=True), 24*ones((1,1,1))) assert_equal(count(a, axis=1, keepdims=True), 3*ones((2,1,4))) assert_equal(count(a, axis=(0,1), keepdims=True), 6*ones((1,1,4))) + assert_equal(count(a, axis=-2), 3*ones((2,4))) + assert_raises(ValueError, count, a, axis=(1,1)) + assert_raises(ValueError, count, a, axis=3) # check the 'masked' singleton assert_equal(count(np.ma.masked), 0)