Skip to content

Commit ad36032

Browse files
authored
Merge pull request #25370 from asmeurer/array_api-portability
ENH: Make `numpy.array_api` more portable
2 parents 4b56203 + 7f354e5 commit ad36032

16 files changed

+262
-127
lines changed
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
Make ``numpy.array_api`` more portable
2+
--------------------------------------
3+
4+
``numpy.array_api`` no longer uses ``"cpu"`` as its "device", but rather a
5+
separate ``CPU_DEVICE`` object (which is not accessible in the namespace).
6+
This is because "cpu" is not part of the array API standard.
7+
8+
``numpy.array_api`` now uses separate wrapped objects for dtypes. Previously
9+
it reused the ``numpy`` dtype objects. This makes it clear which behaviors on
10+
dtypes are part of the array API standard (effectively, the standard only
11+
requires ``==`` on dtype objects).
12+
13+
``numpy.array_api.nonzero`` now errors on zero-dimensional arrays, as required
14+
by the array API standard.

doc/source/reference/array_api.rst

Lines changed: 26 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -622,6 +622,20 @@ Manipulation functions differences
622622
- **Compatible**
623623
- See https://github.com/numpy/numpy/issues/9818.
624624

625+
Searching functions differences
626+
-------------------------------
627+
628+
.. list-table::
629+
:header-rows: 1
630+
631+
* - Feature
632+
- Type
633+
- Notes
634+
* - ``nonzero`` disallows 0-dimensional inputs
635+
- **Breaking**
636+
- This behavior is already deprecated for ``np.nonzero``. See
637+
https://github.com/numpy/numpy/pull/13708.
638+
625639
Set functions differences
626640
-------------------------
627641

@@ -645,8 +659,8 @@ Set functions differences
645659

646660
.. _array_api-set-functions-differences:
647661

648-
Set functions differences
649-
-------------------------
662+
Sorting functions differences
663+
-----------------------------
650664

651665
.. list-table::
652666
:header-rows: 1
@@ -698,6 +712,16 @@ Other differences
698712
- **Strictness**
699713
- For example, ``numpy.array_api.asarray([0], dtype='int32')`` is not
700714
allowed.
715+
* - Dtype objects are wrapped so that they only implement the required
716+
``__eq__`` method, which only compares against dtype objects.
717+
- **Strictness**
718+
- For example, ``float32 == 'float32'`` is not allowed.
719+
* - ``arr.device`` always returns a ``CPU_DEVICE`` object (which is not
720+
part of the namespace). This is the only valid non-default value for
721+
``device`` keyword arguments to creation functions like ``asarray()``.
722+
- **Compatible**
723+
- CPU is the only device supported by NumPy. The standard does not
724+
require device objects to be accessible other than via ``arr.device``.
701725
* - ``asarray`` is not implicitly called in any function.
702726
- **Strictness**
703727
- The exception is Python operators, which accept Python scalars in

numpy/array_api/_array_object.py

Lines changed: 22 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from enum import IntEnum
2020
from ._creation_functions import asarray
2121
from ._dtypes import (
22+
_DType,
2223
_all_dtypes,
2324
_boolean_dtypes,
2425
_integer_dtypes,
@@ -39,6 +40,13 @@
3940

4041
import numpy as np
4142

43+
# Placeholder object to represent the "cpu" device (the only device NumPy
44+
# supports).
45+
class _cpu_device:
46+
def __repr__(self):
47+
return "CPU_DEVICE"
48+
49+
CPU_DEVICE = _cpu_device()
4250

4351
class Array:
4452
"""
@@ -75,11 +83,13 @@ def _new(cls, x, /):
7583
if isinstance(x, np.generic):
7684
# Convert the array scalar to a 0-D array
7785
x = np.asarray(x)
78-
if x.dtype not in _all_dtypes:
86+
_dtype = _DType(x.dtype)
87+
if _dtype not in _all_dtypes:
7988
raise TypeError(
8089
f"The array_api namespace does not support the dtype '{x.dtype}'"
8190
)
8291
obj._array = x
92+
obj._dtype = _dtype
8393
return obj
8494

8595
# Prevent Array() from working
@@ -101,7 +111,7 @@ def __repr__(self: Array, /) -> str:
101111
"""
102112
Performs the operation __repr__.
103113
"""
104-
suffix = f", dtype={self.dtype.name})"
114+
suffix = f", dtype={self.dtype})"
105115
if 0 in self.shape:
106116
prefix = "empty("
107117
mid = str(self.shape)
@@ -176,6 +186,8 @@ def _promote_scalar(self, scalar):
176186
integer that is too large to fit in a NumPy integer dtype, or
177187
TypeError when the scalar type is incompatible with the dtype of self.
178188
"""
189+
from ._data_type_functions import iinfo
190+
179191
# Note: Only Python scalar types that match the array dtype are
180192
# allowed.
181193
if isinstance(scalar, bool):
@@ -189,7 +201,7 @@ def _promote_scalar(self, scalar):
189201
"Python int scalars cannot be promoted with bool arrays"
190202
)
191203
if self.dtype in _integer_dtypes:
192-
info = np.iinfo(self.dtype)
204+
info = iinfo(self.dtype)
193205
if not (info.min <= scalar <= info.max):
194206
raise OverflowError(
195207
"Python int scalars must be within the bounds of the dtype for integer arrays"
@@ -215,7 +227,7 @@ def _promote_scalar(self, scalar):
215227
# behavior for integers within the bounds of the integer dtype.
216228
# Outside of those bounds we use the default NumPy behavior (either
217229
# cast or raise OverflowError).
218-
return Array._new(np.array(scalar, self.dtype))
230+
return Array._new(np.array(scalar, dtype=self.dtype._np_dtype))
219231

220232
@staticmethod
221233
def _normalize_two_args(x1, x2) -> Tuple[Array, Array]:
@@ -325,7 +337,9 @@ def _validate_index(self, key):
325337
for i in _key:
326338
if i is not None:
327339
nonexpanding_key.append(i)
328-
if isinstance(i, Array) or isinstance(i, np.ndarray):
340+
if isinstance(i, np.ndarray):
341+
raise IndexError("Index arrays for np.array_api must be np.array_api arrays")
342+
if isinstance(i, Array):
329343
if i.dtype in _boolean_dtypes:
330344
key_has_mask = True
331345
single_axes.append(i)
@@ -1067,7 +1081,7 @@ def __rxor__(self: Array, other: Union[int, bool, Array], /) -> Array:
10671081
def to_device(self: Array, device: Device, /, stream: None = None) -> Array:
10681082
if stream is not None:
10691083
raise ValueError("The stream argument to to_device() is not supported")
1070-
if device == 'cpu':
1084+
if device == CPU_DEVICE:
10711085
return self
10721086
raise ValueError(f"Unsupported device {device!r}")
10731087

@@ -1078,11 +1092,11 @@ def dtype(self) -> Dtype:
10781092
10791093
See its docstring for more information.
10801094
"""
1081-
return self._array.dtype
1095+
return self._dtype
10821096

10831097
@property
10841098
def device(self) -> Device:
1085-
return "cpu"
1099+
return CPU_DEVICE
10861100

10871101
# Note: mT is new in array API spec (see matrix_transpose)
10881102
@property

0 commit comments

Comments
 (0)