@@ -6991,44 +6991,42 @@ def where(condition, x=_NoValue, y=_NoValue):
6991
6991
[6.0 -- 8.0]]
6992
6992
6993
6993
"""
6994
- missing = (x is _NoValue , y is _NoValue ).count (True )
6995
6994
6995
+ # handle the single-argument case
6996
+ missing = (x is _NoValue , y is _NoValue ).count (True )
6996
6997
if missing == 1 :
6997
6998
raise ValueError ("Must provide both 'x' and 'y' or neither." )
6998
6999
if missing == 2 :
6999
- return filled (condition , 0 ).nonzero ()
7000
-
7001
- # Both x and y are provided
7002
-
7003
- # Get the condition
7004
- fc = filled (condition , 0 ).astype (MaskType )
7005
- notfc = np .logical_not (fc )
7006
-
7007
- # Get the data
7008
- xv = getdata (x )
7009
- yv = getdata (y )
7010
- if x is masked :
7011
- ndtype = yv .dtype
7012
- elif y is masked :
7013
- ndtype = xv .dtype
7014
- else :
7015
- ndtype = np .find_common_type ([xv .dtype , yv .dtype ], [])
7016
-
7017
- # Construct an empty array and fill it
7018
- d = np .empty (fc .shape , dtype = ndtype ).view (MaskedArray )
7019
- np .copyto (d ._data , xv .astype (ndtype ), where = fc )
7020
- np .copyto (d ._data , yv .astype (ndtype ), where = notfc )
7021
-
7022
- # Create an empty mask and fill it
7023
- mask = np .zeros (fc .shape , dtype = MaskType )
7024
- np .copyto (mask , getmask (x ), where = fc )
7025
- np .copyto (mask , getmask (y ), where = notfc )
7026
- mask |= getmaskarray (condition )
7027
-
7028
- # Use d._mask instead of d.mask to avoid copies
7029
- d ._mask = mask if mask .any () else nomask
7000
+ return nonzero (condition )
7001
+
7002
+ # we only care if the condition is true - false or masked pick y
7003
+ cf = filled (condition , False )
7004
+ xd = getdata (x )
7005
+ yd = getdata (y )
7006
+
7007
+ # we need the full arrays here for correct final dimensions
7008
+ cm = getmaskarray (condition )
7009
+ xm = getmaskarray (x )
7010
+ ym = getmaskarray (y )
7011
+
7012
+ # deal with the fact that masked.dtype == float64, but we don't actually
7013
+ # want to treat it as that.
7014
+ if x is masked and y is not masked :
7015
+ xd = np .zeros ((), dtype = yd .dtype )
7016
+ xm = np .ones ((), dtype = ym .dtype )
7017
+ elif y is masked and x is not masked :
7018
+ yd = np .zeros ((), dtype = xd .dtype )
7019
+ ym = np .ones ((), dtype = xm .dtype )
7020
+
7021
+ data = np .where (cf , xd , yd )
7022
+ mask = np .where (cf , xm , ym )
7023
+ mask = np .where (cm , np .ones ((), dtype = mask .dtype ), mask )
7024
+
7025
+ # collapse the mask, for backwards compatibility
7026
+ if mask .dtype == np .bool_ and not mask .any ():
7027
+ mask = nomask
7030
7028
7031
- return d
7029
+ return masked_array ( data , mask = mask )
7032
7030
7033
7031
7034
7032
def choose (indices , choices , out = None , mode = 'raise' ):
0 commit comments