Skip to content

Commit e8bec94

Browse files
committed
FIX: reindex using a list of labels in a dict (closes #1068)
1 parent 53977b7 commit e8bec94

File tree

3 files changed

+93
-42
lines changed

3 files changed

+93
-42
lines changed

doc/source/changes/version_0_34_2.rst.inc

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,4 +55,6 @@ Miscellaneous improvements
5555
Fixes
5656
^^^^^
5757

58-
* fixed reindex when using an axis object from the array as `axes_to_reindex` (closes :issue:`1088`).
58+
* fixed Array.reindex when using an axis object from the array as `axes_to_reindex` (closes :issue:`1088`).
59+
60+
* fixed Array.reindex({axis: list_of_labels}) (closes :issue:`1068`).

larray/core/array.py

Lines changed: 63 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -1985,52 +1985,79 @@ def reindex(self, axes_to_reindex=None, new_axis=None, fill_value=nan, inplace=F
19851985
a0 c0 -1 -1
19861986
a1 c0 2 1
19871987
"""
1988-
if isinstance(axes_to_reindex, str) and '=' in axes_to_reindex:
1989-
axes_to_reindex = Axis(axes_to_reindex)
1990-
elif isinstance(axes_to_reindex, Group):
1991-
axes_to_reindex = Axis(axes_to_reindex)
1988+
def labels_def_and_name_to_axis(labels_def, axis_name=None):
1989+
# TODO: the rename functionality seems weird to me.
1990+
# I think we should either raise an error if the axis name
1991+
# is different (force using new_axis=other_axis.labels instead
1992+
# of new_axis=other_axis) OR do not use the old name
1993+
# (and make sure this effectively does a rename).
1994+
# it might have been the unintended consequence of supporting a
1995+
# list of labels as new_axis
1996+
axis = labels_def if isinstance(labels_def, Axis) else Axis(labels_def)
1997+
return axis.rename(axis_name) if axis_name is not None else axis
1998+
1999+
def axis_ref_to_axis(axes, axis_ref):
2000+
if isinstance(axis_ref, Axis) or is_axis_ref(axis_ref):
2001+
return axes[axis_ref]
2002+
else:
2003+
raise TypeError(
2004+
"In Array.reindex, source axes must be Axis objects or axis references ('axis name', "
2005+
"X.axis_name or axis_integer_position) but got object of "
2006+
f"type {type(axis_ref).__name__} instead."
2007+
)
2008+
2009+
def is_axis_ref(axis_ref):
2010+
return isinstance(axis_ref, (int, str, AxisReference))
2011+
2012+
def is_axis_def(axis_def):
2013+
return ((isinstance(axis_def, str) and '=' in axis_def)
2014+
or isinstance(axis_def, Group))
19922015

1993-
# XXX: can't we move this to AxisCollection.replace?
19942016
if new_axis is None:
1995-
if isinstance(axes_to_reindex, (int, str, AxisReference)):
2017+
if isinstance(axes_to_reindex, Axis) and not isinstance(axes_to_reindex, AxisReference):
2018+
axes_to_reindex = {axes_to_reindex: axes_to_reindex}
2019+
elif is_axis_def(axes_to_reindex):
2020+
axis = Axis(axes_to_reindex)
2021+
axes_to_reindex = {axis: axis}
2022+
elif is_axis_ref(axes_to_reindex):
19962023
raise TypeError("In Array.reindex, when using an axis reference ('axis name', X.axis_name or "
19972024
"axis_integer_position) as axes_to_reindex, you must provide a value for `new_axis`.")
1998-
elif isinstance(axes_to_reindex, Axis):
1999-
new_axis = axes_to_reindex
2000-
axes_to_reindex = self.axes[axes_to_reindex]
2001-
else:
2002-
assert isinstance(axes_to_reindex, (tuple, list, dict, AxisCollection))
2025+
# otherwise axes_to_reindex must be either None (which happens
2026+
# when only kwargs are used), a dict, or a sequence of axes
2027+
# (tuple, list or AxisCollection)
2028+
assert (axes_to_reindex is None or
2029+
isinstance(axes_to_reindex, (tuple, list, dict, AxisCollection)))
20032030
else:
2004-
if isinstance(axes_to_reindex, (int, str, Axis)):
2005-
axes_to_reindex = self.axes[axes_to_reindex]
2006-
else:
2031+
if not (isinstance(axes_to_reindex, Axis) or is_axis_ref(axes_to_reindex)):
20072032
raise TypeError(
2008-
"In Array.reindex, when `new_axis` is used, `axes_to_reindex`"
2009-
" must be an Axis object or an axis reference ('axis name', "
2010-
"X.axis_name or axis_integer_position) but got object of "
2011-
f"type {type(axes_to_reindex).__name__} instead."
2033+
"In Array.reindex, when `new_axis` is used, `axes_to_reindex` "
2034+
"must be an Axis object or an axis reference ('axis name', "
2035+
f"X.axis_name or axis_integer_position) but got {axes_to_reindex} "
2036+
f"(which is of type {type(axes_to_reindex).__name__}) instead."
20122037
)
2013-
assert isinstance(axes_to_reindex, Axis)
2014-
old_axis_name = axes_to_reindex.name
2015-
2016-
if not isinstance(new_axis, Axis):
2017-
new_axis = Axis(new_axis)
2018-
# TODO: this functionality seems weird to me.
2019-
# I think we should either raise an error if the axis name
2020-
# is different (force using new_axis=other_axis.labels instead
2021-
# of new_axis=other_axis) OR do not use the old name
2022-
# (and make sure this effectively does a rename)
2023-
new_axis = new_axis.rename(old_axis_name)
2038+
axes_to_reindex = {axes_to_reindex: new_axis}
2039+
new_axis = None
20242040

2025-
if isinstance(axes_to_reindex, (list, tuple)) and all([isinstance(axis, Axis) for axis in axes_to_reindex]):
2041+
if isinstance(axes_to_reindex, (list, tuple)):
20262042
axes_to_reindex = AxisCollection(axes_to_reindex)
20272043

2044+
assert new_axis is None
2045+
assert axes_to_reindex is None or isinstance(axes_to_reindex, (dict, AxisCollection))
2046+
20282047
if isinstance(axes_to_reindex, AxisCollection):
2029-
assert new_axis is None
2030-
# add extra axes if needed
2048+
# the union with axes_to_reindex (`| axes_to_reindex`) is needed because axes_to_reindex can contain more axes than self.axes
20312049
res_axes = AxisCollection([axes_to_reindex.get(axis, axis) for axis in self.axes]) | axes_to_reindex
20322050
else:
2033-
res_axes = self.axes.replace(axes_to_reindex, new_axis, **kwargs)
2051+
# TODO: move this to AxisCollection.replace
2052+
if isinstance(axes_to_reindex, dict):
2053+
new_axes_to_reindex = {}
2054+
for k, v in axes_to_reindex.items():
2055+
src_axis = axis_ref_to_axis(self.axes, k)
2056+
dst_axis = labels_def_and_name_to_axis(v, src_axis.name)
2057+
new_axes_to_reindex[src_axis] = dst_axis
2058+
axes_to_reindex = new_axes_to_reindex
2059+
2060+
res_axes = self.axes.replace(axes_to_reindex, **kwargs)
20342061
res = full(res_axes, fill_value, dtype=common_dtype((self.data, fill_value)))
20352062

20362063
def get_group(res_axes, self_axis):
@@ -2039,9 +2066,9 @@ def get_group(res_axes, self_axis):
20392066
return self_axis[:]
20402067
else:
20412068
return self_axis[self_axis.intersection(res_axis).labels]
2042-
self_labels = tuple(get_group(res_axes, axis) for axis in self.axes)
2043-
res_labels = tuple(res_axes[group.axis][group] for group in self_labels)
2044-
res[res_labels] = self[self_labels]
2069+
self_groups = tuple(get_group(res_axes, axis) for axis in self.axes)
2070+
res_groups = tuple(res_axes[group.axis][group] for group in self_groups)
2071+
res[res_groups] = self[self_groups]
20452072
if inplace:
20462073
self.axes = res.axes
20472074
self.data = res.data

larray/tests/test_array.py

Lines changed: 27 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3127,6 +3127,22 @@ def test_reindex():
31273127
res = arr.reindex('a', 'b=a0,a1,a2')
31283128
assert_larray_nan_equal(res, expected)
31293129

3130+
# using a list as the new labels
3131+
res = arr.reindex('a', ['a0', 'a1', 'a2'])
3132+
assert_larray_nan_equal(res, expected)
3133+
3134+
# using the dict syntax
3135+
res = arr.reindex({'a': new_a})
3136+
assert_larray_nan_equal(res, expected)
3137+
3138+
# using the dict syntax with a list of labels (issue #1068)
3139+
res = arr.reindex({'a': ['a0', 'a1', 'a2']})
3140+
assert_larray_nan_equal(res, expected)
3141+
3142+
# using the dict syntax with a labels def string
3143+
res = arr.reindex({'a': 'a0,a1,a2'})
3144+
assert_larray_nan_equal(res, expected)
3145+
31303146
# test error conditions
31313147
msg = ("In Array.reindex, when using an axis reference ('axis name', X.axis_name or "
31323148
"axis_integer_position) as axes_to_reindex, you must provide a value for `new_axis`.")
@@ -3141,16 +3157,22 @@ def test_reindex():
31413157

31423158
msg_tmpl = ("In Array.reindex, when `new_axis` is used, `axes_to_reindex`"
31433159
" must be an Axis object or an axis reference ('axis name', "
3144-
"X.axis_name or axis_integer_position) but got object of "
3145-
"type {objtype} instead.")
3160+
"X.axis_name or axis_integer_position) but got {obj_str} "
3161+
"(which is of type {obj_type}) instead.")
31463162

3147-
with must_raise(TypeError, msg_tmpl.format(objtype='list')):
3163+
msg = msg_tmpl.format(obj_str="[Axis(['a0', 'a1'], 'a')]", obj_type='list')
3164+
with must_raise(TypeError, msg):
31483165
res = arr.reindex([a], new_a)
31493166

3150-
with must_raise(TypeError, msg_tmpl.format(objtype='AxisCollection')):
3167+
msg = msg_tmpl.format(obj_str='{a}', obj_type='AxisCollection')
3168+
with must_raise(TypeError, msg):
31513169
res = arr.reindex(AxisCollection([a]), new_a)
31523170

3153-
with must_raise(TypeError, msg_tmpl.format(objtype='dict')):
3171+
msg = msg_tmpl.format(
3172+
obj_str="{Axis(['a0', 'a1'], 'a'): Axis(['a0', 'a1', 'a2'], 'a')}",
3173+
obj_type='dict'
3174+
)
3175+
with must_raise(TypeError, msg):
31543176
res = arr.reindex({a: new_a}, new_a)
31553177

31563178
# 2d array, one axis reindexed

0 commit comments

Comments
 (0)