Skip to content

Commit faa1ed8

Browse files
committed
TST: Expand TestArrayEqual.test_masked_scalar
1 parent 32ba514 commit faa1ed8

File tree

1 file changed

+32
-6
lines changed

1 file changed

+32
-6
lines changed

numpy/testing/tests/test_utils.py

Lines changed: 32 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -197,13 +197,39 @@ def test_masked_nan_inf(self):
197197
self._test_equal(a, b)
198198
self._test_equal(b, a)
199199

200+
# Also provides test cases for gh-11121
200201
def test_masked_scalar(self):
201-
a = np.ma.MaskedArray(3., mask=True)
202-
b = np.array(3.)
203-
self._test_equal(a, b)
204-
# also a test case for gh-11121
205-
b = np.array(np.nan)
206-
self._test_equal(a, b)
202+
# Test masked scalar vs. plain/masked scalar
203+
for a_val, b_val, b_masked in itertools.product(
204+
[3., np.nan, np.inf],
205+
[3., 4., np.nan, np.inf, -np.inf],
206+
[False, True],
207+
):
208+
a = np.ma.MaskedArray(a_val, mask=True)
209+
b = np.ma.MaskedArray(b_val, mask=True) if b_masked else np.array(b_val)
210+
self._test_equal(a, b)
211+
self._test_equal(b, a)
212+
213+
# Test masked scalar vs. plain array
214+
for a_val, b_val in itertools.product(
215+
[3., np.nan, -np.inf],
216+
itertools.product([3., 4., np.nan, np.inf, -np.inf], repeat=2),
217+
):
218+
a = np.ma.MaskedArray(a_val, mask=True)
219+
b = np.array(b_val)
220+
self._test_equal(a, b)
221+
self._test_equal(b, a)
222+
223+
# Test masked scalar vs. masked array
224+
for a_val, b_val, b_mask in itertools.product(
225+
[3., np.nan, np.inf],
226+
itertools.product([3., 4., np.nan, np.inf, -np.inf], repeat=2),
227+
itertools.product([False, True], repeat=2),
228+
):
229+
a = np.ma.MaskedArray(a_val, mask=True)
230+
b = np.ma.MaskedArray(b_val, mask=b_mask)
231+
self._test_equal(a, b)
232+
self._test_equal(b, a)
207233

208234
def test_subclass_that_overrides_eq(self):
209235
# While we cannot guarantee testing functions will always work for

0 commit comments

Comments
 (0)