From 90edf4f709ce54809a7626b5bfb9fb74662a4d2f Mon Sep 17 00:00:00 2001 From: Edwin Date: Sun, 8 Jun 2025 11:35:21 -0700 Subject: [PATCH] Added support of AFArray indexing --- .../assignment_and_indexing/_indexing.py | 29 ++++++++----------- 1 file changed, 12 insertions(+), 17 deletions(-) diff --git a/arrayfire_wrapper/lib/create_and_modify_array/assignment_and_indexing/_indexing.py b/arrayfire_wrapper/lib/create_and_modify_array/assignment_and_indexing/_indexing.py index 36b20c3..c5f42c1 100644 --- a/arrayfire_wrapper/lib/create_and_modify_array/assignment_and_indexing/_indexing.py +++ b/arrayfire_wrapper/lib/create_and_modify_array/assignment_and_indexing/_indexing.py @@ -5,8 +5,8 @@ from typing import Any from arrayfire_wrapper.lib._broadcast import bcast_var -from arrayfire_wrapper.lib.create_and_modify_array.manage_array import release_array - +from arrayfire_wrapper.lib.create_and_modify_array.manage_array import release_array, retain_array +from arrayfire_wrapper.defines import AFArray class _IndexSequence(ctypes.Structure): """ @@ -186,7 +186,7 @@ class IndexStructure(ctypes.Structure): ----------- idx: key - - If of type af.Array, self.idx.arr = idx, self.isSeq = False + - If of type AFArray, self.idx.arr = idx, self.isSeq = False - If of type af.ParallelRange, self.idx.seq = idx, self.isBatch = True - Default:, self.idx.seq = af._IndexSequence(idx) @@ -197,26 +197,21 @@ class IndexStructure(ctypes.Structure): """ - def __init__(self, idx: Any) -> None: + def __init__(self, idx: int | slice | AFArray) -> None: self.idx = _IndexUnion() self.isBatch = False self.isSeq = True - # BUG cyclic reimport - # if isinstance(idx, Array): - # if idx.dtype == af_bool: - # self.idx.arr = everything.where(idx.arr) - # else: - # self.idx.arr = everything.retain_array(idx.arr) - - # self.isSeq = False - - if isinstance(idx, ParallelRange): + if isinstance(idx, int) or isinstance(idx, slice): + self.idx.seq = _IndexSequence(idx) + elif isinstance(idx, ParallelRange): self.idx.seq = idx self.isBatch = True - + elif isinstance(idx, AFArray): + self.idx.arr = retain_array(idx) + self.isSeq = False else: - self.idx.seq = _IndexSequence(idx) + raise IndexError("Invalid type while indexing arrayfire.array") def __del__(self) -> None: if not self.isSeq: @@ -247,7 +242,7 @@ def __setitem__(self, idx: int, value: IndexStructure) -> None: self.idxs[idx] = value -def get_indices(key: int | slice | tuple[int | slice, ...]) -> CIndexStructure: # BUG +def get_indices(key: int | slice | tuple[int | slice | AFArray, ...] | AFArray) -> CIndexStructure: # BUG indices = CIndexStructure() if isinstance(key, tuple):