Skip to content

Commit 22ac89a

Browse files
committed
FEAT: evaluate X.axis_name expressions in make_numpy_broadcastable
(used in all wrapped numpy functions, most notably where())
1 parent 1dfa7a2 commit 22ac89a

File tree

3 files changed

+54
-60
lines changed

3 files changed

+54
-60
lines changed

doc/source/changes/version_0_34_2.rst.inc

Lines changed: 11 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -1,55 +1,21 @@
11
.. py:currentmodule:: larray
22

33

4-
Syntax changes
5-
^^^^^^^^^^^^^^
6-
7-
* renamed ``Array.old_method_name()`` to :py:obj:`Array.new_method_name()` (closes :issue:`1`).
8-
9-
* renamed ``old_argument_name`` argument of :py:obj:`Array.method_name()` to ``new_argument_name``.
10-
11-
12-
Backward incompatible changes
13-
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
14-
15-
* other backward incompatible changes
16-
17-
184
New features
195
^^^^^^^^^^^^
206

21-
* added a feature (see the :ref:`miscellaneous section <misc>` for details). It works on :ref:`api-axis` and
22-
:ref:`api-group` objects.
23-
24-
Here is an example of the new feature:
25-
26-
>>> arr = ndtest((2, 3))
27-
>>> arr
28-
a\b b0 b1 b2
29-
a0 0 1 2
30-
a1 3 4 5
31-
32-
And it can also be used like this:
7+
* added support for evaluating expressions using X.axis_name when calling some
8+
built-in functions, most notably `where()`. For example, the following code
9+
now works (previously it seemed to work but produced the wrong result -- see the
10+
fixes section below): ::
3311

34-
>>> arr = ndtest("a=a0..a2")
12+
>>> arr = ndtest("age=0..3")
3513
>>> arr
36-
a a0 a1 a2
37-
0 1 2
38-
39-
* added another feature in the editor (closes :editor_issue:`1`).
40-
41-
.. note::
42-
43-
- It works for foo bar !
44-
- It does not work for foo baz !
45-
46-
47-
.. _misc:
48-
49-
Miscellaneous improvements
50-
^^^^^^^^^^^^^^^^^^^^^^^^^^
51-
52-
* improved something.
14+
age 0 1 2 3
15+
0 1 2 3
16+
>>> where(X.age == 2, 42, arr)
17+
age 0 1 2 3
18+
0 1 42 3
5319

5420

5521
Fixes
@@ -68,7 +34,7 @@ Fixes
6834
there is no array to extract the axis labels from) instead of always
6935
evaluating to True. This was especially dangerous in the context of a
7036
where() function, which always evaluated to its left side
71-
(e.g. where(X.age > 0, arr, 0) evaluated to 'arr' for all ages).
37+
(e.g. `where(X.age > 0, arr, 0)` evaluated to `arr` for all ages).
7238
Closes :issue:`1083`.
7339

7440
* expressions using `X.axis_name` and an Array now evaluate correctly when

larray/core/array.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9949,13 +9949,19 @@ def make_numpy_broadcastable(values, min_axes=None) -> Tuple[List[Array], AxisCo
99499949
--------
99509950
Axis.iscompatible : tests if axes are compatible between them.
99519951
"""
9952-
all_axes = AxisCollection.union(*[get_axes(v) for v in values])
9952+
axes_union = AxisCollection.union(*[get_axes(v) for v in values])
99539953
if min_axes is not None:
99549954
if not isinstance(min_axes, AxisCollection):
99559955
min_axes = AxisCollection(min_axes)
9956-
all_axes = min_axes | all_axes
9957-
return [v.broadcast_with(all_axes) if isinstance(v, Array) else v
9958-
for v in values], all_axes
9956+
axes_union = min_axes | axes_union
9957+
def broadcasted_value(value):
9958+
if isinstance(value, Array):
9959+
return value.broadcast_with(axes_union)
9960+
elif isinstance(value, ExprNode):
9961+
return value.evaluate(axes_union)
9962+
else:
9963+
return value
9964+
return [broadcasted_value(value) for value in values], axes_union
99599965

99609966

99619967
def raw_broadcastable(values, min_axes=None) -> Tuple[Tuple[Any, ...], AxisCollection]:

larray/tests/test_array.py

Lines changed: 33 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -5011,23 +5011,45 @@ def test_ufuncs(small_array):
50115011
assert_nparray_equal(clip(small_array, low, high).data,
50125012
np.clip(raw, raw_low, raw_high))
50135013

5014+
# round
5015+
small_float = small_array + 0.6
5016+
rounded = round(small_float)
5017+
assert_nparray_equal(rounded.data, np.round(raw + 0.6))
5018+
5019+
5020+
def test_where():
5021+
arr = ndtest((2, 3))
5022+
# a\b b0 b1 b2
5023+
# a0 0 1 2
5024+
# a1 3 4 5
5025+
5026+
expected = from_string(r"""a\b b0 b1 b2
5027+
a0 -1 -1 -1
5028+
a1 -1 4 5""")
5029+
50145030
# where (no broadcasting)
5015-
assert_nparray_equal(where(small_array < 5, -5, small_array).data,
5016-
np.where(raw < 5, -5, raw))
5031+
res = where(arr < 4, -1, arr)
5032+
assert_larray_equal(res, expected)
50175033

50185034
# where (transposed no broadcasting)
5019-
assert_nparray_equal(where(small_array < 5, -5, small_array.T).data,
5020-
np.where(raw < 5, -5, raw))
5035+
res = where(arr < 4, -1, arr.T)
5036+
assert_larray_equal(res, expected)
50215037

50225038
# where (with broadcasting)
5023-
result = where(small_array['d1'] < 5, -5, small_array)
5024-
assert result.axes.names == ['c', 'd']
5025-
assert_nparray_equal(result.data, np.where(raw[:, [0]] < 5, -5, raw))
5039+
res = where(arr['b1'] < 4, -1, arr)
5040+
assert_larray_equal(res, from_string(r"""a\b b0 b1 b2
5041+
a0 -1 -1 -1
5042+
a1 3 4 5"""))
50265043

5027-
# round
5028-
small_float = small_array + 0.6
5029-
rounded = round(small_float)
5030-
assert_nparray_equal(rounded.data, np.round(raw + 0.6))
5044+
# with expressions (issue #1083)
5045+
arr = ndtest("age=0..5")
5046+
res = where(X.age == 3, 42, arr)
5047+
assert_larray_equal(res, from_string("""age 0 1 2 3 4 5
5048+
\t 0 1 2 42 4 5"""))
5049+
5050+
res = where(X.age == 3, arr, 42)
5051+
assert_larray_equal(res, from_string("""age 0 1 2 3 4 5
5052+
\t 42 42 42 3 42 42"""))
50315053

50325054

50335055
def test_eye():

0 commit comments

Comments
 (0)