@@ -59,7 +59,20 @@ def __init__ (self, S):
59
59
if (S .start is not None ):
60
60
self .begin = c_double_t (S .start )
61
61
if (S .stop is not None ):
62
- self .end = c_double_t (S .stop - math .copysign (1 , self .step ))
62
+ self .end = c_double_t (S .stop )
63
+
64
+ # handle special cases
65
+ if self .begin >= 0 and self .end >= 0 and self .end <= self .begin and self .step >= 0 :
66
+ self .begin = 1
67
+ self .end = 1
68
+ self .step = 1
69
+ elif self .begin < 0 and self .end < 0 and self .end >= self .begin and self .step <= 0 :
70
+ self .begin = - 2
71
+ self .end = - 2
72
+ self .step = - 1
73
+
74
+ if (S .stop is not None ):
75
+ self .end = self .end - math .copysign (1 , self .step )
63
76
else :
64
77
raise IndexError ("Invalid type while indexing arrayfire.array" )
65
78
@@ -217,14 +230,15 @@ def __del__(self):
217
230
arr = c_void_ptr_t (self .idx .arr )
218
231
backend .get ().af_release_array (arr )
219
232
233
+ _span = Index (slice (None ))
220
234
class _Index4 (object ):
221
- def __init__ (self , idx0 , idx1 , idx2 , idx3 ):
235
+ def __init__ (self ):
222
236
index_vec = Index * 4
223
- self .array = index_vec (idx0 , idx1 , idx2 , idx3 )
237
+ self .array = index_vec (_span , _span , _span , _span )
224
238
# Do not lose those idx as self.array keeps
225
239
# no reference to them. Otherwise the destructor
226
240
# is prematurely called
227
- self .idxs = [idx0 , idx1 , idx2 , idx3 ]
241
+ self .idxs = [_span , _span , _span , _span ]
228
242
@property
229
243
def pointer (self ):
230
244
return c_pointer (self .array )
0 commit comments