Skip to content

Commit 2d229ef

Browse files
committed
BUG: Make np.ma.where delegate to np.where
Fixes #8600 and #8599 Also makes np.ma.masked work with structured dtypes.
1 parent 6a3edf3 commit 2d229ef

File tree

2 files changed

+63
-33
lines changed

2 files changed

+63
-33
lines changed

numpy/ma/core.py

Lines changed: 31 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -6991,44 +6991,42 @@ def where(condition, x=_NoValue, y=_NoValue):
69916991
[6.0 -- 8.0]]
69926992
69936993
"""
6994-
missing = (x is _NoValue, y is _NoValue).count(True)
69956994

6995+
# handle the single-argument case
6996+
missing = (x is _NoValue, y is _NoValue).count(True)
69966997
if missing == 1:
69976998
raise ValueError("Must provide both 'x' and 'y' or neither.")
69986999
if missing == 2:
6999-
return filled(condition, 0).nonzero()
7000-
7001-
# Both x and y are provided
7002-
7003-
# Get the condition
7004-
fc = filled(condition, 0).astype(MaskType)
7005-
notfc = np.logical_not(fc)
7006-
7007-
# Get the data
7008-
xv = getdata(x)
7009-
yv = getdata(y)
7010-
if x is masked:
7011-
ndtype = yv.dtype
7012-
elif y is masked:
7013-
ndtype = xv.dtype
7014-
else:
7015-
ndtype = np.find_common_type([xv.dtype, yv.dtype], [])
7016-
7017-
# Construct an empty array and fill it
7018-
d = np.empty(fc.shape, dtype=ndtype).view(MaskedArray)
7019-
np.copyto(d._data, xv.astype(ndtype), where=fc)
7020-
np.copyto(d._data, yv.astype(ndtype), where=notfc)
7021-
7022-
# Create an empty mask and fill it
7023-
mask = np.zeros(fc.shape, dtype=MaskType)
7024-
np.copyto(mask, getmask(x), where=fc)
7025-
np.copyto(mask, getmask(y), where=notfc)
7026-
mask |= getmaskarray(condition)
7027-
7028-
# Use d._mask instead of d.mask to avoid copies
7029-
d._mask = mask if mask.any() else nomask
7000+
return nonzero(condition)
7001+
7002+
# we only care if the condition is true - false or masked pick y
7003+
cf = filled(condition, False)
7004+
xd = getdata(x)
7005+
yd = getdata(y)
7006+
7007+
# we need the full arrays here for correct final dimensions
7008+
cm = getmaskarray(condition)
7009+
xm = getmaskarray(x)
7010+
ym = getmaskarray(y)
7011+
7012+
# deal with the fact that masked.dtype == float64, but we don't actually
7013+
# want to treat it as that.
7014+
if x is masked and y is not masked:
7015+
xd = np.zeros((), dtype=yd.dtype)
7016+
xm = np.ones((), dtype=ym.dtype)
7017+
elif y is masked and x is not masked:
7018+
yd = np.zeros((), dtype=xd.dtype)
7019+
ym = np.ones((), dtype=xm.dtype)
7020+
7021+
data = np.where(cf, xd, yd)
7022+
mask = np.where(cf, xm, ym)
7023+
mask = np.where(cm, np.ones((), dtype=mask.dtype), mask)
7024+
7025+
# collapse the mask, for backwards compatibility
7026+
if mask.dtype == np.bool_ and not mask.any():
7027+
mask = nomask
70307028

7031-
return d
7029+
return masked_array(data, mask=mask)
70327030

70337031

70347032
def choose(indices, choices, out=None, mode='raise'):

numpy/ma/tests/test_core.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3942,6 +3942,38 @@ def test_where_type(self):
39423942
control = np.find_common_type([np.int32, np.float32], [])
39433943
assert_equal(test, control)
39443944

3945+
def test_where_broadcast(self):
3946+
# Issue 8599
3947+
x = np.arange(9).reshape(3, 3)
3948+
y = np.zeros(3)
3949+
core = np.where([1, 0, 1], x, y)
3950+
ma = where([1, 0, 1], x, y)
3951+
3952+
assert_equal(core, ma)
3953+
assert_equal(core.dtype, ma.dtype)
3954+
3955+
def test_where_structured(self):
3956+
# Issue 8600
3957+
dt = np.dtype([('a', int), ('b', int)])
3958+
x = np.array([(1, 2), (3, 4), (5, 6)], dtype=dt)
3959+
y = np.array((10, 20), dtype=dt)
3960+
core = np.where([0, 1, 1], x, y)
3961+
ma = np.where([0, 1, 1], x, y)
3962+
3963+
assert_equal(core, ma)
3964+
assert_equal(core.dtype, ma.dtype)
3965+
3966+
def test_where_structured_masked(self):
3967+
dt = np.dtype([('a', int), ('b', int)])
3968+
x = np.array([(1, 2), (3, 4), (5, 6)], dtype=dt)
3969+
3970+
ma = where([0, 1, 1], x, masked)
3971+
expected = masked_where([1, 0, 0], x)
3972+
3973+
assert_equal(ma.dtype, expected.dtype)
3974+
assert_equal(ma, expected)
3975+
assert_equal(ma.mask, expected.mask)
3976+
39453977
def test_choose(self):
39463978
# Test choose
39473979
choices = [[0, 1, 2, 3], [10, 11, 12, 13],

0 commit comments

Comments
 (0)