19
19
from enum import IntEnum
20
20
from ._creation_functions import asarray
21
21
from ._dtypes import (
22
+ _DType ,
22
23
_all_dtypes ,
23
24
_boolean_dtypes ,
24
25
_integer_dtypes ,
39
40
40
41
import numpy as np
41
42
43
+ # Placeholder object to represent the "cpu" device (the only device NumPy
44
+ # supports).
45
+ class _cpu_device :
46
+ def __repr__ (self ):
47
+ return "CPU_DEVICE"
48
+
49
+ CPU_DEVICE = _cpu_device ()
42
50
43
51
class Array :
44
52
"""
@@ -75,11 +83,13 @@ def _new(cls, x, /):
75
83
if isinstance (x , np .generic ):
76
84
# Convert the array scalar to a 0-D array
77
85
x = np .asarray (x )
78
- if x .dtype not in _all_dtypes :
86
+ _dtype = _DType (x .dtype )
87
+ if _dtype not in _all_dtypes :
79
88
raise TypeError (
80
89
f"The array_api namespace does not support the dtype '{ x .dtype } '"
81
90
)
82
91
obj ._array = x
92
+ obj ._dtype = _dtype
83
93
return obj
84
94
85
95
# Prevent Array() from working
@@ -101,7 +111,7 @@ def __repr__(self: Array, /) -> str:
101
111
"""
102
112
Performs the operation __repr__.
103
113
"""
104
- suffix = f", dtype={ self .dtype . name } )"
114
+ suffix = f", dtype={ self .dtype } )"
105
115
if 0 in self .shape :
106
116
prefix = "empty("
107
117
mid = str (self .shape )
@@ -176,6 +186,8 @@ def _promote_scalar(self, scalar):
176
186
integer that is too large to fit in a NumPy integer dtype, or
177
187
TypeError when the scalar type is incompatible with the dtype of self.
178
188
"""
189
+ from ._data_type_functions import iinfo
190
+
179
191
# Note: Only Python scalar types that match the array dtype are
180
192
# allowed.
181
193
if isinstance (scalar , bool ):
@@ -189,7 +201,7 @@ def _promote_scalar(self, scalar):
189
201
"Python int scalars cannot be promoted with bool arrays"
190
202
)
191
203
if self .dtype in _integer_dtypes :
192
- info = np . iinfo (self .dtype )
204
+ info = iinfo (self .dtype )
193
205
if not (info .min <= scalar <= info .max ):
194
206
raise OverflowError (
195
207
"Python int scalars must be within the bounds of the dtype for integer arrays"
@@ -215,7 +227,7 @@ def _promote_scalar(self, scalar):
215
227
# behavior for integers within the bounds of the integer dtype.
216
228
# Outside of those bounds we use the default NumPy behavior (either
217
229
# cast or raise OverflowError).
218
- return Array ._new (np .array (scalar , self .dtype ))
230
+ return Array ._new (np .array (scalar , dtype = self .dtype . _np_dtype ))
219
231
220
232
@staticmethod
221
233
def _normalize_two_args (x1 , x2 ) -> Tuple [Array , Array ]:
@@ -325,7 +337,9 @@ def _validate_index(self, key):
325
337
for i in _key :
326
338
if i is not None :
327
339
nonexpanding_key .append (i )
328
- if isinstance (i , Array ) or isinstance (i , np .ndarray ):
340
+ if isinstance (i , np .ndarray ):
341
+ raise IndexError ("Index arrays for np.array_api must be np.array_api arrays" )
342
+ if isinstance (i , Array ):
329
343
if i .dtype in _boolean_dtypes :
330
344
key_has_mask = True
331
345
single_axes .append (i )
@@ -1067,7 +1081,7 @@ def __rxor__(self: Array, other: Union[int, bool, Array], /) -> Array:
1067
1081
def to_device (self : Array , device : Device , / , stream : None = None ) -> Array :
1068
1082
if stream is not None :
1069
1083
raise ValueError ("The stream argument to to_device() is not supported" )
1070
- if device == 'cpu' :
1084
+ if device == CPU_DEVICE :
1071
1085
return self
1072
1086
raise ValueError (f"Unsupported device { device !r} " )
1073
1087
@@ -1078,11 +1092,11 @@ def dtype(self) -> Dtype:
1078
1092
1079
1093
See its docstring for more information.
1080
1094
"""
1081
- return self ._array . dtype
1095
+ return self ._dtype
1082
1096
1083
1097
@property
1084
1098
def device (self ) -> Device :
1085
- return "cpu"
1099
+ return CPU_DEVICE
1086
1100
1087
1101
# Note: mT is new in array API spec (see matrix_transpose)
1088
1102
@property
0 commit comments