Skip to content

Commit ffcef9f

Browse files
committed
CLN: small simplifications in _key_to_raw_and_axes
* inline _translated_key (used only once) * use raw_broadcastable instead of make_numpy_broadcastable and using .data
1 parent b14727b commit ffcef9f

File tree

1 file changed

+28
-46
lines changed

1 file changed

+28
-46
lines changed

larray/core/axis.py

Lines changed: 28 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -2828,45 +2828,6 @@ def _key_to_igroups(self, key):
28282828
# translate all keys to IGroup
28292829
return tuple(self._translate_axis_key(axis_key) for axis_key in key)
28302830

2831-
def _translated_key(self, key):
2832-
"""
2833-
Transforms any key (from Array.__get|setitem__) to a complete indices-based key.
2834-
2835-
Parameters
2836-
----------
2837-
key : scalar, list/array of scalars, Group or tuple or dict of them
2838-
any key supported by Array.__get|setitem__
2839-
2840-
Returns
2841-
-------
2842-
tuple
2843-
len(tuple) == self.ndim
2844-
2845-
This key is not yet usable as is in a numpy array as it can still contain Array parts and the advanced key
2846-
parts are not broadcasted together yet.
2847-
"""
2848-
# any key -> (IGroup, IGroup, ...)
2849-
igroup_key = self._key_to_igroups(key)
2850-
2851-
# extract axis from Group keys
2852-
key_items = [(k.axis, k) for k in igroup_key]
2853-
2854-
# even keys given as dict can contain duplicates (if the same axis was
2855-
# given under different forms, e.g. name and AxisReference).
2856-
dupe_axes = list(duplicates(axis for axis, axis_key in key_items))
2857-
if dupe_axes:
2858-
dupe_axes = ', '.join(str(axis) for axis in dupe_axes)
2859-
raise ValueError(f"key has several values for axis: {dupe_axes}\n{key_items}")
2860-
2861-
# IGroup -> raw positional
2862-
dict_key = {axis: axis.index(axis_key) for axis, axis_key in key_items}
2863-
2864-
# dict -> tuple (complete and order key)
2865-
assert all(isinstance(k, Axis) for k in dict_key)
2866-
2867-
return tuple(dict_key[axis] if axis in dict_key else slice(None)
2868-
for axis in self)
2869-
28702831
def _key_to_raw_and_axes(self, key, collapse_slices=False, translate_key=True):
28712832
r"""
28722833
Transforms any key (from Array.__getitem__) to a raw numpy key, the resulting axes, and potentially a tuple
@@ -2883,10 +2844,34 @@ def _key_to_raw_and_axes(self, key, collapse_slices=False, translate_key=True):
28832844
-------
28842845
raw_key, res_axes, transposed_indices
28852846
"""
2886-
from .array import make_numpy_broadcastable, Array, sequence
2847+
from .array import raw_broadcastable, Array, sequence
28872848

28882849
if translate_key:
2889-
key = self._translated_key(key)
2850+
# complete key & translate (those two cannot be dissociated because to complete
2851+
# the key we need to know which axis each key belongs to and to do that, we need to
2852+
# translate the key to indices)
2853+
2854+
# any key -> (IGroup, IGroup, ...)
2855+
igroup_key = self._key_to_igroups(key)
2856+
2857+
# extract axis from Group keys
2858+
key_items = [(k1.axis, k1) for k1 in igroup_key]
2859+
2860+
# even keys given as dict can contain duplicates (if the same axis was
2861+
# given under different forms, e.g. name and AxisReference).
2862+
dupe_axes = list(duplicates(axis1 for axis1, key1 in key_items))
2863+
if dupe_axes:
2864+
dupe_axes = ', '.join(str(axis1) for axis1 in dupe_axes)
2865+
raise ValueError(f"key has several values for axis: {dupe_axes}\n{key_items}")
2866+
2867+
# IGroup -> raw positional
2868+
dict_key = {axis1: axis1.index(key1) for axis1, key1 in key_items}
2869+
2870+
# dict -> tuple (complete and order key)
2871+
assert all(isinstance(k1, Axis) for k1 in dict_key)
2872+
key = tuple(dict_key[axis1] if axis1 in dict_key else slice(None)
2873+
for axis1 in self)
2874+
28902875
assert isinstance(key, tuple) and len(key) == self.ndim
28912876

28922877
# scalar array
@@ -2943,7 +2928,7 @@ def slice_to_sequence(axis, axis_key):
29432928

29442929
# if there are only simple keys, do not bother going via the "advanced indexing" code path
29452930
if all(isinstance(axis_key, (int, np.integer, slice)) for axis_key in key):
2946-
bcasted_adv_keys = key
2931+
raw_broadcasted_key = key
29472932
else:
29482933
# Now that we know advanced indexing comes into play, we need to compute were the subspace created by the
29492934
# advanced indexes will be inserted. Note that there is only ever a SINGLE combined subspace (even if it
@@ -2982,14 +2967,11 @@ def slice_to_sequence(axis, axis_key):
29822967
adv_key_subspace_pos = adv_axes_indices[0]
29832968

29842969
# scalar/slice keys are ignored by make_numpy_broadcastable, which is exactly what we need
2985-
bcasted_adv_keys, adv_key_dest_axes = make_numpy_broadcastable(key)
2970+
raw_broadcasted_key, adv_key_dest_axes = raw_broadcastable(key)
29862971

29872972
# insert advanced indexing subspace
29882973
res_axes[adv_key_subspace_pos:adv_key_subspace_pos] = adv_key_dest_axes
29892974

2990-
# transform to raw numpy arrays
2991-
raw_broadcasted_key = tuple(k.data if isinstance(k, Array) else k
2992-
for k in bcasted_adv_keys)
29932975
return raw_broadcasted_key, res_axes, transpose_indices
29942976

29952977
@property

0 commit comments

Comments
 (0)