Skip to content

Commit 64fb62e

Browse files
committed
Cleaning up Index
1 parent 5f8dce0 commit 64fb62e

File tree

2 files changed

+19
-8
lines changed

2 files changed

+19
-8
lines changed

arrayfire/array.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -178,10 +178,7 @@ def _get_info(dims, buf_len):
178178

179179

180180
def _get_indices(key):
181-
182-
S = Index(slice(None))
183-
inds = _Index4(S, S, S, S)
184-
181+
inds = _Index4()
185182
if isinstance(key, tuple):
186183
n_idx = len(key)
187184
for n in range(n_idx):

arrayfire/index.py

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,20 @@ def __init__ (self, S):
5959
if (S.start is not None):
6060
self.begin = c_double_t(S.start)
6161
if (S.stop is not None):
62-
self.end = c_double_t(S.stop - math.copysign(1, self.step))
62+
self.end = c_double_t(S.stop)
63+
64+
# handle special cases
65+
if self.begin >= 0 and self.end >=0 and self.end <= self.begin and self.step >= 0:
66+
self.begin = 1
67+
self.end = 1
68+
self.step = 1
69+
elif self.begin < 0 and self.end < 0 and self.end >= self.begin and self.step <= 0:
70+
self.begin = -2
71+
self.end = -2
72+
self.step = -1
73+
74+
if (S.stop is not None):
75+
self.end = self.end - math.copysign(1, self.step)
6376
else:
6477
raise IndexError("Invalid type while indexing arrayfire.array")
6578

@@ -217,14 +230,15 @@ def __del__(self):
217230
arr = c_void_ptr_t(self.idx.arr)
218231
backend.get().af_release_array(arr)
219232

233+
_span = Index(slice(None))
220234
class _Index4(object):
221-
def __init__(self, idx0, idx1, idx2, idx3):
235+
def __init__(self):
222236
index_vec = Index * 4
223-
self.array = index_vec(idx0, idx1, idx2, idx3)
237+
self.array = index_vec(_span, _span, _span, _span)
224238
# Do not lose those idx as self.array keeps
225239
# no reference to them. Otherwise the destructor
226240
# is prematurely called
227-
self.idxs = [idx0,idx1,idx2,idx3]
241+
self.idxs = [_span, _span, _span, _span]
228242
@property
229243
def pointer(self):
230244
return c_pointer(self.array)

0 commit comments

Comments
 (0)