Skip to content

Commit de387e5

Browse files
committed
a[bool_expr] now uses LArray broadcasting
1 parent 98b0aed commit de387e5

File tree

2 files changed

+23
-1
lines changed

2 files changed

+23
-1
lines changed

larray/core.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -989,6 +989,8 @@ def __setitem__(self, key, value, collapse_slices=True):
989989
data = np.asarray(self)
990990

991991
if isinstance(key, np.ndarray) and np.issubdtype(key.dtype, bool):
992+
if isinstance(key, LArray):
993+
key = key.broadcast_with(self.axes)
992994
data[key] = value
993995
return
994996

@@ -1036,7 +1038,7 @@ def reshape_like(self, target):
10361038

10371039
def broadcast_with(self, target):
10381040
"""
1039-
returns an LArray that is broadcastable with target
1041+
returns an LArray that is (numpy) broadcastable with target
10401042
target can be either an LArray or any collection of Axis
10411043
10421044
* all common axes must be either 1 or the same length

larray/tests/test_la.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -712,13 +712,33 @@ def test_setitem_ndarray(self):
712712
self._assert_equal_raw(la, raw)
713713

714714
def test_setitem_bool_array_key(self):
715+
age, geo, sex, lipro = self.larray.axes
716+
715717
# LArray key
718+
# a) same shape
716719
la = self.larray.copy()
717720
raw = self.array.copy()
718721
la[la < 5] = 0
719722
raw[raw < 5] = 0
720723
self._assert_equal_raw(la, raw)
721724

725+
# b) numpy-broadcastable shape
726+
la = self.larray.copy()
727+
raw = self.array.copy()
728+
key = la[sex['F,']] < 5
729+
self.assertEqual(key.ndim, 4)
730+
la[key] = 0
731+
raw[raw[:, :, [1]] < 5] = 0
732+
self._assert_equal_raw(la, raw)
733+
734+
# c) LArray-broadcastable shape (missing axis)
735+
la = self.larray.copy()
736+
key = la[sex['F']] < 5
737+
self.assertEqual(key.ndim, 3)
738+
la[key] = 0
739+
# same raw as last test
740+
self._assert_equal_raw(la, raw)
741+
722742
# ndarray key
723743
la = self.larray.copy()
724744
raw = self.array.copy()

0 commit comments

Comments
 (0)