diff --git a/array_api_tests/test_indexing_functions.py b/array_api_tests/test_indexing_functions.py index 3ef01cb7..a599d218 100644 --- a/array_api_tests/test_indexing_functions.py +++ b/array_api_tests/test_indexing_functions.py @@ -60,3 +60,63 @@ def test_take(x, data): # sanity check with pytest.raises(StopIteration): next(out_indices) + + + +@pytest.mark.unvectorized +@pytest.mark.min_version("2024.12") +@given( + x=hh.arrays(hh.all_dtypes, hh.shapes(min_dims=1, min_side=1)), + data=st.data(), +) +def test_take_along_axis(x, data): + # TODO + # 2. negative indices + # 3. different dtypes for indices + # 4. "broadcast-compatible" indices + axis = data.draw( + st.integers(-x.ndim, max(x.ndim - 1, 0)) | st.none(), + label="axis" + ) + if axis is None: + axis_kw = {} + n_axis = x.ndim - 1 + else: + axis_kw = {"axis": axis} + n_axis = axis + x.ndim if axis < 0 else axis + + new_len = data.draw(st.integers(0, 2*x.shape[n_axis]), label="new_len") + idx_shape = x.shape[:n_axis] + (new_len,) + x.shape[n_axis+1:] + indices = data.draw( + hh.arrays( + shape=idx_shape, + dtype=dh.default_int, + elements={"min_value": 0, "max_value": x.shape[n_axis]-1} + ), + label="indices" + ) + note(f"{indices=} {idx_shape=}") + + out = xp.take_along_axis(x, indices, **axis_kw) + + ph.assert_dtype("take_along_axis", in_dtype=x.dtype, out_dtype=out.dtype) + ph.assert_shape( + "take_along_axis", + out_shape=out.shape, + expected=x.shape[:n_axis] + (new_len,) + x.shape[n_axis+1:], + kw=dict( + x=x, + indices=indices, + axis=axis, + ), + ) + + # value test: notation is from `np.take_along_axis` docstring + Ni, Nk = x.shape[:n_axis], x.shape[n_axis+1:] + for ii in sh.ndindex(Ni): + for kk in sh.ndindex(Nk): + a_1d = x[ii + (slice(None),) + kk] + i_1d = indices[ii + (slice(None),) + kk] + o_1d = out[ii + (slice(None),) + kk] + for j in range(new_len): + assert o_1d[j] == a_1d[i_1d[j]], f'{ii=}, {kk=}, {j=}'