Skip to content

Commit c1ca9fb

Browse files
committed
Docstrings for axes helpers
1 parent b970435 commit c1ca9fb

File tree

2 files changed

+11
-9
lines changed

2 files changed

+11
-9
lines changed

array_api_tests/test_manipulation_functions.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from . import hypothesis_helpers as hh
1414
from . import pytest_helpers as ph
1515
from . import xps
16-
from .test_statistical_functions import axes_ndindex, normalise_axis # TODO: Move
16+
from .test_statistical_functions import axes_ndindex, normalise_axis
1717
from .typing import Array, Shape
1818

1919
MAX_SIDE = hh.MAX_ARRAY_SIZE // 64
@@ -32,6 +32,7 @@ def shared_shapes(*args, **kwargs) -> st.SearchStrategy[Shape]:
3232
def axis_ndindex(
3333
shape: Shape, axis: int
3434
) -> Iterator[Tuple[Tuple[Union[int, slice], ...], ...]]:
35+
"""Generate indices that index all elements in dimensions beyond `axis`"""
3536
assert axis >= 0 # sanity check
3637
axis_indices = [range(side) for side in shape[:axis]]
3738
for _ in range(axis, len(shape)):

array_api_tests/test_statistical_functions.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -39,18 +39,19 @@ def normalise_axis(
3939

4040

4141
def axes_ndindex(shape: Shape, axes: Tuple[int, ...]) -> Iterator[Tuple[Shape, ...]]:
42-
base_iterables = []
43-
axes_iterables = []
42+
"""Generate indices that index all elements except in `axes` dimensions"""
43+
base_indices = []
44+
axes_indices = []
4445
for axis, side in enumerate(shape):
4546
if axis in axes:
46-
base_iterables.append([None])
47-
axes_iterables.append(range(side))
47+
base_indices.append([None])
48+
axes_indices.append(range(side))
4849
else:
49-
base_iterables.append(range(side))
50-
axes_iterables.append([None])
51-
for base_idx in product(*base_iterables):
50+
base_indices.append(range(side))
51+
axes_indices.append([None])
52+
for base_idx in product(*base_indices):
5253
indices = []
53-
for idx in product(*axes_iterables):
54+
for idx in product(*axes_indices):
5455
idx = list(idx)
5556
for axis, side in enumerate(idx):
5657
if axis not in axes:

0 commit comments

Comments
 (0)