From b1316cff516d147519a9c30f0e8327e5895598f4 Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Thu, 20 Mar 2025 16:47:38 +0100 Subject: [PATCH 01/28] TST: skip tests of binary funcs w/scalar on older numpies NumPy < 2 fails to promote an empty f32 array with a scalar, returns an empty f64 array --- numpy-1-21-xfails.txt | 3 +++ numpy-1-26-xfails.txt | 3 +++ 2 files changed, 6 insertions(+) diff --git a/numpy-1-21-xfails.txt b/numpy-1-21-xfails.txt index 28c0e13a..7c7a0757 100644 --- a/numpy-1-21-xfails.txt +++ b/numpy-1-21-xfails.txt @@ -212,3 +212,6 @@ array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(x1_i is -infinity array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(x1_i is -infinity and isfinite(x2_i) and x2_i < 0) -> +infinity] array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(isfinite(x1_i) and x1_i > 0 and x2_i is -infinity) -> -0] array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(isfinite(x1_i) and x1_i < 0 and x2_i is +infinity) -> -0] + +# numpy < 2 bug: type promotion of asarray([], 'float32') and (np.finfo(float32).max + 1) -> float64 +array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real diff --git a/numpy-1-26-xfails.txt b/numpy-1-26-xfails.txt index 80790534..57259b6f 100644 --- a/numpy-1-26-xfails.txt +++ b/numpy-1-26-xfails.txt @@ -66,3 +66,6 @@ array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(x1_i is -infinity array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(x1_i is -infinity and isfinite(x2_i) and x2_i < 0) -> +infinity] array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(isfinite(x1_i) and x1_i > 0 and x2_i is -infinity) -> -0] array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(isfinite(x1_i) and x1_i < 0 and x2_i is +infinity) -> -0] + +# numpy < 2 bug: type promotion of asarray([], 'float32') and (finfo(float32).max + 1) gives float64 not float32 +array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real From 64ab7e26b86d0cd2d4cb544fdd39699a887823e8 Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Fri, 21 Mar 2025 01:27:29 +0100 Subject: [PATCH 02/28] MAINT: update the version for 1.12.dev0 development --- array_api_compat/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/array_api_compat/__init__.py b/array_api_compat/__init__.py index 96b061e7..60b37e97 100644 --- a/array_api_compat/__init__.py +++ b/array_api_compat/__init__.py @@ -17,6 +17,6 @@ this implementation for the default when working with NumPy arrays. """ -__version__ = '1.11.2' +__version__ = '1.12.dev0' from .common import * # noqa: F401, F403 From 0080afed5b110c311cb88314d0370a2a3fcbefef Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Fri, 21 Mar 2025 12:02:54 +0100 Subject: [PATCH 03/28] Add a CuPy xfail CuPy 13.x follows NumPy 1.x without "weak scalars". In NumPy `result_type(int32, uint8, 1) != result_type(int32, uint8)` has been fixed in 2.x (or 1.x with set_promotion_state("weak"), so hopefully CuPy 14.x follows the suite, when released. Until then, just xfail the test. --- cupy-xfails.txt | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/cupy-xfails.txt b/cupy-xfails.txt index 63e844cd..3d20d745 100644 --- a/cupy-xfails.txt +++ b/cupy-xfails.txt @@ -183,7 +183,7 @@ array_api_tests/test_manipulation_functions.py::test_repeat array_api_tests/test_signatures.py::test_func_signature[from_dlpack] array_api_tests/test_signatures.py::test_array_method_signature[__dlpack__] -+# 2024.12 support +# 2024.12 support array_api_tests/test_signatures.py::test_func_signature[count_nonzero] array_api_tests/test_signatures.py::test_func_signature[bitwise_and] array_api_tests/test_signatures.py::test_func_signature[bitwise_left_shift] @@ -192,3 +192,5 @@ array_api_tests/test_signatures.py::test_func_signature[bitwise_right_shift] array_api_tests/test_signatures.py::test_func_signature[bitwise_xor] array_api_tests/test_special_cases.py::test_binary[nextafter(x1_i is +0 and x2_i is -0) -> -0] +# cupy 13.x follows numpy 1.x w/o weak promotion: result_type(int32, uint8, 1) != result_type(int32, uint8) +array_api_tests/test_data_type_functions.py::TestResultType::test_with_scalars From a5a1d8ba722da9b8a2783ccd63c0b60713932793 Mon Sep 17 00:00:00 2001 From: Guido Imperiale Date: Sat, 22 Mar 2025 17:34:57 +0000 Subject: [PATCH 04/28] TYP: Type annotations overhaul, part 1 (#257) * ENH: Type annotations overhaul * Re-add py.typed * code review * lint * asarray * fill_value * result_type * lint * Arrays don't need to support buffer protocol * bool is a subclass of int * reshape: copy kwarg is keyword-only * tensordot formatting * Reinstate explicit bool | complex --- array_api_compat/common/_aliases.py | 248 +++++++++++++----------- array_api_compat/common/_fft.py | 87 +++++---- array_api_compat/common/_helpers.py | 32 +-- array_api_compat/common/_linalg.py | 84 +++++--- array_api_compat/common/_typing.py | 16 +- array_api_compat/cupy/_aliases.py | 36 ++-- array_api_compat/cupy/_typing.py | 63 +++--- array_api_compat/dask/array/_aliases.py | 54 ++---- array_api_compat/dask/array/fft.py | 13 +- array_api_compat/dask/array/linalg.py | 25 +-- array_api_compat/numpy/_aliases.py | 41 ++-- array_api_compat/numpy/_typing.py | 63 +++--- array_api_compat/py.typed | 0 array_api_compat/torch/_aliases.py | 168 ++++++++-------- array_api_compat/torch/_typing.py | 4 + array_api_compat/torch/fft.py | 35 ++-- array_api_compat/torch/linalg.py | 28 ++- setup.py | 5 +- tests/test_all.py | 17 +- 19 files changed, 511 insertions(+), 508 deletions(-) create mode 100644 array_api_compat/py.typed create mode 100644 array_api_compat/torch/_typing.py diff --git a/array_api_compat/common/_aliases.py b/array_api_compat/common/_aliases.py index 35262d3a..0d123b99 100644 --- a/array_api_compat/common/_aliases.py +++ b/array_api_compat/common/_aliases.py @@ -4,15 +4,11 @@ from __future__ import annotations -from typing import TYPE_CHECKING -if TYPE_CHECKING: - from typing import Optional, Sequence, Tuple, Union - from ._typing import ndarray, Device, Dtype - -from typing import NamedTuple import inspect +from typing import NamedTuple, Optional, Sequence, Tuple, Union from ._helpers import array_namespace, _check_device, device, is_cupy_namespace +from ._typing import Array, Device, DType, Namespace # These functions are modified from the NumPy versions. @@ -24,29 +20,34 @@ def arange( stop: Optional[Union[int, float]] = None, step: Union[int, float] = 1, *, - xp, - dtype: Optional[Dtype] = None, + xp: Namespace, + dtype: Optional[DType] = None, device: Optional[Device] = None, - **kwargs -) -> ndarray: + **kwargs, +) -> Array: _check_device(xp, device) return xp.arange(start, stop=stop, step=step, dtype=dtype, **kwargs) def empty( shape: Union[int, Tuple[int, ...]], - xp, + xp: Namespace, *, - dtype: Optional[Dtype] = None, + dtype: Optional[DType] = None, device: Optional[Device] = None, - **kwargs -) -> ndarray: + **kwargs, +) -> Array: _check_device(xp, device) return xp.empty(shape, dtype=dtype, **kwargs) def empty_like( - x: ndarray, /, xp, *, dtype: Optional[Dtype] = None, device: Optional[Device] = None, - **kwargs -) -> ndarray: + x: Array, + /, + xp: Namespace, + *, + dtype: Optional[DType] = None, + device: Optional[Device] = None, + **kwargs, +) -> Array: _check_device(xp, device) return xp.empty_like(x, dtype=dtype, **kwargs) @@ -55,37 +56,37 @@ def eye( n_cols: Optional[int] = None, /, *, - xp, + xp: Namespace, k: int = 0, - dtype: Optional[Dtype] = None, + dtype: Optional[DType] = None, device: Optional[Device] = None, **kwargs, -) -> ndarray: +) -> Array: _check_device(xp, device) return xp.eye(n_rows, M=n_cols, k=k, dtype=dtype, **kwargs) def full( shape: Union[int, Tuple[int, ...]], - fill_value: Union[int, float], - xp, + fill_value: complex, + xp: Namespace, *, - dtype: Optional[Dtype] = None, + dtype: Optional[DType] = None, device: Optional[Device] = None, **kwargs, -) -> ndarray: +) -> Array: _check_device(xp, device) return xp.full(shape, fill_value, dtype=dtype, **kwargs) def full_like( - x: ndarray, + x: Array, /, - fill_value: Union[int, float], + fill_value: complex, *, - xp, - dtype: Optional[Dtype] = None, + xp: Namespace, + dtype: Optional[DType] = None, device: Optional[Device] = None, **kwargs, -) -> ndarray: +) -> Array: _check_device(xp, device) return xp.full_like(x, fill_value, dtype=dtype, **kwargs) @@ -95,48 +96,58 @@ def linspace( /, num: int, *, - xp, - dtype: Optional[Dtype] = None, + xp: Namespace, + dtype: Optional[DType] = None, device: Optional[Device] = None, endpoint: bool = True, **kwargs, -) -> ndarray: +) -> Array: _check_device(xp, device) return xp.linspace(start, stop, num, dtype=dtype, endpoint=endpoint, **kwargs) def ones( shape: Union[int, Tuple[int, ...]], - xp, + xp: Namespace, *, - dtype: Optional[Dtype] = None, + dtype: Optional[DType] = None, device: Optional[Device] = None, **kwargs, -) -> ndarray: +) -> Array: _check_device(xp, device) return xp.ones(shape, dtype=dtype, **kwargs) def ones_like( - x: ndarray, /, xp, *, dtype: Optional[Dtype] = None, device: Optional[Device] = None, + x: Array, + /, + xp: Namespace, + *, + dtype: Optional[DType] = None, + device: Optional[Device] = None, **kwargs, -) -> ndarray: +) -> Array: _check_device(xp, device) return xp.ones_like(x, dtype=dtype, **kwargs) def zeros( shape: Union[int, Tuple[int, ...]], - xp, + xp: Namespace, *, - dtype: Optional[Dtype] = None, + dtype: Optional[DType] = None, device: Optional[Device] = None, **kwargs, -) -> ndarray: +) -> Array: _check_device(xp, device) return xp.zeros(shape, dtype=dtype, **kwargs) def zeros_like( - x: ndarray, /, xp, *, dtype: Optional[Dtype] = None, device: Optional[Device] = None, + x: Array, + /, + xp: Namespace, + *, + dtype: Optional[DType] = None, + device: Optional[Device] = None, **kwargs, -) -> ndarray: +) -> Array: _check_device(xp, device) return xp.zeros_like(x, dtype=dtype, **kwargs) @@ -150,23 +161,23 @@ def zeros_like( # Note that these named tuples aren't actually part of the standard namespace, # but I don't see any issue with exporting the names here regardless. class UniqueAllResult(NamedTuple): - values: ndarray - indices: ndarray - inverse_indices: ndarray - counts: ndarray + values: Array + indices: Array + inverse_indices: Array + counts: Array class UniqueCountsResult(NamedTuple): - values: ndarray - counts: ndarray + values: Array + counts: Array class UniqueInverseResult(NamedTuple): - values: ndarray - inverse_indices: ndarray + values: Array + inverse_indices: Array -def _unique_kwargs(xp): +def _unique_kwargs(xp: Namespace) -> dict[str, bool]: # Older versions of NumPy and CuPy do not have equal_nan. Rather than # trying to parse version numbers, just check if equal_nan is in the # signature. @@ -175,7 +186,7 @@ def _unique_kwargs(xp): return {'equal_nan': False} return {} -def unique_all(x: ndarray, /, xp) -> UniqueAllResult: +def unique_all(x: Array, /, xp: Namespace) -> UniqueAllResult: kwargs = _unique_kwargs(xp) values, indices, inverse_indices, counts = xp.unique( x, @@ -195,7 +206,7 @@ def unique_all(x: ndarray, /, xp) -> UniqueAllResult: ) -def unique_counts(x: ndarray, /, xp) -> UniqueCountsResult: +def unique_counts(x: Array, /, xp: Namespace) -> UniqueCountsResult: kwargs = _unique_kwargs(xp) res = xp.unique( x, @@ -208,7 +219,7 @@ def unique_counts(x: ndarray, /, xp) -> UniqueCountsResult: return UniqueCountsResult(*res) -def unique_inverse(x: ndarray, /, xp) -> UniqueInverseResult: +def unique_inverse(x: Array, /, xp: Namespace) -> UniqueInverseResult: kwargs = _unique_kwargs(xp) values, inverse_indices = xp.unique( x, @@ -223,7 +234,7 @@ def unique_inverse(x: ndarray, /, xp) -> UniqueInverseResult: return UniqueInverseResult(values, inverse_indices) -def unique_values(x: ndarray, /, xp) -> ndarray: +def unique_values(x: Array, /, xp: Namespace) -> Array: kwargs = _unique_kwargs(xp) return xp.unique( x, @@ -236,42 +247,42 @@ def unique_values(x: ndarray, /, xp) -> ndarray: # These functions have different keyword argument names def std( - x: ndarray, + x: Array, /, - xp, + xp: Namespace, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, - correction: Union[int, float] = 0.0, # correction instead of ddof + correction: Union[int, float] = 0.0, # correction instead of ddof keepdims: bool = False, **kwargs, -) -> ndarray: +) -> Array: return xp.std(x, axis=axis, ddof=correction, keepdims=keepdims, **kwargs) def var( - x: ndarray, + x: Array, /, - xp, + xp: Namespace, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, - correction: Union[int, float] = 0.0, # correction instead of ddof + correction: Union[int, float] = 0.0, # correction instead of ddof keepdims: bool = False, **kwargs, -) -> ndarray: +) -> Array: return xp.var(x, axis=axis, ddof=correction, keepdims=keepdims, **kwargs) # cumulative_sum is renamed from cumsum, and adds the include_initial keyword # argument def cumulative_sum( - x: ndarray, + x: Array, /, - xp, + xp: Namespace, *, axis: Optional[int] = None, - dtype: Optional[Dtype] = None, + dtype: Optional[DType] = None, include_initial: bool = False, - **kwargs -) -> ndarray: + **kwargs, +) -> Array: wrapped_xp = array_namespace(x) # TODO: The standard is not clear about what should happen when x.ndim == 0. @@ -294,15 +305,15 @@ def cumulative_sum( def cumulative_prod( - x: ndarray, + x: Array, /, - xp, + xp: Namespace, *, axis: Optional[int] = None, - dtype: Optional[Dtype] = None, + dtype: Optional[DType] = None, include_initial: bool = False, - **kwargs -) -> ndarray: + **kwargs, +) -> Array: wrapped_xp = array_namespace(x) if axis is None: @@ -325,17 +336,18 @@ def cumulative_prod( # The min and max argument names in clip are different and not optional in numpy, and type # promotion behavior is different. def clip( - x: ndarray, + x: Array, /, - min: Optional[Union[int, float, ndarray]] = None, - max: Optional[Union[int, float, ndarray]] = None, + min: Optional[Union[int, float, Array]] = None, + max: Optional[Union[int, float, Array]] = None, *, - xp, + xp: Namespace, # TODO: np.clip has other ufunc kwargs - out: Optional[ndarray] = None, -) -> ndarray: + out: Optional[Array] = None, +) -> Array: def _isscalar(a): return isinstance(a, (int, float, type(None))) + min_shape = () if _isscalar(min) else min.shape max_shape = () if _isscalar(max) else max.shape @@ -390,15 +402,19 @@ def _isscalar(a): return out[()] # Unlike transpose(), the axes argument to permute_dims() is required. -def permute_dims(x: ndarray, /, axes: Tuple[int, ...], xp) -> ndarray: +def permute_dims(x: Array, /, axes: Tuple[int, ...], xp: Namespace) -> Array: return xp.transpose(x, axes) # np.reshape calls the keyword argument 'newshape' instead of 'shape' -def reshape(x: ndarray, - /, - shape: Tuple[int, ...], - xp, copy: Optional[bool] = None, - **kwargs) -> ndarray: +def reshape( + x: Array, + /, + shape: Tuple[int, ...], + xp: Namespace, + *, + copy: Optional[bool] = None, + **kwargs, +) -> Array: if copy is True: x = x.copy() elif copy is False: @@ -410,9 +426,15 @@ def reshape(x: ndarray, # The descending keyword is new in sort and argsort, and 'kind' replaced with # 'stable' def argsort( - x: ndarray, /, xp, *, axis: int = -1, descending: bool = False, stable: bool = True, + x: Array, + /, + xp: Namespace, + *, + axis: int = -1, + descending: bool = False, + stable: bool = True, **kwargs, -) -> ndarray: +) -> Array: # Note: this keyword argument is different, and the default is different. # We set it in kwargs like this because numpy.sort uses kind='quicksort' # as the default whereas cupy.sort uses kind=None. @@ -435,9 +457,15 @@ def argsort( return res def sort( - x: ndarray, /, xp, *, axis: int = -1, descending: bool = False, stable: bool = True, + x: Array, + /, + xp: Namespace, + *, + axis: int = -1, + descending: bool = False, + stable: bool = True, **kwargs, -) -> ndarray: +) -> Array: # Note: this keyword argument is different, and the default is different. # We set it in kwargs like this because numpy.sort uses kind='quicksort' # as the default whereas cupy.sort uses kind=None. @@ -449,50 +477,51 @@ def sort( return res # nonzero should error for zero-dimensional arrays -def nonzero(x: ndarray, /, xp, **kwargs) -> Tuple[ndarray, ...]: +def nonzero(x: Array, /, xp: Namespace, **kwargs) -> Tuple[Array, ...]: if x.ndim == 0: raise ValueError("nonzero() does not support zero-dimensional arrays") return xp.nonzero(x, **kwargs) # ceil, floor, and trunc return integers for integer inputs -def ceil(x: ndarray, /, xp, **kwargs) -> ndarray: +def ceil(x: Array, /, xp: Namespace, **kwargs) -> Array: if xp.issubdtype(x.dtype, xp.integer): return x return xp.ceil(x, **kwargs) -def floor(x: ndarray, /, xp, **kwargs) -> ndarray: +def floor(x: Array, /, xp: Namespace, **kwargs) -> Array: if xp.issubdtype(x.dtype, xp.integer): return x return xp.floor(x, **kwargs) -def trunc(x: ndarray, /, xp, **kwargs) -> ndarray: +def trunc(x: Array, /, xp: Namespace, **kwargs) -> Array: if xp.issubdtype(x.dtype, xp.integer): return x return xp.trunc(x, **kwargs) # linear algebra functions -def matmul(x1: ndarray, x2: ndarray, /, xp, **kwargs) -> ndarray: +def matmul(x1: Array, x2: Array, /, xp: Namespace, **kwargs) -> Array: return xp.matmul(x1, x2, **kwargs) # Unlike transpose, matrix_transpose only transposes the last two axes. -def matrix_transpose(x: ndarray, /, xp) -> ndarray: +def matrix_transpose(x: Array, /, xp: Namespace) -> Array: if x.ndim < 2: raise ValueError("x must be at least 2-dimensional for matrix_transpose") return xp.swapaxes(x, -1, -2) -def tensordot(x1: ndarray, - x2: ndarray, - /, - xp, - *, - axes: Union[int, Tuple[Sequence[int], Sequence[int]]] = 2, - **kwargs, -) -> ndarray: +def tensordot( + x1: Array, + x2: Array, + /, + xp: Namespace, + *, + axes: Union[int, Tuple[Sequence[int], Sequence[int]]] = 2, + **kwargs, +) -> Array: return xp.tensordot(x1, x2, axes=axes, **kwargs) -def vecdot(x1: ndarray, x2: ndarray, /, xp, *, axis: int = -1) -> ndarray: +def vecdot(x1: Array, x2: Array, /, xp: Namespace, *, axis: int = -1) -> Array: if x1.shape[axis] != x2.shape[axis]: raise ValueError("x1 and x2 must have the same size along the given axis") @@ -511,8 +540,11 @@ def vecdot(x1: ndarray, x2: ndarray, /, xp, *, axis: int = -1) -> ndarray: # isdtype is a new function in the 2022.12 array API specification. def isdtype( - dtype: Dtype, kind: Union[Dtype, str, Tuple[Union[Dtype, str], ...]], xp, - *, _tuple=True, # Disallow nested tuples + dtype: DType, + kind: Union[DType, str, Tuple[Union[DType, str], ...]], + xp: Namespace, + *, + _tuple: bool = True, # Disallow nested tuples ) -> bool: """ Returns a boolean indicating whether a provided dtype is of a specified data type ``kind``. @@ -551,14 +583,14 @@ def isdtype( return dtype == kind # unstack is a new function in the 2023.12 array API standard -def unstack(x: ndarray, /, xp, *, axis: int = 0) -> Tuple[ndarray, ...]: +def unstack(x: Array, /, xp: Namespace, *, axis: int = 0) -> Tuple[Array, ...]: if x.ndim == 0: raise ValueError("Input array must be at least 1-d.") return tuple(xp.moveaxis(x, axis, 0)) # numpy 1.26 does not use the standard definition for sign on complex numbers -def sign(x: ndarray, /, xp, **kwargs) -> ndarray: +def sign(x: Array, /, xp: Namespace, **kwargs) -> Array: if isdtype(x.dtype, 'complex floating', xp=xp): out = (x/xp.abs(x, **kwargs))[...] # sign(0) = 0 but the above formula would give nan diff --git a/array_api_compat/common/_fft.py b/array_api_compat/common/_fft.py index e5caebef..bd2a4e1a 100644 --- a/array_api_compat/common/_fft.py +++ b/array_api_compat/common/_fft.py @@ -1,149 +1,148 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Union, Optional, Literal +from collections.abc import Sequence +from typing import Union, Optional, Literal -if TYPE_CHECKING: - from ._typing import Device, ndarray, DType - from collections.abc import Sequence +from ._typing import Device, Array, DType, Namespace # Note: NumPy fft functions improperly upcast float32 and complex64 to # complex128, which is why we require wrapping them all here. def fft( - x: ndarray, + x: Array, /, - xp, + xp: Namespace, *, n: Optional[int] = None, axis: int = -1, norm: Literal["backward", "ortho", "forward"] = "backward", -) -> ndarray: +) -> Array: res = xp.fft.fft(x, n=n, axis=axis, norm=norm) if x.dtype in [xp.float32, xp.complex64]: return res.astype(xp.complex64) return res def ifft( - x: ndarray, + x: Array, /, - xp, + xp: Namespace, *, n: Optional[int] = None, axis: int = -1, norm: Literal["backward", "ortho", "forward"] = "backward", -) -> ndarray: +) -> Array: res = xp.fft.ifft(x, n=n, axis=axis, norm=norm) if x.dtype in [xp.float32, xp.complex64]: return res.astype(xp.complex64) return res def fftn( - x: ndarray, + x: Array, /, - xp, + xp: Namespace, *, s: Sequence[int] = None, axes: Sequence[int] = None, norm: Literal["backward", "ortho", "forward"] = "backward", -) -> ndarray: +) -> Array: res = xp.fft.fftn(x, s=s, axes=axes, norm=norm) if x.dtype in [xp.float32, xp.complex64]: return res.astype(xp.complex64) return res def ifftn( - x: ndarray, + x: Array, /, - xp, + xp: Namespace, *, s: Sequence[int] = None, axes: Sequence[int] = None, norm: Literal["backward", "ortho", "forward"] = "backward", -) -> ndarray: +) -> Array: res = xp.fft.ifftn(x, s=s, axes=axes, norm=norm) if x.dtype in [xp.float32, xp.complex64]: return res.astype(xp.complex64) return res def rfft( - x: ndarray, + x: Array, /, - xp, + xp: Namespace, *, n: Optional[int] = None, axis: int = -1, norm: Literal["backward", "ortho", "forward"] = "backward", -) -> ndarray: +) -> Array: res = xp.fft.rfft(x, n=n, axis=axis, norm=norm) if x.dtype == xp.float32: return res.astype(xp.complex64) return res def irfft( - x: ndarray, + x: Array, /, - xp, + xp: Namespace, *, n: Optional[int] = None, axis: int = -1, norm: Literal["backward", "ortho", "forward"] = "backward", -) -> ndarray: +) -> Array: res = xp.fft.irfft(x, n=n, axis=axis, norm=norm) if x.dtype == xp.complex64: return res.astype(xp.float32) return res def rfftn( - x: ndarray, + x: Array, /, - xp, + xp: Namespace, *, s: Sequence[int] = None, axes: Sequence[int] = None, norm: Literal["backward", "ortho", "forward"] = "backward", -) -> ndarray: +) -> Array: res = xp.fft.rfftn(x, s=s, axes=axes, norm=norm) if x.dtype == xp.float32: return res.astype(xp.complex64) return res def irfftn( - x: ndarray, + x: Array, /, - xp, + xp: Namespace, *, s: Sequence[int] = None, axes: Sequence[int] = None, norm: Literal["backward", "ortho", "forward"] = "backward", -) -> ndarray: +) -> Array: res = xp.fft.irfftn(x, s=s, axes=axes, norm=norm) if x.dtype == xp.complex64: return res.astype(xp.float32) return res def hfft( - x: ndarray, + x: Array, /, - xp, + xp: Namespace, *, n: Optional[int] = None, axis: int = -1, norm: Literal["backward", "ortho", "forward"] = "backward", -) -> ndarray: +) -> Array: res = xp.fft.hfft(x, n=n, axis=axis, norm=norm) if x.dtype in [xp.float32, xp.complex64]: return res.astype(xp.float32) return res def ihfft( - x: ndarray, + x: Array, /, - xp, + xp: Namespace, *, n: Optional[int] = None, axis: int = -1, norm: Literal["backward", "ortho", "forward"] = "backward", -) -> ndarray: +) -> Array: res = xp.fft.ihfft(x, n=n, axis=axis, norm=norm) if x.dtype in [xp.float32, xp.complex64]: return res.astype(xp.complex64) @@ -152,12 +151,12 @@ def ihfft( def fftfreq( n: int, /, - xp, + xp: Namespace, *, d: float = 1.0, dtype: Optional[DType] = None, - device: Optional[Device] = None -) -> ndarray: + device: Optional[Device] = None, +) -> Array: if device not in ["cpu", None]: raise ValueError(f"Unsupported device {device!r}") res = xp.fft.fftfreq(n, d=d) @@ -168,12 +167,12 @@ def fftfreq( def rfftfreq( n: int, /, - xp, + xp: Namespace, *, d: float = 1.0, dtype: Optional[DType] = None, - device: Optional[Device] = None -) -> ndarray: + device: Optional[Device] = None, +) -> Array: if device not in ["cpu", None]: raise ValueError(f"Unsupported device {device!r}") res = xp.fft.rfftfreq(n, d=d) @@ -181,10 +180,14 @@ def rfftfreq( return res.astype(dtype) return res -def fftshift(x: ndarray, /, xp, *, axes: Union[int, Sequence[int]] = None) -> ndarray: +def fftshift( + x: Array, /, xp: Namespace, *, axes: Union[int, Sequence[int]] = None +) -> Array: return xp.fft.fftshift(x, axes=axes) -def ifftshift(x: ndarray, /, xp, *, axes: Union[int, Sequence[int]] = None) -> ndarray: +def ifftshift( + x: Array, /, xp: Namespace, *, axes: Union[int, Sequence[int]] = None +) -> Array: return xp.fft.ifftshift(x, axes=axes) __all__ = [ diff --git a/array_api_compat/common/_helpers.py b/array_api_compat/common/_helpers.py index 791edb81..6d95069d 100644 --- a/array_api_compat/common/_helpers.py +++ b/array_api_compat/common/_helpers.py @@ -7,16 +7,14 @@ """ from __future__ import annotations -from typing import TYPE_CHECKING - -if TYPE_CHECKING: - from typing import Optional, Union, Any - from ._typing import Array, Device, Namespace - import sys import math import inspect import warnings +from typing import Optional, Union, Any + +from ._typing import Array, Device, Namespace + def _is_jax_zero_gradient_array(x: object) -> bool: """Return True if `x` is a zero-gradient array. @@ -268,7 +266,7 @@ def _compat_module_name() -> str: return __name__.removesuffix('.common._helpers') -def is_numpy_namespace(xp) -> bool: +def is_numpy_namespace(xp: Namespace) -> bool: """ Returns True if `xp` is a NumPy namespace. @@ -289,7 +287,7 @@ def is_numpy_namespace(xp) -> bool: return xp.__name__ in {'numpy', _compat_module_name() + '.numpy'} -def is_cupy_namespace(xp) -> bool: +def is_cupy_namespace(xp: Namespace) -> bool: """ Returns True if `xp` is a CuPy namespace. @@ -310,7 +308,7 @@ def is_cupy_namespace(xp) -> bool: return xp.__name__ in {'cupy', _compat_module_name() + '.cupy'} -def is_torch_namespace(xp) -> bool: +def is_torch_namespace(xp: Namespace) -> bool: """ Returns True if `xp` is a PyTorch namespace. @@ -331,7 +329,7 @@ def is_torch_namespace(xp) -> bool: return xp.__name__ in {'torch', _compat_module_name() + '.torch'} -def is_ndonnx_namespace(xp) -> bool: +def is_ndonnx_namespace(xp: Namespace) -> bool: """ Returns True if `xp` is an NDONNX namespace. @@ -350,7 +348,7 @@ def is_ndonnx_namespace(xp) -> bool: return xp.__name__ == 'ndonnx' -def is_dask_namespace(xp) -> bool: +def is_dask_namespace(xp: Namespace) -> bool: """ Returns True if `xp` is a Dask namespace. @@ -371,7 +369,7 @@ def is_dask_namespace(xp) -> bool: return xp.__name__ in {'dask.array', _compat_module_name() + '.dask.array'} -def is_jax_namespace(xp) -> bool: +def is_jax_namespace(xp: Namespace) -> bool: """ Returns True if `xp` is a JAX namespace. @@ -393,7 +391,7 @@ def is_jax_namespace(xp) -> bool: return xp.__name__ in {'jax.numpy', 'jax.experimental.array_api'} -def is_pydata_sparse_namespace(xp) -> bool: +def is_pydata_sparse_namespace(xp: Namespace) -> bool: """ Returns True if `xp` is a pydata/sparse namespace. @@ -412,7 +410,7 @@ def is_pydata_sparse_namespace(xp) -> bool: return xp.__name__ == 'sparse' -def is_array_api_strict_namespace(xp) -> bool: +def is_array_api_strict_namespace(xp: Namespace) -> bool: """ Returns True if `xp` is an array-api-strict namespace. @@ -439,7 +437,11 @@ def _check_api_version(api_version: str) -> None: raise ValueError("Only the 2024.12 version of the array API specification is currently supported") -def array_namespace(*xs, api_version=None, use_compat=None) -> Namespace: +def array_namespace( + *xs: Union[Array, bool, int, float, complex, None], + api_version: Optional[str] = None, + use_compat: Optional[bool] = None, +) -> Namespace: """ Get the array API compatible namespace for the arrays `xs`. diff --git a/array_api_compat/common/_linalg.py b/array_api_compat/common/_linalg.py index bfa1f1b9..c77ee3b8 100644 --- a/array_api_compat/common/_linalg.py +++ b/array_api_compat/common/_linalg.py @@ -1,11 +1,7 @@ from __future__ import annotations -from typing import TYPE_CHECKING, NamedTuple -if TYPE_CHECKING: - from typing import Literal, Optional, Tuple, Union - from ._typing import ndarray - import math +from typing import Literal, NamedTuple, Optional, Tuple, Union import numpy as np if np.__version__[0] == "2": @@ -15,50 +11,53 @@ from ._aliases import matmul, matrix_transpose, tensordot, vecdot, isdtype from .._internal import get_xp +from ._typing import Array, Namespace # These are in the main NumPy namespace but not in numpy.linalg -def cross(x1: ndarray, x2: ndarray, /, xp, *, axis: int = -1, **kwargs) -> ndarray: +def cross(x1: Array, x2: Array, /, xp: Namespace, *, axis: int = -1, **kwargs) -> Array: return xp.cross(x1, x2, axis=axis, **kwargs) -def outer(x1: ndarray, x2: ndarray, /, xp, **kwargs) -> ndarray: +def outer(x1: Array, x2: Array, /, xp: Namespace, **kwargs) -> Array: return xp.outer(x1, x2, **kwargs) class EighResult(NamedTuple): - eigenvalues: ndarray - eigenvectors: ndarray + eigenvalues: Array + eigenvectors: Array class QRResult(NamedTuple): - Q: ndarray - R: ndarray + Q: Array + R: Array class SlogdetResult(NamedTuple): - sign: ndarray - logabsdet: ndarray + sign: Array + logabsdet: Array class SVDResult(NamedTuple): - U: ndarray - S: ndarray - Vh: ndarray + U: Array + S: Array + Vh: Array # These functions are the same as their NumPy counterparts except they return # a namedtuple. -def eigh(x: ndarray, /, xp, **kwargs) -> EighResult: +def eigh(x: Array, /, xp: Namespace, **kwargs) -> EighResult: return EighResult(*xp.linalg.eigh(x, **kwargs)) -def qr(x: ndarray, /, xp, *, mode: Literal['reduced', 'complete'] = 'reduced', +def qr(x: Array, /, xp: Namespace, *, mode: Literal['reduced', 'complete'] = 'reduced', **kwargs) -> QRResult: return QRResult(*xp.linalg.qr(x, mode=mode, **kwargs)) -def slogdet(x: ndarray, /, xp, **kwargs) -> SlogdetResult: +def slogdet(x: Array, /, xp: Namespace, **kwargs) -> SlogdetResult: return SlogdetResult(*xp.linalg.slogdet(x, **kwargs)) -def svd(x: ndarray, /, xp, *, full_matrices: bool = True, **kwargs) -> SVDResult: +def svd( + x: Array, /, xp: Namespace, *, full_matrices: bool = True, **kwargs +) -> SVDResult: return SVDResult(*xp.linalg.svd(x, full_matrices=full_matrices, **kwargs)) # These functions have additional keyword arguments # The upper keyword argument is new from NumPy -def cholesky(x: ndarray, /, xp, *, upper: bool = False, **kwargs) -> ndarray: +def cholesky(x: Array, /, xp: Namespace, *, upper: bool = False, **kwargs) -> Array: L = xp.linalg.cholesky(x, **kwargs) if upper: U = get_xp(xp)(matrix_transpose)(L) @@ -69,12 +68,12 @@ def cholesky(x: ndarray, /, xp, *, upper: bool = False, **kwargs) -> ndarray: # The rtol keyword argument of matrix_rank() and pinv() is new from NumPy. # Note that it has a different semantic meaning from tol and rcond. -def matrix_rank(x: ndarray, +def matrix_rank(x: Array, /, - xp, + xp: Namespace, *, - rtol: Optional[Union[float, ndarray]] = None, - **kwargs) -> ndarray: + rtol: Optional[Union[float, Array]] = None, + **kwargs) -> Array: # this is different from xp.linalg.matrix_rank, which supports 1 # dimensional arrays. if x.ndim < 2: @@ -88,7 +87,9 @@ def matrix_rank(x: ndarray, tol = S.max(axis=-1, keepdims=True)*xp.asarray(rtol)[..., xp.newaxis] return xp.count_nonzero(S > tol, axis=-1) -def pinv(x: ndarray, /, xp, *, rtol: Optional[Union[float, ndarray]] = None, **kwargs) -> ndarray: +def pinv( + x: Array, /, xp: Namespace, *, rtol: Optional[Union[float, Array]] = None, **kwargs +) -> Array: # this is different from xp.linalg.pinv, which does not multiply the # default tolerance by max(M, N). if rtol is None: @@ -97,15 +98,30 @@ def pinv(x: ndarray, /, xp, *, rtol: Optional[Union[float, ndarray]] = None, **k # These functions are new in the array API spec -def matrix_norm(x: ndarray, /, xp, *, keepdims: bool = False, ord: Optional[Union[int, float, Literal['fro', 'nuc']]] = 'fro') -> ndarray: +def matrix_norm( + x: Array, + /, + xp: Namespace, + *, + keepdims: bool = False, + ord: Optional[Union[int, float, Literal['fro', 'nuc']]] = 'fro', +) -> Array: return xp.linalg.norm(x, axis=(-2, -1), keepdims=keepdims, ord=ord) # svdvals is not in NumPy (but it is in SciPy). It is equivalent to # xp.linalg.svd(compute_uv=False). -def svdvals(x: ndarray, /, xp) -> Union[ndarray, Tuple[ndarray, ...]]: +def svdvals(x: Array, /, xp: Namespace) -> Union[Array, Tuple[Array, ...]]: return xp.linalg.svd(x, compute_uv=False) -def vector_norm(x: ndarray, /, xp, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False, ord: Optional[Union[int, float]] = 2) -> ndarray: +def vector_norm( + x: Array, + /, + xp: Namespace, + *, + axis: Optional[Union[int, Tuple[int, ...]]] = None, + keepdims: bool = False, + ord: Optional[Union[int, float]] = 2, +) -> Array: # xp.linalg.norm tries to do a matrix norm whenever axis is a 2-tuple or # when axis=None and the input is 2-D, so to force a vector norm, we make # it so the input is 1-D (for axis=None), or reshape so that norm is done @@ -143,11 +159,15 @@ def vector_norm(x: ndarray, /, xp, *, axis: Optional[Union[int, Tuple[int, ...]] # xp.diagonal and xp.trace operate on the first two axes whereas these # operates on the last two -def diagonal(x: ndarray, /, xp, *, offset: int = 0, **kwargs) -> ndarray: +def diagonal(x: Array, /, xp: Namespace, *, offset: int = 0, **kwargs) -> Array: return xp.diagonal(x, offset=offset, axis1=-2, axis2=-1, **kwargs) -def trace(x: ndarray, /, xp, *, offset: int = 0, dtype=None, **kwargs) -> ndarray: - return xp.asarray(xp.trace(x, offset=offset, dtype=dtype, axis1=-2, axis2=-1, **kwargs)) +def trace( + x: Array, /, xp: Namespace, *, offset: int = 0, dtype=None, **kwargs +) -> Array: + return xp.asarray( + xp.trace(x, offset=offset, dtype=dtype, axis1=-2, axis2=-1, **kwargs) + ) __all__ = ['cross', 'matmul', 'outer', 'tensordot', 'EighResult', 'QRResult', 'SlogdetResult', 'SVDResult', 'eigh', 'qr', 'slogdet', diff --git a/array_api_compat/common/_typing.py b/array_api_compat/common/_typing.py index d8acdef7..4c3b356b 100644 --- a/array_api_compat/common/_typing.py +++ b/array_api_compat/common/_typing.py @@ -1,26 +1,24 @@ from __future__ import annotations +from types import ModuleType as Namespace +from typing import Any, TypeVar, Protocol __all__ = [ + "Array", + "DType", + "Device", + "Namespace", "NestedSequence", "SupportsBufferProtocol", ] -from types import ModuleType -from typing import ( - Any, - TypeVar, - Protocol, -) - _T_co = TypeVar("_T_co", covariant=True) class NestedSequence(Protocol[_T_co]): def __getitem__(self, key: int, /) -> _T_co | NestedSequence[_T_co]: ... def __len__(self, /) -> int: ... -SupportsBufferProtocol = Any +SupportsBufferProtocol = Any Array = Any Device = Any DType = Any -Namespace = ModuleType diff --git a/array_api_compat/cupy/_aliases.py b/array_api_compat/cupy/_aliases.py index 30d9fe48..ebc7ccd9 100644 --- a/array_api_compat/cupy/_aliases.py +++ b/array_api_compat/cupy/_aliases.py @@ -1,16 +1,14 @@ from __future__ import annotations +from typing import Optional + import cupy as cp from ..common import _aliases, _helpers +from ..common._typing import NestedSequence, SupportsBufferProtocol from .._internal import get_xp - from ._info import __array_namespace_info__ - -from typing import TYPE_CHECKING -if TYPE_CHECKING: - from typing import Optional, Union - from ._typing import ndarray, Device, Dtype, NestedSequence, SupportsBufferProtocol +from ._typing import Array, Device, DType bool = cp.bool_ @@ -66,23 +64,19 @@ _copy_default = object() + # asarray also adds the copy keyword, which is not present in numpy 1.0. def asarray( - obj: Union[ - ndarray, - bool, - int, - float, - NestedSequence[bool | int | float], - SupportsBufferProtocol, - ], + obj: ( + Array | bool | complex | NestedSequence[bool | complex] | SupportsBufferProtocol + ), /, *, - dtype: Optional[Dtype] = None, + dtype: Optional[DType] = None, device: Optional[Device] = None, copy: Optional[bool] = _copy_default, **kwargs, -) -> ndarray: +) -> Array: """ Array API compatibility wrapper for asarray(). @@ -112,13 +106,13 @@ def asarray( def astype( - x: ndarray, - dtype: Dtype, + x: Array, + dtype: DType, /, *, copy: bool = True, device: Optional[Device] = None, -) -> ndarray: +) -> Array: if device is None: return x.astype(dtype=dtype, copy=copy) out = _helpers.to_device(x.astype(dtype=dtype, copy=False), device) @@ -127,10 +121,10 @@ def astype( # cupy.count_nonzero does not have keepdims def count_nonzero( - x: ndarray, + x: Array, axis=None, keepdims=False -) -> ndarray: +) -> Array: result = cp.count_nonzero(x, axis) if keepdims: if axis is None: diff --git a/array_api_compat/cupy/_typing.py b/array_api_compat/cupy/_typing.py index f3d9aab6..66af5d19 100644 --- a/array_api_compat/cupy/_typing.py +++ b/array_api_compat/cupy/_typing.py @@ -1,46 +1,31 @@ from __future__ import annotations -__all__ = [ - "ndarray", - "Device", - "Dtype", -] +__all__ = ["Array", "DType", "Device"] +_all_ignore = ["cp"] -import sys -from typing import ( - Union, - TYPE_CHECKING, -) - -from cupy import ( - ndarray, - dtype, - int8, - int16, - int32, - int64, - uint8, - uint16, - uint32, - uint64, - float32, - float64, -) +from typing import TYPE_CHECKING +import cupy as cp +from cupy import ndarray as Array from cupy.cuda.device import Device -if TYPE_CHECKING or sys.version_info >= (3, 9): - Dtype = dtype[Union[ - int8, - int16, - int32, - int64, - uint8, - uint16, - uint32, - uint64, - float32, - float64, - ]] +if TYPE_CHECKING: + # NumPy 1.x on Python 3.9 and 3.10 fails to parse np.dtype[] + DType = cp.dtype[ + cp.intp + | cp.int8 + | cp.int16 + | cp.int32 + | cp.int64 + | cp.uint8 + | cp.uint16 + | cp.uint32 + | cp.uint64 + | cp.float32 + | cp.float64 + | cp.complex64 + | cp.complex128 + | cp.bool_ + ] else: - Dtype = dtype + DType = cp.dtype diff --git a/array_api_compat/dask/array/_aliases.py b/array_api_compat/dask/array/_aliases.py index 80d66281..e737cebd 100644 --- a/array_api_compat/dask/array/_aliases.py +++ b/array_api_compat/dask/array/_aliases.py @@ -1,16 +1,10 @@ from __future__ import annotations -from typing import Callable - -from ...common import _aliases, array_namespace - -from ..._internal import get_xp - -from ._info import __array_namespace_info__ +from typing import Callable, Optional, Union import numpy as np from numpy import ( - # Dtypes + # dtypes iinfo, finfo, bool_ as bool, @@ -29,22 +23,19 @@ can_cast, result_type, ) - -from typing import TYPE_CHECKING - -if TYPE_CHECKING: - from typing import Optional, Union - - from ...common._typing import ( - Device, - Dtype, - Array, - NestedSequence, - SupportsBufferProtocol, - ) - import dask.array as da +from ...common import _aliases, array_namespace +from ...common._typing import ( + Array, + Device, + DType, + NestedSequence, + SupportsBufferProtocol, +) +from ..._internal import get_xp +from ._info import __array_namespace_info__ + isdtype = get_xp(np)(_aliases.isdtype) unstack = get_xp(da)(_aliases.unstack) @@ -52,7 +43,7 @@ # da.astype doesn't respect copy=True def astype( x: Array, - dtype: Dtype, + dtype: DType, /, *, copy: bool = True, @@ -84,7 +75,7 @@ def arange( stop: Optional[Union[int, float]] = None, step: Union[int, float] = 1, *, - dtype: Optional[Dtype] = None, + dtype: Optional[DType] = None, device: Optional[Device] = None, **kwargs, ) -> Array: @@ -144,17 +135,12 @@ def arange( # asarray also adds the copy keyword, which is not present in numpy 1.0. def asarray( - obj: Union[ - Array, - bool, - int, - float, - NestedSequence[bool | int | float], - SupportsBufferProtocol, - ], + obj: ( + Array | bool | complex | NestedSequence[bool | complex] | SupportsBufferProtocol + ), /, *, - dtype: Optional[Dtype] = None, + dtype: Optional[DType] = None, device: Optional[Device] = None, copy: Optional[Union[bool, np._CopyMode]] = None, **kwargs, @@ -360,4 +346,4 @@ def count_nonzero( 'complex64', 'complex128', 'iinfo', 'finfo', 'can_cast', 'count_nonzero', 'result_type'] -_all_ignore = ["Callable", "array_namespace", "get_xp", "da", "np"] +_all_ignore = ["array_namespace", "get_xp", "da", "np"] diff --git a/array_api_compat/dask/array/fft.py b/array_api_compat/dask/array/fft.py index aebd86f7..3f40dffe 100644 --- a/array_api_compat/dask/array/fft.py +++ b/array_api_compat/dask/array/fft.py @@ -4,9 +4,10 @@ # from dask.array.fft import __all__ as linalg_all _n = {} exec('from dask.array.fft import *', _n) -del _n['__builtins__'] +for k in ("__builtins__", "Sequence", "annotations", "warnings"): + _n.pop(k, None) fft_all = list(_n) -del _n +del _n, k from ...common import _fft from ..._internal import get_xp @@ -16,9 +17,5 @@ fftfreq = get_xp(da)(_fft.fftfreq) rfftfreq = get_xp(da)(_fft.rfftfreq) -__all__ = [elem for elem in fft_all if elem != "annotations"] + ["fftfreq", "rfftfreq"] - -del get_xp -del da -del fft_all -del _fft +__all__ = fft_all + ["fftfreq", "rfftfreq"] +_all_ignore = ["da", "fft_all", "get_xp", "warnings"] diff --git a/array_api_compat/dask/array/linalg.py b/array_api_compat/dask/array/linalg.py index 49c26d8b..bd53f0df 100644 --- a/array_api_compat/dask/array/linalg.py +++ b/array_api_compat/dask/array/linalg.py @@ -1,33 +1,28 @@ from __future__ import annotations -from ...common import _linalg -from ..._internal import get_xp +from typing import Literal +import dask.array as da # Exports from dask.array.linalg import * # noqa: F403 from dask.array import outer - # These functions are in both the main and linalg namespaces from dask.array import matmul, tensordot -from ._aliases import matrix_transpose, vecdot -import dask.array as da - -from typing import TYPE_CHECKING -if TYPE_CHECKING: - from ...common._typing import Array - from typing import Literal +from ..._internal import get_xp +from ...common import _linalg +from ...common._typing import Array +from ._aliases import matrix_transpose, vecdot # dask.array.linalg doesn't have __all__. If it is added, replace this with # # from dask.array.linalg import __all__ as linalg_all _n = {} exec('from dask.array.linalg import *', _n) -del _n['__builtins__'] -if 'annotations' in _n: - del _n['annotations'] +for k in ('__builtins__', 'annotations', 'operator', 'warnings', 'Array'): + _n.pop(k, None) linalg_all = list(_n) -del _n +del _n, k EighResult = _linalg.EighResult QRResult = _linalg.QRResult @@ -70,4 +65,4 @@ def svdvals(x: Array) -> Array: "cholesky", "matrix_rank", "matrix_norm", "svdvals", "vector_norm", "diagonal"] -_all_ignore = ['get_xp', 'da', 'linalg_all'] +_all_ignore = ['get_xp', 'da', 'linalg_all', 'warnings'] diff --git a/array_api_compat/numpy/_aliases.py b/array_api_compat/numpy/_aliases.py index a47f7121..6536d9a8 100644 --- a/array_api_compat/numpy/_aliases.py +++ b/array_api_compat/numpy/_aliases.py @@ -1,17 +1,15 @@ from __future__ import annotations -from ..common import _aliases +from typing import Optional, Union from .._internal import get_xp - +from ..common import _aliases +from ..common._typing import NestedSequence, SupportsBufferProtocol from ._info import __array_namespace_info__ - -from typing import TYPE_CHECKING -if TYPE_CHECKING: - from typing import Optional, Union - from ._typing import ndarray, Device, Dtype, NestedSequence, SupportsBufferProtocol +from ._typing import Array, Device, DType import numpy as np + bool = np.bool_ # Basic renames @@ -64,6 +62,7 @@ tensordot = get_xp(np)(_aliases.tensordot) sign = get_xp(np)(_aliases.sign) + def _supports_buffer_protocol(obj): try: memoryview(obj) @@ -71,26 +70,22 @@ def _supports_buffer_protocol(obj): return False return True + # asarray also adds the copy keyword, which is not present in numpy 1.0. # asarray() is different enough between numpy, cupy, and dask, the logic # complicated enough that it's easier to define it separately for each module # rather than trying to combine everything into one function in common/ def asarray( - obj: Union[ - ndarray, - bool, - int, - float, - NestedSequence[bool | int | float], - SupportsBufferProtocol, - ], + obj: ( + Array | bool | complex | NestedSequence[bool | complex] | SupportsBufferProtocol + ), /, *, - dtype: Optional[Dtype] = None, + dtype: Optional[DType] = None, device: Optional[Device] = None, copy: "Optional[Union[bool, np._CopyMode]]" = None, **kwargs, -) -> ndarray: +) -> Array: """ Array API compatibility wrapper for asarray(). @@ -117,23 +112,19 @@ def asarray( def astype( - x: ndarray, - dtype: Dtype, + x: Array, + dtype: DType, /, *, copy: bool = True, device: Optional[Device] = None, -) -> ndarray: +) -> Array: return x.astype(dtype=dtype, copy=copy) # count_nonzero returns a python int for axis=None and keepdims=False # https://github.com/numpy/numpy/issues/17562 -def count_nonzero( - x : ndarray, - axis=None, - keepdims=False -) -> ndarray: +def count_nonzero(x: Array, axis=None, keepdims=False) -> Array: result = np.count_nonzero(x, axis=axis, keepdims=keepdims) if axis is None and not keepdims: return np.asarray(result) diff --git a/array_api_compat/numpy/_typing.py b/array_api_compat/numpy/_typing.py index c5ebb5ab..6a18a3b2 100644 --- a/array_api_compat/numpy/_typing.py +++ b/array_api_compat/numpy/_typing.py @@ -1,46 +1,31 @@ from __future__ import annotations -__all__ = [ - "ndarray", - "Device", - "Dtype", -] +__all__ = ["Array", "DType", "Device"] +_all_ignore = ["np"] -import sys -from typing import ( - Literal, - Union, - TYPE_CHECKING, -) +from typing import Literal, TYPE_CHECKING -from numpy import ( - ndarray, - dtype, - int8, - int16, - int32, - int64, - uint8, - uint16, - uint32, - uint64, - float32, - float64, -) +import numpy as np +from numpy import ndarray as Array Device = Literal["cpu"] -if TYPE_CHECKING or sys.version_info >= (3, 9): - Dtype = dtype[Union[ - int8, - int16, - int32, - int64, - uint8, - uint16, - uint32, - uint64, - float32, - float64, - ]] +if TYPE_CHECKING: + # NumPy 1.x on Python 3.9 and 3.10 fails to parse np.dtype[] + DType = np.dtype[ + np.intp + | np.int8 + | np.int16 + | np.int32 + | np.int64 + | np.uint8 + | np.uint16 + | np.uint32 + | np.uint64 + | np.float32 + | np.float64 + | np.complex64 + | np.complex128 + | np.bool + ] else: - Dtype = dtype + DType = np.dtype diff --git a/array_api_compat/py.typed b/array_api_compat/py.typed new file mode 100644 index 00000000..e69de29b diff --git a/array_api_compat/torch/_aliases.py b/array_api_compat/torch/_aliases.py index 4b727f1c..87d32d85 100644 --- a/array_api_compat/torch/_aliases.py +++ b/array_api_compat/torch/_aliases.py @@ -2,21 +2,14 @@ from functools import reduce as _reduce, wraps as _wraps from builtins import all as _builtin_all, any as _builtin_any - -from ..common import _aliases -from .._internal import get_xp - -from ._info import __array_namespace_info__ +from typing import List, Optional, Sequence, Tuple, Union import torch -from typing import TYPE_CHECKING -if TYPE_CHECKING: - from typing import List, Optional, Sequence, Tuple, Union - from ..common._typing import Device - from torch import dtype as Dtype - - array = torch.Tensor +from .._internal import get_xp +from ..common import _aliases +from ._info import __array_namespace_info__ +from ._typing import Array, Device, DType _int_dtypes = { torch.uint8, @@ -123,7 +116,7 @@ def _fix_promotion(x1, x2, only_scalar=True): _py_scalars = (bool, int, float, complex) -def result_type(*arrays_and_dtypes: Union[array, Dtype, bool, int, float, complex]) -> Dtype: +def result_type(*arrays_and_dtypes: Array | DType | complex) -> DType: num = len(arrays_and_dtypes) if num == 0: @@ -170,7 +163,7 @@ def _result_type(x, y): return torch.result_type(x, y) -def can_cast(from_: Union[Dtype, array], to: Dtype, /) -> bool: +def can_cast(from_: Union[DType, Array], to: DType, /) -> bool: if not isinstance(from_, torch.dtype): from_ = from_.dtype return torch.can_cast(from_, to) @@ -216,13 +209,13 @@ def can_cast(from_: Union[Dtype, array], to: Dtype, /) -> bool: # of 'axis'. # torch.min and torch.max return a tuple and don't support multiple axes https://github.com/pytorch/pytorch/issues/58745 -def max(x: array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False) -> array: +def max(x: Array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False) -> Array: # https://github.com/pytorch/pytorch/issues/29137 if axis == (): return torch.clone(x) return torch.amax(x, axis, keepdims=keepdims) -def min(x: array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False) -> array: +def min(x: Array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False) -> Array: # https://github.com/pytorch/pytorch/issues/29137 if axis == (): return torch.clone(x) @@ -235,7 +228,7 @@ def min(x: array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, keep # torch.sort also returns a tuple # https://github.com/pytorch/pytorch/issues/70921 -def sort(x: array, /, *, axis: int = -1, descending: bool = False, stable: bool = True, **kwargs) -> array: +def sort(x: Array, /, *, axis: int = -1, descending: bool = False, stable: bool = True, **kwargs) -> Array: return torch.sort(x, dim=axis, descending=descending, stable=stable, **kwargs).values def _normalize_axes(axis, ndim): @@ -280,13 +273,13 @@ def _reduce_multiple_axes(f, x, axis, keepdims=False, **kwargs): out = torch.unsqueeze(out, a) return out -def prod(x: array, +def prod(x: Array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, - dtype: Optional[Dtype] = None, + dtype: Optional[DType] = None, keepdims: bool = False, - **kwargs) -> array: + **kwargs) -> Array: x = torch.asarray(x) ndim = x.ndim @@ -316,13 +309,13 @@ def prod(x: array, return torch.prod(x, axis, dtype=dtype, keepdims=keepdims, **kwargs) -def sum(x: array, +def sum(x: Array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, - dtype: Optional[Dtype] = None, + dtype: Optional[DType] = None, keepdims: bool = False, - **kwargs) -> array: + **kwargs) -> Array: x = torch.asarray(x) ndim = x.ndim @@ -347,12 +340,12 @@ def sum(x: array, return torch.sum(x, axis, dtype=dtype, keepdims=keepdims, **kwargs) -def any(x: array, +def any(x: Array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False, - **kwargs) -> array: + **kwargs) -> Array: x = torch.asarray(x) ndim = x.ndim if axis == (): @@ -372,12 +365,12 @@ def any(x: array, # torch.any doesn't return bool for uint8 return torch.any(x, axis, keepdims=keepdims).to(torch.bool) -def all(x: array, +def all(x: Array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False, - **kwargs) -> array: + **kwargs) -> Array: x = torch.asarray(x) ndim = x.ndim if axis == (): @@ -397,12 +390,12 @@ def all(x: array, # torch.all doesn't return bool for uint8 return torch.all(x, axis, keepdims=keepdims).to(torch.bool) -def mean(x: array, +def mean(x: Array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False, - **kwargs) -> array: + **kwargs) -> Array: # https://github.com/pytorch/pytorch/issues/29137 if axis == (): return torch.clone(x) @@ -414,13 +407,13 @@ def mean(x: array, return res return torch.mean(x, axis, keepdims=keepdims, **kwargs) -def std(x: array, +def std(x: Array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, correction: Union[int, float] = 0.0, keepdims: bool = False, - **kwargs) -> array: + **kwargs) -> Array: # Note, float correction is not supported # https://github.com/pytorch/pytorch/issues/61492. We don't try to # implement it here for now. @@ -445,13 +438,13 @@ def std(x: array, return res return torch.std(x, axis, correction=_correction, keepdims=keepdims, **kwargs) -def var(x: array, +def var(x: Array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, correction: Union[int, float] = 0.0, keepdims: bool = False, - **kwargs) -> array: + **kwargs) -> Array: # Note, float correction is not supported # https://github.com/pytorch/pytorch/issues/61492. We don't try to # implement it here for now. @@ -474,11 +467,11 @@ def var(x: array, # torch.concat doesn't support dim=None # https://github.com/pytorch/pytorch/issues/70925 -def concat(arrays: Union[Tuple[array, ...], List[array]], +def concat(arrays: Union[Tuple[Array, ...], List[Array]], /, *, axis: Optional[int] = 0, - **kwargs) -> array: + **kwargs) -> Array: if axis is None: arrays = tuple(ar.flatten() for ar in arrays) axis = 0 @@ -487,7 +480,7 @@ def concat(arrays: Union[Tuple[array, ...], List[array]], # torch.squeeze only accepts int dim and doesn't require it # https://github.com/pytorch/pytorch/issues/70924. Support for tuple dim was # added at https://github.com/pytorch/pytorch/pull/89017. -def squeeze(x: array, /, axis: Union[int, Tuple[int, ...]]) -> array: +def squeeze(x: Array, /, axis: Union[int, Tuple[int, ...]]) -> Array: if isinstance(axis, int): axis = (axis,) for a in axis: @@ -501,27 +494,27 @@ def squeeze(x: array, /, axis: Union[int, Tuple[int, ...]]) -> array: return x # torch.broadcast_to uses size instead of shape -def broadcast_to(x: array, /, shape: Tuple[int, ...], **kwargs) -> array: +def broadcast_to(x: Array, /, shape: Tuple[int, ...], **kwargs) -> Array: return torch.broadcast_to(x, shape, **kwargs) # torch.permute uses dims instead of axes -def permute_dims(x: array, /, axes: Tuple[int, ...]) -> array: +def permute_dims(x: Array, /, axes: Tuple[int, ...]) -> Array: return torch.permute(x, axes) # The axis parameter doesn't work for flip() and roll() # https://github.com/pytorch/pytorch/issues/71210. Also torch.flip() doesn't # accept axis=None -def flip(x: array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, **kwargs) -> array: +def flip(x: Array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, **kwargs) -> Array: if axis is None: axis = tuple(range(x.ndim)) # torch.flip doesn't accept dim as an int but the method does # https://github.com/pytorch/pytorch/issues/18095 return x.flip(axis, **kwargs) -def roll(x: array, /, shift: Union[int, Tuple[int, ...]], *, axis: Optional[Union[int, Tuple[int, ...]]] = None, **kwargs) -> array: +def roll(x: Array, /, shift: Union[int, Tuple[int, ...]], *, axis: Optional[Union[int, Tuple[int, ...]]] = None, **kwargs) -> Array: return torch.roll(x, shift, axis, **kwargs) -def nonzero(x: array, /, **kwargs) -> Tuple[array, ...]: +def nonzero(x: Array, /, **kwargs) -> Tuple[Array, ...]: if x.ndim == 0: raise ValueError("nonzero() does not support zero-dimensional arrays") return torch.nonzero(x, as_tuple=True, **kwargs) @@ -529,25 +522,25 @@ def nonzero(x: array, /, **kwargs) -> Tuple[array, ...]: # torch uses `dim` instead of `axis` def diff( - x: array, + x: Array, /, *, axis: int = -1, n: int = 1, - prepend: Optional[array] = None, - append: Optional[array] = None, -) -> array: + prepend: Optional[Array] = None, + append: Optional[Array] = None, +) -> Array: return torch.diff(x, dim=axis, n=n, prepend=prepend, append=append) # torch uses `dim` instead of `axis`, does not have keepdims def count_nonzero( - x: array, + x: Array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False, -) -> array: +) -> Array: result = torch.count_nonzero(x, dim=axis) if keepdims: if axis is not None: @@ -557,17 +550,17 @@ def count_nonzero( return result - -def where(condition: array, x1: array, x2: array, /) -> array: +def where(condition: Array, x1: Array, x2: Array, /) -> Array: x1, x2 = _fix_promotion(x1, x2) return torch.where(condition, x1, x2) # torch.reshape doesn't have the copy keyword -def reshape(x: array, +def reshape(x: Array, /, shape: Tuple[int, ...], + *, copy: Optional[bool] = None, - **kwargs) -> array: + **kwargs) -> Array: if copy is not None: raise NotImplementedError("torch.reshape doesn't yet support the copy keyword") return torch.reshape(x, shape, **kwargs) @@ -581,9 +574,9 @@ def arange(start: Union[int, float], stop: Optional[Union[int, float]] = None, step: Union[int, float] = 1, *, - dtype: Optional[Dtype] = None, + dtype: Optional[DType] = None, device: Optional[Device] = None, - **kwargs) -> array: + **kwargs) -> Array: if stop is None: start, stop = 0, start if step > 0 and stop <= start or step < 0 and stop >= start: @@ -602,9 +595,9 @@ def eye(n_rows: int, /, *, k: int = 0, - dtype: Optional[Dtype] = None, + dtype: Optional[DType] = None, device: Optional[Device] = None, - **kwargs) -> array: + **kwargs) -> Array: if n_cols is None: n_cols = n_rows z = torch.zeros(n_rows, n_cols, dtype=dtype, device=device, **kwargs) @@ -618,10 +611,10 @@ def linspace(start: Union[int, float], /, num: int, *, - dtype: Optional[Dtype] = None, + dtype: Optional[DType] = None, device: Optional[Device] = None, endpoint: bool = True, - **kwargs) -> array: + **kwargs) -> Array: if not endpoint: return torch.linspace(start, stop, num+1, dtype=dtype, device=device, **kwargs)[:-1] return torch.linspace(start, stop, num, dtype=dtype, device=device, **kwargs) @@ -629,11 +622,11 @@ def linspace(start: Union[int, float], # torch.full does not accept an int size # https://github.com/pytorch/pytorch/issues/70906 def full(shape: Union[int, Tuple[int, ...]], - fill_value: Union[bool, int, float, complex], + fill_value: complex, *, - dtype: Optional[Dtype] = None, + dtype: Optional[DType] = None, device: Optional[Device] = None, - **kwargs) -> array: + **kwargs) -> Array: if isinstance(shape, int): shape = (shape,) @@ -642,52 +635,52 @@ def full(shape: Union[int, Tuple[int, ...]], # ones, zeros, and empty do not accept shape as a keyword argument def ones(shape: Union[int, Tuple[int, ...]], *, - dtype: Optional[Dtype] = None, + dtype: Optional[DType] = None, device: Optional[Device] = None, - **kwargs) -> array: + **kwargs) -> Array: return torch.ones(shape, dtype=dtype, device=device, **kwargs) def zeros(shape: Union[int, Tuple[int, ...]], *, - dtype: Optional[Dtype] = None, + dtype: Optional[DType] = None, device: Optional[Device] = None, - **kwargs) -> array: + **kwargs) -> Array: return torch.zeros(shape, dtype=dtype, device=device, **kwargs) def empty(shape: Union[int, Tuple[int, ...]], *, - dtype: Optional[Dtype] = None, + dtype: Optional[DType] = None, device: Optional[Device] = None, - **kwargs) -> array: + **kwargs) -> Array: return torch.empty(shape, dtype=dtype, device=device, **kwargs) # tril and triu do not call the keyword argument k -def tril(x: array, /, *, k: int = 0) -> array: +def tril(x: Array, /, *, k: int = 0) -> Array: return torch.tril(x, k) -def triu(x: array, /, *, k: int = 0) -> array: +def triu(x: Array, /, *, k: int = 0) -> Array: return torch.triu(x, k) # Functions that aren't in torch https://github.com/pytorch/pytorch/issues/58742 -def expand_dims(x: array, /, *, axis: int = 0) -> array: +def expand_dims(x: Array, /, *, axis: int = 0) -> Array: return torch.unsqueeze(x, axis) def astype( - x: array, - dtype: Dtype, + x: Array, + dtype: DType, /, *, copy: bool = True, device: Optional[Device] = None, -) -> array: +) -> Array: if device is not None: return x.to(device, dtype=dtype, copy=copy) return x.to(dtype=dtype, copy=copy) -def broadcast_arrays(*arrays: array) -> List[array]: +def broadcast_arrays(*arrays: Array) -> List[Array]: shape = torch.broadcast_shapes(*[a.shape for a in arrays]) return [torch.broadcast_to(a, shape) for a in arrays] @@ -697,7 +690,7 @@ def broadcast_arrays(*arrays: array) -> List[array]: UniqueInverseResult) # https://github.com/pytorch/pytorch/issues/70920 -def unique_all(x: array) -> UniqueAllResult: +def unique_all(x: Array) -> UniqueAllResult: # torch.unique doesn't support returning indices. # https://github.com/pytorch/pytorch/issues/36748. The workaround # suggested in that issue doesn't actually function correctly (it relies @@ -710,7 +703,7 @@ def unique_all(x: array) -> UniqueAllResult: # counts[torch.isnan(values)] = 1 # return UniqueAllResult(values, indices, inverse_indices, counts) -def unique_counts(x: array) -> UniqueCountsResult: +def unique_counts(x: Array) -> UniqueCountsResult: values, counts = torch.unique(x, return_counts=True) # torch.unique incorrectly gives a 0 count for nan values. @@ -718,14 +711,14 @@ def unique_counts(x: array) -> UniqueCountsResult: counts[torch.isnan(values)] = 1 return UniqueCountsResult(values, counts) -def unique_inverse(x: array) -> UniqueInverseResult: +def unique_inverse(x: Array) -> UniqueInverseResult: values, inverse = torch.unique(x, return_inverse=True) return UniqueInverseResult(values, inverse) -def unique_values(x: array) -> array: +def unique_values(x: Array) -> Array: return torch.unique(x) -def matmul(x1: array, x2: array, /, **kwargs) -> array: +def matmul(x1: Array, x2: Array, /, **kwargs) -> Array: # torch.matmul doesn't type promote (but differently from _fix_promotion) x1, x2 = _fix_promotion(x1, x2, only_scalar=False) return torch.matmul(x1, x2, **kwargs) @@ -733,12 +726,19 @@ def matmul(x1: array, x2: array, /, **kwargs) -> array: matrix_transpose = get_xp(torch)(_aliases.matrix_transpose) _vecdot = get_xp(torch)(_aliases.vecdot) -def vecdot(x1: array, x2: array, /, *, axis: int = -1) -> array: +def vecdot(x1: Array, x2: Array, /, *, axis: int = -1) -> Array: x1, x2 = _fix_promotion(x1, x2, only_scalar=False) return _vecdot(x1, x2, axis=axis) # torch.tensordot uses dims instead of axes -def tensordot(x1: array, x2: array, /, *, axes: Union[int, Tuple[Sequence[int], Sequence[int]]] = 2, **kwargs) -> array: +def tensordot( + x1: Array, + x2: Array, + /, + *, + axes: Union[int, Tuple[Sequence[int], Sequence[int]]] = 2, + **kwargs, +) -> Array: # Note: torch.tensordot fails with integer dtypes when there is only 1 # element in the axis (https://github.com/pytorch/pytorch/issues/84530). x1, x2 = _fix_promotion(x1, x2, only_scalar=False) @@ -746,7 +746,7 @@ def tensordot(x1: array, x2: array, /, *, axes: Union[int, Tuple[Sequence[int], def isdtype( - dtype: Dtype, kind: Union[Dtype, str, Tuple[Union[Dtype, str], ...]], + dtype: DType, kind: Union[DType, str, Tuple[Union[DType, str], ...]], *, _tuple=True, # Disallow nested tuples ) -> bool: """ @@ -781,7 +781,7 @@ def isdtype( else: return dtype == kind -def take(x: array, indices: array, /, *, axis: Optional[int] = None, **kwargs) -> array: +def take(x: Array, indices: Array, /, *, axis: Optional[int] = None, **kwargs) -> Array: if axis is None: if x.ndim != 1: raise ValueError("axis must be specified when ndim > 1") @@ -789,11 +789,11 @@ def take(x: array, indices: array, /, *, axis: Optional[int] = None, **kwargs) - return torch.index_select(x, axis, indices, **kwargs) -def take_along_axis(x: array, indices: array, /, *, axis: int = -1) -> array: +def take_along_axis(x: Array, indices: Array, /, *, axis: int = -1) -> Array: return torch.take_along_dim(x, indices, dim=axis) -def sign(x: array, /) -> array: +def sign(x: Array, /) -> Array: # torch sign() does not support complex numbers and does not propagate # nans. See https://github.com/data-apis/array-api-compat/issues/136 if x.dtype.is_complex: diff --git a/array_api_compat/torch/_typing.py b/array_api_compat/torch/_typing.py new file mode 100644 index 00000000..29ad3fa7 --- /dev/null +++ b/array_api_compat/torch/_typing.py @@ -0,0 +1,4 @@ +__all__ = ["Array", "DType", "Device"] + +from torch import dtype as DType, Tensor as Array +from ..common._typing import Device diff --git a/array_api_compat/torch/fft.py b/array_api_compat/torch/fft.py index 3c9117ee..50e6a0d0 100644 --- a/array_api_compat/torch/fft.py +++ b/array_api_compat/torch/fft.py @@ -1,76 +1,75 @@ from __future__ import annotations -from typing import TYPE_CHECKING -if TYPE_CHECKING: - import torch - array = torch.Tensor - from typing import Union, Sequence, Literal +from typing import Union, Sequence, Literal -from torch.fft import * # noqa: F403 +import torch import torch.fft +from torch.fft import * # noqa: F403 + +from ._typing import Array # Several torch fft functions do not map axes to dim def fftn( - x: array, + x: Array, /, *, s: Sequence[int] = None, axes: Sequence[int] = None, norm: Literal["backward", "ortho", "forward"] = "backward", **kwargs, -) -> array: +) -> Array: return torch.fft.fftn(x, s=s, dim=axes, norm=norm, **kwargs) def ifftn( - x: array, + x: Array, /, *, s: Sequence[int] = None, axes: Sequence[int] = None, norm: Literal["backward", "ortho", "forward"] = "backward", **kwargs, -) -> array: +) -> Array: return torch.fft.ifftn(x, s=s, dim=axes, norm=norm, **kwargs) def rfftn( - x: array, + x: Array, /, *, s: Sequence[int] = None, axes: Sequence[int] = None, norm: Literal["backward", "ortho", "forward"] = "backward", **kwargs, -) -> array: +) -> Array: return torch.fft.rfftn(x, s=s, dim=axes, norm=norm, **kwargs) def irfftn( - x: array, + x: Array, /, *, s: Sequence[int] = None, axes: Sequence[int] = None, norm: Literal["backward", "ortho", "forward"] = "backward", **kwargs, -) -> array: +) -> Array: return torch.fft.irfftn(x, s=s, dim=axes, norm=norm, **kwargs) def fftshift( - x: array, + x: Array, /, *, axes: Union[int, Sequence[int]] = None, **kwargs, -) -> array: +) -> Array: return torch.fft.fftshift(x, dim=axes, **kwargs) def ifftshift( - x: array, + x: Array, /, *, axes: Union[int, Sequence[int]] = None, **kwargs, -) -> array: +) -> Array: return torch.fft.ifftshift(x, dim=axes, **kwargs) diff --git a/array_api_compat/torch/linalg.py b/array_api_compat/torch/linalg.py index e26198b9..7b59a670 100644 --- a/array_api_compat/torch/linalg.py +++ b/array_api_compat/torch/linalg.py @@ -1,14 +1,7 @@ from __future__ import annotations -from typing import TYPE_CHECKING -if TYPE_CHECKING: - import torch - array = torch.Tensor - from torch import dtype as Dtype - from typing import Optional, Union, Tuple, Literal - inf = float('inf') - -from ._aliases import _fix_promotion, sum +import torch +from typing import Optional, Union, Tuple from torch.linalg import * # noqa: F403 @@ -19,15 +12,17 @@ # outer is implemented in torch but aren't in the linalg namespace from torch import outer +from ._aliases import _fix_promotion, sum # These functions are in both the main and linalg namespaces from ._aliases import matmul, matrix_transpose, tensordot +from ._typing import Array, DType # Note: torch.linalg.cross does not default to axis=-1 (it defaults to the # first axis with size 3), see https://github.com/pytorch/pytorch/issues/58743 # torch.cross also does not support broadcasting when it would add new # dimensions https://github.com/pytorch/pytorch/issues/39656 -def cross(x1: array, x2: array, /, *, axis: int = -1) -> array: +def cross(x1: Array, x2: Array, /, *, axis: int = -1) -> Array: x1, x2 = _fix_promotion(x1, x2, only_scalar=False) if not (-min(x1.ndim, x2.ndim) <= axis < max(x1.ndim, x2.ndim)): raise ValueError(f"axis {axis} out of bounds for cross product of arrays with shapes {x1.shape} and {x2.shape}") @@ -36,7 +31,7 @@ def cross(x1: array, x2: array, /, *, axis: int = -1) -> array: x1, x2 = torch.broadcast_tensors(x1, x2) return torch_linalg.cross(x1, x2, dim=axis) -def vecdot(x1: array, x2: array, /, *, axis: int = -1, **kwargs) -> array: +def vecdot(x1: Array, x2: Array, /, *, axis: int = -1, **kwargs) -> Array: from ._aliases import isdtype x1, x2 = _fix_promotion(x1, x2, only_scalar=False) @@ -58,7 +53,7 @@ def vecdot(x1: array, x2: array, /, *, axis: int = -1, **kwargs) -> array: return res[..., 0, 0] return torch.linalg.vecdot(x1, x2, dim=axis, **kwargs) -def solve(x1: array, x2: array, /, **kwargs) -> array: +def solve(x1: Array, x2: Array, /, **kwargs) -> Array: x1, x2 = _fix_promotion(x1, x2, only_scalar=False) # Torch tries to emulate NumPy 1 solve behavior by using batched 1-D solve # whenever @@ -79,19 +74,20 @@ def solve(x1: array, x2: array, /, **kwargs) -> array: return torch.linalg.solve(x1, x2, **kwargs) # torch.trace doesn't support the offset argument and doesn't support stacking -def trace(x: array, /, *, offset: int = 0, dtype: Optional[Dtype] = None) -> array: +def trace(x: Array, /, *, offset: int = 0, dtype: Optional[DType] = None) -> Array: # Use our wrapped sum to make sure it does upcasting correctly return sum(torch.diagonal(x, offset=offset, dim1=-2, dim2=-1), axis=-1, dtype=dtype) def vector_norm( - x: array, + x: Array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False, - ord: Union[int, float, Literal[inf, -inf]] = 2, + # float stands for inf | -inf, which are not valid for Literal + ord: Union[int, float, float] = 2, **kwargs, -) -> array: +) -> Array: # torch.vector_norm incorrectly treats axis=() the same as axis=None if axis == (): out = kwargs.get('out') diff --git a/setup.py b/setup.py index 3d2b68a2..2368ccc4 100644 --- a/setup.py +++ b/setup.py @@ -33,5 +33,8 @@ "Programming Language :: Python :: 3.13", "License :: OSI Approved :: MIT License", "Operating System :: OS Independent", - ] + ], + package_data={ + "array_api_compat": ["py.typed"], + }, ) diff --git a/tests/test_all.py b/tests/test_all.py index d2e9b768..598fab62 100644 --- a/tests/test_all.py +++ b/tests/test_all.py @@ -15,6 +15,16 @@ from ._helpers import import_, wrapped_libraries import pytest +import typing + +TYPING_NAMES = frozenset(( + "Array", + "Device", + "DType", + "Namespace", + "NestedSequence", + "SupportsBufferProtocol", +)) @pytest.mark.skip(reason="TODO: starts failing after adding test_torch.py in gh-277") @pytest.mark.parametrize("library", ["common"] + wrapped_libraries) @@ -38,8 +48,11 @@ def test_all(library): dir_names = [n for n in dir(module) if not n.startswith('_')] if '__array_namespace_info__' in dir(module): dir_names.append('__array_namespace_info__') - ignore_all_names = getattr(module, '_all_ignore', []) - ignore_all_names += ['annotations', 'TYPE_CHECKING'] + ignore_all_names = set(getattr(module, '_all_ignore', ())) + ignore_all_names |= set(dir(typing)) + ignore_all_names |= {"annotations"} + if not module.__name__.endswith("._typing"): + ignore_all_names |= TYPING_NAMES dir_names = set(dir_names) - set(ignore_all_names) all_names = module.__all__ From 26845bd904ee66bb830463f46bb39f1cc5392275 Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Sat, 22 Mar 2025 11:12:31 +0100 Subject: [PATCH 05/28] Revert "TST: skip test_all" This reverts commit 5473d84d5c36b23e091b880279c863c32f41b828. --- tests/test_all.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/test_all.py b/tests/test_all.py index 598fab62..eeb67e4b 100644 --- a/tests/test_all.py +++ b/tests/test_all.py @@ -26,7 +26,6 @@ "SupportsBufferProtocol", )) -@pytest.mark.skip(reason="TODO: starts failing after adding test_torch.py in gh-277") @pytest.mark.parametrize("library", ["common"] + wrapped_libraries) def test_all(library): if library == "common": From 07a3cd41e1c5804b7c11d358400431e8a53a984a Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Sat, 22 Mar 2025 11:40:02 +0100 Subject: [PATCH 06/28] MAINT: run self-tests even if a library is missing --- tests/test_array_namespace.py | 6 ++++-- tests/test_dask.py | 8 ++++++-- tests/test_jax.py | 8 ++++++-- tests/test_torch.py | 6 +++++- tests/test_vendoring.py | 2 ++ 5 files changed, 23 insertions(+), 7 deletions(-) diff --git a/tests/test_array_namespace.py b/tests/test_array_namespace.py index 605c69a1..cdb80007 100644 --- a/tests/test_array_namespace.py +++ b/tests/test_array_namespace.py @@ -2,10 +2,8 @@ import sys import warnings -import jax import numpy as np import pytest -import torch import array_api_compat from array_api_compat import array_namespace @@ -76,6 +74,7 @@ def test_array_namespace(library, api_version, use_compat): subprocess.run([sys.executable, "-c", code], check=True) def test_jax_zero_gradient(): + jax = import_("jax") jx = jax.numpy.arange(4) jax_zero = jax.vmap(jax.grad(jax.numpy.float32, allow_int=True))(jx) assert array_namespace(jax_zero) is array_namespace(jx) @@ -89,11 +88,13 @@ def test_array_namespace_errors(): pytest.raises(TypeError, lambda: array_namespace(x, (x, x))) def test_array_namespace_errors_torch(): + torch = import_("torch") y = torch.asarray([1, 2]) x = np.asarray([1, 2]) pytest.raises(TypeError, lambda: array_namespace(x, y)) def test_api_version_torch(): + torch = import_("torch") x = torch.asarray([1, 2]) torch_ = import_("torch", wrapper=True) assert array_namespace(x, api_version="2023.12") == torch_ @@ -118,6 +119,7 @@ def test_get_namespace(): assert array_api_compat.get_namespace is array_namespace def test_python_scalars(): + torch = import_("torch") a = torch.asarray([1, 2]) xp = import_("torch", wrapper=True) diff --git a/tests/test_dask.py b/tests/test_dask.py index be2b1e39..69c738f6 100644 --- a/tests/test_dask.py +++ b/tests/test_dask.py @@ -1,10 +1,14 @@ from contextlib import contextmanager import array_api_strict -import dask import numpy as np import pytest -import dask.array as da + +try: + import dask + import dask.array as da +except ImportError: + pytestmark = pytest.skip(allow_module_level=True, reason="dask not found") from array_api_compat import array_namespace diff --git a/tests/test_jax.py b/tests/test_jax.py index e33cec02..285958d4 100644 --- a/tests/test_jax.py +++ b/tests/test_jax.py @@ -1,10 +1,14 @@ -import jax -import jax.numpy as jnp from numpy.testing import assert_equal import pytest from array_api_compat import device, to_device +try: + import jax + import jax.numpy as jnp +except ImportError: + pytestmark = pytest.skip(allow_module_level=True, reason="jax not found") + HAS_JAX_0_4_31 = jax.__version__ >= "0.4.31" diff --git a/tests/test_torch.py b/tests/test_torch.py index 75b3a136..e8340f31 100644 --- a/tests/test_torch.py +++ b/tests/test_torch.py @@ -3,7 +3,11 @@ import itertools import pytest -import torch + +try: + import torch +except ImportError: + pytestmark = pytest.skip(allow_module_level=True, reason="pytorch not found") from array_api_compat import torch as xp diff --git a/tests/test_vendoring.py b/tests/test_vendoring.py index 70083b49..8b561551 100644 --- a/tests/test_vendoring.py +++ b/tests/test_vendoring.py @@ -16,11 +16,13 @@ def test_vendoring_cupy(): def test_vendoring_torch(): + pytest.importorskip("torch") from vendor_test import uses_torch uses_torch._test_torch() def test_vendoring_dask(): + pytest.importorskip("dask") from vendor_test import uses_dask uses_dask._test_dask() From 89466a6b43672b9a4a2dbdaea2896c24e4dcdd76 Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Sat, 22 Mar 2025 14:01:44 +0100 Subject: [PATCH 07/28] MAINT: common._aliases.__all__ --- array_api_compat/common/_aliases.py | 18 +++++++++++++----- tests/test_all.py | 2 +- 2 files changed, 14 insertions(+), 6 deletions(-) diff --git a/array_api_compat/common/_aliases.py b/array_api_compat/common/_aliases.py index 0d123b99..0d1ecfbc 100644 --- a/array_api_compat/common/_aliases.py +++ b/array_api_compat/common/_aliases.py @@ -7,8 +7,14 @@ import inspect from typing import NamedTuple, Optional, Sequence, Tuple, Union -from ._helpers import array_namespace, _check_device, device, is_cupy_namespace from ._typing import Array, Device, DType, Namespace +from ._helpers import ( + array_namespace, + _check_device, + device as _get_device, + is_cupy_namespace as _is_cupy_namespace +) + # These functions are modified from the NumPy versions. @@ -298,7 +304,7 @@ def cumulative_sum( initial_shape = list(x.shape) initial_shape[axis] = 1 res = xp.concatenate( - [wrapped_xp.zeros(shape=initial_shape, dtype=res.dtype, device=device(res)), res], + [wrapped_xp.zeros(shape=initial_shape, dtype=res.dtype, device=_get_device(res)), res], axis=axis, ) return res @@ -328,7 +334,7 @@ def cumulative_prod( initial_shape = list(x.shape) initial_shape[axis] = 1 res = xp.concatenate( - [wrapped_xp.ones(shape=initial_shape, dtype=res.dtype, device=device(res)), res], + [wrapped_xp.ones(shape=initial_shape, dtype=res.dtype, device=_get_device(res)), res], axis=axis, ) return res @@ -381,7 +387,7 @@ def _isscalar(a): if type(max) is int and max >= wrapped_xp.iinfo(x.dtype).max: max = None - dev = device(x) + dev = _get_device(x) if out is None: out = wrapped_xp.empty(result_shape, dtype=x.dtype, device=dev) out[()] = x @@ -599,7 +605,7 @@ def sign(x: Array, /, xp: Namespace, **kwargs) -> Array: out = xp.sign(x, **kwargs) # CuPy sign() does not propagate nans. See # https://github.com/data-apis/array-api-compat/issues/136 - if is_cupy_namespace(xp) and isdtype(x.dtype, 'real floating', xp=xp): + if _is_cupy_namespace(xp) and isdtype(x.dtype, 'real floating', xp=xp): out[xp.isnan(x)] = xp.nan return out[()] @@ -611,3 +617,5 @@ def sign(x: Array, /, xp: Namespace, **kwargs) -> Array: 'reshape', 'argsort', 'sort', 'nonzero', 'ceil', 'floor', 'trunc', 'matmul', 'matrix_transpose', 'tensordot', 'vecdot', 'isdtype', 'unstack', 'sign'] + +_all_ignore = ['inspect', 'array_namespace', 'NamedTuple'] diff --git a/tests/test_all.py b/tests/test_all.py index eeb67e4b..4df4a361 100644 --- a/tests/test_all.py +++ b/tests/test_all.py @@ -33,7 +33,7 @@ def test_all(library): else: import_(library, wrapper=True) - for mod_name in sys.modules: + for mod_name in sys.modules.copy(): if not mod_name.startswith('array_api_compat.' + library): continue From 23841dfdb319fbb66a4065e0c138235c56e611f0 Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Sun, 23 Mar 2025 09:28:03 +0100 Subject: [PATCH 08/28] TST: update the torch skiplist --- torch-xfails.txt | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/torch-xfails.txt b/torch-xfails.txt index 6e8f7dc6..f8333d90 100644 --- a/torch-xfails.txt +++ b/torch-xfails.txt @@ -144,10 +144,12 @@ array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_sc array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[logaddexp] array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[maximum] array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[minimum] + +# https://github.com/pytorch/pytorch/issues/149815 array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[equal] -array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[neq] +array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[not_equal] array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[less] -array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[les_equal] +array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[less_equal] array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[greater] array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[greater_equal] From 3b4ea593d43c3d522aa1e601a93781774606bbc3 Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Sun, 23 Mar 2025 09:33:26 +0100 Subject: [PATCH 09/28] TST: update numpy<2 skiplists --- numpy-1-21-xfails.txt | 1 + numpy-1-26-xfails.txt | 1 + 2 files changed, 2 insertions(+) diff --git a/numpy-1-21-xfails.txt b/numpy-1-21-xfails.txt index 7c7a0757..30cde668 100644 --- a/numpy-1-21-xfails.txt +++ b/numpy-1-21-xfails.txt @@ -192,6 +192,7 @@ array_api_tests/test_signatures.py::test_func_signature[bitwise_left_shift] array_api_tests/test_signatures.py::test_func_signature[bitwise_or] array_api_tests/test_signatures.py::test_func_signature[bitwise_right_shift] array_api_tests/test_signatures.py::test_func_signature[bitwise_xor] +array_api_tests/test_data_type_functions.py::TestResultType::test_with_scalars # Stubs have a comment: (**note**: libraries may return ``NaN`` to match Python behavior.); Apparently,NumPy does just that array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is +infinity and isfinite(x2_i) and x2_i > 0) -> +infinity] diff --git a/numpy-1-26-xfails.txt b/numpy-1-26-xfails.txt index 57259b6f..1ce28ef4 100644 --- a/numpy-1-26-xfails.txt +++ b/numpy-1-26-xfails.txt @@ -46,6 +46,7 @@ array_api_tests/test_signatures.py::test_func_signature[bitwise_left_shift] array_api_tests/test_signatures.py::test_func_signature[bitwise_or] array_api_tests/test_signatures.py::test_func_signature[bitwise_right_shift] array_api_tests/test_signatures.py::test_func_signature[bitwise_xor] +array_api_tests/test_data_type_functions.py::TestResultType::test_with_scalars # Stubs have a comment: (**note**: libraries may return ``NaN`` to match Python behavior.); Apparently, NumPy does just that array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is +infinity and isfinite(x2_i) and x2_i > 0) -> +infinity] From f19256e3e132f0c16147936d1cf320680366055a Mon Sep 17 00:00:00 2001 From: Neil Girdhar Date: Fri, 21 Mar 2025 07:02:22 -0400 Subject: [PATCH 10/28] Add pyprject.toml --- .github/workflows/docs-build.yml | 2 +- .github/workflows/tests.yml | 6 +- docs/dev/tests.md | 2 +- docs/requirements.txt | 6 -- pyproject.toml | 96 ++++++++++++++++++++++++++++++++ requirements-dev.txt | 8 --- ruff.toml | 17 ------ setup.py | 40 ------------- 8 files changed, 99 insertions(+), 78 deletions(-) delete mode 100644 docs/requirements.txt create mode 100644 pyproject.toml delete mode 100644 requirements-dev.txt delete mode 100644 ruff.toml delete mode 100644 setup.py diff --git a/.github/workflows/docs-build.yml b/.github/workflows/docs-build.yml index 04c3aa66..34b9cbc6 100644 --- a/.github/workflows/docs-build.yml +++ b/.github/workflows/docs-build.yml @@ -10,7 +10,7 @@ jobs: - uses: actions/setup-python@v5 - name: Install Dependencies run: | - python -m pip install -r docs/requirements.txt + python -m pip install .[docs] - name: Build Docs run: | cd docs diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index fcd43367..54f6f402 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -29,11 +29,7 @@ jobs: PIP_EXTRA='numpy==1.26.*' fi - if [ "${{ matrix.python-version }}" == "3.9" ]; then - sed -i '/^ndonnx/d' requirements-dev.txt - fi - - python -m pip install -r requirements-dev.txt $PIP_EXTRA + python -m pip install .[dev] $PIP_EXTRA - name: Run Tests run: | diff --git a/docs/dev/tests.md b/docs/dev/tests.md index 6d9d1d7b..18fb7cf5 100644 --- a/docs/dev/tests.md +++ b/docs/dev/tests.md @@ -7,7 +7,7 @@ the array API standard. There are also array-api-compat specific tests in These tests should be limited to things that are not tested by the test suite, e.g., tests for [helper functions](../helper-functions.rst) or for behavior that is not strictly required by the standard. To run these tests, install the -dependencies from `requirements-dev.txt` (array-api-compat has [no hard +dependencies from the `dev` optional group (array-api-compat has [no hard runtime dependencies](no-dependencies)). array-api-tests is run against all supported libraries are tested on CI diff --git a/docs/requirements.txt b/docs/requirements.txt deleted file mode 100644 index dbec7740..00000000 --- a/docs/requirements.txt +++ /dev/null @@ -1,6 +0,0 @@ -furo -linkify-it-py -myst-parser -sphinx -sphinx-copybutton -sphinx-autobuild diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 00000000..f17c720f --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,96 @@ +[build-system] +requires = ["setuptools", "setuptools-scm"] +build-backend = "setuptools.build_meta" + +[project] +name = "array-api-compat" +dynamic = ["version"] +description = "A wrapper around NumPy and other array libraries to make them compatible with the Array API standard" +readme = "README.md" +requires-python = ">=3.9" +license = "MIT" +authors = [{name = "Consortium for Python Data API Standards"}] +classifiers = [ + "Operating System :: OS Independent", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", + "Programming Language :: Python :: 3.13", + "Topic :: Software Development :: Libraries :: Python Modules", + "Typing :: Typed", +] + +[project.optional-dependencies] +cupy = ["cupy"] +dask = ["dask"] +jax = ["jax"] +numpy = ["numpy"] +pytorch = ["torch"] +sparse = ["sparse>=0.15.1"] +docs = [ + "furo", + "linkify-it-py", + "myst-parser", + "sphinx", + "sphinx-copybutton", + "sphinx-autobuild", +] +dev = [ + "array-api-strict", + "dask[array]", + "jax[cpu]", + "numpy", + "pytest", + "torch", + "sparse>=0.15.1", + "ndonnx; python_version>=\"3.10\"" +] + +[project.urls] +homepage = "https://data-apis.org/array-api-compat/" +repository = "https://github.com/data-apis/array-api-compat/" + +[tool.setuptools.dynamic] +version = {attr = "array_api_compat.__version__"} + +[tool.setuptools.packages.find] +include = ["array_api_compat*"] +namespaces = false + +[toolint] +preview = true +select = [ +# Defaults +"E4", "E7", "E9", "F", +# Undefined export +"F822", +# Useless import alias +"PLC0414" +] + +ignore = [ + # Module import not at top of file + "E402", + # Do not use bare `except` + "E722" +] + +[tool.ruff.lint] +preview = true +select = [ +# Defaults +"E4", "E7", "E9", "F", +# Undefined export +"F822", +# Useless import alias +"PLC0414" +] + +ignore = [ + # Module import not at top of file + "E402", + # Do not use bare `except` + "E722" +] diff --git a/requirements-dev.txt b/requirements-dev.txt deleted file mode 100644 index c9d10f71..00000000 --- a/requirements-dev.txt +++ /dev/null @@ -1,8 +0,0 @@ -array-api-strict -dask[array] -jax[cpu] -numpy -pytest -torch -sparse >=0.15.1 -ndonnx diff --git a/ruff.toml b/ruff.toml deleted file mode 100644 index 72e111b5..00000000 --- a/ruff.toml +++ /dev/null @@ -1,17 +0,0 @@ -[lint] -preview = true -select = [ -# Defaults -"E4", "E7", "E9", "F", -# Undefined export -"F822", -# Useless import alias -"PLC0414" -] - -ignore = [ - # Module import not at top of file - "E402", - # Do not use bare `except` - "E722" -] diff --git a/setup.py b/setup.py deleted file mode 100644 index 2368ccc4..00000000 --- a/setup.py +++ /dev/null @@ -1,40 +0,0 @@ -from setuptools import setup, find_packages - -with open("README.md", "r") as fh: - long_description = fh.read() - -import array_api_compat - -setup( - name='array_api_compat', - version=array_api_compat.__version__, - packages=find_packages(include=["array_api_compat*"]), - author="Consortium for Python Data API Standards", - description="A wrapper around NumPy and other array libraries to make them compatible with the Array API standard", - long_description=long_description, - long_description_content_type="text/markdown", - url="https://data-apis.org/array-api-compat/", - license="MIT", - extras_require={ - "numpy": "numpy", - "cupy": "cupy", - "jax": "jax", - "pytorch": "pytorch", - "dask": "dask", - "sparse": "sparse >=0.15.1", - }, - python_requires=">=3.9", - classifiers=[ - "Programming Language :: Python :: 3", - "Programming Language :: Python :: 3.9", - "Programming Language :: Python :: 3.10", - "Programming Language :: Python :: 3.11", - "Programming Language :: Python :: 3.12", - "Programming Language :: Python :: 3.13", - "License :: OSI Approved :: MIT License", - "Operating System :: OS Independent", - ], - package_data={ - "array_api_compat": ["py.typed"], - }, -) From 1db3fae0f682199bda3ae920f8a695e4f579b439 Mon Sep 17 00:00:00 2001 From: crusaderky Date: Tue, 25 Mar 2025 18:07:45 +0000 Subject: [PATCH 11/28] ENH: correct Dask capabilities --- array_api_compat/dask/array/_info.py | 22 ++++++++++++++++------ dask-xfails.txt | 8 +++++--- 2 files changed, 21 insertions(+), 9 deletions(-) diff --git a/array_api_compat/dask/array/_info.py b/array_api_compat/dask/array/_info.py index e15a69f4..fc70b5a2 100644 --- a/array_api_compat/dask/array/_info.py +++ b/array_api_compat/dask/array/_info.py @@ -68,11 +68,22 @@ def capabilities(self): The resulting dictionary has the following keys: - **"boolean indexing"**: boolean indicating whether an array library - supports boolean indexing. Always ``False`` for Dask. + supports boolean indexing. + + Dask support boolean indexing as long as both the index + and the indexed arrays have known shapes. + Note however that the output .shape and .size properties + will contain a non-compliant math.nan instead of None. - **"data-dependent shapes"**: boolean indicating whether an array - library supports data-dependent output shapes. Always ``False`` for - Dask. + library supports data-dependent output shapes. + + Dask implements unique_values et.al. + Note however that the output .shape and .size properties + will contain a non-compliant math.nan instead of None. + + - **"max dimensions"**: integer indicating the maximum number of + dimensions supported by the array library. See https://data-apis.org/array-api/latest/API_specification/generated/array_api.info.capabilities.html @@ -99,9 +110,8 @@ def capabilities(self): """ return { - "boolean indexing": False, - "data-dependent shapes": False, - # 'max rank' will be part of the 2024.12 standard + "boolean indexing": True, + "data-dependent shapes": True, "max dimensions": 64, } diff --git a/dask-xfails.txt b/dask-xfails.txt index d2474f9f..bd65d004 100644 --- a/dask-xfails.txt +++ b/dask-xfails.txt @@ -28,12 +28,14 @@ array_api_tests/test_has_names.py::test_has_names[array_method-to_device] array_api_tests/test_has_names.py::test_has_names[array_attribute-device] array_api_tests/test_has_names.py::test_has_names[array_attribute-mT] -# Fails because shape is NaN since we don't materialize it yet +# Data-dependent output shape +# These tests fail as array-api-tests doesn't cope with unknown shapes +# Also, output shape is (math.nan, ) instead of (None, ) +# Also, da.unique() doesn't accept equals_nan which causes non-compliant +# output when there are NaNs in the input. array_api_tests/test_searching_functions.py::test_nonzero array_api_tests/test_set_functions.py::test_unique_all array_api_tests/test_set_functions.py::test_unique_counts - -# Different error but same cause as above, we're just trying to do ndindex on nan shape array_api_tests/test_set_functions.py::test_unique_inverse array_api_tests/test_set_functions.py::test_unique_values From 71d90ead399c03f5fcbc15d205d7cedb6bc9825c Mon Sep 17 00:00:00 2001 From: Lucas Colley Date: Sun, 30 Mar 2025 09:19:56 +0100 Subject: [PATCH 12/28] Update test_all.py Co-authored-by: Evgeni Burovski --- tests/test_all.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/test_all.py b/tests/test_all.py index 4df4a361..271cd189 100644 --- a/tests/test_all.py +++ b/tests/test_all.py @@ -33,6 +33,7 @@ def test_all(library): else: import_(library, wrapper=True) + # NB: iterate over a copy to avoid a "dictionary size changed" error for mod_name in sys.modules.copy(): if not mod_name.startswith('array_api_compat.' + library): continue From b2af137864a484908fc96fddb1e47af56f0a4adf Mon Sep 17 00:00:00 2001 From: Guido Imperiale Date: Mon, 31 Mar 2025 23:51:14 +0100 Subject: [PATCH 13/28] TYP: Type annotations overhaul, part 2 (#291) --- array_api_compat/common/_aliases.py | 4 ++-- array_api_compat/cupy/_aliases.py | 5 ++++- array_api_compat/dask/array/_aliases.py | 5 ++++- array_api_compat/numpy/_aliases.py | 5 ++++- array_api_compat/torch/_aliases.py | 14 +++++++++++--- array_api_compat/torch/linalg.py | 2 +- 6 files changed, 26 insertions(+), 9 deletions(-) diff --git a/array_api_compat/common/_aliases.py b/array_api_compat/common/_aliases.py index 0d1ecfbc..03910681 100644 --- a/array_api_compat/common/_aliases.py +++ b/array_api_compat/common/_aliases.py @@ -73,7 +73,7 @@ def eye( def full( shape: Union[int, Tuple[int, ...]], - fill_value: complex, + fill_value: bool | int | float | complex, xp: Namespace, *, dtype: Optional[DType] = None, @@ -86,7 +86,7 @@ def full( def full_like( x: Array, /, - fill_value: complex, + fill_value: bool | int | float | complex, *, xp: Namespace, dtype: Optional[DType] = None, diff --git a/array_api_compat/cupy/_aliases.py b/array_api_compat/cupy/_aliases.py index ebc7ccd9..423fd10a 100644 --- a/array_api_compat/cupy/_aliases.py +++ b/array_api_compat/cupy/_aliases.py @@ -68,7 +68,10 @@ # asarray also adds the copy keyword, which is not present in numpy 1.0. def asarray( obj: ( - Array | bool | complex | NestedSequence[bool | complex] | SupportsBufferProtocol + Array + | bool | int | float | complex + | NestedSequence[bool | int | float | complex] + | SupportsBufferProtocol ), /, *, diff --git a/array_api_compat/dask/array/_aliases.py b/array_api_compat/dask/array/_aliases.py index e737cebd..e6eff359 100644 --- a/array_api_compat/dask/array/_aliases.py +++ b/array_api_compat/dask/array/_aliases.py @@ -136,7 +136,10 @@ def arange( # asarray also adds the copy keyword, which is not present in numpy 1.0. def asarray( obj: ( - Array | bool | complex | NestedSequence[bool | complex] | SupportsBufferProtocol + Array + | bool | int | float | complex + | NestedSequence[bool | int | float | complex] + | SupportsBufferProtocol ), /, *, diff --git a/array_api_compat/numpy/_aliases.py b/array_api_compat/numpy/_aliases.py index 6536d9a8..1d084b2b 100644 --- a/array_api_compat/numpy/_aliases.py +++ b/array_api_compat/numpy/_aliases.py @@ -77,7 +77,10 @@ def _supports_buffer_protocol(obj): # rather than trying to combine everything into one function in common/ def asarray( obj: ( - Array | bool | complex | NestedSequence[bool | complex] | SupportsBufferProtocol + Array + | bool | int | float | complex + | NestedSequence[bool | int | float | complex] + | SupportsBufferProtocol ), /, *, diff --git a/array_api_compat/torch/_aliases.py b/array_api_compat/torch/_aliases.py index 87d32d85..982500b0 100644 --- a/array_api_compat/torch/_aliases.py +++ b/array_api_compat/torch/_aliases.py @@ -116,7 +116,9 @@ def _fix_promotion(x1, x2, only_scalar=True): _py_scalars = (bool, int, float, complex) -def result_type(*arrays_and_dtypes: Array | DType | complex) -> DType: +def result_type( + *arrays_and_dtypes: Array | DType | bool | int | float | complex +) -> DType: num = len(arrays_and_dtypes) if num == 0: @@ -550,10 +552,16 @@ def count_nonzero( return result -def where(condition: Array, x1: Array, x2: Array, /) -> Array: +def where( + condition: Array, + x1: Array | bool | int | float | complex, + x2: Array | bool | int | float | complex, + /, +) -> Array: x1, x2 = _fix_promotion(x1, x2) return torch.where(condition, x1, x2) + # torch.reshape doesn't have the copy keyword def reshape(x: Array, /, @@ -622,7 +630,7 @@ def linspace(start: Union[int, float], # torch.full does not accept an int size # https://github.com/pytorch/pytorch/issues/70906 def full(shape: Union[int, Tuple[int, ...]], - fill_value: complex, + fill_value: bool | int | float | complex, *, dtype: Optional[DType] = None, device: Optional[Device] = None, diff --git a/array_api_compat/torch/linalg.py b/array_api_compat/torch/linalg.py index 7b59a670..1ff7319d 100644 --- a/array_api_compat/torch/linalg.py +++ b/array_api_compat/torch/linalg.py @@ -85,7 +85,7 @@ def vector_norm( axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False, # float stands for inf | -inf, which are not valid for Literal - ord: Union[int, float, float] = 2, + ord: Union[int, float] = 2, **kwargs, ) -> Array: # torch.vector_norm incorrectly treats axis=() the same as axis=None From 29f494160a7657dc4da21113851bd6880e39dc7c Mon Sep 17 00:00:00 2001 From: crusaderky Date: Thu, 3 Apr 2025 10:26:59 +0100 Subject: [PATCH 14/28] TST: bump to ndonnx 0.10.1 --- tests/test_common.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tests/test_common.py b/tests/test_common.py index f86e0936..bbf14572 100644 --- a/tests/test_common.py +++ b/tests/test_common.py @@ -234,6 +234,7 @@ def test_asarray_cross_library(source_library, target_library, request): # TODO: remove xfail once # https://github.com/dask/dask/issues/8260 is resolved xfail(request, reason="Bug in dask raising error on conversion") + elif ( source_library == "ndonnx" and target_library not in ("array_api_strict", "ndonnx", "numpy") @@ -241,6 +242,9 @@ def test_asarray_cross_library(source_library, target_library, request): xfail(request, reason="The truth value of lazy Array Array(dtype=Boolean) is unknown") elif source_library == "ndonnx" and target_library == "numpy": xfail(request, reason="produces numpy array of ndonnx scalar arrays") + elif target_library == "ndonnx" and source_library in ("torch", "dask.array", "jax.numpy"): + xfail(request, reason="unable to infer dtype") + elif source_library == "jax.numpy" and target_library == "torch": xfail(request, reason="casts int to float") elif source_library == "cupy" and target_library != "cupy": From f80f15792ec981e943bef7f49faff687ef29b27c Mon Sep 17 00:00:00 2001 From: crusaderky Date: Wed, 2 Apr 2025 17:07:20 +0100 Subject: [PATCH 15/28] ENH: wrap iinfo/finfo --- array_api_compat/common/_aliases.py | 21 +++++++++++++++++++-- array_api_compat/cupy/_aliases.py | 2 ++ array_api_compat/dask/array/_aliases.py | 9 ++++----- array_api_compat/numpy/_aliases.py | 2 ++ array_api_compat/torch/_aliases.py | 5 ++++- cupy-xfails.txt | 11 ++++++++--- dask-xfails.txt | 10 ++++++++-- numpy-1-21-xfails.txt | 12 +++++++++--- numpy-1-26-xfails.txt | 10 ++++++++-- numpy-dev-xfails.txt | 10 ++++++++-- numpy-xfails.txt | 10 ++++++++-- torch-xfails.txt | 4 ++++ 12 files changed, 84 insertions(+), 22 deletions(-) diff --git a/array_api_compat/common/_aliases.py b/array_api_compat/common/_aliases.py index 03910681..46cbb359 100644 --- a/array_api_compat/common/_aliases.py +++ b/array_api_compat/common/_aliases.py @@ -5,7 +5,7 @@ from __future__ import annotations import inspect -from typing import NamedTuple, Optional, Sequence, Tuple, Union +from typing import Any, NamedTuple, Optional, Sequence, Tuple, Union from ._typing import Array, Device, DType, Namespace from ._helpers import ( @@ -609,6 +609,23 @@ def sign(x: Array, /, xp: Namespace, **kwargs) -> Array: out[xp.isnan(x)] = xp.nan return out[()] + +def finfo(type_: DType | Array, /, xp: Namespace) -> Any: + # It is surprisingly difficult to recognize a dtype apart from an array. + # np.int64 is not the same as np.asarray(1).dtype! + try: + return xp.finfo(type_) + except (ValueError, TypeError): + return xp.finfo(type_.dtype) + + +def iinfo(type_: DType | Array, /, xp: Namespace) -> Any: + try: + return xp.iinfo(type_) + except (ValueError, TypeError): + return xp.iinfo(type_.dtype) + + __all__ = ['arange', 'empty', 'empty_like', 'eye', 'full', 'full_like', 'linspace', 'ones', 'ones_like', 'zeros', 'zeros_like', 'UniqueAllResult', 'UniqueCountsResult', 'UniqueInverseResult', @@ -616,6 +633,6 @@ def sign(x: Array, /, xp: Namespace, **kwargs) -> Array: 'std', 'var', 'cumulative_sum', 'cumulative_prod','clip', 'permute_dims', 'reshape', 'argsort', 'sort', 'nonzero', 'ceil', 'floor', 'trunc', 'matmul', 'matrix_transpose', 'tensordot', 'vecdot', 'isdtype', - 'unstack', 'sign'] + 'unstack', 'sign', 'finfo', 'iinfo'] _all_ignore = ['inspect', 'array_namespace', 'NamedTuple'] diff --git a/array_api_compat/cupy/_aliases.py b/array_api_compat/cupy/_aliases.py index 423fd10a..fd1460ae 100644 --- a/array_api_compat/cupy/_aliases.py +++ b/array_api_compat/cupy/_aliases.py @@ -61,6 +61,8 @@ matrix_transpose = get_xp(cp)(_aliases.matrix_transpose) tensordot = get_xp(cp)(_aliases.tensordot) sign = get_xp(cp)(_aliases.sign) +finfo = get_xp(cp)(_aliases.finfo) +iinfo = get_xp(cp)(_aliases.iinfo) _copy_default = object() diff --git a/array_api_compat/dask/array/_aliases.py b/array_api_compat/dask/array/_aliases.py index e6eff359..dca6d570 100644 --- a/array_api_compat/dask/array/_aliases.py +++ b/array_api_compat/dask/array/_aliases.py @@ -5,8 +5,6 @@ import numpy as np from numpy import ( # dtypes - iinfo, - finfo, bool_ as bool, float32, float64, @@ -131,6 +129,8 @@ def arange( matmul = get_xp(np)(_aliases.matmul) tensordot = get_xp(np)(_aliases.tensordot) sign = get_xp(np)(_aliases.sign) +finfo = get_xp(np)(_aliases.finfo) +iinfo = get_xp(np)(_aliases.iinfo) # asarray also adds the copy keyword, which is not present in numpy 1.0. @@ -343,10 +343,9 @@ def count_nonzero( '__array_namespace_info__', 'asarray', 'astype', 'acos', 'acosh', 'asin', 'asinh', 'atan', 'atan2', 'atanh', 'bitwise_left_shift', 'bitwise_invert', - 'bitwise_right_shift', 'concat', 'pow', 'iinfo', 'finfo', 'can_cast', + 'bitwise_right_shift', 'concat', 'pow', 'can_cast', 'result_type', 'bool', 'float32', 'float64', 'int8', 'int16', 'int32', 'int64', - 'uint8', 'uint16', 'uint32', 'uint64', - 'complex64', 'complex128', 'iinfo', 'finfo', + 'uint8', 'uint16', 'uint32', 'uint64', 'complex64', 'complex128', 'can_cast', 'count_nonzero', 'result_type'] _all_ignore = ["array_namespace", "get_xp", "da", "np"] diff --git a/array_api_compat/numpy/_aliases.py b/array_api_compat/numpy/_aliases.py index 1d084b2b..ae0d006d 100644 --- a/array_api_compat/numpy/_aliases.py +++ b/array_api_compat/numpy/_aliases.py @@ -61,6 +61,8 @@ matrix_transpose = get_xp(np)(_aliases.matrix_transpose) tensordot = get_xp(np)(_aliases.tensordot) sign = get_xp(np)(_aliases.sign) +finfo = get_xp(np)(_aliases.finfo) +iinfo = get_xp(np)(_aliases.iinfo) def _supports_buffer_protocol(obj): diff --git a/array_api_compat/torch/_aliases.py b/array_api_compat/torch/_aliases.py index 982500b0..9384e4c0 100644 --- a/array_api_compat/torch/_aliases.py +++ b/array_api_compat/torch/_aliases.py @@ -227,6 +227,9 @@ def min(x: Array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, keep unstack = get_xp(torch)(_aliases.unstack) cumulative_sum = get_xp(torch)(_aliases.cumulative_sum) cumulative_prod = get_xp(torch)(_aliases.cumulative_prod) +finfo = get_xp(torch)(_aliases.finfo) +iinfo = get_xp(torch)(_aliases.iinfo) + # torch.sort also returns a tuple # https://github.com/pytorch/pytorch/issues/70921 @@ -832,6 +835,6 @@ def sign(x: Array, /) -> Array: 'UniqueAllResult', 'UniqueCountsResult', 'UniqueInverseResult', 'unique_all', 'unique_counts', 'unique_inverse', 'unique_values', 'matmul', 'matrix_transpose', 'vecdot', 'tensordot', 'isdtype', - 'take', 'take_along_axis', 'sign'] + 'take', 'take_along_axis', 'sign', 'finfo', 'iinfo'] _all_ignore = ['torch', 'get_xp'] diff --git a/cupy-xfails.txt b/cupy-xfails.txt index 3d20d745..f4cd1e36 100644 --- a/cupy-xfails.txt +++ b/cupy-xfails.txt @@ -14,9 +14,14 @@ array_api_tests/test_array_object.py::test_getitem # copy=False is not yet implemented array_api_tests/test_creation_functions.py::test_asarray_arrays -# finfo test is testing that the result is a float instead of float32 (see -# also https://github.com/data-apis/array-api/issues/405) -array_api_tests/test_data_type_functions.py::test_finfo[float32] +# attributes are np.float32 instead of float +# (see also https://github.com/data-apis/array-api/issues/405) +array_api_tests/test_data_type_functions.py::test_finfo[float32-dtype] +array_api_tests/test_data_type_functions.py::test_finfo[float32-array] +array_api_tests/test_data_type_functions.py::test_finfo[float32-array.dtype] +array_api_tests/test_data_type_functions.py::test_finfo[complex64-dtype] +array_api_tests/test_data_type_functions.py::test_finfo[complex64-array] +array_api_tests/test_data_type_functions.py::test_finfo[complex64-array.dtype] # Some array attributes are missing, and we do not wrap the array object array_api_tests/test_has_names.py::test_has_names[array_method-__array_namespace__] diff --git a/dask-xfails.txt b/dask-xfails.txt index bd65d004..abab825c 100644 --- a/dask-xfails.txt +++ b/dask-xfails.txt @@ -12,8 +12,14 @@ array_api_tests/test_array_object.py::test_getitem_masking # zero division error, and typeerror: tuple indices must be integers or slices not tuple array_api_tests/test_creation_functions.py::test_eye -# finfo(float32).eps returns float32 but should return float -array_api_tests/test_data_type_functions.py::test_finfo[float32] +# attributes are np.float32 instead of float +# (see also https://github.com/data-apis/array-api/issues/405) +array_api_tests/test_data_type_functions.py::test_finfo[float32-dtype] +array_api_tests/test_data_type_functions.py::test_finfo[float32-array] +array_api_tests/test_data_type_functions.py::test_finfo[float32-array.dtype] +array_api_tests/test_data_type_functions.py::test_finfo[complex64-dtype] +array_api_tests/test_data_type_functions.py::test_finfo[complex64-array] +array_api_tests/test_data_type_functions.py::test_finfo[complex64-array.dtype] # out[-1]=dask.array but should be some floating number # (I think the test is not forcing the op to be computed?) diff --git a/numpy-1-21-xfails.txt b/numpy-1-21-xfails.txt index 30cde668..93a90757 100644 --- a/numpy-1-21-xfails.txt +++ b/numpy-1-21-xfails.txt @@ -1,8 +1,14 @@ # asarray(copy=False) is not yet implemented array_api_tests/test_creation_functions.py::test_asarray_arrays -# finfo(float32).eps returns float32 but should return float -array_api_tests/test_data_type_functions.py::test_finfo[float32] +# attributes are np.float32 instead of float +# (see also https://github.com/data-apis/array-api/issues/405) +array_api_tests/test_data_type_functions.py::test_finfo[float32-dtype] +array_api_tests/test_data_type_functions.py::test_finfo[float32-array] +array_api_tests/test_data_type_functions.py::test_finfo[float32-array.dtype] +array_api_tests/test_data_type_functions.py::test_finfo[complex64-dtype] +array_api_tests/test_data_type_functions.py::test_finfo[complex64-array] +array_api_tests/test_data_type_functions.py::test_finfo[complex64-array.dtype] # Array methods and attributes not already on np.ndarray cannot be wrapped array_api_tests/test_has_names.py::test_has_names[array_method-__array_namespace__] @@ -41,7 +47,7 @@ array_api_tests/meta/test_hypothesis_helpers.py::test_symmetric_matrices ############################ # finfo has no smallest_normal -array_api_tests/test_data_type_functions.py::test_finfo[float64] +array_api_tests/test_data_type_functions.py::test_finfo # dlpack stuff array_api_tests/test_has_names.py::test_has_names[creation-from_dlpack] diff --git a/numpy-1-26-xfails.txt b/numpy-1-26-xfails.txt index 1ce28ef4..84916e73 100644 --- a/numpy-1-26-xfails.txt +++ b/numpy-1-26-xfails.txt @@ -1,5 +1,11 @@ -# finfo(float32).eps returns float32 but should return float -array_api_tests/test_data_type_functions.py::test_finfo[float32] +# attributes are np.float32 instead of float +# (see also https://github.com/data-apis/array-api/issues/405) +array_api_tests/test_data_type_functions.py::test_finfo[float32-dtype] +array_api_tests/test_data_type_functions.py::test_finfo[float32-array] +array_api_tests/test_data_type_functions.py::test_finfo[float32-array.dtype] +array_api_tests/test_data_type_functions.py::test_finfo[complex64-dtype] +array_api_tests/test_data_type_functions.py::test_finfo[complex64-array] +array_api_tests/test_data_type_functions.py::test_finfo[complex64-array.dtype] # Array methods and attributes not already on np.ndarray cannot be wrapped array_api_tests/test_has_names.py::test_has_names[array_method-__array_namespace__] diff --git a/numpy-dev-xfails.txt b/numpy-dev-xfails.txt index 98659710..31bcb63b 100644 --- a/numpy-dev-xfails.txt +++ b/numpy-dev-xfails.txt @@ -1,5 +1,11 @@ -# finfo(float32).eps returns float32 but should return float -array_api_tests/test_data_type_functions.py::test_finfo[float32] +# attributes are np.float32 instead of float +# (see also https://github.com/data-apis/array-api/issues/405) +array_api_tests/test_data_type_functions.py::test_finfo[float32-dtype] +array_api_tests/test_data_type_functions.py::test_finfo[float32-array] +array_api_tests/test_data_type_functions.py::test_finfo[float32-array.dtype] +array_api_tests/test_data_type_functions.py::test_finfo[complex64-dtype] +array_api_tests/test_data_type_functions.py::test_finfo[complex64-array] +array_api_tests/test_data_type_functions.py::test_finfo[complex64-array.dtype] # The test suite cannot properly get the signature for vecdot # https://github.com/numpy/numpy/pull/26237 diff --git a/numpy-xfails.txt b/numpy-xfails.txt index 0885dcaa..0810aea6 100644 --- a/numpy-xfails.txt +++ b/numpy-xfails.txt @@ -1,5 +1,11 @@ -# finfo(float32).eps returns float32 but should return float -array_api_tests/test_data_type_functions.py::test_finfo[float32] +# attributes are np.float32 instead of float +# (see also https://github.com/data-apis/array-api/issues/405) +array_api_tests/test_data_type_functions.py::test_finfo[float32-dtype] +array_api_tests/test_data_type_functions.py::test_finfo[float32-array] +array_api_tests/test_data_type_functions.py::test_finfo[float32-array.dtype] +array_api_tests/test_data_type_functions.py::test_finfo[complex64-dtype] +array_api_tests/test_data_type_functions.py::test_finfo[complex64-array] +array_api_tests/test_data_type_functions.py::test_finfo[complex64-array.dtype] # The test suite cannot properly get the signature for vecdot # https://github.com/numpy/numpy/pull/26237 diff --git a/torch-xfails.txt b/torch-xfails.txt index f8333d90..e556fa4f 100644 --- a/torch-xfails.txt +++ b/torch-xfails.txt @@ -115,6 +115,10 @@ array_api_tests/test_operators_and_elementwise_functions.py::test_round array_api_tests/test_set_functions.py::test_unique_counts array_api_tests/test_set_functions.py::test_unique_values +# finfo/iinfo.dtype is a string instead of a dtype +array_api_tests/test_data_type_functions.py::test_finfo_dtype +array_api_tests/test_data_type_functions.py::test_iinfo_dtype + # 2023.12 support array_api_tests/test_has_names.py::test_has_names[manipulation-repeat] array_api_tests/test_manipulation_functions.py::test_repeat From 37b1c475c98fb092135ef021f11b7f79cd46debd Mon Sep 17 00:00:00 2001 From: crusaderky Date: Thu, 3 Apr 2025 12:42:35 +0100 Subject: [PATCH 16/28] MAINT: validate device on numpy and dask --- array_api_compat/common/_helpers.py | 24 +++++++++++++++++++++--- array_api_compat/dask/array/_aliases.py | 5 ++++- array_api_compat/numpy/_aliases.py | 6 +++--- 3 files changed, 28 insertions(+), 7 deletions(-) diff --git a/array_api_compat/common/_helpers.py b/array_api_compat/common/_helpers.py index 6d95069d..67c619b8 100644 --- a/array_api_compat/common/_helpers.py +++ b/array_api_compat/common/_helpers.py @@ -595,11 +595,29 @@ def your_function(x, y): # backwards compatibility alias get_namespace = array_namespace -def _check_device(xp, device): - if xp == sys.modules.get('numpy'): - if device not in ["cpu", None]: + +def _check_device(bare_xp, device): + """ + Validate dummy device on device-less array backends. + + Notes + ----- + This function is also invoked by CuPy, which does have multiple devices + if there are multiple GPUs available. + However, CuPy multi-device support is currently impossible + without using the global device or a context manager: + + https://github.com/data-apis/array-api-compat/pull/293 + """ + if bare_xp is sys.modules.get('numpy'): + if device not in ("cpu", None): raise ValueError(f"Unsupported device for NumPy: {device!r}") + elif bare_xp is sys.modules.get('dask.array'): + if device not in ("cpu", _DASK_DEVICE, None): + raise ValueError(f"Unsupported device for Dask: {device!r}") + + # Placeholder object to represent the dask device # when the array backend is not the CPU. # (since it is not easy to tell which device a dask array is on) diff --git a/array_api_compat/dask/array/_aliases.py b/array_api_compat/dask/array/_aliases.py index e6eff359..c5cd7489 100644 --- a/array_api_compat/dask/array/_aliases.py +++ b/array_api_compat/dask/array/_aliases.py @@ -25,7 +25,7 @@ ) import dask.array as da -from ...common import _aliases, array_namespace +from ...common import _aliases, _helpers, array_namespace from ...common._typing import ( Array, Device, @@ -56,6 +56,7 @@ def astype( specification for more details. """ # TODO: respect device keyword? + _helpers._check_device(da, device) if not copy and dtype == x.dtype: return x @@ -86,6 +87,7 @@ def arange( specification for more details. """ # TODO: respect device keyword? + _helpers._check_device(da, device) args = [start] if stop is not None: @@ -155,6 +157,7 @@ def asarray( specification for more details. """ # TODO: respect device keyword? + _helpers._check_device(da, device) if isinstance(obj, da.Array): if dtype is not None and dtype != obj.dtype: diff --git a/array_api_compat/numpy/_aliases.py b/array_api_compat/numpy/_aliases.py index 1d084b2b..d5b7feac 100644 --- a/array_api_compat/numpy/_aliases.py +++ b/array_api_compat/numpy/_aliases.py @@ -3,7 +3,7 @@ from typing import Optional, Union from .._internal import get_xp -from ..common import _aliases +from ..common import _aliases, _helpers from ..common._typing import NestedSequence, SupportsBufferProtocol from ._info import __array_namespace_info__ from ._typing import Array, Device, DType @@ -95,8 +95,7 @@ def asarray( See the corresponding documentation in the array library and/or the array API specification for more details. """ - if device not in ["cpu", None]: - raise ValueError(f"Unsupported device for NumPy: {device!r}") + _helpers._check_device(np, device) if hasattr(np, '_CopyMode'): if copy is None: @@ -122,6 +121,7 @@ def astype( copy: bool = True, device: Optional[Device] = None, ) -> Array: + _helpers._check_device(np, device) return x.astype(dtype=dtype, copy=copy) From 2c1cb6b515849048cd062e31462b6a193b81471c Mon Sep 17 00:00:00 2001 From: crusaderky Date: Thu, 3 Apr 2025 13:30:49 +0100 Subject: [PATCH 17/28] BUG: Don't import helpers in namespaces --- array_api_compat/common/_linalg.py | 2 ++ array_api_compat/cupy/__init__.py | 3 --- array_api_compat/dask/array/__init__.py | 1 + array_api_compat/numpy/__init__.py | 9 --------- array_api_compat/numpy/_aliases.py | 2 +- array_api_compat/torch/__init__.py | 6 ++---- tests/test_common.py | 2 +- 7 files changed, 7 insertions(+), 18 deletions(-) diff --git a/array_api_compat/common/_linalg.py b/array_api_compat/common/_linalg.py index c77ee3b8..d1e7ebd8 100644 --- a/array_api_compat/common/_linalg.py +++ b/array_api_compat/common/_linalg.py @@ -174,3 +174,5 @@ def trace( 'svd', 'cholesky', 'matrix_rank', 'pinv', 'matrix_norm', 'matrix_transpose', 'svdvals', 'vecdot', 'vector_norm', 'diagonal', 'trace'] + +_all_ignore = ['math', 'normalize_axis_tuple', 'get_xp', 'np', 'isdtype'] diff --git a/array_api_compat/cupy/__init__.py b/array_api_compat/cupy/__init__.py index 59e01058..9a30f95d 100644 --- a/array_api_compat/cupy/__init__.py +++ b/array_api_compat/cupy/__init__.py @@ -8,9 +8,6 @@ # See the comment in the numpy __init__.py __import__(__package__ + '.linalg') - __import__(__package__ + '.fft') -from ..common._helpers import * # noqa: F401,F403 - __array_api_version__ = '2024.12' diff --git a/array_api_compat/dask/array/__init__.py b/array_api_compat/dask/array/__init__.py index a6e69ad3..bb649306 100644 --- a/array_api_compat/dask/array/__init__.py +++ b/array_api_compat/dask/array/__init__.py @@ -5,5 +5,6 @@ __array_api_version__ = '2024.12' +# See the comment in the numpy __init__.py __import__(__package__ + '.linalg') __import__(__package__ + '.fft') diff --git a/array_api_compat/numpy/__init__.py b/array_api_compat/numpy/__init__.py index 02c55d28..6a5d9867 100644 --- a/array_api_compat/numpy/__init__.py +++ b/array_api_compat/numpy/__init__.py @@ -14,17 +14,8 @@ # It doesn't overwrite np.linalg from above. The import is generated # dynamically so that the library can be vendored. __import__(__package__ + '.linalg') - __import__(__package__ + '.fft') from .linalg import matrix_transpose, vecdot # noqa: F401 -from ..common._helpers import * # noqa: F403 - -try: - # Used in asarray(). Not present in older versions. - from numpy import _CopyMode # noqa: F401 -except ImportError: - pass - __array_api_version__ = '2024.12' diff --git a/array_api_compat/numpy/_aliases.py b/array_api_compat/numpy/_aliases.py index 1d084b2b..9e4f1174 100644 --- a/array_api_compat/numpy/_aliases.py +++ b/array_api_compat/numpy/_aliases.py @@ -86,7 +86,7 @@ def asarray( *, dtype: Optional[DType] = None, device: Optional[Device] = None, - copy: "Optional[Union[bool, np._CopyMode]]" = None, + copy: Optional[Union[bool, np._CopyMode]] = None, **kwargs, ) -> Array: """ diff --git a/array_api_compat/torch/__init__.py b/array_api_compat/torch/__init__.py index a985986e..69fd19ce 100644 --- a/array_api_compat/torch/__init__.py +++ b/array_api_compat/torch/__init__.py @@ -9,16 +9,14 @@ or 'cpu' in n or 'backward' in n): continue - exec(n + ' = torch.' + n) + exec(f"{n} = torch.{n}") +del n # These imports may overwrite names from the import * above. from ._aliases import * # noqa: F403 # See the comment in the numpy __init__.py __import__(__package__ + '.linalg') - __import__(__package__ + '.fft') -from ..common._helpers import * # noqa: F403 - __array_api_version__ = '2024.12' diff --git a/tests/test_common.py b/tests/test_common.py index bbf14572..54024d47 100644 --- a/tests/test_common.py +++ b/tests/test_common.py @@ -276,7 +276,7 @@ def test_asarray_copy(library): is_lib_func = globals()[is_array_functions[library]] all = xp.all if library != 'dask.array' else lambda x: xp.all(x).compute() - if library == 'numpy' and xp.__version__[0] < '2' and not hasattr(xp, '_CopyMode') : + if library == 'numpy' and xp.__version__[0] < '2' and not hasattr(np, "_CopyMode"): supports_copy_false_other_ns = False supports_copy_false_same_ns = False elif library == 'cupy': From 621494be1bd8682f1d76ae874272c12464953d3d Mon Sep 17 00:00:00 2001 From: Guido Imperiale Date: Fri, 4 Apr 2025 12:21:20 +0100 Subject: [PATCH 18/28] ENH: torch.asarray device propagation (#299) --- array_api_compat/torch/_aliases.py | 31 ++++++++++++++++++++++++------ array_api_compat/torch/_typing.py | 5 ++--- 2 files changed, 27 insertions(+), 9 deletions(-) diff --git a/array_api_compat/torch/_aliases.py b/array_api_compat/torch/_aliases.py index 982500b0..0891525a 100644 --- a/array_api_compat/torch/_aliases.py +++ b/array_api_compat/torch/_aliases.py @@ -2,12 +2,13 @@ from functools import reduce as _reduce, wraps as _wraps from builtins import all as _builtin_all, any as _builtin_any -from typing import List, Optional, Sequence, Tuple, Union +from typing import Any, List, Optional, Sequence, Tuple, Union import torch from .._internal import get_xp from ..common import _aliases +from ..common._typing import NestedSequence, SupportsBufferProtocol from ._info import __array_namespace_info__ from ._typing import Array, Device, DType @@ -207,6 +208,28 @@ def can_cast(from_: Union[DType, Array], to: DType, /) -> bool: remainder = _two_arg(torch.remainder) subtract = _two_arg(torch.subtract) + +def asarray( + obj: ( + Array + | bool | int | float | complex + | NestedSequence[bool | int | float | complex] + | SupportsBufferProtocol + ), + /, + *, + dtype: DType | None = None, + device: Device | None = None, + copy: bool | None = None, + **kwargs: Any, +) -> Array: + # torch.asarray does not respect input->output device propagation + # https://github.com/pytorch/pytorch/issues/150199 + if device is None and isinstance(obj, torch.Tensor): + device = obj.device + return torch.asarray(obj, dtype=dtype, device=device, copy=copy, **kwargs) + + # These wrappers are mostly based on the fact that pytorch uses 'dim' instead # of 'axis'. @@ -282,7 +305,6 @@ def prod(x: Array, dtype: Optional[DType] = None, keepdims: bool = False, **kwargs) -> Array: - x = torch.asarray(x) ndim = x.ndim # https://github.com/pytorch/pytorch/issues/29137. Separate from the logic @@ -318,7 +340,6 @@ def sum(x: Array, dtype: Optional[DType] = None, keepdims: bool = False, **kwargs) -> Array: - x = torch.asarray(x) ndim = x.ndim # https://github.com/pytorch/pytorch/issues/29137. @@ -348,7 +369,6 @@ def any(x: Array, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False, **kwargs) -> Array: - x = torch.asarray(x) ndim = x.ndim if axis == (): return x.to(torch.bool) @@ -373,7 +393,6 @@ def all(x: Array, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False, **kwargs) -> Array: - x = torch.asarray(x) ndim = x.ndim if axis == (): return x.to(torch.bool) @@ -816,7 +835,7 @@ def sign(x: Array, /) -> Array: return out -__all__ = ['__array_namespace_info__', 'result_type', 'can_cast', +__all__ = ['__array_namespace_info__', 'asarray', 'result_type', 'can_cast', 'permute_dims', 'bitwise_invert', 'newaxis', 'conj', 'add', 'atan2', 'bitwise_and', 'bitwise_left_shift', 'bitwise_or', 'bitwise_right_shift', 'bitwise_xor', 'copysign', 'count_nonzero', diff --git a/array_api_compat/torch/_typing.py b/array_api_compat/torch/_typing.py index 29ad3fa7..52670871 100644 --- a/array_api_compat/torch/_typing.py +++ b/array_api_compat/torch/_typing.py @@ -1,4 +1,3 @@ -__all__ = ["Array", "DType", "Device"] +__all__ = ["Array", "Device", "DType"] -from torch import dtype as DType, Tensor as Array -from ..common._typing import Device +from torch import device as Device, dtype as DType, Tensor as Array From c629a64c928bd76fdf0bec28a1399467801364be Mon Sep 17 00:00:00 2001 From: crusaderky Date: Mon, 7 Apr 2025 10:27:05 +0100 Subject: [PATCH 19/28] Simplify test parametrization --- cupy-xfails.txt | 8 ++------ dask-xfails.txt | 8 ++------ numpy-1-21-xfails.txt | 8 ++------ numpy-1-26-xfails.txt | 8 ++------ numpy-dev-xfails.txt | 8 ++------ numpy-xfails.txt | 8 ++------ 6 files changed, 12 insertions(+), 36 deletions(-) diff --git a/cupy-xfails.txt b/cupy-xfails.txt index f4cd1e36..a30572f8 100644 --- a/cupy-xfails.txt +++ b/cupy-xfails.txt @@ -16,12 +16,8 @@ array_api_tests/test_creation_functions.py::test_asarray_arrays # attributes are np.float32 instead of float # (see also https://github.com/data-apis/array-api/issues/405) -array_api_tests/test_data_type_functions.py::test_finfo[float32-dtype] -array_api_tests/test_data_type_functions.py::test_finfo[float32-array] -array_api_tests/test_data_type_functions.py::test_finfo[float32-array.dtype] -array_api_tests/test_data_type_functions.py::test_finfo[complex64-dtype] -array_api_tests/test_data_type_functions.py::test_finfo[complex64-array] -array_api_tests/test_data_type_functions.py::test_finfo[complex64-array.dtype] +array_api_tests/test_data_type_functions.py::test_finfo[float32] +array_api_tests/test_data_type_functions.py::test_finfo[complex64] # Some array attributes are missing, and we do not wrap the array object array_api_tests/test_has_names.py::test_has_names[array_method-__array_namespace__] diff --git a/dask-xfails.txt b/dask-xfails.txt index abab825c..932aeada 100644 --- a/dask-xfails.txt +++ b/dask-xfails.txt @@ -14,12 +14,8 @@ array_api_tests/test_creation_functions.py::test_eye # attributes are np.float32 instead of float # (see also https://github.com/data-apis/array-api/issues/405) -array_api_tests/test_data_type_functions.py::test_finfo[float32-dtype] -array_api_tests/test_data_type_functions.py::test_finfo[float32-array] -array_api_tests/test_data_type_functions.py::test_finfo[float32-array.dtype] -array_api_tests/test_data_type_functions.py::test_finfo[complex64-dtype] -array_api_tests/test_data_type_functions.py::test_finfo[complex64-array] -array_api_tests/test_data_type_functions.py::test_finfo[complex64-array.dtype] +array_api_tests/test_data_type_functions.py::test_finfo[float32] +array_api_tests/test_data_type_functions.py::test_finfo[complex64] # out[-1]=dask.array but should be some floating number # (I think the test is not forcing the op to be computed?) diff --git a/numpy-1-21-xfails.txt b/numpy-1-21-xfails.txt index 93a90757..66443a73 100644 --- a/numpy-1-21-xfails.txt +++ b/numpy-1-21-xfails.txt @@ -3,12 +3,8 @@ array_api_tests/test_creation_functions.py::test_asarray_arrays # attributes are np.float32 instead of float # (see also https://github.com/data-apis/array-api/issues/405) -array_api_tests/test_data_type_functions.py::test_finfo[float32-dtype] -array_api_tests/test_data_type_functions.py::test_finfo[float32-array] -array_api_tests/test_data_type_functions.py::test_finfo[float32-array.dtype] -array_api_tests/test_data_type_functions.py::test_finfo[complex64-dtype] -array_api_tests/test_data_type_functions.py::test_finfo[complex64-array] -array_api_tests/test_data_type_functions.py::test_finfo[complex64-array.dtype] +array_api_tests/test_data_type_functions.py::test_finfo[float32] +array_api_tests/test_data_type_functions.py::test_finfo[complex64] # Array methods and attributes not already on np.ndarray cannot be wrapped array_api_tests/test_has_names.py::test_has_names[array_method-__array_namespace__] diff --git a/numpy-1-26-xfails.txt b/numpy-1-26-xfails.txt index 84916e73..ed95083a 100644 --- a/numpy-1-26-xfails.txt +++ b/numpy-1-26-xfails.txt @@ -1,11 +1,7 @@ # attributes are np.float32 instead of float # (see also https://github.com/data-apis/array-api/issues/405) -array_api_tests/test_data_type_functions.py::test_finfo[float32-dtype] -array_api_tests/test_data_type_functions.py::test_finfo[float32-array] -array_api_tests/test_data_type_functions.py::test_finfo[float32-array.dtype] -array_api_tests/test_data_type_functions.py::test_finfo[complex64-dtype] -array_api_tests/test_data_type_functions.py::test_finfo[complex64-array] -array_api_tests/test_data_type_functions.py::test_finfo[complex64-array.dtype] +array_api_tests/test_data_type_functions.py::test_finfo[float32] +array_api_tests/test_data_type_functions.py::test_finfo[complex64] # Array methods and attributes not already on np.ndarray cannot be wrapped array_api_tests/test_has_names.py::test_has_names[array_method-__array_namespace__] diff --git a/numpy-dev-xfails.txt b/numpy-dev-xfails.txt index 31bcb63b..972d2346 100644 --- a/numpy-dev-xfails.txt +++ b/numpy-dev-xfails.txt @@ -1,11 +1,7 @@ # attributes are np.float32 instead of float # (see also https://github.com/data-apis/array-api/issues/405) -array_api_tests/test_data_type_functions.py::test_finfo[float32-dtype] -array_api_tests/test_data_type_functions.py::test_finfo[float32-array] -array_api_tests/test_data_type_functions.py::test_finfo[float32-array.dtype] -array_api_tests/test_data_type_functions.py::test_finfo[complex64-dtype] -array_api_tests/test_data_type_functions.py::test_finfo[complex64-array] -array_api_tests/test_data_type_functions.py::test_finfo[complex64-array.dtype] +array_api_tests/test_data_type_functions.py::test_finfo[float32] +array_api_tests/test_data_type_functions.py::test_finfo[complex64] # The test suite cannot properly get the signature for vecdot # https://github.com/numpy/numpy/pull/26237 diff --git a/numpy-xfails.txt b/numpy-xfails.txt index 0810aea6..0f09985e 100644 --- a/numpy-xfails.txt +++ b/numpy-xfails.txt @@ -1,11 +1,7 @@ # attributes are np.float32 instead of float # (see also https://github.com/data-apis/array-api/issues/405) -array_api_tests/test_data_type_functions.py::test_finfo[float32-dtype] -array_api_tests/test_data_type_functions.py::test_finfo[float32-array] -array_api_tests/test_data_type_functions.py::test_finfo[float32-array.dtype] -array_api_tests/test_data_type_functions.py::test_finfo[complex64-dtype] -array_api_tests/test_data_type_functions.py::test_finfo[complex64-array] -array_api_tests/test_data_type_functions.py::test_finfo[complex64-array.dtype] +array_api_tests/test_data_type_functions.py::test_finfo[float32] +array_api_tests/test_data_type_functions.py::test_finfo[complex64] # The test suite cannot properly get the signature for vecdot # https://github.com/numpy/numpy/pull/26237 From bff3bf467d6f126015179558f1b8c71242014cbc Mon Sep 17 00:00:00 2001 From: Guido Imperiale Date: Tue, 15 Apr 2025 11:48:45 +0100 Subject: [PATCH 20/28] Drop Python 3.9; test on Python 3.13; drop NumPy 1.21; skip CUDA install (#304) reviewed at https://github.com/data-apis/array-api-compat/pull/304 --- .github/workflows/array-api-tests-dask.yml | 2 +- .../workflows/array-api-tests-numpy-1-21.yml | 11 --- .../workflows/array-api-tests-numpy-1-22.yml | 12 +++ .../workflows/array-api-tests-numpy-1-26.yml | 1 + .../workflows/array-api-tests-numpy-dev.yml | 1 + .../array-api-tests-numpy-latest.yml | 3 +- .github/workflows/array-api-tests-torch.yml | 4 +- .github/workflows/array-api-tests.yml | 23 +++-- .github/workflows/tests.yml | 58 ++++++++----- array_api_compat/cupy/_typing.py | 2 +- array_api_compat/dask/array/_aliases.py | 2 +- array_api_compat/numpy/_aliases.py | 18 ++-- array_api_compat/numpy/_typing.py | 2 +- docs/supported-array-libraries.md | 17 +--- ...y-1-21-xfails.txt => numpy-1-22-xfails.txt | 83 +++---------------- numpy-1-26-xfails.txt | 3 - numpy-skips.txt | 11 --- numpy-xfails.txt | 4 +- pyproject.toml | 16 ++-- tests/test_common.py | 5 +- tests/test_dask.py | 6 +- torch-skips.txt | 11 --- torch-xfails.txt | 4 + 23 files changed, 114 insertions(+), 185 deletions(-) delete mode 100644 .github/workflows/array-api-tests-numpy-1-21.yml create mode 100644 .github/workflows/array-api-tests-numpy-1-22.yml rename numpy-1-21-xfails.txt => numpy-1-22-xfails.txt (68%) diff --git a/.github/workflows/array-api-tests-dask.yml b/.github/workflows/array-api-tests-dask.yml index 2ad98586..afc67975 100644 --- a/.github/workflows/array-api-tests-dask.yml +++ b/.github/workflows/array-api-tests-dask.yml @@ -7,7 +7,6 @@ jobs: uses: ./.github/workflows/array-api-tests.yml with: package-name: dask - package-version: '>= 2024.9.0' module-name: dask.array extra-requires: numpy # Dask is substantially slower then other libraries on unit tests. @@ -16,3 +15,4 @@ jobs: # flakiness. Before changes to dask-xfails.txt or dask-skips.txt, please run # the full test suite with at least 200 examples. pytest-extra-args: --max-examples=5 + python-versions: '[''3.10'', ''3.13'']' diff --git a/.github/workflows/array-api-tests-numpy-1-21.yml b/.github/workflows/array-api-tests-numpy-1-21.yml deleted file mode 100644 index 2d81c3cd..00000000 --- a/.github/workflows/array-api-tests-numpy-1-21.yml +++ /dev/null @@ -1,11 +0,0 @@ -name: Array API Tests (NumPy 1.21) - -on: [push, pull_request] - -jobs: - array-api-tests-numpy-1-21: - uses: ./.github/workflows/array-api-tests.yml - with: - package-name: numpy - package-version: '== 1.21.*' - xfails-file-extra: '-1-21' diff --git a/.github/workflows/array-api-tests-numpy-1-22.yml b/.github/workflows/array-api-tests-numpy-1-22.yml new file mode 100644 index 00000000..d8f60432 --- /dev/null +++ b/.github/workflows/array-api-tests-numpy-1-22.yml @@ -0,0 +1,12 @@ +name: Array API Tests (NumPy 1.22) + +on: [push, pull_request] + +jobs: + array-api-tests-numpy-1-22: + uses: ./.github/workflows/array-api-tests.yml + with: + package-name: numpy + package-version: '== 1.22.*' + xfails-file-extra: '-1-22' + python-versions: '[''3.10'']' diff --git a/.github/workflows/array-api-tests-numpy-1-26.yml b/.github/workflows/array-api-tests-numpy-1-26.yml index 660935f0..33780760 100644 --- a/.github/workflows/array-api-tests-numpy-1-26.yml +++ b/.github/workflows/array-api-tests-numpy-1-26.yml @@ -9,3 +9,4 @@ jobs: package-name: numpy package-version: '== 1.26.*' xfails-file-extra: '-1-26' + python-versions: '[''3.10'', ''3.12'']' diff --git a/.github/workflows/array-api-tests-numpy-dev.yml b/.github/workflows/array-api-tests-numpy-dev.yml index eef4269d..d6de1a53 100644 --- a/.github/workflows/array-api-tests-numpy-dev.yml +++ b/.github/workflows/array-api-tests-numpy-dev.yml @@ -9,3 +9,4 @@ jobs: package-name: numpy extra-requires: '--pre --extra-index https://pypi.anaconda.org/scientific-python-nightly-wheels/simple' xfails-file-extra: '-dev' + python-versions: '[''3.11'', ''3.13'']' diff --git a/.github/workflows/array-api-tests-numpy-latest.yml b/.github/workflows/array-api-tests-numpy-latest.yml index 36984345..4d3667f6 100644 --- a/.github/workflows/array-api-tests-numpy-latest.yml +++ b/.github/workflows/array-api-tests-numpy-latest.yml @@ -1,4 +1,4 @@ -name: Array API Tests (NumPy Latest) +name: Array API Tests (NumPy latest) on: [push, pull_request] @@ -7,3 +7,4 @@ jobs: uses: ./.github/workflows/array-api-tests.yml with: package-name: numpy + python-versions: '[''3.10'', ''3.13'']' diff --git a/.github/workflows/array-api-tests-torch.yml b/.github/workflows/array-api-tests-torch.yml index 56ab81a3..ac20df25 100644 --- a/.github/workflows/array-api-tests-torch.yml +++ b/.github/workflows/array-api-tests-torch.yml @@ -1,4 +1,4 @@ -name: Array API Tests (PyTorch Latest) +name: Array API Tests (PyTorch CPU) on: [push, pull_request] @@ -7,5 +7,7 @@ jobs: uses: ./.github/workflows/array-api-tests.yml with: package-name: torch + extra-requires: '--index-url https://download.pytorch.org/whl/cpu' extra-env-vars: | ARRAY_API_TESTS_SKIP_DTYPES=uint16,uint32,uint64 + python-versions: '[''3.10'', ''3.13'']' diff --git a/.github/workflows/array-api-tests.yml b/.github/workflows/array-api-tests.yml index 6ace193a..31bedde6 100644 --- a/.github/workflows/array-api-tests.yml +++ b/.github/workflows/array-api-tests.yml @@ -16,6 +16,10 @@ on: required: false type: string default: '>= 0' + python-versions: + required: true + type: string + description: JSON array of Python versions to test against. pytest-extra-args: required: false type: string @@ -30,7 +34,7 @@ on: extra-env-vars: required: false type: string - description: "Multiline string of environment variables to set for the test run." + description: Multiline string of environment variables to set for the test run. env: PYTEST_ARGS: "--max-examples 200 -v -rxXfE --ci ${{ inputs.pytest-extra-args }} --hypothesis-disable-deadline --durations 10" @@ -39,41 +43,44 @@ jobs: tests: runs-on: ubuntu-latest strategy: + fail-fast: false matrix: - # Min version of dask we need dropped support for Python 3.9 - # There is no numpy git tip for Python 3.9 or 3.10 - python-version: ${{ (inputs.package-name == 'dask' && fromJson('[''3.10'', ''3.11'', ''3.12'']')) || (inputs.package-name == 'numpy' && inputs.xfails-file-extra == '-dev' && fromJson('[''3.11'', ''3.12'']')) || fromJson('[''3.9'', ''3.10'', ''3.11'', ''3.12'']') }} + python-version: ${{ fromJson(inputs.python-versions) }} steps: - name: Checkout array-api-compat uses: actions/checkout@v4 with: path: array-api-compat + - name: Checkout array-api-tests uses: actions/checkout@v4 with: repository: data-apis/array-api-tests submodules: 'true' path: array-api-tests + - name: Set up Python ${{ matrix.python-version }} uses: actions/setup-python@v5 with: python-version: ${{ matrix.python-version }} + - name: Set Extra Environment Variables # Set additional environment variables if provided if: inputs.extra-env-vars run: | echo "${{ inputs.extra-env-vars }}" >> $GITHUB_ENV + - name: Install dependencies - # NumPy 1.21 doesn't support Python 3.11. There doesn't seem to be a way - # to put this in the numpy 1.21 config file. - if: "! ((matrix.python-version == '3.11' || matrix.python-version == '3.12') && inputs.package-name == 'numpy' && contains(inputs.package-version, '1.21'))" run: | python -m pip install --upgrade pip python -m pip install '${{ inputs.package-name }} ${{ inputs.package-version }}' ${{ inputs.extra-requires }} python -m pip install -r ${GITHUB_WORKSPACE}/array-api-tests/requirements.txt + + - name: Dump pip environment + run: pip freeze + - name: Run the array API testsuite (${{ inputs.package-name }}) - if: "! ((matrix.python-version == '3.11' || matrix.python-version == '3.12') && inputs.package-name == 'numpy' && contains(inputs.package-version, '1.21'))" env: ARRAY_API_TESTS_MODULE: array_api_compat.${{ inputs.module-name || inputs.package-name }} ARRAY_API_TESTS_VERSION: 2024.12 diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 54f6f402..81a05b3f 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -4,15 +4,24 @@ jobs: tests: runs-on: ubuntu-latest strategy: + fail-fast: false matrix: - python-version: ['3.9', '3.10', '3.11', '3.12'] - numpy-version: ['1.21', '1.26', '2.0', 'dev'] - exclude: - - python-version: '3.11' - numpy-version: '1.21' - - python-version: '3.12' - numpy-version: '1.21' - fail-fast: true + include: + - numpy-version: '1.22' + python-version: '3.10' + - numpy-version: '1.26' + python-version: '3.10' + - numpy-version: '1.26' + python-version: '3.12' + - numpy-version: 'latest' + python-version: '3.10' + - numpy-version: 'latest' + python-version: '3.13' + - numpy-version: 'dev' + python-version: '3.11' + - numpy-version: 'dev' + python-version: '3.13' + steps: - uses: actions/checkout@v4 - uses: actions/setup-python@v5 @@ -21,22 +30,29 @@ jobs: - name: Install Dependencies run: | python -m pip install --upgrade pip + python -m pip install pytest + if [ "${{ matrix.numpy-version }}" == "dev" ]; then - PIP_EXTRA='numpy --pre --extra-index-url https://pypi.anaconda.org/scientific-python-nightly-wheels/simple' - elif [ "${{ matrix.numpy-version }}" == "1.21" ]; then - PIP_EXTRA='numpy==1.21.*' + python -m pip install numpy --pre --extra-index-url https://pypi.anaconda.org/scientific-python-nightly-wheels/simple + elif [ "${{ matrix.numpy-version }}" == "1.22" ]; then + python -m pip install 'numpy==1.22.*' + elif [ "${{ matrix.numpy-version }}" == "1.26" ]; then + python -m pip install 'numpy==1.26.*' else - PIP_EXTRA='numpy==1.26.*' + # Don't `pip install .[dev]` as it would pull in the whole torch cuda stack + python -m pip install array-api-strict dask[array] jax[cpu] numpy sparse + python -m pip install torch --index-url https://download.pytorch.org/whl/cpu + if [ "${{ matrix.python-version }}" != "3.13" ]; then + # onnx wheels are not available on Python 3.13 at the moment of writing + python -m pip install ndonnx + fi fi - python -m pip install .[dev] $PIP_EXTRA + - name: Dump pip environment + run: pip freeze - - name: Run Tests - run: | - if [[ "${{ matrix.numpy-version }}" == "1.21" || "${{ matrix.numpy-version }}" == "dev" ]]; then - PYTEST_EXTRA=(-k "numpy and not jax and not torch and not dask and not sparse") - fi - pytest -v "${PYTEST_EXTRA[@]}" + - name: Test it installs + run: python -m pip install . - # Make sure it installs - python -m pip install . + - name: Run Tests + run: pytest -v diff --git a/array_api_compat/cupy/_typing.py b/array_api_compat/cupy/_typing.py index 66af5d19..d8e49ca7 100644 --- a/array_api_compat/cupy/_typing.py +++ b/array_api_compat/cupy/_typing.py @@ -10,7 +10,7 @@ from cupy.cuda.device import Device if TYPE_CHECKING: - # NumPy 1.x on Python 3.9 and 3.10 fails to parse np.dtype[] + # NumPy 1.x on Python 3.10 fails to parse np.dtype[] DType = cp.dtype[ cp.intp | cp.int8 diff --git a/array_api_compat/dask/array/_aliases.py b/array_api_compat/dask/array/_aliases.py index 4733b1a6..e7ddde78 100644 --- a/array_api_compat/dask/array/_aliases.py +++ b/array_api_compat/dask/array/_aliases.py @@ -147,7 +147,7 @@ def asarray( *, dtype: Optional[DType] = None, device: Optional[Device] = None, - copy: Optional[Union[bool, np._CopyMode]] = None, + copy: Optional[bool] = None, **kwargs, ) -> Array: """ diff --git a/array_api_compat/numpy/_aliases.py b/array_api_compat/numpy/_aliases.py index 59a0b8f4..d1fd46a1 100644 --- a/array_api_compat/numpy/_aliases.py +++ b/array_api_compat/numpy/_aliases.py @@ -99,18 +99,12 @@ def asarray( """ _helpers._check_device(np, device) - if hasattr(np, '_CopyMode'): - if copy is None: - copy = np._CopyMode.IF_NEEDED - elif copy is False: - copy = np._CopyMode.NEVER - elif copy is True: - copy = np._CopyMode.ALWAYS - else: - # Not present in older NumPys. In this case, we cannot really support - # copy=False. - if copy is False: - raise NotImplementedError("asarray(copy=False) requires a newer version of NumPy.") + if copy is None: + copy = np._CopyMode.IF_NEEDED + elif copy is False: + copy = np._CopyMode.NEVER + elif copy is True: + copy = np._CopyMode.ALWAYS return np.array(obj, copy=copy, dtype=dtype, **kwargs) diff --git a/array_api_compat/numpy/_typing.py b/array_api_compat/numpy/_typing.py index 6a18a3b2..a6c96924 100644 --- a/array_api_compat/numpy/_typing.py +++ b/array_api_compat/numpy/_typing.py @@ -10,7 +10,7 @@ Device = Literal["cpu"] if TYPE_CHECKING: - # NumPy 1.x on Python 3.9 and 3.10 fails to parse np.dtype[] + # NumPy 1.x on Python 3.10 fails to parse np.dtype[] DType = np.dtype[ np.intp | np.int8 diff --git a/docs/supported-array-libraries.md b/docs/supported-array-libraries.md index 4519c4ac..46fcdc27 100644 --- a/docs/supported-array-libraries.md +++ b/docs/supported-array-libraries.md @@ -36,23 +36,16 @@ deviations from the standard should be noted: 50](https://numpy.org/neps/nep-0050-scalar-promotion.html) and https://github.com/numpy/numpy/issues/22341) -- `asarray()` does not support `copy=False`. - - Functions which are not wrapped may not have the same type annotations as the spec. - Functions which are not wrapped may not use positional-only arguments. -The minimum supported NumPy version is 1.21. However, this older version of +The minimum supported NumPy version is 1.22. However, this older version of NumPy has a few issues: - `unique_*` will not compare nans as unequal. -- `finfo()` has no `smallest_normal`. - No `from_dlpack` or `__dlpack__`. -- `argmax()` and `argmin()` do not have `keepdims`. -- `qr()` doesn't support matrix stacks. -- `asarray()` doesn't support `copy=True` (as noted above, `copy=False` is not - supported even in the latest NumPy). - Type promotion behavior will be value based for 0-D arrays (and there is no `NPY_PROMOTION_STATE=weak` to disable this). @@ -72,8 +65,8 @@ version. attribute in the spec. Use the {func}`~.size()` helper function as a portable workaround. -- PyTorch does not have unsigned integer types other than `uint8`, and no - attempt is made to implement them here. +- PyTorch has incomplete support for unsigned integer types other + than `uint8`, and no attempt is made to implement them here. - PyTorch has type promotion semantics that differ from the array API specification for 0-D tensor objects. The array functions in this wrapper @@ -100,8 +93,6 @@ version. - As with NumPy, type annotations and positional-only arguments may not exactly match the spec for functions that are not wrapped at all. -The minimum supported PyTorch version is 1.13. - (jax-support)= ## [JAX](https://jax.readthedocs.io/en/latest/) @@ -131,8 +122,6 @@ For `linalg`, several methods are missing, for example: - `matrix_rank` Other methods may only be partially implemented or return incorrect results at times. -The minimum supported Dask version is 2023.12.0. - (sparse-support)= ## [Sparse](https://sparse.pydata.org/en/stable/) diff --git a/numpy-1-21-xfails.txt b/numpy-1-22-xfails.txt similarity index 68% rename from numpy-1-21-xfails.txt rename to numpy-1-22-xfails.txt index 66443a73..93edf311 100644 --- a/numpy-1-21-xfails.txt +++ b/numpy-1-22-xfails.txt @@ -1,6 +1,3 @@ -# asarray(copy=False) is not yet implemented -array_api_tests/test_creation_functions.py::test_asarray_arrays - # attributes are np.float32 instead of float # (see also https://github.com/data-apis/array-api/issues/405) array_api_tests/test_data_type_functions.py::test_finfo[float32] @@ -39,38 +36,24 @@ array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(isfinite(x1_i) and # https://github.com/numpy/numpy/issues/21213 array_api_tests/meta/test_hypothesis_helpers.py::test_symmetric_matrices -# NumPy 1.21 specific XFAILS +# NumPy 1.22 specific XFAILS ############################ -# finfo has no smallest_normal -array_api_tests/test_data_type_functions.py::test_finfo - -# dlpack stuff -array_api_tests/test_has_names.py::test_has_names[creation-from_dlpack] -array_api_tests/test_has_names.py::test_has_names[array_method-__dlpack__] -array_api_tests/test_has_names.py::test_has_names[array_method-__dlpack_device__] -array_api_tests/test_signatures.py::test_func_signature[from_dlpack] -array_api_tests/test_signatures.py::test_array_method_signature[__dlpack__] -array_api_tests/test_signatures.py::test_array_method_signature[__dlpack_device__] - -# qr() doesn't support matrix stacks -array_api_tests/test_linalg.py::test_qr - # cross has some promotion bug that is fixed in newer numpy versions array_api_tests/test_linalg.py::test_cross +# linspace(-0.0, -1.0, num=1) returns +0.0 instead of -0.0. +# Fixed in newer numpy versions. +array_api_tests/test_creation_functions.py::test_linspace + # vector_norm with ord=-1 which has since been fixed # https://github.com/numpy/numpy/issues/21083 array_api_tests/test_linalg.py::test_vector_norm -# argmax and argmin do not support keepdims -array_api_tests/test_searching_functions.py::test_argmax -array_api_tests/test_searching_functions.py::test_argmin -array_api_tests/test_signatures.py::test_func_signature[argmax] -array_api_tests/test_signatures.py::test_func_signature[argmin] - -# NumPy 1.21 doesn't support NPY_PROMOTION_STATE=weak, so many tests fail with +# NumPy 1.22 doesn't support NPY_PROMOTION_STATE=weak, so many tests fail with # type promotion issues +# NOTE: some of these may not fail until one runs array-api-tests with +# --max-examples 100000 array_api_tests/test_manipulation_functions.py::test_concat array_api_tests/test_operators_and_elementwise_functions.py::test_add[__add__(x, s)] array_api_tests/test_operators_and_elementwise_functions.py::test_add[__add__(x1, x2)] @@ -109,6 +92,7 @@ array_api_tests/test_operators_and_elementwise_functions.py::test_floor_divide[_ array_api_tests/test_operators_and_elementwise_functions.py::test_floor_divide[floor_divide(x1, x2)] array_api_tests/test_operators_and_elementwise_functions.py::test_greater[__gt__(x1, x2)] array_api_tests/test_operators_and_elementwise_functions.py::test_greater[greater(x1, x2)] +array_api_tests/test_operators_and_elementwise_functions.py::test_greater_equal[__ge__(x1, x2)] array_api_tests/test_operators_and_elementwise_functions.py::test_greater_equal[greater_equal(x1, x2)] array_api_tests/test_operators_and_elementwise_functions.py::test_hypot array_api_tests/test_operators_and_elementwise_functions.py::test_less[__lt__(x1, x2)] @@ -136,53 +120,11 @@ array_api_tests/test_operators_and_elementwise_functions.py::test_subtract[__isu array_api_tests/test_operators_and_elementwise_functions.py::test_subtract[__sub__(x, s)] array_api_tests/test_operators_and_elementwise_functions.py::test_subtract[__sub__(x1, x2)] array_api_tests/test_operators_and_elementwise_functions.py::test_subtract[subtract(x1, x2)] -array_api_tests/test_special_cases.py::test_binary[__floordiv__(x1_i < 0 and x2_i is +0) -> -infinity] -array_api_tests/test_special_cases.py::test_binary[__floordiv__(x1_i < 0 and x2_i is -0) -> +infinity] -array_api_tests/test_special_cases.py::test_binary[__floordiv__(x1_i > 0 and x2_i is +0) -> +infinity] -array_api_tests/test_special_cases.py::test_binary[__floordiv__(x1_i > 0 and x2_i is -0) -> -infinity] -array_api_tests/test_special_cases.py::test_binary[__floordiv__(x1_i is +0 and x2_i < 0) -> -0] -array_api_tests/test_special_cases.py::test_binary[__floordiv__(x1_i is +0 and x2_i > 0) -> +0] -array_api_tests/test_special_cases.py::test_binary[__floordiv__(x1_i is -0 and x2_i < 0) -> +0] -array_api_tests/test_special_cases.py::test_binary[__floordiv__(x1_i is -0 and x2_i > 0) -> -0] -array_api_tests/test_special_cases.py::test_binary[__mod__(isfinite(x1_i) and x1_i < 0 and x2_i is +infinity) -> x2_i] -array_api_tests/test_special_cases.py::test_binary[__mod__(isfinite(x1_i) and x1_i < 0 and x2_i is -infinity) -> x1_i] -array_api_tests/test_special_cases.py::test_binary[__mod__(isfinite(x1_i) and x1_i > 0 and x2_i is +infinity) -> x1_i] -array_api_tests/test_special_cases.py::test_binary[__mod__(isfinite(x1_i) and x1_i > 0 and x2_i is -infinity) -> x2_i] -array_api_tests/test_special_cases.py::test_binary[__mod__(x1_i is +0 and x2_i < 0) -> -0] -array_api_tests/test_special_cases.py::test_binary[__mod__(x1_i is +0 and x2_i > 0) -> +0] -array_api_tests/test_special_cases.py::test_binary[__mod__(x1_i is -0 and x2_i < 0) -> -0] -array_api_tests/test_special_cases.py::test_binary[__mod__(x1_i is -0 and x2_i > 0) -> +0] -array_api_tests/test_special_cases.py::test_binary[__pow__(x1_i < 0 and isfinite(x1_i) and isfinite(x2_i) and not x2_i.is_integer()) -> NaN] -array_api_tests/test_special_cases.py::test_binary[__pow__(x1_i is +0 and x2_i < 0) -> +infinity] -array_api_tests/test_special_cases.py::test_binary[__pow__(x1_i is +0 and x2_i > 0) -> +0] -array_api_tests/test_special_cases.py::test_binary[__pow__(x1_i is -0 and x2_i < 0 and not (x2_i.is_integer() and x2_i % 2 == 1)) -> +infinity] -array_api_tests/test_special_cases.py::test_binary[__pow__(x1_i is -0 and x2_i < 0 and x2_i.is_integer() and x2_i % 2 == 1) -> -infinity] -array_api_tests/test_special_cases.py::test_binary[__pow__(x1_i is -0 and x2_i > 0 and x2_i.is_integer() and x2_i % 2 == 1) -> -0] -array_api_tests/test_special_cases.py::test_binary[__pow__(x1_i is -infinity and x2_i < 0 and not (x2_i.is_integer() and x2_i % 2 == 1)) -> +0] -array_api_tests/test_special_cases.py::test_binary[__pow__(x1_i is -infinity and x2_i < 0 and x2_i.is_integer() and x2_i % 2 == 1) -> -0] -array_api_tests/test_special_cases.py::test_binary[atan2(x1_i < 0 and x2_i is +0) -> roughly -pi/2] -array_api_tests/test_special_cases.py::test_binary[atan2(x1_i < 0 and x2_i is -0) -> roughly -pi/2] -array_api_tests/test_special_cases.py::test_binary[atan2(x1_i > 0 and x2_i is +0) -> roughly +pi/2] -array_api_tests/test_special_cases.py::test_binary[atan2(x1_i > 0 and x2_i is -0) -> roughly +pi/2] -array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i < 0 and x2_i is +0) -> -infinity] -array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i < 0 and x2_i is -0) -> +infinity] -array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i > 0 and x2_i is +0) -> +infinity] -array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i > 0 and x2_i is -0) -> -infinity] -array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is +0 and x2_i < 0) -> -0] -array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is +0 and x2_i > 0) -> +0] -array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is -0 and x2_i < 0) -> +0] -array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is -0 and x2_i > 0) -> -0] -array_api_tests/test_special_cases.py::test_binary[remainder(isfinite(x1_i) and x1_i < 0 and x2_i is +infinity) -> x2_i] -array_api_tests/test_special_cases.py::test_binary[remainder(isfinite(x1_i) and x1_i < 0 and x2_i is -infinity) -> x1_i] -array_api_tests/test_special_cases.py::test_binary[remainder(isfinite(x1_i) and x1_i > 0 and x2_i is +infinity) -> x1_i] -array_api_tests/test_special_cases.py::test_binary[remainder(isfinite(x1_i) and x1_i > 0 and x2_i is -infinity) -> x2_i] -array_api_tests/test_special_cases.py::test_binary[remainder(x1_i is +0 and x2_i < 0) -> -0] -array_api_tests/test_special_cases.py::test_binary[remainder(x1_i is +0 and x2_i > 0) -> +0] -array_api_tests/test_special_cases.py::test_binary[remainder(x1_i is -0 and x2_i < 0) -> -0] -array_api_tests/test_special_cases.py::test_binary[remainder(x1_i is -0 and x2_i > 0) -> +0] +array_api_tests/test_searching_functions.py::test_where array_api_tests/test_special_cases.py::test_iop[__iadd__(x1_i is -0 and x2_i is -0) -> -0] # 2023.12 support +array_api_tests/test_has_names.py::test_has_names[creation-from_dlpack] array_api_tests/test_signatures.py::test_func_signature[from_dlpack] array_api_tests/test_signatures.py::test_array_method_signature[__dlpack__] # uint64 repeats not supported @@ -215,6 +157,3 @@ array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(x1_i is -infinity array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(x1_i is -infinity and isfinite(x2_i) and x2_i < 0) -> +infinity] array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(isfinite(x1_i) and x1_i > 0 and x2_i is -infinity) -> -0] array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(isfinite(x1_i) and x1_i < 0 and x2_i is +infinity) -> -0] - -# numpy < 2 bug: type promotion of asarray([], 'float32') and (np.finfo(float32).max + 1) -> float64 -array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real diff --git a/numpy-1-26-xfails.txt b/numpy-1-26-xfails.txt index ed95083a..51e1a658 100644 --- a/numpy-1-26-xfails.txt +++ b/numpy-1-26-xfails.txt @@ -69,6 +69,3 @@ array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(x1_i is -infinity array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(x1_i is -infinity and isfinite(x2_i) and x2_i < 0) -> +infinity] array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(isfinite(x1_i) and x1_i > 0 and x2_i is -infinity) -> -0] array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(isfinite(x1_i) and x1_i < 0 and x2_i is +infinity) -> -0] - -# numpy < 2 bug: type promotion of asarray([], 'float32') and (finfo(float32).max + 1) gives float64 not float32 -array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real diff --git a/numpy-skips.txt b/numpy-skips.txt index cbf7235b..e69de29b 100644 --- a/numpy-skips.txt +++ b/numpy-skips.txt @@ -1,11 +0,0 @@ -# These tests cause a core dump on CI, so we have to skip them entirely -array_api_tests/test_operators_and_elementwise_functions.py::test_floor_divide[__floordiv__(x, s)] -array_api_tests/test_operators_and_elementwise_functions.py::test_floor_divide[__floordiv__(x1, x2)] -array_api_tests/test_operators_and_elementwise_functions.py::test_floor_divide[__ifloordiv__(x, s)] -array_api_tests/test_operators_and_elementwise_functions.py::test_floor_divide[__ifloordiv__(x1, x2)] -array_api_tests/test_operators_and_elementwise_functions.py::test_floor_divide[floor_divide(x1, x2)] -array_api_tests/test_operators_and_elementwise_functions.py::test_remainder[__imod__(x, s)] -array_api_tests/test_operators_and_elementwise_functions.py::test_remainder[__imod__(x1, x2)] -array_api_tests/test_operators_and_elementwise_functions.py::test_remainder[__mod__(x, s)] -array_api_tests/test_operators_and_elementwise_functions.py::test_remainder[__mod__(x1, x2)] -array_api_tests/test_operators_and_elementwise_functions.py::test_remainder[remainder(x1, x2)] diff --git a/numpy-xfails.txt b/numpy-xfails.txt index 0f09985e..632b4ec3 100644 --- a/numpy-xfails.txt +++ b/numpy-xfails.txt @@ -9,8 +9,6 @@ array_api_tests/test_signatures.py::test_func_signature[vecdot] array_api_tests/test_signatures.py::test_extension_func_signature[linalg.vecdot] # 2023.12 support -array_api_tests/test_signatures.py::test_func_signature[from_dlpack] -array_api_tests/test_signatures.py::test_array_method_signature[__dlpack__] # uint64 repeats not supported array_api_tests/test_manipulation_functions.py::test_repeat @@ -20,6 +18,8 @@ array_api_tests/test_signatures.py::test_func_signature[bitwise_left_shift] array_api_tests/test_signatures.py::test_func_signature[bitwise_or] array_api_tests/test_signatures.py::test_func_signature[bitwise_right_shift] array_api_tests/test_signatures.py::test_func_signature[bitwise_xor] + +# Stubs have a comment: (**note**: libraries may return ``NaN`` to match Python behavior.); Apparently, NumPy does just that array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is +infinity and isfinite(x2_i) and x2_i > 0) -> +infinity] array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is +infinity and isfinite(x2_i) and x2_i < 0) -> -infinity] array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is -infinity and isfinite(x2_i) and x2_i > 0) -> -infinity] diff --git a/pyproject.toml b/pyproject.toml index f17c720f..aacebd11 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -7,13 +7,12 @@ name = "array-api-compat" dynamic = ["version"] description = "A wrapper around NumPy and other array libraries to make them compatible with the Array API standard" readme = "README.md" -requires-python = ">=3.9" +requires-python = ">=3.10" license = "MIT" authors = [{name = "Consortium for Python Data API Standards"}] classifiers = [ "Operating System :: OS Independent", "Programming Language :: Python :: 3", - "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", "Programming Language :: Python :: 3.12", @@ -24,11 +23,14 @@ classifiers = [ [project.optional-dependencies] cupy = ["cupy"] -dask = ["dask"] +dask = ["dask>=2024.9.0"] jax = ["jax"] -numpy = ["numpy"] +# Note: array-api-compat follows scikit-learn minimum dependencies, which support +# much older versions of NumPy than what SPEC0 recommends. +numpy = ["numpy>=1.22"] pytorch = ["torch"] sparse = ["sparse>=0.15.1"] +ndonnx = ["ndonnx"] docs = [ "furo", "linkify-it-py", @@ -39,13 +41,13 @@ docs = [ ] dev = [ "array-api-strict", - "dask[array]", + "dask[array]>=2024.9.0", "jax[cpu]", - "numpy", + "numpy>=1.22", "pytest", "torch", "sparse>=0.15.1", - "ndonnx; python_version>=\"3.10\"" + "ndonnx" ] [project.urls] diff --git a/tests/test_common.py b/tests/test_common.py index 54024d47..6b1aa160 100644 --- a/tests/test_common.py +++ b/tests/test_common.py @@ -276,10 +276,7 @@ def test_asarray_copy(library): is_lib_func = globals()[is_array_functions[library]] all = xp.all if library != 'dask.array' else lambda x: xp.all(x).compute() - if library == 'numpy' and xp.__version__[0] < '2' and not hasattr(np, "_CopyMode"): - supports_copy_false_other_ns = False - supports_copy_false_same_ns = False - elif library == 'cupy': + if library == 'cupy': supports_copy_false_other_ns = False supports_copy_false_same_ns = False elif library == 'dask.array': diff --git a/tests/test_dask.py b/tests/test_dask.py index 69c738f6..fb0a84d4 100644 --- a/tests/test_dask.py +++ b/tests/test_dask.py @@ -1,6 +1,5 @@ from contextlib import contextmanager -import array_api_strict import numpy as np import pytest @@ -171,9 +170,10 @@ def test_sort_argsort_chunk_size(xp, func, shape, chunks): @pytest.mark.parametrize("func", ["sort", "argsort"]) def test_sort_argsort_meta(xp, func): """Test meta-namespace other than numpy""" - typ = type(array_api_strict.asarray(0)) + mxp = pytest.importorskip("array_api_strict") + typ = type(mxp.asarray(0)) a = da.random.random(10) - b = a.map_blocks(array_api_strict.asarray) + b = a.map_blocks(mxp.asarray) assert isinstance(b._meta, typ) c = getattr(xp, func)(b) assert isinstance(c._meta, typ) diff --git a/torch-skips.txt b/torch-skips.txt index cbf7235b..e69de29b 100644 --- a/torch-skips.txt +++ b/torch-skips.txt @@ -1,11 +0,0 @@ -# These tests cause a core dump on CI, so we have to skip them entirely -array_api_tests/test_operators_and_elementwise_functions.py::test_floor_divide[__floordiv__(x, s)] -array_api_tests/test_operators_and_elementwise_functions.py::test_floor_divide[__floordiv__(x1, x2)] -array_api_tests/test_operators_and_elementwise_functions.py::test_floor_divide[__ifloordiv__(x, s)] -array_api_tests/test_operators_and_elementwise_functions.py::test_floor_divide[__ifloordiv__(x1, x2)] -array_api_tests/test_operators_and_elementwise_functions.py::test_floor_divide[floor_divide(x1, x2)] -array_api_tests/test_operators_and_elementwise_functions.py::test_remainder[__imod__(x, s)] -array_api_tests/test_operators_and_elementwise_functions.py::test_remainder[__imod__(x1, x2)] -array_api_tests/test_operators_and_elementwise_functions.py::test_remainder[__mod__(x, s)] -array_api_tests/test_operators_and_elementwise_functions.py::test_remainder[__mod__(x1, x2)] -array_api_tests/test_operators_and_elementwise_functions.py::test_remainder[remainder(x1, x2)] diff --git a/torch-xfails.txt b/torch-xfails.txt index e556fa4f..abee88b1 100644 --- a/torch-xfails.txt +++ b/torch-xfails.txt @@ -29,6 +29,10 @@ array_api_tests/test_operators_and_elementwise_functions.py::test_divide[__trued array_api_tests/test_operators_and_elementwise_functions.py::test_equal[__eq__(x1, x2)] array_api_tests/test_operators_and_elementwise_functions.py::test_not_equal[__ne__(x1, x2)] array_api_tests/test_operators_and_elementwise_functions.py::test_floor_divide[__floordiv__(x1, x2)] +array_api_tests/test_operators_and_elementwise_functions.py::test_floor_divide[__floordiv__(x, s)] +array_api_tests/test_operators_and_elementwise_functions.py::test_floor_divide[__ifloordiv__(x, s)] +array_api_tests/test_operators_and_elementwise_functions.py::test_floor_divide[__ifloordiv__(x1, x2)] +array_api_tests/test_operators_and_elementwise_functions.py::test_floor_divide[floor_divide(x1, x2)] array_api_tests/test_operators_and_elementwise_functions.py::test_greater[__gt__(x1, x2)] array_api_tests/test_operators_and_elementwise_functions.py::test_greater_equal[__ge__(x1, x2)] array_api_tests/test_operators_and_elementwise_functions.py::test_less[__lt__(x1, x2)] From 00e7cceb338025d9428af2bb6afbe7eaac8cf414 Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Tue, 15 Apr 2025 11:53:21 +0200 Subject: [PATCH 21/28] BUG: add torch.repeat --- array_api_compat/torch/_aliases.py | 7 ++++++- torch-xfails.txt | 3 +-- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/array_api_compat/torch/_aliases.py b/array_api_compat/torch/_aliases.py index a2ed1449..0a604b8c 100644 --- a/array_api_compat/torch/_aliases.py +++ b/array_api_compat/torch/_aliases.py @@ -574,6 +574,11 @@ def count_nonzero( return result +# "repeat" is torch.repeat_interleave; also the dim argument +def repeat(x: Array, repeats: int | Array, /, *, axis: int | None = None) -> Array: + return torch.repeat_interleave(x, repeats, axis) + + def where( condition: Array, x1: Array | bool | int | float | complex, @@ -854,6 +859,6 @@ def sign(x: Array, /) -> Array: 'UniqueAllResult', 'UniqueCountsResult', 'UniqueInverseResult', 'unique_all', 'unique_counts', 'unique_inverse', 'unique_values', 'matmul', 'matrix_transpose', 'vecdot', 'tensordot', 'isdtype', - 'take', 'take_along_axis', 'sign', 'finfo', 'iinfo'] + 'take', 'take_along_axis', 'sign', 'finfo', 'iinfo', 'repeat'] _all_ignore = ['torch', 'get_xp'] diff --git a/torch-xfails.txt b/torch-xfails.txt index e556fa4f..ab11f457 100644 --- a/torch-xfails.txt +++ b/torch-xfails.txt @@ -120,9 +120,8 @@ array_api_tests/test_data_type_functions.py::test_finfo_dtype array_api_tests/test_data_type_functions.py::test_iinfo_dtype # 2023.12 support -array_api_tests/test_has_names.py::test_has_names[manipulation-repeat] +# https://github.com/pytorch/pytorch/issues/151311: torch.repeat_interleave rejects short integers array_api_tests/test_manipulation_functions.py::test_repeat -array_api_tests/test_signatures.py::test_func_signature[repeat] # Argument 'device' missing from signature array_api_tests/test_signatures.py::test_func_signature[from_dlpack] # Argument 'max_version' missing from signature From d743dc13e16a2328e3ce0951dd3633629b6537a6 Mon Sep 17 00:00:00 2001 From: Guido Imperiale Date: Tue, 15 Apr 2025 11:54:15 +0100 Subject: [PATCH 22/28] MAINT: `__array_namespace_info__` docstrings tweaks (#300) --- array_api_compat/common/_aliases.py | 2 +- array_api_compat/cupy/_info.py | 20 ++++++++++---- array_api_compat/dask/array/_info.py | 19 +++++++------ array_api_compat/numpy/_info.py | 8 +++--- array_api_compat/torch/_info.py | 41 ++++++++++++++++++---------- 5 files changed, 56 insertions(+), 34 deletions(-) diff --git a/array_api_compat/common/_aliases.py b/array_api_compat/common/_aliases.py index 46cbb359..351b5bd6 100644 --- a/array_api_compat/common/_aliases.py +++ b/array_api_compat/common/_aliases.py @@ -18,7 +18,7 @@ # These functions are modified from the NumPy versions. -# Creation functions add the device keyword (which does nothing for NumPy) +# Creation functions add the device keyword (which does nothing for NumPy and Dask) def arange( start: Union[int, float], diff --git a/array_api_compat/cupy/_info.py b/array_api_compat/cupy/_info.py index 790621e4..78e48a33 100644 --- a/array_api_compat/cupy/_info.py +++ b/array_api_compat/cupy/_info.py @@ -26,6 +26,7 @@ complex128, ) + class __array_namespace_info__: """ Get the array API inspection namespace for CuPy. @@ -49,7 +50,7 @@ class __array_namespace_info__: Examples -------- - >>> info = np.__array_namespace_info__() + >>> info = xp.__array_namespace_info__() >>> info.default_dtypes() {'real floating': cupy.float64, 'complex floating': cupy.complex128, @@ -94,13 +95,13 @@ def capabilities(self): >>> info = xp.__array_namespace_info__() >>> info.capabilities() {'boolean indexing': True, - 'data-dependent shapes': True} + 'data-dependent shapes': True, + 'max dimensions': 64} """ return { "boolean indexing": True, "data-dependent shapes": True, - # 'max rank' will be part of the 2024.12 standard "max dimensions": 64, } @@ -117,7 +118,7 @@ def default_device(self): Returns ------- - device : str + device : Device The default device used for new CuPy arrays. Examples @@ -126,6 +127,15 @@ def default_device(self): >>> info.default_device() Device(0) + Notes + ----- + This method returns the static default device when CuPy is initialized. + However, the *current* device used by creation functions (``empty`` etc.) + can be changed globally or with a context manager. + + See Also + -------- + https://github.com/data-apis/array-api/issues/835 """ return cuda.Device(0) @@ -312,7 +322,7 @@ def devices(self): Returns ------- - devices : list of str + devices : list[Device] The devices supported by CuPy. See Also diff --git a/array_api_compat/dask/array/_info.py b/array_api_compat/dask/array/_info.py index fc70b5a2..614f43d9 100644 --- a/array_api_compat/dask/array/_info.py +++ b/array_api_compat/dask/array/_info.py @@ -50,7 +50,7 @@ class __array_namespace_info__: Examples -------- - >>> info = np.__array_namespace_info__() + >>> info = xp.__array_namespace_info__() >>> info.default_dtypes() {'real floating': dask.float64, 'complex floating': dask.complex128, @@ -103,10 +103,11 @@ def capabilities(self): Examples -------- - >>> info = np.__array_namespace_info__() + >>> info = xp.__array_namespace_info__() >>> info.capabilities() {'boolean indexing': True, - 'data-dependent shapes': True} + 'data-dependent shapes': True, + 'max dimensions': 64} """ return { @@ -130,12 +131,12 @@ def default_device(self): Returns ------- - device : str + device : Device The default device used for new Dask arrays. Examples -------- - >>> info = np.__array_namespace_info__() + >>> info = xp.__array_namespace_info__() >>> info.default_device() 'cpu' @@ -173,7 +174,7 @@ def default_dtypes(self, *, device=None): Examples -------- - >>> info = np.__array_namespace_info__() + >>> info = xp.__array_namespace_info__() >>> info.default_dtypes() {'real floating': dask.float64, 'complex floating': dask.complex128, @@ -239,7 +240,7 @@ def dtypes(self, *, device=None, kind=None): Examples -------- - >>> info = np.__array_namespace_info__() + >>> info = xp.__array_namespace_info__() >>> info.dtypes(kind='signed integer') {'int8': dask.int8, 'int16': dask.int16, @@ -335,7 +336,7 @@ def devices(self): Returns ------- - devices : list of str + devices : list[Device] The devices supported by Dask. See Also @@ -347,7 +348,7 @@ def devices(self): Examples -------- - >>> info = np.__array_namespace_info__() + >>> info = xp.__array_namespace_info__() >>> info.devices() ['cpu', DASK_DEVICE] diff --git a/array_api_compat/numpy/_info.py b/array_api_compat/numpy/_info.py index e706d118..365855b8 100644 --- a/array_api_compat/numpy/_info.py +++ b/array_api_compat/numpy/_info.py @@ -94,13 +94,13 @@ def capabilities(self): >>> info = np.__array_namespace_info__() >>> info.capabilities() {'boolean indexing': True, - 'data-dependent shapes': True} + 'data-dependent shapes': True, + 'max dimensions': 64} """ return { "boolean indexing": True, "data-dependent shapes": True, - # 'max rank' will be part of the 2024.12 standard "max dimensions": 64, } @@ -119,7 +119,7 @@ def default_device(self): Returns ------- - device : str + device : Device The default device used for new NumPy arrays. Examples @@ -326,7 +326,7 @@ def devices(self): Returns ------- - devices : list of str + devices : list[Device] The devices supported by NumPy. See Also diff --git a/array_api_compat/torch/_info.py b/array_api_compat/torch/_info.py index 34fbcb21..818e5d37 100644 --- a/array_api_compat/torch/_info.py +++ b/array_api_compat/torch/_info.py @@ -34,7 +34,7 @@ class __array_namespace_info__: Examples -------- - >>> info = np.__array_namespace_info__() + >>> info = xp.__array_namespace_info__() >>> info.default_dtypes() {'real floating': numpy.float64, 'complex floating': numpy.complex128, @@ -76,16 +76,16 @@ def capabilities(self): Examples -------- - >>> info = np.__array_namespace_info__() + >>> info = xp.__array_namespace_info__() >>> info.capabilities() {'boolean indexing': True, - 'data-dependent shapes': True} + 'data-dependent shapes': True, + 'max dimensions': 64} """ return { "boolean indexing": True, "data-dependent shapes": True, - # 'max rank' will be part of the 2024.12 standard "max dimensions": 64, } @@ -102,15 +102,24 @@ def default_device(self): Returns ------- - device : str + device : Device The default device used for new PyTorch arrays. Examples -------- - >>> info = np.__array_namespace_info__() + >>> info = xp.__array_namespace_info__() >>> info.default_device() - 'cpu' + device(type='cpu') + Notes + ----- + This method returns the static default device when PyTorch is initialized. + However, the *current* device used by creation functions (``empty`` etc.) + can be changed at runtime. + + See Also + -------- + https://github.com/data-apis/array-api/issues/835 """ return torch.device("cpu") @@ -120,9 +129,9 @@ def default_dtypes(self, *, device=None): Parameters ---------- - device : str, optional - The device to get the default data types for. For PyTorch, only - ``'cpu'`` is allowed. + device : Device, optional + The device to get the default data types for. + Unused for PyTorch, as all devices use the same default dtypes. Returns ------- @@ -139,7 +148,7 @@ def default_dtypes(self, *, device=None): Examples -------- - >>> info = np.__array_namespace_info__() + >>> info = xp.__array_namespace_info__() >>> info.default_dtypes() {'real floating': torch.float32, 'complex floating': torch.complex64, @@ -250,8 +259,9 @@ def dtypes(self, *, device=None, kind=None): Parameters ---------- - device : str, optional + device : Device, optional The device to get the data types for. + Unused for PyTorch, as all devices use the same dtypes. kind : str or tuple of str, optional The kind of data types to return. If ``None``, all data types are returned. If a string, only data types of that kind are returned. @@ -287,7 +297,7 @@ def dtypes(self, *, device=None, kind=None): Examples -------- - >>> info = np.__array_namespace_info__() + >>> info = xp.__array_namespace_info__() >>> info.dtypes(kind='signed integer') {'int8': numpy.int8, 'int16': numpy.int16, @@ -310,7 +320,7 @@ def devices(self): Returns ------- - devices : list of str + devices : list[Device] The devices supported by PyTorch. See Also @@ -322,7 +332,7 @@ def devices(self): Examples -------- - >>> info = np.__array_namespace_info__() + >>> info = xp.__array_namespace_info__() >>> info.devices() [device(type='cpu'), device(type='mps', index=0), device(type='meta')] @@ -333,6 +343,7 @@ def devices(self): # device: try: torch.device('notadevice') + raise AssertionError("unreachable") # pragma: nocover except RuntimeError as e: # The error message is something like: # "Expected one of cpu, cuda, ipu, xpu, mkldnn, opengl, opencl, ideep, hip, ve, fpga, ort, xla, lazy, vulkan, mps, meta, hpu, mtia, privateuseone device type at start of device string: notadevice" From 9194c5cb7706e08f1a1092aece1fce76ac6e089a Mon Sep 17 00:00:00 2001 From: Guido Imperiale Date: Tue, 15 Apr 2025 12:04:09 +0100 Subject: [PATCH 23/28] MAINT: simplify `torch` dtype promotion (#303) reviewed at https://github.com/data-apis/array-api-compat/pull/303 --- array_api_compat/torch/_aliases.py | 99 ++++++++++++------------------ 1 file changed, 40 insertions(+), 59 deletions(-) diff --git a/array_api_compat/torch/_aliases.py b/array_api_compat/torch/_aliases.py index a2ed1449..5370803f 100644 --- a/array_api_compat/torch/_aliases.py +++ b/array_api_compat/torch/_aliases.py @@ -35,47 +35,23 @@ torch.complex128, } -_promotion_table = { - # bool - (torch.bool, torch.bool): torch.bool, +_promotion_table = { # ints - (torch.int8, torch.int8): torch.int8, (torch.int8, torch.int16): torch.int16, (torch.int8, torch.int32): torch.int32, (torch.int8, torch.int64): torch.int64, - (torch.int16, torch.int8): torch.int16, - (torch.int16, torch.int16): torch.int16, (torch.int16, torch.int32): torch.int32, (torch.int16, torch.int64): torch.int64, - (torch.int32, torch.int8): torch.int32, - (torch.int32, torch.int16): torch.int32, - (torch.int32, torch.int32): torch.int32, (torch.int32, torch.int64): torch.int64, - (torch.int64, torch.int8): torch.int64, - (torch.int64, torch.int16): torch.int64, - (torch.int64, torch.int32): torch.int64, - (torch.int64, torch.int64): torch.int64, - # uints - (torch.uint8, torch.uint8): torch.uint8, # ints and uints (mixed sign) - (torch.int8, torch.uint8): torch.int16, - (torch.int16, torch.uint8): torch.int16, - (torch.int32, torch.uint8): torch.int32, - (torch.int64, torch.uint8): torch.int64, (torch.uint8, torch.int8): torch.int16, (torch.uint8, torch.int16): torch.int16, (torch.uint8, torch.int32): torch.int32, (torch.uint8, torch.int64): torch.int64, # floats - (torch.float32, torch.float32): torch.float32, (torch.float32, torch.float64): torch.float64, - (torch.float64, torch.float32): torch.float64, - (torch.float64, torch.float64): torch.float64, # complexes - (torch.complex64, torch.complex64): torch.complex64, (torch.complex64, torch.complex128): torch.complex128, - (torch.complex128, torch.complex64): torch.complex128, - (torch.complex128, torch.complex128): torch.complex128, # Mixed float and complex (torch.float32, torch.complex64): torch.complex64, (torch.float32, torch.complex128): torch.complex128, @@ -83,6 +59,9 @@ (torch.float64, torch.complex128): torch.complex128, } +_promotion_table.update({(b, a): c for (a, b), c in _promotion_table.items()}) +_promotion_table.update({(a, a): a for a in _array_api_dtypes}) + def _two_arg(f): @_wraps(f) @@ -150,13 +129,18 @@ def result_type( return _reduce(_result_type, others + scalars) -def _result_type(x, y): +def _result_type( + x: Array | DType | bool | int | float | complex, + y: Array | DType | bool | int | float | complex, +) -> DType: if not (isinstance(x, _py_scalars) or isinstance(y, _py_scalars)): - xdt = x.dtype if not isinstance(x, torch.dtype) else x - ydt = y.dtype if not isinstance(y, torch.dtype) else y + xdt = x if isinstance(x, torch.dtype) else x.dtype + ydt = y if isinstance(y, torch.dtype) else y.dtype - if (xdt, ydt) in _promotion_table: + try: return _promotion_table[xdt, ydt] + except KeyError: + pass # This doesn't result_type(dtype, dtype) for non-array API dtypes # because torch.result_type only accepts tensors. This does however, allow @@ -301,6 +285,25 @@ def _reduce_multiple_axes(f, x, axis, keepdims=False, **kwargs): out = torch.unsqueeze(out, a) return out + +def _sum_prod_no_axis(x: Array, dtype: DType | None) -> Array: + """ + Implements `sum(..., axis=())` and `prod(..., axis=())`. + + Works around https://github.com/pytorch/pytorch/issues/29137 + """ + if dtype is not None: + return x.clone() if dtype == x.dtype else x.to(dtype) + + # We can't upcast uint8 according to the spec because there is no + # torch.uint64, so at least upcast to int64 which is what prod does + # when axis=None. + if x.dtype in (torch.uint8, torch.int8, torch.int16, torch.int32): + return x.to(torch.int64) + + return x.clone() + + def prod(x: Array, /, *, @@ -308,20 +311,9 @@ def prod(x: Array, dtype: Optional[DType] = None, keepdims: bool = False, **kwargs) -> Array: - ndim = x.ndim - # https://github.com/pytorch/pytorch/issues/29137. Separate from the logic - # below because it still needs to upcast. if axis == (): - if dtype is None: - # We can't upcast uint8 according to the spec because there is no - # torch.uint64, so at least upcast to int64 which is what sum does - # when axis=None. - if x.dtype in [torch.int8, torch.int16, torch.int32, torch.uint8]: - return x.to(torch.int64) - return x.clone() - return x.to(dtype) - + return _sum_prod_no_axis(x, dtype) # torch.prod doesn't support multiple axes # (https://github.com/pytorch/pytorch/issues/56586). if isinstance(axis, tuple): @@ -330,7 +322,7 @@ def prod(x: Array, # torch doesn't support keepdims with axis=None # (https://github.com/pytorch/pytorch/issues/71209) res = torch.prod(x, dtype=dtype, **kwargs) - res = _axis_none_keepdims(res, ndim, keepdims) + res = _axis_none_keepdims(res, x.ndim, keepdims) return res return torch.prod(x, axis, dtype=dtype, keepdims=keepdims, **kwargs) @@ -343,25 +335,14 @@ def sum(x: Array, dtype: Optional[DType] = None, keepdims: bool = False, **kwargs) -> Array: - ndim = x.ndim - # https://github.com/pytorch/pytorch/issues/29137. - # Make sure it upcasts. if axis == (): - if dtype is None: - # We can't upcast uint8 according to the spec because there is no - # torch.uint64, so at least upcast to int64 which is what sum does - # when axis=None. - if x.dtype in [torch.int8, torch.int16, torch.int32, torch.uint8]: - return x.to(torch.int64) - return x.clone() - return x.to(dtype) - + return _sum_prod_no_axis(x, dtype) if axis is None: # torch doesn't support keepdims with axis=None # (https://github.com/pytorch/pytorch/issues/71209) res = torch.sum(x, dtype=dtype, **kwargs) - res = _axis_none_keepdims(res, ndim, keepdims) + res = _axis_none_keepdims(res, x.ndim, keepdims) return res return torch.sum(x, axis, dtype=dtype, keepdims=keepdims, **kwargs) @@ -372,7 +353,7 @@ def any(x: Array, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False, **kwargs) -> Array: - ndim = x.ndim + if axis == (): return x.to(torch.bool) # torch.any doesn't support multiple axes @@ -384,7 +365,7 @@ def any(x: Array, # torch doesn't support keepdims with axis=None # (https://github.com/pytorch/pytorch/issues/71209) res = torch.any(x, **kwargs) - res = _axis_none_keepdims(res, ndim, keepdims) + res = _axis_none_keepdims(res, x.ndim, keepdims) return res.to(torch.bool) # torch.any doesn't return bool for uint8 @@ -396,7 +377,7 @@ def all(x: Array, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False, **kwargs) -> Array: - ndim = x.ndim + if axis == (): return x.to(torch.bool) # torch.all doesn't support multiple axes @@ -408,7 +389,7 @@ def all(x: Array, # torch doesn't support keepdims with axis=None # (https://github.com/pytorch/pytorch/issues/71209) res = torch.all(x, **kwargs) - res = _axis_none_keepdims(res, ndim, keepdims) + res = _axis_none_keepdims(res, x.ndim, keepdims) return res.to(torch.bool) # torch.all doesn't return bool for uint8 From b94efc1f5e490a23c0ca74aafb93cc3118471f46 Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Tue, 15 Apr 2025 14:37:45 +0200 Subject: [PATCH 24/28] TST: skip testing nextafter with scalars on torch --- torch-xfails.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/torch-xfails.txt b/torch-xfails.txt index f8333d90..538403a3 100644 --- a/torch-xfails.txt +++ b/torch-xfails.txt @@ -144,6 +144,7 @@ array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_sc array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[logaddexp] array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[maximum] array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[minimum] +array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[nextafter] # https://github.com/pytorch/pytorch/issues/149815 array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[equal] From 205c967d658de24b2738dcae8d91684a1f99d2cd Mon Sep 17 00:00:00 2001 From: Joren Hammudoglu Date: Thu, 17 Apr 2025 20:14:48 +0200 Subject: [PATCH 25/28] TYP: Type annotations overhaul, episode 2 (#288) * TYP: annotate `_internal.get_xp` (and curse at `ParamSpec` for being so useless) * TYP: fix (or ignore) typing errors in `common._helpers` (and curse at cupy) * TYP: fix typing errors in `common._fft` * TYP: fix typing errors in `common._aliases` * TYP: fix typing errors in `common._linalg` * TYP: fix/ignore typing errors in `numpy.__init__` * TYP: fix typing errors in `numpy._typing` * TYP: fix typing errors in `numpy._aliases` * TYP: fix typing errors in `numpy._info` * TYP: fix typing errors in `numpy._fft` * TYP: it's a bad idea to import `TypeAlias` from `typing` on `python<3.10` * TYP: it's also a bad idea to import `TypeGuard` from `typing` on `python<3.10` * TYP: don't scare the prehistoric `dtype` from numpy 1.21 * TYP: dust off the DeLorean * TYP: figure out how to drive a DeLorean * TYP: apply review suggestions Co-authored-by: crusaderky * TYP: sprinkle some `TypeAlias`es and `Final`s around * TYP: `__dir__` * TYP: fix typing errors in `numpy.linalg` * TYP: add a `common._typing.Capabilities` typed dict type * TYP: `__array_namespace_info__` helper types * TYP: `dask.array` typing fixes and improvements * STY: give the `=` some breathing room Co-authored-by: Lucas Colley * STY: apply review suggestions Co-authored-by: lucascolley --------- Co-authored-by: crusaderky Co-authored-by: Lucas Colley --- array_api_compat/_internal.py | 25 +- array_api_compat/common/__init__.py | 2 +- array_api_compat/common/_aliases.py | 331 +++++++++++++++--------- array_api_compat/common/_fft.py | 69 ++--- array_api_compat/common/_helpers.py | 287 +++++++++++++------- array_api_compat/common/_linalg.py | 110 ++++++-- array_api_compat/common/_typing.py | 148 ++++++++++- array_api_compat/dask/array/__init__.py | 8 +- array_api_compat/dask/array/_aliases.py | 162 +++++++----- array_api_compat/dask/array/_info.py | 96 +++++-- array_api_compat/dask/array/linalg.py | 22 +- array_api_compat/numpy/__init__.py | 28 +- array_api_compat/numpy/_aliases.py | 86 +++--- array_api_compat/numpy/_info.py | 42 ++- array_api_compat/numpy/_typing.py | 35 ++- array_api_compat/numpy/fft.py | 16 +- array_api_compat/numpy/linalg.py | 97 +++++-- 17 files changed, 1076 insertions(+), 488 deletions(-) diff --git a/array_api_compat/_internal.py b/array_api_compat/_internal.py index 170a1ff9..cd8d939f 100644 --- a/array_api_compat/_internal.py +++ b/array_api_compat/_internal.py @@ -2,10 +2,16 @@ Internal helpers """ +from collections.abc import Callable from functools import wraps from inspect import signature +from types import ModuleType +from typing import TypeVar -def get_xp(xp): +_T = TypeVar("_T") + + +def get_xp(xp: ModuleType) -> Callable[[Callable[..., _T]], Callable[..., _T]]: """ Decorator to automatically replace xp with the corresponding array module. @@ -22,14 +28,14 @@ def func(x, /, xp, kwarg=None): """ - def inner(f): + def inner(f: Callable[..., _T], /) -> Callable[..., _T]: @wraps(f) - def wrapped_f(*args, **kwargs): + def wrapped_f(*args: object, **kwargs: object) -> object: return f(*args, xp=xp, **kwargs) sig = signature(f) new_sig = sig.replace( - parameters=[sig.parameters[i] for i in sig.parameters if i != "xp"] + parameters=[par for i, par in sig.parameters.items() if i != "xp"] ) if wrapped_f.__doc__ is None: @@ -40,7 +46,14 @@ def wrapped_f(*args, **kwargs): specification for more details. """ - wrapped_f.__signature__ = new_sig - return wrapped_f + wrapped_f.__signature__ = new_sig # pyright: ignore[reportAttributeAccessIssue] + return wrapped_f # pyright: ignore[reportReturnType] return inner + + +__all__ = ["get_xp"] + + +def __dir__() -> list[str]: + return __all__ diff --git a/array_api_compat/common/__init__.py b/array_api_compat/common/__init__.py index 91ab1c40..82360807 100644 --- a/array_api_compat/common/__init__.py +++ b/array_api_compat/common/__init__.py @@ -1 +1 @@ -from ._helpers import * # noqa: F403 +from ._helpers import * # noqa: F403 diff --git a/array_api_compat/common/_aliases.py b/array_api_compat/common/_aliases.py index 351b5bd6..8ea9162a 100644 --- a/array_api_compat/common/_aliases.py +++ b/array_api_compat/common/_aliases.py @@ -5,158 +5,170 @@ from __future__ import annotations import inspect -from typing import Any, NamedTuple, Optional, Sequence, Tuple, Union +from typing import TYPE_CHECKING, Any, NamedTuple, Optional, Sequence, cast +from ._helpers import _check_device, array_namespace +from ._helpers import device as _get_device +from ._helpers import is_cupy_namespace as _is_cupy_namespace from ._typing import Array, Device, DType, Namespace -from ._helpers import ( - array_namespace, - _check_device, - device as _get_device, - is_cupy_namespace as _is_cupy_namespace -) +if TYPE_CHECKING: + # TODO: import from typing (requires Python >=3.13) + from typing_extensions import TypeIs # These functions are modified from the NumPy versions. # Creation functions add the device keyword (which does nothing for NumPy and Dask) + def arange( - start: Union[int, float], + start: float, /, - stop: Optional[Union[int, float]] = None, - step: Union[int, float] = 1, + stop: float | None = None, + step: float = 1, *, xp: Namespace, - dtype: Optional[DType] = None, - device: Optional[Device] = None, - **kwargs, + dtype: DType | None = None, + device: Device | None = None, + **kwargs: object, ) -> Array: _check_device(xp, device) return xp.arange(start, stop=stop, step=step, dtype=dtype, **kwargs) + def empty( - shape: Union[int, Tuple[int, ...]], + shape: int | tuple[int, ...], xp: Namespace, *, - dtype: Optional[DType] = None, - device: Optional[Device] = None, - **kwargs, + dtype: DType | None = None, + device: Device | None = None, + **kwargs: object, ) -> Array: _check_device(xp, device) return xp.empty(shape, dtype=dtype, **kwargs) + def empty_like( x: Array, /, xp: Namespace, *, - dtype: Optional[DType] = None, - device: Optional[Device] = None, - **kwargs, + dtype: DType | None = None, + device: Device | None = None, + **kwargs: object, ) -> Array: _check_device(xp, device) return xp.empty_like(x, dtype=dtype, **kwargs) + def eye( n_rows: int, - n_cols: Optional[int] = None, + n_cols: int | None = None, /, *, xp: Namespace, k: int = 0, - dtype: Optional[DType] = None, - device: Optional[Device] = None, - **kwargs, + dtype: DType | None = None, + device: Device | None = None, + **kwargs: object, ) -> Array: _check_device(xp, device) return xp.eye(n_rows, M=n_cols, k=k, dtype=dtype, **kwargs) + def full( - shape: Union[int, Tuple[int, ...]], - fill_value: bool | int | float | complex, + shape: int | tuple[int, ...], + fill_value: complex, xp: Namespace, *, - dtype: Optional[DType] = None, - device: Optional[Device] = None, - **kwargs, + dtype: DType | None = None, + device: Device | None = None, + **kwargs: object, ) -> Array: _check_device(xp, device) return xp.full(shape, fill_value, dtype=dtype, **kwargs) + def full_like( x: Array, /, - fill_value: bool | int | float | complex, + fill_value: complex, *, xp: Namespace, - dtype: Optional[DType] = None, - device: Optional[Device] = None, - **kwargs, + dtype: DType | None = None, + device: Device | None = None, + **kwargs: object, ) -> Array: _check_device(xp, device) return xp.full_like(x, fill_value, dtype=dtype, **kwargs) + def linspace( - start: Union[int, float], - stop: Union[int, float], + start: float, + stop: float, /, num: int, *, xp: Namespace, - dtype: Optional[DType] = None, - device: Optional[Device] = None, + dtype: DType | None = None, + device: Device | None = None, endpoint: bool = True, - **kwargs, + **kwargs: object, ) -> Array: _check_device(xp, device) return xp.linspace(start, stop, num, dtype=dtype, endpoint=endpoint, **kwargs) + def ones( - shape: Union[int, Tuple[int, ...]], + shape: int | tuple[int, ...], xp: Namespace, *, - dtype: Optional[DType] = None, - device: Optional[Device] = None, - **kwargs, + dtype: DType | None = None, + device: Device | None = None, + **kwargs: object, ) -> Array: _check_device(xp, device) return xp.ones(shape, dtype=dtype, **kwargs) + def ones_like( x: Array, /, xp: Namespace, *, - dtype: Optional[DType] = None, - device: Optional[Device] = None, - **kwargs, + dtype: DType | None = None, + device: Device | None = None, + **kwargs: object, ) -> Array: _check_device(xp, device) return xp.ones_like(x, dtype=dtype, **kwargs) + def zeros( - shape: Union[int, Tuple[int, ...]], + shape: int | tuple[int, ...], xp: Namespace, *, - dtype: Optional[DType] = None, - device: Optional[Device] = None, - **kwargs, + dtype: DType | None = None, + device: Device | None = None, + **kwargs: object, ) -> Array: _check_device(xp, device) return xp.zeros(shape, dtype=dtype, **kwargs) + def zeros_like( x: Array, /, xp: Namespace, *, - dtype: Optional[DType] = None, - device: Optional[Device] = None, - **kwargs, + dtype: DType | None = None, + device: Device | None = None, + **kwargs: object, ) -> Array: _check_device(xp, device) return xp.zeros_like(x, dtype=dtype, **kwargs) + # np.unique() is split into four functions in the array API: # unique_all, unique_counts, unique_inverse, and unique_values (this is done # to remove polymorphic return types). @@ -164,6 +176,7 @@ def zeros_like( # The functions here return namedtuples (np.unique() returns a normal # tuple). + # Note that these named tuples aren't actually part of the standard namespace, # but I don't see any issue with exporting the names here regardless. class UniqueAllResult(NamedTuple): @@ -188,10 +201,11 @@ def _unique_kwargs(xp: Namespace) -> dict[str, bool]: # trying to parse version numbers, just check if equal_nan is in the # signature. s = inspect.signature(xp.unique) - if 'equal_nan' in s.parameters: - return {'equal_nan': False} + if "equal_nan" in s.parameters: + return {"equal_nan": False} return {} + def unique_all(x: Array, /, xp: Namespace) -> UniqueAllResult: kwargs = _unique_kwargs(xp) values, indices, inverse_indices, counts = xp.unique( @@ -215,11 +229,7 @@ def unique_all(x: Array, /, xp: Namespace) -> UniqueAllResult: def unique_counts(x: Array, /, xp: Namespace) -> UniqueCountsResult: kwargs = _unique_kwargs(xp) res = xp.unique( - x, - return_counts=True, - return_index=False, - return_inverse=False, - **kwargs + x, return_counts=True, return_index=False, return_inverse=False, **kwargs ) return UniqueCountsResult(*res) @@ -250,51 +260,58 @@ def unique_values(x: Array, /, xp: Namespace) -> Array: **kwargs, ) + # These functions have different keyword argument names + def std( x: Array, /, xp: Namespace, *, - axis: Optional[Union[int, Tuple[int, ...]]] = None, - correction: Union[int, float] = 0.0, # correction instead of ddof + axis: int | tuple[int, ...] | None = None, + correction: float = 0.0, # correction instead of ddof keepdims: bool = False, - **kwargs, + **kwargs: object, ) -> Array: return xp.std(x, axis=axis, ddof=correction, keepdims=keepdims, **kwargs) + def var( x: Array, /, xp: Namespace, *, - axis: Optional[Union[int, Tuple[int, ...]]] = None, - correction: Union[int, float] = 0.0, # correction instead of ddof + axis: int | tuple[int, ...] | None = None, + correction: float = 0.0, # correction instead of ddof keepdims: bool = False, - **kwargs, + **kwargs: object, ) -> Array: return xp.var(x, axis=axis, ddof=correction, keepdims=keepdims, **kwargs) + # cumulative_sum is renamed from cumsum, and adds the include_initial keyword # argument + def cumulative_sum( x: Array, /, xp: Namespace, *, - axis: Optional[int] = None, - dtype: Optional[DType] = None, + axis: int | None = None, + dtype: DType | None = None, include_initial: bool = False, - **kwargs, + **kwargs: object, ) -> Array: wrapped_xp = array_namespace(x) # TODO: The standard is not clear about what should happen when x.ndim == 0. if axis is None: if x.ndim > 1: - raise ValueError("axis must be specified in cumulative_sum for more than one dimension") + raise ValueError( + "axis must be specified in cumulative_sum for more than one dimension" + ) axis = 0 res = xp.cumsum(x, axis=axis, dtype=dtype, **kwargs) @@ -304,7 +321,12 @@ def cumulative_sum( initial_shape = list(x.shape) initial_shape[axis] = 1 res = xp.concatenate( - [wrapped_xp.zeros(shape=initial_shape, dtype=res.dtype, device=_get_device(res)), res], + [ + wrapped_xp.zeros( + shape=initial_shape, dtype=res.dtype, device=_get_device(res) + ), + res, + ], axis=axis, ) return res @@ -315,16 +337,18 @@ def cumulative_prod( /, xp: Namespace, *, - axis: Optional[int] = None, - dtype: Optional[DType] = None, + axis: int | None = None, + dtype: DType | None = None, include_initial: bool = False, - **kwargs, + **kwargs: object, ) -> Array: wrapped_xp = array_namespace(x) if axis is None: if x.ndim > 1: - raise ValueError("axis must be specified in cumulative_prod for more than one dimension") + raise ValueError( + "axis must be specified in cumulative_prod for more than one dimension" + ) axis = 0 res = xp.cumprod(x, axis=axis, dtype=dtype, **kwargs) @@ -334,24 +358,30 @@ def cumulative_prod( initial_shape = list(x.shape) initial_shape[axis] = 1 res = xp.concatenate( - [wrapped_xp.ones(shape=initial_shape, dtype=res.dtype, device=_get_device(res)), res], + [ + wrapped_xp.ones( + shape=initial_shape, dtype=res.dtype, device=_get_device(res) + ), + res, + ], axis=axis, ) return res + # The min and max argument names in clip are different and not optional in numpy, and type # promotion behavior is different. def clip( x: Array, /, - min: Optional[Union[int, float, Array]] = None, - max: Optional[Union[int, float, Array]] = None, + min: float | Array | None = None, + max: float | Array | None = None, *, xp: Namespace, # TODO: np.clip has other ufunc kwargs - out: Optional[Array] = None, + out: Array | None = None, ) -> Array: - def _isscalar(a): + def _isscalar(a: object) -> TypeIs[int | float | None]: return isinstance(a, (int, float, type(None))) min_shape = () if _isscalar(min) else min.shape @@ -378,7 +408,6 @@ def _isscalar(a): # but an answer of 0 might be preferred. See # https://github.com/numpy/numpy/issues/24976 for more discussion on this issue. - # At least handle the case of Python integers correctly (see # https://github.com/numpy/numpy/pull/26892). if wrapped_xp.isdtype(x.dtype, "integral"): @@ -390,6 +419,7 @@ def _isscalar(a): dev = _get_device(x) if out is None: out = wrapped_xp.empty(result_shape, dtype=x.dtype, device=dev) + assert out is not None # workaround for a type-narrowing issue in pyright out[()] = x if min is not None: @@ -407,19 +437,21 @@ def _isscalar(a): # Return a scalar for 0-D return out[()] + # Unlike transpose(), the axes argument to permute_dims() is required. -def permute_dims(x: Array, /, axes: Tuple[int, ...], xp: Namespace) -> Array: +def permute_dims(x: Array, /, axes: tuple[int, ...], xp: Namespace) -> Array: return xp.transpose(x, axes) + # np.reshape calls the keyword argument 'newshape' instead of 'shape' def reshape( x: Array, /, - shape: Tuple[int, ...], + shape: tuple[int, ...], xp: Namespace, *, copy: Optional[bool] = None, - **kwargs, + **kwargs: object, ) -> Array: if copy is True: x = x.copy() @@ -429,6 +461,7 @@ def reshape( return y return xp.reshape(x, shape, **kwargs) + # The descending keyword is new in sort and argsort, and 'kind' replaced with # 'stable' def argsort( @@ -439,13 +472,13 @@ def argsort( axis: int = -1, descending: bool = False, stable: bool = True, - **kwargs, + **kwargs: object, ) -> Array: # Note: this keyword argument is different, and the default is different. # We set it in kwargs like this because numpy.sort uses kind='quicksort' # as the default whereas cupy.sort uses kind=None. if stable: - kwargs['kind'] = "stable" + kwargs["kind"] = "stable" if not descending: res = xp.argsort(x, axis=axis, **kwargs) else: @@ -462,6 +495,7 @@ def argsort( res = max_i - res return res + def sort( x: Array, /, @@ -470,68 +504,78 @@ def sort( axis: int = -1, descending: bool = False, stable: bool = True, - **kwargs, + **kwargs: object, ) -> Array: # Note: this keyword argument is different, and the default is different. # We set it in kwargs like this because numpy.sort uses kind='quicksort' # as the default whereas cupy.sort uses kind=None. if stable: - kwargs['kind'] = "stable" + kwargs["kind"] = "stable" res = xp.sort(x, axis=axis, **kwargs) if descending: res = xp.flip(res, axis=axis) return res + # nonzero should error for zero-dimensional arrays -def nonzero(x: Array, /, xp: Namespace, **kwargs) -> Tuple[Array, ...]: +def nonzero(x: Array, /, xp: Namespace, **kwargs: object) -> tuple[Array, ...]: if x.ndim == 0: raise ValueError("nonzero() does not support zero-dimensional arrays") return xp.nonzero(x, **kwargs) + # ceil, floor, and trunc return integers for integer inputs -def ceil(x: Array, /, xp: Namespace, **kwargs) -> Array: + +def ceil(x: Array, /, xp: Namespace, **kwargs: object) -> Array: if xp.issubdtype(x.dtype, xp.integer): return x return xp.ceil(x, **kwargs) -def floor(x: Array, /, xp: Namespace, **kwargs) -> Array: + +def floor(x: Array, /, xp: Namespace, **kwargs: object) -> Array: if xp.issubdtype(x.dtype, xp.integer): return x return xp.floor(x, **kwargs) -def trunc(x: Array, /, xp: Namespace, **kwargs) -> Array: + +def trunc(x: Array, /, xp: Namespace, **kwargs: object) -> Array: if xp.issubdtype(x.dtype, xp.integer): return x return xp.trunc(x, **kwargs) + # linear algebra functions -def matmul(x1: Array, x2: Array, /, xp: Namespace, **kwargs) -> Array: + +def matmul(x1: Array, x2: Array, /, xp: Namespace, **kwargs: object) -> Array: return xp.matmul(x1, x2, **kwargs) + # Unlike transpose, matrix_transpose only transposes the last two axes. def matrix_transpose(x: Array, /, xp: Namespace) -> Array: if x.ndim < 2: raise ValueError("x must be at least 2-dimensional for matrix_transpose") return xp.swapaxes(x, -1, -2) + def tensordot( x1: Array, x2: Array, /, xp: Namespace, *, - axes: Union[int, Tuple[Sequence[int], Sequence[int]]] = 2, - **kwargs, + axes: int | tuple[Sequence[int], Sequence[int]] = 2, + **kwargs: object, ) -> Array: return xp.tensordot(x1, x2, axes=axes, **kwargs) + def vecdot(x1: Array, x2: Array, /, xp: Namespace, *, axis: int = -1) -> Array: if x1.shape[axis] != x2.shape[axis]: raise ValueError("x1 and x2 must have the same size along the given axis") - if hasattr(xp, 'broadcast_tensors'): + if hasattr(xp, "broadcast_tensors"): _broadcast = xp.broadcast_tensors else: _broadcast = xp.broadcast_arrays @@ -543,14 +587,16 @@ def vecdot(x1: Array, x2: Array, /, xp: Namespace, *, axis: int = -1) -> Array: res = xp.conj(x1_[..., None, :]) @ x2_[..., None] return res[..., 0, 0] + # isdtype is a new function in the 2022.12 array API specification. + def isdtype( dtype: DType, - kind: Union[DType, str, Tuple[Union[DType, str], ...]], + kind: DType | str | tuple[DType | str, ...], xp: Namespace, *, - _tuple: bool = True, # Disallow nested tuples + _tuple: bool = True, # Disallow nested tuples ) -> bool: """ Returns a boolean indicating whether a provided dtype is of a specified data type ``kind``. @@ -563,21 +609,24 @@ def isdtype( for more details """ if isinstance(kind, tuple) and _tuple: - return any(isdtype(dtype, k, xp, _tuple=False) for k in kind) + return any( + isdtype(dtype, k, xp, _tuple=False) + for k in cast("tuple[DType | str, ...]", kind) + ) elif isinstance(kind, str): - if kind == 'bool': + if kind == "bool": return dtype == xp.bool_ - elif kind == 'signed integer': + elif kind == "signed integer": return xp.issubdtype(dtype, xp.signedinteger) - elif kind == 'unsigned integer': + elif kind == "unsigned integer": return xp.issubdtype(dtype, xp.unsignedinteger) - elif kind == 'integral': + elif kind == "integral": return xp.issubdtype(dtype, xp.integer) - elif kind == 'real floating': + elif kind == "real floating": return xp.issubdtype(dtype, xp.floating) - elif kind == 'complex floating': + elif kind == "complex floating": return xp.issubdtype(dtype, xp.complexfloating) - elif kind == 'numeric': + elif kind == "numeric": return xp.issubdtype(dtype, xp.number) else: raise ValueError(f"Unrecognized data type kind: {kind!r}") @@ -588,24 +637,27 @@ def isdtype( # array_api_strict implementation will be very strict. return dtype == kind + # unstack is a new function in the 2023.12 array API standard -def unstack(x: Array, /, xp: Namespace, *, axis: int = 0) -> Tuple[Array, ...]: +def unstack(x: Array, /, xp: Namespace, *, axis: int = 0) -> tuple[Array, ...]: if x.ndim == 0: raise ValueError("Input array must be at least 1-d.") return tuple(xp.moveaxis(x, axis, 0)) + # numpy 1.26 does not use the standard definition for sign on complex numbers -def sign(x: Array, /, xp: Namespace, **kwargs) -> Array: - if isdtype(x.dtype, 'complex floating', xp=xp): - out = (x/xp.abs(x, **kwargs))[...] + +def sign(x: Array, /, xp: Namespace, **kwargs: object) -> Array: + if isdtype(x.dtype, "complex floating", xp=xp): + out = (x / xp.abs(x, **kwargs))[...] # sign(0) = 0 but the above formula would give nan - out[x == 0+0j] = 0+0j + out[x == 0j] = 0j else: out = xp.sign(x, **kwargs) # CuPy sign() does not propagate nans. See # https://github.com/data-apis/array-api-compat/issues/136 - if _is_cupy_namespace(xp) and isdtype(x.dtype, 'real floating', xp=xp): + if _is_cupy_namespace(xp) and isdtype(x.dtype, "real floating", xp=xp): out[xp.isnan(x)] = xp.nan return out[()] @@ -626,13 +678,50 @@ def iinfo(type_: DType | Array, /, xp: Namespace) -> Any: return xp.iinfo(type_.dtype) -__all__ = ['arange', 'empty', 'empty_like', 'eye', 'full', 'full_like', - 'linspace', 'ones', 'ones_like', 'zeros', 'zeros_like', - 'UniqueAllResult', 'UniqueCountsResult', 'UniqueInverseResult', - 'unique_all', 'unique_counts', 'unique_inverse', 'unique_values', - 'std', 'var', 'cumulative_sum', 'cumulative_prod','clip', 'permute_dims', - 'reshape', 'argsort', 'sort', 'nonzero', 'ceil', 'floor', 'trunc', - 'matmul', 'matrix_transpose', 'tensordot', 'vecdot', 'isdtype', - 'unstack', 'sign', 'finfo', 'iinfo'] - -_all_ignore = ['inspect', 'array_namespace', 'NamedTuple'] +__all__ = [ + "arange", + "empty", + "empty_like", + "eye", + "full", + "full_like", + "linspace", + "ones", + "ones_like", + "zeros", + "zeros_like", + "UniqueAllResult", + "UniqueCountsResult", + "UniqueInverseResult", + "unique_all", + "unique_counts", + "unique_inverse", + "unique_values", + "std", + "var", + "cumulative_sum", + "cumulative_prod", + "clip", + "permute_dims", + "reshape", + "argsort", + "sort", + "nonzero", + "ceil", + "floor", + "trunc", + "matmul", + "matrix_transpose", + "tensordot", + "vecdot", + "isdtype", + "unstack", + "sign", + "finfo", + "iinfo", +] +_all_ignore = ["inspect", "array_namespace", "NamedTuple"] + + +def __dir__() -> list[str]: + return __all__ diff --git a/array_api_compat/common/_fft.py b/array_api_compat/common/_fft.py index bd2a4e1a..18839d37 100644 --- a/array_api_compat/common/_fft.py +++ b/array_api_compat/common/_fft.py @@ -1,9 +1,11 @@ from __future__ import annotations from collections.abc import Sequence -from typing import Union, Optional, Literal +from typing import Literal, TypeAlias -from ._typing import Device, Array, DType, Namespace +from ._typing import Array, Device, DType, Namespace + +_Norm: TypeAlias = Literal["backward", "ortho", "forward"] # Note: NumPy fft functions improperly upcast float32 and complex64 to # complex128, which is why we require wrapping them all here. @@ -13,9 +15,9 @@ def fft( /, xp: Namespace, *, - n: Optional[int] = None, + n: int | None = None, axis: int = -1, - norm: Literal["backward", "ortho", "forward"] = "backward", + norm: _Norm = "backward", ) -> Array: res = xp.fft.fft(x, n=n, axis=axis, norm=norm) if x.dtype in [xp.float32, xp.complex64]: @@ -27,9 +29,9 @@ def ifft( /, xp: Namespace, *, - n: Optional[int] = None, + n: int | None = None, axis: int = -1, - norm: Literal["backward", "ortho", "forward"] = "backward", + norm: _Norm = "backward", ) -> Array: res = xp.fft.ifft(x, n=n, axis=axis, norm=norm) if x.dtype in [xp.float32, xp.complex64]: @@ -41,9 +43,9 @@ def fftn( /, xp: Namespace, *, - s: Sequence[int] = None, - axes: Sequence[int] = None, - norm: Literal["backward", "ortho", "forward"] = "backward", + s: Sequence[int] | None = None, + axes: Sequence[int] | None = None, + norm: _Norm = "backward", ) -> Array: res = xp.fft.fftn(x, s=s, axes=axes, norm=norm) if x.dtype in [xp.float32, xp.complex64]: @@ -55,9 +57,9 @@ def ifftn( /, xp: Namespace, *, - s: Sequence[int] = None, - axes: Sequence[int] = None, - norm: Literal["backward", "ortho", "forward"] = "backward", + s: Sequence[int] | None = None, + axes: Sequence[int] | None = None, + norm: _Norm = "backward", ) -> Array: res = xp.fft.ifftn(x, s=s, axes=axes, norm=norm) if x.dtype in [xp.float32, xp.complex64]: @@ -69,9 +71,9 @@ def rfft( /, xp: Namespace, *, - n: Optional[int] = None, + n: int | None = None, axis: int = -1, - norm: Literal["backward", "ortho", "forward"] = "backward", + norm: _Norm = "backward", ) -> Array: res = xp.fft.rfft(x, n=n, axis=axis, norm=norm) if x.dtype == xp.float32: @@ -83,9 +85,9 @@ def irfft( /, xp: Namespace, *, - n: Optional[int] = None, + n: int | None = None, axis: int = -1, - norm: Literal["backward", "ortho", "forward"] = "backward", + norm: _Norm = "backward", ) -> Array: res = xp.fft.irfft(x, n=n, axis=axis, norm=norm) if x.dtype == xp.complex64: @@ -97,9 +99,9 @@ def rfftn( /, xp: Namespace, *, - s: Sequence[int] = None, - axes: Sequence[int] = None, - norm: Literal["backward", "ortho", "forward"] = "backward", + s: Sequence[int] | None = None, + axes: Sequence[int] | None = None, + norm: _Norm = "backward", ) -> Array: res = xp.fft.rfftn(x, s=s, axes=axes, norm=norm) if x.dtype == xp.float32: @@ -111,9 +113,9 @@ def irfftn( /, xp: Namespace, *, - s: Sequence[int] = None, - axes: Sequence[int] = None, - norm: Literal["backward", "ortho", "forward"] = "backward", + s: Sequence[int] | None = None, + axes: Sequence[int] | None = None, + norm: _Norm = "backward", ) -> Array: res = xp.fft.irfftn(x, s=s, axes=axes, norm=norm) if x.dtype == xp.complex64: @@ -125,9 +127,9 @@ def hfft( /, xp: Namespace, *, - n: Optional[int] = None, + n: int | None = None, axis: int = -1, - norm: Literal["backward", "ortho", "forward"] = "backward", + norm: _Norm = "backward", ) -> Array: res = xp.fft.hfft(x, n=n, axis=axis, norm=norm) if x.dtype in [xp.float32, xp.complex64]: @@ -139,9 +141,9 @@ def ihfft( /, xp: Namespace, *, - n: Optional[int] = None, + n: int | None = None, axis: int = -1, - norm: Literal["backward", "ortho", "forward"] = "backward", + norm: _Norm = "backward", ) -> Array: res = xp.fft.ihfft(x, n=n, axis=axis, norm=norm) if x.dtype in [xp.float32, xp.complex64]: @@ -154,8 +156,8 @@ def fftfreq( xp: Namespace, *, d: float = 1.0, - dtype: Optional[DType] = None, - device: Optional[Device] = None, + dtype: DType | None = None, + device: Device | None = None, ) -> Array: if device not in ["cpu", None]: raise ValueError(f"Unsupported device {device!r}") @@ -170,8 +172,8 @@ def rfftfreq( xp: Namespace, *, d: float = 1.0, - dtype: Optional[DType] = None, - device: Optional[Device] = None, + dtype: DType | None = None, + device: Device | None = None, ) -> Array: if device not in ["cpu", None]: raise ValueError(f"Unsupported device {device!r}") @@ -181,12 +183,12 @@ def rfftfreq( return res def fftshift( - x: Array, /, xp: Namespace, *, axes: Union[int, Sequence[int]] = None + x: Array, /, xp: Namespace, *, axes: int | Sequence[int] | None = None ) -> Array: return xp.fft.fftshift(x, axes=axes) def ifftshift( - x: Array, /, xp: Namespace, *, axes: Union[int, Sequence[int]] = None + x: Array, /, xp: Namespace, *, axes: int | Sequence[int] | None = None ) -> Array: return xp.fft.ifftshift(x, axes=axes) @@ -206,3 +208,6 @@ def ifftshift( "fftshift", "ifftshift", ] + +def __dir__() -> list[str]: + return __all__ diff --git a/array_api_compat/common/_helpers.py b/array_api_compat/common/_helpers.py index 67c619b8..db3e4cd7 100644 --- a/array_api_compat/common/_helpers.py +++ b/array_api_compat/common/_helpers.py @@ -5,33 +5,82 @@ that are in __all__ are intended as additional helper functions for use by end users of the compat library. """ + from __future__ import annotations -import sys -import math import inspect +import math +import sys import warnings -from typing import Optional, Union, Any +from collections.abc import Collection +from typing import ( + TYPE_CHECKING, + Any, + Final, + Literal, + SupportsIndex, + TypeAlias, + TypeGuard, + TypeVar, + cast, + overload, +) + +from ._typing import Array, Device, HasShape, Namespace, SupportsArrayNamespace + +if TYPE_CHECKING: + + import dask.array as da + import jax + import ndonnx as ndx + import numpy as np + import numpy.typing as npt + import sparse # pyright: ignore[reportMissingTypeStubs] + import torch + + # TODO: import from typing (requires Python >=3.13) + from typing_extensions import TypeIs, TypeVar -from ._typing import Array, Device, Namespace + _SizeT = TypeVar("_SizeT", bound = int | None) + _ZeroGradientArray: TypeAlias = npt.NDArray[np.void] + _CupyArray: TypeAlias = Any # cupy has no py.typed -def _is_jax_zero_gradient_array(x: object) -> bool: + _ArrayApiObj: TypeAlias = ( + npt.NDArray[Any] + | da.Array + | jax.Array + | ndx.Array + | sparse.SparseArray + | torch.Tensor + | SupportsArrayNamespace[Any] + | _CupyArray + ) + +_API_VERSIONS_OLD: Final = frozenset({"2021.12", "2022.12", "2023.12"}) +_API_VERSIONS: Final = _API_VERSIONS_OLD | frozenset({"2024.12"}) + + +def _is_jax_zero_gradient_array(x: object) -> TypeGuard[_ZeroGradientArray]: """Return True if `x` is a zero-gradient array. These arrays are a design quirk of Jax that may one day be removed. See https://github.com/google/jax/issues/20620. """ - if 'numpy' not in sys.modules or 'jax' not in sys.modules: + if "numpy" not in sys.modules or "jax" not in sys.modules: return False - import numpy as np import jax + import numpy as np - return isinstance(x, np.ndarray) and x.dtype == jax.float0 + jax_float0 = cast("np.dtype[np.void]", jax.float0) + return ( + isinstance(x, np.ndarray) + and cast("npt.NDArray[np.void]", x).dtype == jax_float0 + ) -def is_numpy_array(x: object) -> bool: +def is_numpy_array(x: object) -> TypeGuard[npt.NDArray[Any]]: """ Return True if `x` is a NumPy array. @@ -53,14 +102,14 @@ def is_numpy_array(x: object) -> bool: is_pydata_sparse_array """ # Avoid importing NumPy if it isn't already - if 'numpy' not in sys.modules: + if "numpy" not in sys.modules: return False import numpy as np # TODO: Should we reject ndarray subclasses? return (isinstance(x, (np.ndarray, np.generic)) - and not _is_jax_zero_gradient_array(x)) + and not _is_jax_zero_gradient_array(x)) # pyright: ignore[reportUnknownArgumentType] # fmt: skip def is_cupy_array(x: object) -> bool: @@ -85,16 +134,16 @@ def is_cupy_array(x: object) -> bool: is_pydata_sparse_array """ # Avoid importing CuPy if it isn't already - if 'cupy' not in sys.modules: + if "cupy" not in sys.modules: return False - import cupy as cp + import cupy as cp # pyright: ignore[reportMissingTypeStubs] # TODO: Should we reject ndarray subclasses? - return isinstance(x, cp.ndarray) + return isinstance(x, cp.ndarray) # pyright: ignore[reportUnknownMemberType] -def is_torch_array(x: object) -> bool: +def is_torch_array(x: object) -> TypeIs[torch.Tensor]: """ Return True if `x` is a PyTorch tensor. @@ -113,7 +162,7 @@ def is_torch_array(x: object) -> bool: is_pydata_sparse_array """ # Avoid importing torch if it isn't already - if 'torch' not in sys.modules: + if "torch" not in sys.modules: return False import torch @@ -122,7 +171,7 @@ def is_torch_array(x: object) -> bool: return isinstance(x, torch.Tensor) -def is_ndonnx_array(x: object) -> bool: +def is_ndonnx_array(x: object) -> TypeIs[ndx.Array]: """ Return True if `x` is a ndonnx Array. @@ -142,7 +191,7 @@ def is_ndonnx_array(x: object) -> bool: is_pydata_sparse_array """ # Avoid importing torch if it isn't already - if 'ndonnx' not in sys.modules: + if "ndonnx" not in sys.modules: return False import ndonnx as ndx @@ -150,7 +199,7 @@ def is_ndonnx_array(x: object) -> bool: return isinstance(x, ndx.Array) -def is_dask_array(x: object) -> bool: +def is_dask_array(x: object) -> TypeIs[da.Array]: """ Return True if `x` is a dask.array Array. @@ -170,7 +219,7 @@ def is_dask_array(x: object) -> bool: is_pydata_sparse_array """ # Avoid importing dask if it isn't already - if 'dask.array' not in sys.modules: + if "dask.array" not in sys.modules: return False import dask.array @@ -178,7 +227,7 @@ def is_dask_array(x: object) -> bool: return isinstance(x, dask.array.Array) -def is_jax_array(x: object) -> bool: +def is_jax_array(x: object) -> TypeIs[jax.Array]: """ Return True if `x` is a JAX array. @@ -199,7 +248,7 @@ def is_jax_array(x: object) -> bool: is_pydata_sparse_array """ # Avoid importing jax if it isn't already - if 'jax' not in sys.modules: + if "jax" not in sys.modules: return False import jax @@ -207,7 +256,7 @@ def is_jax_array(x: object) -> bool: return isinstance(x, jax.Array) or _is_jax_zero_gradient_array(x) -def is_pydata_sparse_array(x) -> bool: +def is_pydata_sparse_array(x: object) -> TypeIs[sparse.SparseArray]: """ Return True if `x` is an array from the `sparse` package. @@ -228,16 +277,16 @@ def is_pydata_sparse_array(x) -> bool: is_jax_array """ # Avoid importing jax if it isn't already - if 'sparse' not in sys.modules: + if "sparse" not in sys.modules: return False - import sparse + import sparse # pyright: ignore[reportMissingTypeStubs] # TODO: Account for other backends. return isinstance(x, sparse.SparseArray) -def is_array_api_obj(x: object) -> bool: +def is_array_api_obj(x: object) -> TypeIs[_ArrayApiObj]: # pyright: ignore[reportUnknownParameterType] """ Return True if `x` is an array API compatible array object. @@ -252,18 +301,20 @@ def is_array_api_obj(x: object) -> bool: is_dask_array is_jax_array """ - return is_numpy_array(x) \ - or is_cupy_array(x) \ - or is_torch_array(x) \ - or is_dask_array(x) \ - or is_jax_array(x) \ - or is_pydata_sparse_array(x) \ - or hasattr(x, '__array_namespace__') + return ( + is_numpy_array(x) + or is_cupy_array(x) + or is_torch_array(x) + or is_dask_array(x) + or is_jax_array(x) + or is_pydata_sparse_array(x) + or hasattr(x, "__array_namespace__") + ) def _compat_module_name() -> str: - assert __name__.endswith('.common._helpers') - return __name__.removesuffix('.common._helpers') + assert __name__.endswith(".common._helpers") + return __name__.removesuffix(".common._helpers") def is_numpy_namespace(xp: Namespace) -> bool: @@ -284,7 +335,7 @@ def is_numpy_namespace(xp: Namespace) -> bool: is_pydata_sparse_namespace is_array_api_strict_namespace """ - return xp.__name__ in {'numpy', _compat_module_name() + '.numpy'} + return xp.__name__ in {"numpy", _compat_module_name() + ".numpy"} def is_cupy_namespace(xp: Namespace) -> bool: @@ -305,7 +356,7 @@ def is_cupy_namespace(xp: Namespace) -> bool: is_pydata_sparse_namespace is_array_api_strict_namespace """ - return xp.__name__ in {'cupy', _compat_module_name() + '.cupy'} + return xp.__name__ in {"cupy", _compat_module_name() + ".cupy"} def is_torch_namespace(xp: Namespace) -> bool: @@ -326,7 +377,7 @@ def is_torch_namespace(xp: Namespace) -> bool: is_pydata_sparse_namespace is_array_api_strict_namespace """ - return xp.__name__ in {'torch', _compat_module_name() + '.torch'} + return xp.__name__ in {"torch", _compat_module_name() + ".torch"} def is_ndonnx_namespace(xp: Namespace) -> bool: @@ -345,7 +396,7 @@ def is_ndonnx_namespace(xp: Namespace) -> bool: is_pydata_sparse_namespace is_array_api_strict_namespace """ - return xp.__name__ == 'ndonnx' + return xp.__name__ == "ndonnx" def is_dask_namespace(xp: Namespace) -> bool: @@ -366,7 +417,7 @@ def is_dask_namespace(xp: Namespace) -> bool: is_pydata_sparse_namespace is_array_api_strict_namespace """ - return xp.__name__ in {'dask.array', _compat_module_name() + '.dask.array'} + return xp.__name__ in {"dask.array", _compat_module_name() + ".dask.array"} def is_jax_namespace(xp: Namespace) -> bool: @@ -388,7 +439,7 @@ def is_jax_namespace(xp: Namespace) -> bool: is_pydata_sparse_namespace is_array_api_strict_namespace """ - return xp.__name__ in {'jax.numpy', 'jax.experimental.array_api'} + return xp.__name__ in {"jax.numpy", "jax.experimental.array_api"} def is_pydata_sparse_namespace(xp: Namespace) -> bool: @@ -407,7 +458,7 @@ def is_pydata_sparse_namespace(xp: Namespace) -> bool: is_jax_namespace is_array_api_strict_namespace """ - return xp.__name__ == 'sparse' + return xp.__name__ == "sparse" def is_array_api_strict_namespace(xp: Namespace) -> bool: @@ -426,21 +477,24 @@ def is_array_api_strict_namespace(xp: Namespace) -> bool: is_jax_namespace is_pydata_sparse_namespace """ - return xp.__name__ == 'array_api_strict' + return xp.__name__ == "array_api_strict" -def _check_api_version(api_version: str) -> None: - if api_version in ['2021.12', '2022.12', '2023.12']: - warnings.warn(f"The {api_version} version of the array API specification was requested but the returned namespace is actually version 2024.12") - elif api_version is not None and api_version not in ['2021.12', '2022.12', - '2023.12', '2024.12']: - raise ValueError("Only the 2024.12 version of the array API specification is currently supported") +def _check_api_version(api_version: str | None) -> None: + if api_version in _API_VERSIONS_OLD: + warnings.warn( + f"The {api_version} version of the array API specification was requested but the returned namespace is actually version 2024.12" + ) + elif api_version is not None and api_version not in _API_VERSIONS: + raise ValueError( + "Only the 2024.12 version of the array API specification is currently supported" + ) def array_namespace( - *xs: Union[Array, bool, int, float, complex, None], - api_version: Optional[str] = None, - use_compat: Optional[bool] = None, + *xs: Array | complex | None, + api_version: str | None = None, + use_compat: bool | None = None, ) -> Namespace: """ Get the array API compatible namespace for the arrays `xs`. @@ -510,11 +564,13 @@ def your_function(x, y): _use_compat = use_compat in [None, True] - namespaces = set() + namespaces: set[Namespace] = set() for x in xs: if is_numpy_array(x): - from .. import numpy as numpy_namespace import numpy as np + + from .. import numpy as numpy_namespace + if use_compat is True: _check_api_version(api_version) namespaces.add(numpy_namespace) @@ -528,25 +584,31 @@ def your_function(x, y): if _use_compat: _check_api_version(api_version) from .. import cupy as cupy_namespace + namespaces.add(cupy_namespace) else: - import cupy as cp + import cupy as cp # pyright: ignore[reportMissingTypeStubs] + namespaces.add(cp) elif is_torch_array(x): if _use_compat: _check_api_version(api_version) from .. import torch as torch_namespace + namespaces.add(torch_namespace) else: import torch + namespaces.add(torch) elif is_dask_array(x): if _use_compat: _check_api_version(api_version) from ..dask import array as dask_namespace + namespaces.add(dask_namespace) else: import dask.array as da + namespaces.add(da) elif is_jax_array(x): if use_compat is True: @@ -558,23 +620,27 @@ def your_function(x, y): # JAX v0.4.32 and newer implements the array API directly in jax.numpy. # For older JAX versions, it is available via jax.experimental.array_api. import jax.numpy + if hasattr(jax.numpy, "__array_api_version__"): jnp = jax.numpy else: - import jax.experimental.array_api as jnp + import jax.experimental.array_api as jnp # pyright: ignore[reportMissingImports] namespaces.add(jnp) elif is_pydata_sparse_array(x): if use_compat is True: _check_api_version(api_version) raise ValueError("`sparse` does not have an array-api-compat wrapper") else: - import sparse + import sparse # pyright: ignore[reportMissingTypeStubs] # `sparse` is already an array namespace. We do not have a wrapper # submodule for it. namespaces.add(sparse) - elif hasattr(x, '__array_namespace__'): + elif hasattr(x, "__array_namespace__"): if use_compat is True: - raise ValueError("The given array does not have an array-api-compat wrapper") + raise ValueError( + "The given array does not have an array-api-compat wrapper" + ) + x = cast("SupportsArrayNamespace[Any]", x) namespaces.add(x.__array_namespace__(api_version=api_version)) elif isinstance(x, (bool, int, float, complex, type(None))): continue @@ -588,15 +654,16 @@ def your_function(x, y): if len(namespaces) != 1: raise TypeError(f"Multiple namespaces for array inputs: {namespaces}") - xp, = namespaces + (xp,) = namespaces return xp + # backwards compatibility alias get_namespace = array_namespace -def _check_device(bare_xp, device): +def _check_device(bare_xp: Namespace, device: Device) -> None: # pyright: ignore[reportUnusedFunction] """ Validate dummy device on device-less array backends. @@ -609,11 +676,11 @@ def _check_device(bare_xp, device): https://github.com/data-apis/array-api-compat/pull/293 """ - if bare_xp is sys.modules.get('numpy'): + if bare_xp is sys.modules.get("numpy"): if device not in ("cpu", None): raise ValueError(f"Unsupported device for NumPy: {device!r}") - elif bare_xp is sys.modules.get('dask.array'): + elif bare_xp is sys.modules.get("dask.array"): if device not in ("cpu", _DASK_DEVICE, None): raise ValueError(f"Unsupported device for Dask: {device!r}") @@ -622,18 +689,20 @@ def _check_device(bare_xp, device): # when the array backend is not the CPU. # (since it is not easy to tell which device a dask array is on) class _dask_device: - def __repr__(self): + def __repr__(self) -> Literal["DASK_DEVICE"]: return "DASK_DEVICE" + _DASK_DEVICE = _dask_device() + # device() is not on numpy.ndarray or dask.array and to_device() is not on numpy.ndarray # or cupy.ndarray. They are not included in array objects of this library # because this library just reuses the respective ndarray classes without # wrapping or subclassing them. These helper functions can be used instead of # the wrapper functions for libraries that need to support both NumPy/CuPy and # other libraries that use devices. -def device(x: Array, /) -> Device: +def device(x: _ArrayApiObj, /) -> Device: """ Hardware device the array data resides on. @@ -669,7 +738,7 @@ def device(x: Array, /) -> Device: return "cpu" elif is_dask_array(x): # Peek at the metadata of the Dask array to determine type - if is_numpy_array(x._meta): + if is_numpy_array(x._meta): # pyright: ignore # Must be on CPU since backed by numpy return "cpu" return _DASK_DEVICE @@ -679,7 +748,7 @@ def device(x: Array, /) -> Device: # Return None in this case. Note that this workaround breaks # the standard and will result in new arrays being created on the # default device instead of the same device as the input array(s). - x_device = getattr(x, 'device', None) + x_device = getattr(x, "device", None) # Older JAX releases had .device() as a method, which has been replaced # with a property in accordance with the standard. if inspect.ismethod(x_device): @@ -688,27 +757,34 @@ def device(x: Array, /) -> Device: return x_device elif is_pydata_sparse_array(x): # `sparse` will gain `.device`, so check for this first. - x_device = getattr(x, 'device', None) + x_device = getattr(x, "device", None) if x_device is not None: return x_device # Everything but DOK has this attr. try: - inner = x.data + inner = x.data # pyright: ignore except AttributeError: return "cpu" # Return the device of the constituent array - return device(inner) - return x.device + return device(inner) # pyright: ignore + return x.device # pyright: ignore + # Prevent shadowing, used below _device = device + # Based on cupy.array_api.Array.to_device -def _cupy_to_device(x, device, /, stream=None): - import cupy as cp - from cupy.cuda import Device as _Device - from cupy.cuda import stream as stream_module - from cupy_backends.cuda.api import runtime +def _cupy_to_device( + x: _CupyArray, + device: Device, + /, + stream: int | Any | None = None, +) -> _CupyArray: + import cupy as cp # pyright: ignore[reportMissingTypeStubs] + from cupy.cuda import Device as _Device # pyright: ignore + from cupy.cuda import stream as stream_module # pyright: ignore + from cupy_backends.cuda.api import runtime # pyright: ignore if device == x.device: return x @@ -721,33 +797,40 @@ def _cupy_to_device(x, device, /, stream=None): raise ValueError(f"Unsupported device {device!r}") else: # see cupy/cupy#5985 for the reason how we handle device/stream here - prev_device = runtime.getDevice() - prev_stream: stream_module.Stream = None + prev_device: Any = runtime.getDevice() # pyright: ignore[reportUnknownMemberType] + prev_stream = None if stream is not None: - prev_stream = stream_module.get_current_stream() + prev_stream: Any = stream_module.get_current_stream() # pyright: ignore # stream can be an int as specified in __dlpack__, or a CuPy stream if isinstance(stream, int): - stream = cp.cuda.ExternalStream(stream) - elif isinstance(stream, cp.cuda.Stream): + stream = cp.cuda.ExternalStream(stream) # pyright: ignore + elif isinstance(stream, cp.cuda.Stream): # pyright: ignore[reportUnknownMemberType] pass else: - raise ValueError('the input stream is not recognized') - stream.use() + raise ValueError("the input stream is not recognized") + stream.use() # pyright: ignore[reportUnknownMemberType] try: - runtime.setDevice(device.id) + runtime.setDevice(device.id) # pyright: ignore[reportUnknownMemberType] arr = x.copy() finally: - runtime.setDevice(prev_device) + runtime.setDevice(prev_device) # pyright: ignore[reportUnknownMemberType] if stream is not None: prev_stream.use() return arr -def _torch_to_device(x, device, /, stream=None): + +def _torch_to_device( + x: torch.Tensor, + device: torch.device | str | int, + /, + stream: None = None, +) -> torch.Tensor: if stream is not None: raise NotImplementedError return x.to(device) -def to_device(x: Array, device: Device, /, *, stream: Optional[Union[int, Any]] = None) -> Array: + +def to_device(x: Array, device: Device, /, *, stream: int | Any | None = None) -> Array: """ Copy the array from the device on which it currently resides to the specified ``device``. @@ -767,7 +850,7 @@ def to_device(x: Array, device: Device, /, *, stream: Optional[Union[int, Any]] a ``device`` object (see the `Device Support `__ section of the array API specification). - stream: Optional[Union[int, Any]] + stream: int | Any | None stream object to use during copy. In addition to the types supported in ``array.__dlpack__``, implementations may choose to support any library-specific stream object with the caveat that any code using @@ -799,25 +882,26 @@ def to_device(x: Array, device: Device, /, *, stream: Optional[Union[int, Any]] if is_numpy_array(x): if stream is not None: raise ValueError("The stream argument to to_device() is not supported") - if device == 'cpu': + if device == "cpu": return x raise ValueError(f"Unsupported device {device!r}") elif is_cupy_array(x): # cupy does not yet have to_device return _cupy_to_device(x, device, stream=stream) elif is_torch_array(x): - return _torch_to_device(x, device, stream=stream) + return _torch_to_device(x, device, stream=stream) # pyright: ignore[reportArgumentType] elif is_dask_array(x): if stream is not None: raise ValueError("The stream argument to to_device() is not supported") # TODO: What if our array is on the GPU already? - if device == 'cpu': + if device == "cpu": return x raise ValueError(f"Unsupported device {device!r}") elif is_jax_array(x): if not hasattr(x, "__array_namespace__"): # In JAX v0.4.31 and older, this import adds to_device method to x... - import jax.experimental.array_api # noqa: F401 + import jax.experimental.array_api # noqa: F401 # pyright: ignore + # ... but only on eager JAX. It won't work inside jax.jit. if not hasattr(x, "to_device"): return x @@ -826,10 +910,16 @@ def to_device(x: Array, device: Device, /, *, stream: Optional[Union[int, Any]] # Perform trivial check to return the same array if # device is same instead of err-ing. return x - return x.to_device(device, stream=stream) + return x.to_device(device, stream=stream) # pyright: ignore -def size(x: Array) -> int | None: +@overload +def size(x: HasShape[Collection[SupportsIndex]]) -> int: ... +@overload +def size(x: HasShape[Collection[None]]) -> None: ... +@overload +def size(x: HasShape[Collection[SupportsIndex | None]]) -> int | None: ... +def size(x: HasShape[Collection[SupportsIndex | None]]) -> int | None: """ Return the total number of elements of x. @@ -844,7 +934,7 @@ def size(x: Array) -> int | None: # Lazy API compliant arrays, such as ndonnx, can contain None in their shape if None in x.shape: return None - out = math.prod(x.shape) + out = math.prod(cast("Collection[SupportsIndex]", x.shape)) # dask.array.Array.shape can contain NaN return None if math.isnan(out) else out @@ -907,7 +997,7 @@ def is_lazy_array(x: object) -> bool: # on __bool__ (dask is one such example, which however is special-cased above). # Select a single point of the array - s = size(x) + s = size(cast("HasShape[Collection[SupportsIndex | None]]", x)) if s is None: return True xp = array_namespace(x) @@ -952,4 +1042,7 @@ def is_lazy_array(x: object) -> bool: "to_device", ] -_all_ignore = ['sys', 'math', 'inspect', 'warnings'] +_all_ignore = ["sys", "math", "inspect", "warnings"] + +def __dir__() -> list[str]: + return __all__ diff --git a/array_api_compat/common/_linalg.py b/array_api_compat/common/_linalg.py index d1e7ebd8..7e002aed 100644 --- a/array_api_compat/common/_linalg.py +++ b/array_api_compat/common/_linalg.py @@ -1,23 +1,33 @@ from __future__ import annotations import math -from typing import Literal, NamedTuple, Optional, Tuple, Union +from typing import Literal, NamedTuple, cast import numpy as np + if np.__version__[0] == "2": from numpy.lib.array_utils import normalize_axis_tuple else: from numpy.core.numeric import normalize_axis_tuple -from ._aliases import matmul, matrix_transpose, tensordot, vecdot, isdtype from .._internal import get_xp -from ._typing import Array, Namespace +from ._aliases import isdtype, matmul, matrix_transpose, tensordot, vecdot +from ._typing import Array, DType, Namespace + # These are in the main NumPy namespace but not in numpy.linalg -def cross(x1: Array, x2: Array, /, xp: Namespace, *, axis: int = -1, **kwargs) -> Array: +def cross( + x1: Array, + x2: Array, + /, + xp: Namespace, + *, + axis: int = -1, + **kwargs: object, +) -> Array: return xp.cross(x1, x2, axis=axis, **kwargs) -def outer(x1: Array, x2: Array, /, xp: Namespace, **kwargs) -> Array: +def outer(x1: Array, x2: Array, /, xp: Namespace, **kwargs: object) -> Array: return xp.outer(x1, x2, **kwargs) class EighResult(NamedTuple): @@ -39,46 +49,66 @@ class SVDResult(NamedTuple): # These functions are the same as their NumPy counterparts except they return # a namedtuple. -def eigh(x: Array, /, xp: Namespace, **kwargs) -> EighResult: +def eigh(x: Array, /, xp: Namespace, **kwargs: object) -> EighResult: return EighResult(*xp.linalg.eigh(x, **kwargs)) -def qr(x: Array, /, xp: Namespace, *, mode: Literal['reduced', 'complete'] = 'reduced', - **kwargs) -> QRResult: +def qr( + x: Array, + /, + xp: Namespace, + *, + mode: Literal["reduced", "complete"] = "reduced", + **kwargs: object, +) -> QRResult: return QRResult(*xp.linalg.qr(x, mode=mode, **kwargs)) -def slogdet(x: Array, /, xp: Namespace, **kwargs) -> SlogdetResult: +def slogdet(x: Array, /, xp: Namespace, **kwargs: object) -> SlogdetResult: return SlogdetResult(*xp.linalg.slogdet(x, **kwargs)) def svd( - x: Array, /, xp: Namespace, *, full_matrices: bool = True, **kwargs + x: Array, + /, + xp: Namespace, + *, + full_matrices: bool = True, + **kwargs: object, ) -> SVDResult: return SVDResult(*xp.linalg.svd(x, full_matrices=full_matrices, **kwargs)) # These functions have additional keyword arguments # The upper keyword argument is new from NumPy -def cholesky(x: Array, /, xp: Namespace, *, upper: bool = False, **kwargs) -> Array: +def cholesky( + x: Array, + /, + xp: Namespace, + *, + upper: bool = False, + **kwargs: object, +) -> Array: L = xp.linalg.cholesky(x, **kwargs) if upper: U = get_xp(xp)(matrix_transpose)(L) if get_xp(xp)(isdtype)(U.dtype, 'complex floating'): - U = xp.conj(U) + U = xp.conj(U) # pyright: ignore[reportConstantRedefinition] return U return L # The rtol keyword argument of matrix_rank() and pinv() is new from NumPy. # Note that it has a different semantic meaning from tol and rcond. -def matrix_rank(x: Array, - /, - xp: Namespace, - *, - rtol: Optional[Union[float, Array]] = None, - **kwargs) -> Array: +def matrix_rank( + x: Array, + /, + xp: Namespace, + *, + rtol: float | Array | None = None, + **kwargs: object, +) -> Array: # this is different from xp.linalg.matrix_rank, which supports 1 # dimensional arrays. if x.ndim < 2: raise xp.linalg.LinAlgError("1-dimensional array given. Array must be at least two-dimensional") - S = get_xp(xp)(svdvals)(x, **kwargs) + S: Array = get_xp(xp)(svdvals)(x, **kwargs) if rtol is None: tol = S.max(axis=-1, keepdims=True) * max(x.shape[-2:]) * xp.finfo(S.dtype).eps else: @@ -88,7 +118,12 @@ def matrix_rank(x: Array, return xp.count_nonzero(S > tol, axis=-1) def pinv( - x: Array, /, xp: Namespace, *, rtol: Optional[Union[float, Array]] = None, **kwargs + x: Array, + /, + xp: Namespace, + *, + rtol: float | Array | None = None, + **kwargs: object, ) -> Array: # this is different from xp.linalg.pinv, which does not multiply the # default tolerance by max(M, N). @@ -104,13 +139,13 @@ def matrix_norm( xp: Namespace, *, keepdims: bool = False, - ord: Optional[Union[int, float, Literal['fro', 'nuc']]] = 'fro', + ord: float | Literal["fro", "nuc"] | None = "fro", ) -> Array: return xp.linalg.norm(x, axis=(-2, -1), keepdims=keepdims, ord=ord) # svdvals is not in NumPy (but it is in SciPy). It is equivalent to # xp.linalg.svd(compute_uv=False). -def svdvals(x: Array, /, xp: Namespace) -> Union[Array, Tuple[Array, ...]]: +def svdvals(x: Array, /, xp: Namespace) -> Array | tuple[Array, ...]: return xp.linalg.svd(x, compute_uv=False) def vector_norm( @@ -118,9 +153,9 @@ def vector_norm( /, xp: Namespace, *, - axis: Optional[Union[int, Tuple[int, ...]]] = None, + axis: int | tuple[int, ...] | None = None, keepdims: bool = False, - ord: Optional[Union[int, float]] = 2, + ord: float = 2, ) -> Array: # xp.linalg.norm tries to do a matrix norm whenever axis is a 2-tuple or # when axis=None and the input is 2-D, so to force a vector norm, we make @@ -133,7 +168,10 @@ def vector_norm( elif isinstance(axis, tuple): # Note: The axis argument supports any number of axes, whereas # xp.linalg.norm() only supports a single axis for vector norm. - normalized_axis = normalize_axis_tuple(axis, x.ndim) + normalized_axis = cast( + "tuple[int, ...]", + normalize_axis_tuple(axis, x.ndim), # pyright: ignore[reportCallIssue] + ) rest = tuple(i for i in range(x.ndim) if i not in normalized_axis) newshape = axis + rest _x = xp.transpose(x, newshape).reshape( @@ -149,7 +187,13 @@ def vector_norm( # We can't reuse xp.linalg.norm(keepdims) because of the reshape hacks # above to avoid matrix norm logic. shape = list(x.shape) - _axis = normalize_axis_tuple(range(x.ndim) if axis is None else axis, x.ndim) + _axis = cast( + "tuple[int, ...]", + normalize_axis_tuple( # pyright: ignore[reportCallIssue] + range(x.ndim) if axis is None else axis, + x.ndim, + ), + ) for i in _axis: shape[i] = 1 res = xp.reshape(res, tuple(shape)) @@ -159,11 +203,17 @@ def vector_norm( # xp.diagonal and xp.trace operate on the first two axes whereas these # operates on the last two -def diagonal(x: Array, /, xp: Namespace, *, offset: int = 0, **kwargs) -> Array: +def diagonal(x: Array, /, xp: Namespace, *, offset: int = 0, **kwargs: object) -> Array: return xp.diagonal(x, offset=offset, axis1=-2, axis2=-1, **kwargs) def trace( - x: Array, /, xp: Namespace, *, offset: int = 0, dtype=None, **kwargs + x: Array, + /, + xp: Namespace, + *, + offset: int = 0, + dtype: DType | None = None, + **kwargs: object, ) -> Array: return xp.asarray( xp.trace(x, offset=offset, dtype=dtype, axis1=-2, axis2=-1, **kwargs) @@ -176,3 +226,7 @@ def trace( 'trace'] _all_ignore = ['math', 'normalize_axis_tuple', 'get_xp', 'np', 'isdtype'] + + +def __dir__() -> list[str]: + return __all__ diff --git a/array_api_compat/common/_typing.py b/array_api_compat/common/_typing.py index 4c3b356b..d7deade1 100644 --- a/array_api_compat/common/_typing.py +++ b/array_api_compat/common/_typing.py @@ -1,24 +1,150 @@ from __future__ import annotations + +from collections.abc import Mapping from types import ModuleType as Namespace -from typing import Any, TypeVar, Protocol +from typing import TYPE_CHECKING, Literal, Protocol, TypeAlias, TypedDict, TypeVar + +if TYPE_CHECKING: + from _typeshed import Incomplete + + SupportsBufferProtocol: TypeAlias = Incomplete + Array: TypeAlias = Incomplete + Device: TypeAlias = Incomplete + DType: TypeAlias = Incomplete +else: + SupportsBufferProtocol = object + Array = object + Device = object + DType = object + + +_T_co = TypeVar("_T_co", covariant=True) + + +class NestedSequence(Protocol[_T_co]): + def __getitem__(self, key: int, /) -> _T_co | NestedSequence[_T_co]: ... + def __len__(self, /) -> int: ... + + +class SupportsArrayNamespace(Protocol[_T_co]): + def __array_namespace__(self, /, *, api_version: str | None) -> _T_co: ... + + +class HasShape(Protocol[_T_co]): + @property + def shape(self, /) -> _T_co: ... + + +# Return type of `__array_namespace_info__.default_dtypes` +Capabilities = TypedDict( + "Capabilities", + { + "boolean indexing": bool, + "data-dependent shapes": bool, + "max dimensions": int, + }, +) + +# Return type of `__array_namespace_info__.default_dtypes` +DefaultDTypes = TypedDict( + "DefaultDTypes", + { + "real floating": DType, + "complex floating": DType, + "integral": DType, + "indexing": DType, + }, +) + + +_DTypeKind: TypeAlias = Literal[ + "bool", + "signed integer", + "unsigned integer", + "integral", + "real floating", + "complex floating", + "numeric", +] +# Type of the `kind` parameter in `__array_namespace_info__.dtypes` +DTypeKind: TypeAlias = _DTypeKind | tuple[_DTypeKind, ...] + + +# `__array_namespace_info__.dtypes(kind="bool")` +class DTypesBool(TypedDict): + bool: DType + + +# `__array_namespace_info__.dtypes(kind="signed integer")` +class DTypesSigned(TypedDict): + int8: DType + int16: DType + int32: DType + int64: DType + + +# `__array_namespace_info__.dtypes(kind="unsigned integer")` +class DTypesUnsigned(TypedDict): + uint8: DType + uint16: DType + uint32: DType + uint64: DType + + +# `__array_namespace_info__.dtypes(kind="integral")` +class DTypesIntegral(DTypesSigned, DTypesUnsigned): + pass + + +# `__array_namespace_info__.dtypes(kind="real floating")` +class DTypesReal(TypedDict): + float32: DType + float64: DType + + +# `__array_namespace_info__.dtypes(kind="complex floating")` +class DTypesComplex(TypedDict): + complex64: DType + complex128: DType + + +# `__array_namespace_info__.dtypes(kind="numeric")` +class DTypesNumeric(DTypesIntegral, DTypesReal, DTypesComplex): + pass + + +# `__array_namespace_info__.dtypes(kind=None)` (default) +class DTypesAll(DTypesBool, DTypesNumeric): + pass + + +# `__array_namespace_info__.dtypes(kind=?)` (fallback) +DTypesAny: TypeAlias = Mapping[str, DType] + __all__ = [ "Array", + "Capabilities", "DType", + "DTypeKind", + "DTypesAny", + "DTypesAll", + "DTypesBool", + "DTypesNumeric", + "DTypesIntegral", + "DTypesSigned", + "DTypesUnsigned", + "DTypesReal", + "DTypesComplex", + "DefaultDTypes", "Device", + "HasShape", "Namespace", "NestedSequence", + "SupportsArrayNamespace", "SupportsBufferProtocol", ] -_T_co = TypeVar("_T_co", covariant=True) - -class NestedSequence(Protocol[_T_co]): - def __getitem__(self, key: int, /) -> _T_co | NestedSequence[_T_co]: ... - def __len__(self, /) -> int: ... - -SupportsBufferProtocol = Any -Array = Any -Device = Any -DType = Any +def __dir__() -> list[str]: + return __all__ diff --git a/array_api_compat/dask/array/__init__.py b/array_api_compat/dask/array/__init__.py index bb649306..1e47b960 100644 --- a/array_api_compat/dask/array/__init__.py +++ b/array_api_compat/dask/array/__init__.py @@ -1,9 +1,11 @@ -from dask.array import * # noqa: F403 +from typing import Final + +from dask.array import * # noqa: F403 # These imports may overwrite names from the import * above. -from ._aliases import * # noqa: F403 +from ._aliases import * # noqa: F403 -__array_api_version__ = '2024.12' +__array_api_version__: Final = "2024.12" # See the comment in the numpy __init__.py __import__(__package__ + '.linalg') diff --git a/array_api_compat/dask/array/_aliases.py b/array_api_compat/dask/array/_aliases.py index e7ddde78..9687a9cd 100644 --- a/array_api_compat/dask/array/_aliases.py +++ b/array_api_compat/dask/array/_aliases.py @@ -1,28 +1,38 @@ +# pyright: reportPrivateUsage=false +# pyright: reportUnknownArgumentType=false +# pyright: reportUnknownMemberType=false +# pyright: reportUnknownVariableType=false + from __future__ import annotations -from typing import Callable, Optional, Union +from builtins import bool as py_bool +from collections.abc import Callable +from typing import TYPE_CHECKING, Any + +if TYPE_CHECKING: + from typing_extensions import TypeIs +import dask.array as da import numpy as np +from numpy import bool_ as bool from numpy import ( - # dtypes - bool_ as bool, + can_cast, + complex64, + complex128, float32, float64, int8, int16, int32, int64, + result_type, uint8, uint16, uint32, uint64, - complex64, - complex128, - can_cast, - result_type, ) -import dask.array as da +from ..._internal import get_xp from ...common import _aliases, _helpers, array_namespace from ...common._typing import ( Array, @@ -31,7 +41,6 @@ NestedSequence, SupportsBufferProtocol, ) -from ..._internal import get_xp from ._info import __array_namespace_info__ isdtype = get_xp(np)(_aliases.isdtype) @@ -44,8 +53,8 @@ def astype( dtype: DType, /, *, - copy: bool = True, - device: Optional[Device] = None, + copy: py_bool = True, + device: Device | None = None, ) -> Array: """ Array API compatibility wrapper for astype(). @@ -69,14 +78,14 @@ def astype( # not pass stop/step as keyword arguments, which will cause # an error with dask def arange( - start: Union[int, float], + start: float, /, - stop: Optional[Union[int, float]] = None, - step: Union[int, float] = 1, + stop: float | None = None, + step: float = 1, *, - dtype: Optional[DType] = None, - device: Optional[Device] = None, - **kwargs, + dtype: DType | None = None, + device: Device | None = None, + **kwargs: object, ) -> Array: """ Array API compatibility wrapper for arange(). @@ -87,7 +96,7 @@ def arange( # TODO: respect device keyword? _helpers._check_device(da, device) - args = [start] + args: list[Any] = [start] if stop is not None: args.append(stop) else: @@ -137,18 +146,13 @@ def arange( # asarray also adds the copy keyword, which is not present in numpy 1.0. def asarray( - obj: ( - Array - | bool | int | float | complex - | NestedSequence[bool | int | float | complex] - | SupportsBufferProtocol - ), + obj: complex | NestedSequence[complex] | Array | SupportsBufferProtocol, /, *, - dtype: Optional[DType] = None, - device: Optional[Device] = None, - copy: Optional[bool] = None, - **kwargs, + dtype: DType | None = None, + device: Device | None = None, + copy: py_bool | None = None, + **kwargs: object, ) -> Array: """ Array API compatibility wrapper for asarray(). @@ -164,7 +168,7 @@ def asarray( if copy is False: raise ValueError("Unable to avoid copy when changing dtype") obj = obj.astype(dtype) - return obj.copy() if copy else obj + return obj.copy() if copy else obj # pyright: ignore[reportAttributeAccessIssue] if copy is False: raise NotImplementedError( @@ -177,22 +181,21 @@ def asarray( return da.from_array(obj) -from dask.array import ( - # Element wise aliases - arccos as acos, - arccosh as acosh, - arcsin as asin, - arcsinh as asinh, - arctan as atan, - arctan2 as atan2, - arctanh as atanh, - left_shift as bitwise_left_shift, - right_shift as bitwise_right_shift, - invert as bitwise_invert, - power as pow, - # Other - concatenate as concat, -) +# Element wise aliases +from dask.array import arccos as acos +from dask.array import arccosh as acosh +from dask.array import arcsin as asin +from dask.array import arcsinh as asinh +from dask.array import arctan as atan +from dask.array import arctan2 as atan2 +from dask.array import arctanh as atanh + +# Other +from dask.array import concatenate as concat +from dask.array import invert as bitwise_invert +from dask.array import left_shift as bitwise_left_shift +from dask.array import power as pow +from dask.array import right_shift as bitwise_right_shift # dask.array.clip does not work unless all three arguments are provided. @@ -202,8 +205,8 @@ def asarray( def clip( x: Array, /, - min: Optional[Union[int, float, Array]] = None, - max: Optional[Union[int, float, Array]] = None, + min: float | Array | None = None, + max: float | Array | None = None, ) -> Array: """ Array API compatibility wrapper for clip(). @@ -212,8 +215,8 @@ def clip( specification for more details. """ - def _isscalar(a): - return isinstance(a, (int, float, type(None))) + def _isscalar(a: float | Array | None, /) -> TypeIs[float | None]: + return a is None or isinstance(a, (int, float)) min_shape = () if _isscalar(min) else min.shape max_shape = () if _isscalar(max) else max.shape @@ -266,7 +269,12 @@ def _ensure_single_chunk(x: Array, axis: int) -> tuple[Array, Callable[[Array], def sort( - x: Array, /, *, axis: int = -1, descending: bool = False, stable: bool = True + x: Array, + /, + *, + axis: int = -1, + descending: py_bool = False, + stable: py_bool = True, ) -> Array: """ Array API compatibility layer around the lack of sort() in Dask. @@ -296,7 +304,12 @@ def sort( def argsort( - x: Array, /, *, axis: int = -1, descending: bool = False, stable: bool = True + x: Array, + /, + *, + axis: int = -1, + descending: py_bool = False, + stable: py_bool = True, ) -> Array: """ Array API compatibility layer around the lack of argsort() in Dask. @@ -330,25 +343,34 @@ def argsort( # dask.array.count_nonzero does not have keepdims def count_nonzero( x: Array, - axis=None, - keepdims=False + axis: int | None = None, + keepdims: py_bool = False, ) -> Array: - result = da.count_nonzero(x, axis) - if keepdims: - if axis is None: - return da.reshape(result, [1]*x.ndim) - return da.expand_dims(result, axis) - return result - - + result = da.count_nonzero(x, axis) + if keepdims: + if axis is None: + return da.reshape(result, [1] * x.ndim) + return da.expand_dims(result, axis) + return result + + +__all__ = [ + "__array_namespace_info__", + "count_nonzero", + "bool", + "int8", "int16", "int32", "int64", + "uint8", "uint16", "uint32", "uint64", + "float32", "float64", + "complex64", "complex128", + "asarray", "astype", "can_cast", "result_type", + "pow", + "concat", + "acos", "acosh", "asin", "asinh", "atan", "atan2", "atanh", + "bitwise_left_shift", "bitwise_right_shift", "bitwise_invert", +] # fmt: skip +__all__ += _aliases.__all__ +_all_ignore = ["array_namespace", "get_xp", "da", "np"] -__all__ = _aliases.__all__ + [ - '__array_namespace_info__', 'asarray', 'astype', 'acos', - 'acosh', 'asin', 'asinh', 'atan', 'atan2', - 'atanh', 'bitwise_left_shift', 'bitwise_invert', - 'bitwise_right_shift', 'concat', 'pow', 'can_cast', - 'result_type', 'bool', 'float32', 'float64', 'int8', 'int16', 'int32', 'int64', - 'uint8', 'uint16', 'uint32', 'uint64', 'complex64', 'complex128', - 'can_cast', 'count_nonzero', 'result_type'] -_all_ignore = ["array_namespace", "get_xp", "da", "np"] +def __dir__() -> list[str]: + return __all__ diff --git a/array_api_compat/dask/array/_info.py b/array_api_compat/dask/array/_info.py index 614f43d9..9e4d736f 100644 --- a/array_api_compat/dask/array/_info.py +++ b/array_api_compat/dask/array/_info.py @@ -7,25 +7,51 @@ more details. """ + +# pyright: reportPrivateUsage=false + +from __future__ import annotations + +from typing import Literal as L +from typing import TypeAlias, overload + +from numpy import bool_ as bool from numpy import ( + complex64, + complex128, dtype, - bool_ as bool, - intp, + float32, + float64, int8, int16, int32, int64, + intp, uint8, uint16, uint32, uint64, - float32, - float64, - complex64, - complex128, ) -from ...common._helpers import _DASK_DEVICE +from ...common._helpers import _DASK_DEVICE, _dask_device +from ...common._typing import ( + Capabilities, + DefaultDTypes, + DType, + DTypeKind, + DTypesAll, + DTypesAny, + DTypesBool, + DTypesComplex, + DTypesIntegral, + DTypesNumeric, + DTypesReal, + DTypesSigned, + DTypesUnsigned, +) + +_Device: TypeAlias = L["cpu"] | _dask_device + class __array_namespace_info__: """ @@ -59,9 +85,9 @@ class __array_namespace_info__: """ - __module__ = 'dask.array' + __module__ = "dask.array" - def capabilities(self): + def capabilities(self) -> Capabilities: """ Return a dictionary of array API library capabilities. @@ -116,7 +142,7 @@ def capabilities(self): "max dimensions": 64, } - def default_device(self): + def default_device(self) -> L["cpu"]: """ The default device used for new Dask arrays. @@ -143,7 +169,7 @@ def default_device(self): """ return "cpu" - def default_dtypes(self, *, device=None): + def default_dtypes(self, /, *, device: _Device | None = None) -> DefaultDTypes: """ The default data types used for new Dask arrays. @@ -184,8 +210,8 @@ def default_dtypes(self, *, device=None): """ if device not in ["cpu", _DASK_DEVICE, None]: raise ValueError( - 'Device not understood. Only "cpu" or _DASK_DEVICE is allowed, but received:' - f' {device}' + f'Device not understood. Only "cpu" or _DASK_DEVICE is allowed, ' + f"but received: {device!r}" ) return { "real floating": dtype(float64), @@ -194,7 +220,41 @@ def default_dtypes(self, *, device=None): "indexing": dtype(intp), } - def dtypes(self, *, device=None, kind=None): + @overload + def dtypes( + self, /, *, device: _Device | None = None, kind: None = None + ) -> DTypesAll: ... + @overload + def dtypes( + self, /, *, device: _Device | None = None, kind: L["bool"] + ) -> DTypesBool: ... + @overload + def dtypes( + self, /, *, device: _Device | None = None, kind: L["signed integer"] + ) -> DTypesSigned: ... + @overload + def dtypes( + self, /, *, device: _Device | None = None, kind: L["unsigned integer"] + ) -> DTypesUnsigned: ... + @overload + def dtypes( + self, /, *, device: _Device | None = None, kind: L["integral"] + ) -> DTypesIntegral: ... + @overload + def dtypes( + self, /, *, device: _Device | None = None, kind: L["real floating"] + ) -> DTypesReal: ... + @overload + def dtypes( + self, /, *, device: _Device | None = None, kind: L["complex floating"] + ) -> DTypesComplex: ... + @overload + def dtypes( + self, /, *, device: _Device | None = None, kind: L["numeric"] + ) -> DTypesNumeric: ... + def dtypes( + self, /, *, device: _Device | None = None, kind: DTypeKind | None = None + ) -> DTypesAny: """ The array API data types supported by Dask. @@ -251,7 +311,7 @@ def dtypes(self, *, device=None, kind=None): if device not in ["cpu", _DASK_DEVICE, None]: raise ValueError( 'Device not understood. Only "cpu" or _DASK_DEVICE is allowed, but received:' - f' {device}' + f" {device}" ) if kind is None: return { @@ -321,14 +381,14 @@ def dtypes(self, *, device=None, kind=None): "complex64": dtype(complex64), "complex128": dtype(complex128), } - if isinstance(kind, tuple): - res = {} + if isinstance(kind, tuple): # type: ignore[reportUnnecessaryIsinstanceCall] + res: dict[str, DType] = {} for k in kind: res.update(self.dtypes(kind=k)) return res raise ValueError(f"unsupported kind: {kind!r}") - def devices(self): + def devices(self) -> list[_Device]: """ The devices supported by Dask. diff --git a/array_api_compat/dask/array/linalg.py b/array_api_compat/dask/array/linalg.py index bd53f0df..0825386e 100644 --- a/array_api_compat/dask/array/linalg.py +++ b/array_api_compat/dask/array/linalg.py @@ -3,15 +3,16 @@ from typing import Literal import dask.array as da + +# The `matmul` and `tensordot` functions are in both the main and linalg namespaces +from dask.array import matmul, outer, tensordot + # Exports -from dask.array.linalg import * # noqa: F403 -from dask.array import outer -# These functions are in both the main and linalg namespaces -from dask.array import matmul, tensordot +from dask.array.linalg import * # noqa: F403 from ..._internal import get_xp from ...common import _linalg -from ...common._typing import Array +from ...common._typing import Array as _Array from ._aliases import matrix_transpose, vecdot # dask.array.linalg doesn't have __all__. If it is added, replace this with @@ -32,8 +33,11 @@ # supports the mode keyword on QR # https://github.com/dask/dask/issues/10388 #qr = get_xp(da)(_linalg.qr) -def qr(x: Array, mode: Literal['reduced', 'complete'] = 'reduced', - **kwargs) -> QRResult: +def qr( + x: _Array, + mode: Literal["reduced", "complete"] = "reduced", + **kwargs: object, +) -> QRResult: if mode != "reduced": raise ValueError("dask arrays only support using mode='reduced'") return QRResult(*da.linalg.qr(x, **kwargs)) @@ -46,12 +50,12 @@ def qr(x: Array, mode: Literal['reduced', 'complete'] = 'reduced', # Wrap the svd functions to not pass full_matrices to dask # when full_matrices=False (as that is the default behavior for dask), # and dask doesn't have the full_matrices keyword -def svd(x: Array, full_matrices: bool = True, **kwargs) -> SVDResult: +def svd(x: _Array, full_matrices: bool = True, **kwargs) -> SVDResult: if full_matrices: raise ValueError("full_matrics=True is not supported by dask.") return da.linalg.svd(x, coerce_signs=False, **kwargs) -def svdvals(x: Array) -> Array: +def svdvals(x: _Array) -> _Array: # TODO: can't avoid computing U or V for dask _, s, _ = svd(x) return s diff --git a/array_api_compat/numpy/__init__.py b/array_api_compat/numpy/__init__.py index 6a5d9867..f7b558ba 100644 --- a/array_api_compat/numpy/__init__.py +++ b/array_api_compat/numpy/__init__.py @@ -1,10 +1,16 @@ -from numpy import * # noqa: F403 +# ruff: noqa: PLC0414 +from typing import Final + +from numpy import * # noqa: F403 # pyright: ignore[reportWildcardImportFromLibrary] # from numpy import * doesn't overwrite these builtin names -from numpy import abs, max, min, round # noqa: F401 +from numpy import abs as abs +from numpy import max as max +from numpy import min as min +from numpy import round as round # These imports may overwrite names from the import * above. -from ._aliases import * # noqa: F403 +from ._aliases import * # noqa: F403 # Don't know why, but we have to do an absolute import to import linalg. If we # instead do @@ -13,9 +19,17 @@ # # It doesn't overwrite np.linalg from above. The import is generated # dynamically so that the library can be vendored. -__import__(__package__ + '.linalg') -__import__(__package__ + '.fft') +__import__(__package__ + ".linalg") + +__import__(__package__ + ".fft") + +from ..common._helpers import * # noqa: F403 +from .linalg import matrix_transpose, vecdot # noqa: F401 -from .linalg import matrix_transpose, vecdot # noqa: F401 +try: + # Used in asarray(). Not present in older versions. + from numpy import _CopyMode # noqa: F401 +except ImportError: + pass -__array_api_version__ = '2024.12' +__array_api_version__: Final = "2024.12" diff --git a/array_api_compat/numpy/_aliases.py b/array_api_compat/numpy/_aliases.py index d1fd46a1..d8792611 100644 --- a/array_api_compat/numpy/_aliases.py +++ b/array_api_compat/numpy/_aliases.py @@ -1,6 +1,10 @@ +# pyright: reportPrivateUsage=false from __future__ import annotations -from typing import Optional, Union +from builtins import bool as py_bool +from typing import TYPE_CHECKING, Any, Literal, TypeAlias, cast + +import numpy as np from .._internal import get_xp from ..common import _aliases, _helpers @@ -8,7 +12,12 @@ from ._info import __array_namespace_info__ from ._typing import Array, Device, DType -import numpy as np +if TYPE_CHECKING: + from typing_extensions import Buffer, TypeIs + +# The values of the `_CopyMode` enum can be either `False`, `True`, or `2`: +# https://github.com/numpy/numpy/blob/5a8a6a79d9c2fff8f07dcab5d41e14f8508d673f/numpy/_globals.pyi#L7-L10 +_Copy: TypeAlias = py_bool | Literal[2] | np._CopyMode bool = np.bool_ @@ -65,9 +74,9 @@ iinfo = get_xp(np)(_aliases.iinfo) -def _supports_buffer_protocol(obj): +def _supports_buffer_protocol(obj: object) -> TypeIs[Buffer]: # pyright: ignore[reportUnusedFunction] try: - memoryview(obj) + memoryview(obj) # pyright: ignore[reportArgumentType] except TypeError: return False return True @@ -78,18 +87,13 @@ def _supports_buffer_protocol(obj): # complicated enough that it's easier to define it separately for each module # rather than trying to combine everything into one function in common/ def asarray( - obj: ( - Array - | bool | int | float | complex - | NestedSequence[bool | int | float | complex] - | SupportsBufferProtocol - ), + obj: Array | complex | NestedSequence[complex] | SupportsBufferProtocol, /, *, - dtype: Optional[DType] = None, - device: Optional[Device] = None, - copy: Optional[Union[bool, np._CopyMode]] = None, - **kwargs, + dtype: DType | None = None, + device: Device | None = None, + copy: _Copy | None = None, + **kwargs: Any, ) -> Array: """ Array API compatibility wrapper for asarray(). @@ -106,7 +110,7 @@ def asarray( elif copy is True: copy = np._CopyMode.ALWAYS - return np.array(obj, copy=copy, dtype=dtype, **kwargs) + return np.array(obj, copy=copy, dtype=dtype, **kwargs) # pyright: ignore def astype( @@ -114,8 +118,8 @@ def astype( dtype: DType, /, *, - copy: bool = True, - device: Optional[Device] = None, + copy: py_bool = True, + device: Device | None = None, ) -> Array: _helpers._check_device(np, device) return x.astype(dtype=dtype, copy=copy) @@ -123,8 +127,14 @@ def astype( # count_nonzero returns a python int for axis=None and keepdims=False # https://github.com/numpy/numpy/issues/17562 -def count_nonzero(x: Array, axis=None, keepdims=False) -> Array: - result = np.count_nonzero(x, axis=axis, keepdims=keepdims) +def count_nonzero( + x: Array, + axis: int | tuple[int, ...] | None = None, + keepdims: py_bool = False, +) -> Array: + # NOTE: this is currently incorrectly typed in numpy, but will be fixed in + # numpy 2.2.5 and 2.3.0: https://github.com/numpy/numpy/pull/28750 + result = cast("Any", np.count_nonzero(x, axis=axis, keepdims=keepdims)) # pyright: ignore[reportArgumentType, reportCallIssue] if axis is None and not keepdims: return np.asarray(result) return result @@ -132,25 +142,43 @@ def count_nonzero(x: Array, axis=None, keepdims=False) -> Array: # These functions are completely new here. If the library already has them # (i.e., numpy 2.0), use the library version instead of our wrapper. -if hasattr(np, 'vecdot'): +if hasattr(np, "vecdot"): vecdot = np.vecdot else: vecdot = get_xp(np)(_aliases.vecdot) -if hasattr(np, 'isdtype'): +if hasattr(np, "isdtype"): isdtype = np.isdtype else: isdtype = get_xp(np)(_aliases.isdtype) -if hasattr(np, 'unstack'): +if hasattr(np, "unstack"): unstack = np.unstack else: unstack = get_xp(np)(_aliases.unstack) -__all__ = _aliases.__all__ + ['__array_namespace_info__', 'asarray', 'astype', - 'acos', 'acosh', 'asin', 'asinh', 'atan', - 'atan2', 'atanh', 'bitwise_left_shift', - 'bitwise_invert', 'bitwise_right_shift', - 'bool', 'concat', 'count_nonzero', 'pow'] - -_all_ignore = ['np', 'get_xp'] +__all__ = [ + "__array_namespace_info__", + "asarray", + "astype", + "acos", + "acosh", + "asin", + "asinh", + "atan", + "atan2", + "atanh", + "bitwise_left_shift", + "bitwise_invert", + "bitwise_right_shift", + "bool", + "concat", + "count_nonzero", + "pow", +] +__all__ += _aliases.__all__ +_all_ignore = ["np", "get_xp"] + + +def __dir__() -> list[str]: + return __all__ diff --git a/array_api_compat/numpy/_info.py b/array_api_compat/numpy/_info.py index 365855b8..f307f62c 100644 --- a/array_api_compat/numpy/_info.py +++ b/array_api_compat/numpy/_info.py @@ -7,24 +7,28 @@ more details. """ +from __future__ import annotations + +from numpy import bool_ as bool from numpy import ( + complex64, + complex128, dtype, - bool_ as bool, - intp, + float32, + float64, int8, int16, int32, int64, + intp, uint8, uint16, uint32, uint64, - float32, - float64, - complex64, - complex128, ) +from ._typing import Device, DType + class __array_namespace_info__: """ @@ -131,7 +135,11 @@ def default_device(self): """ return "cpu" - def default_dtypes(self, *, device=None): + def default_dtypes( + self, + *, + device: Device | None = None, + ) -> dict[str, dtype[intp | float64 | complex128]]: """ The default data types used for new NumPy arrays. @@ -183,7 +191,12 @@ def default_dtypes(self, *, device=None): "indexing": dtype(intp), } - def dtypes(self, *, device=None, kind=None): + def dtypes( + self, + *, + device: Device | None = None, + kind: str | tuple[str, ...] | None = None, + ) -> dict[str, DType]: """ The array API data types supported by NumPy. @@ -260,7 +273,7 @@ def dtypes(self, *, device=None, kind=None): "complex128": dtype(complex128), } if kind == "bool": - return {"bool": bool} + return {"bool": dtype(bool)} if kind == "signed integer": return { "int8": dtype(int8), @@ -312,13 +325,13 @@ def dtypes(self, *, device=None, kind=None): "complex128": dtype(complex128), } if isinstance(kind, tuple): - res = {} + res: dict[str, DType] = {} for k in kind: res.update(self.dtypes(kind=k)) return res raise ValueError(f"unsupported kind: {kind!r}") - def devices(self): + def devices(self) -> list[Device]: """ The devices supported by NumPy. @@ -344,3 +357,10 @@ def devices(self): """ return ["cpu"] + + +__all__ = ["__array_namespace_info__"] + + +def __dir__() -> list[str]: + return __all__ diff --git a/array_api_compat/numpy/_typing.py b/array_api_compat/numpy/_typing.py index a6c96924..e771c788 100644 --- a/array_api_compat/numpy/_typing.py +++ b/array_api_compat/numpy/_typing.py @@ -1,31 +1,30 @@ from __future__ import annotations -__all__ = ["Array", "DType", "Device"] -_all_ignore = ["np"] - -from typing import Literal, TYPE_CHECKING +from typing import TYPE_CHECKING, Any, Literal, TypeAlias import numpy as np -from numpy import ndarray as Array -Device = Literal["cpu"] +Device: TypeAlias = Literal["cpu"] + if TYPE_CHECKING: + # NumPy 1.x on Python 3.10 fails to parse np.dtype[] - DType = np.dtype[ - np.intp - | np.int8 - | np.int16 - | np.int32 - | np.int64 - | np.uint8 - | np.uint16 - | np.uint32 - | np.uint64 + DType: TypeAlias = np.dtype[ + np.bool_ + | np.integer[Any] | np.float32 | np.float64 | np.complex64 | np.complex128 - | np.bool ] + Array: TypeAlias = np.ndarray[Any, DType] else: - DType = np.dtype + DType: TypeAlias = np.dtype + Array: TypeAlias = np.ndarray + +__all__ = ["Array", "DType", "Device"] +_all_ignore = ["np"] + + +def __dir__() -> list[str]: + return __all__ diff --git a/array_api_compat/numpy/fft.py b/array_api_compat/numpy/fft.py index 28667594..06875f00 100644 --- a/array_api_compat/numpy/fft.py +++ b/array_api_compat/numpy/fft.py @@ -1,10 +1,9 @@ -from numpy.fft import * # noqa: F403 +import numpy as np from numpy.fft import __all__ as fft_all +from numpy.fft import fft2, ifft2, irfft2, rfft2 -from ..common import _fft from .._internal import get_xp - -import numpy as np +from ..common import _fft fft = get_xp(np)(_fft.fft) ifft = get_xp(np)(_fft.ifft) @@ -21,7 +20,14 @@ fftshift = get_xp(np)(_fft.fftshift) ifftshift = get_xp(np)(_fft.ifftshift) -__all__ = fft_all + _fft.__all__ + +__all__ = ["rfft2", "irfft2", "fft2", "ifft2"] +__all__ += _fft.__all__ + + +def __dir__() -> list[str]: + return __all__ + del get_xp del np diff --git a/array_api_compat/numpy/linalg.py b/array_api_compat/numpy/linalg.py index 8f01593b..2d3e731d 100644 --- a/array_api_compat/numpy/linalg.py +++ b/array_api_compat/numpy/linalg.py @@ -1,14 +1,35 @@ -from numpy.linalg import * # noqa: F403 -from numpy.linalg import __all__ as linalg_all -import numpy as _np +# pyright: reportAttributeAccessIssue=false +# pyright: reportUnknownArgumentType=false +# pyright: reportUnknownMemberType=false +# pyright: reportUnknownVariableType=false + +from __future__ import annotations + +import numpy as np + +# intersection of `np.linalg.__all__` on numpy 1.22 and 2.2, minus `_linalg.__all__` +from numpy.linalg import ( + LinAlgError, + cond, + det, + eig, + eigvals, + eigvalsh, + inv, + lstsq, + matrix_power, + multi_dot, + norm, + tensorinv, + tensorsolve, +) -from ..common import _linalg from .._internal import get_xp +from ..common import _linalg # These functions are in both the main and linalg namespaces -from ._aliases import matmul, matrix_transpose, tensordot, vecdot # noqa: F401 - -import numpy as np +from ._aliases import matmul, matrix_transpose, tensordot, vecdot # noqa: F401 +from ._typing import Array cross = get_xp(np)(_linalg.cross) outer = get_xp(np)(_linalg.outer) @@ -38,19 +59,28 @@ # To workaround this, the below is the code from np.linalg.solve except # only calling solve1 in the exactly 1D case. + # This code is here instead of in common because it is numpy specific. Also # note that CuPy's solve() does not currently support broadcasting (see # https://github.com/cupy/cupy/blob/main/cupy/cublas.py#L43). -def solve(x1: _np.ndarray, x2: _np.ndarray, /) -> _np.ndarray: +def solve(x1: Array, x2: Array, /) -> Array: try: from numpy.linalg._linalg import ( - _makearray, _assert_stacked_2d, _assert_stacked_square, - _commonType, isComplexType, _raise_linalgerror_singular + _assert_stacked_2d, + _assert_stacked_square, + _commonType, + _makearray, + _raise_linalgerror_singular, + isComplexType, ) except ImportError: from numpy.linalg.linalg import ( - _makearray, _assert_stacked_2d, _assert_stacked_square, - _commonType, isComplexType, _raise_linalgerror_singular + _assert_stacked_2d, + _assert_stacked_square, + _commonType, + _makearray, + _raise_linalgerror_singular, + isComplexType, ) from numpy.linalg import _umath_linalg @@ -61,6 +91,7 @@ def solve(x1: _np.ndarray, x2: _np.ndarray, /) -> _np.ndarray: t, result_t = _commonType(x1, x2) # This part is different from np.linalg.solve + gufunc: np.ufunc if x2.ndim == 1: gufunc = _umath_linalg.solve1 else: @@ -68,23 +99,45 @@ def solve(x1: _np.ndarray, x2: _np.ndarray, /) -> _np.ndarray: # This does nothing currently but is left in because it will be relevant # when complex dtype support is added to the spec in 2022. - signature = 'DD->D' if isComplexType(t) else 'dd->d' - with _np.errstate(call=_raise_linalgerror_singular, invalid='call', - over='ignore', divide='ignore', under='ignore'): - r = gufunc(x1, x2, signature=signature) + signature = "DD->D" if isComplexType(t) else "dd->d" + with np.errstate( + call=_raise_linalgerror_singular, + invalid="call", + over="ignore", + divide="ignore", + under="ignore", + ): + r: Array = gufunc(x1, x2, signature=signature) return wrap(r.astype(result_t, copy=False)) + # These functions are completely new here. If the library already has them # (i.e., numpy 2.0), use the library version instead of our wrapper. -if hasattr(np.linalg, 'vector_norm'): +if hasattr(np.linalg, "vector_norm"): vector_norm = np.linalg.vector_norm else: vector_norm = get_xp(np)(_linalg.vector_norm) -__all__ = linalg_all + _linalg.__all__ + ['solve'] -del get_xp -del np -del linalg_all -del _linalg +__all__ = [ + "LinAlgError", + "cond", + "det", + "eig", + "eigvals", + "eigvalsh", + "inv", + "lstsq", + "matrix_power", + "multi_dot", + "norm", + "tensorinv", + "tensorsolve", +] +__all__ += _linalg.__all__ +__all__ += ["solve", "vector_norm"] + + +def __dir__() -> list[str]: + return __all__ From 5e14b53a3558765a8f9b921c72f0249cc0c1c5b9 Mon Sep 17 00:00:00 2001 From: Joren Hammudoglu Date: Sat, 19 Apr 2025 16:08:41 +0200 Subject: [PATCH 26/28] TYP: reject `bool` in the `ord` params of `vector_norm` and `matrix_norm` (#310) * TYP: auto-plagiarize the optypean `Just*` types * TYP: reject `bool` in the `ord` params of `vector_norm` and `matrix_norm` * TYP: remove accidental type alias * TYP: Tighten the `ord` param of `matrix_norm` Co-authored-by: Lucas Colley --------- Co-authored-by: Lucas Colley --- array_api_compat/common/_linalg.py | 6 ++-- array_api_compat/common/_typing.py | 44 +++++++++++++++++++++++++++++- array_api_compat/torch/linalg.py | 8 ++++-- 3 files changed, 52 insertions(+), 6 deletions(-) diff --git a/array_api_compat/common/_linalg.py b/array_api_compat/common/_linalg.py index 7e002aed..7ad87a1b 100644 --- a/array_api_compat/common/_linalg.py +++ b/array_api_compat/common/_linalg.py @@ -12,7 +12,7 @@ from .._internal import get_xp from ._aliases import isdtype, matmul, matrix_transpose, tensordot, vecdot -from ._typing import Array, DType, Namespace +from ._typing import Array, DType, JustFloat, JustInt, Namespace # These are in the main NumPy namespace but not in numpy.linalg @@ -139,7 +139,7 @@ def matrix_norm( xp: Namespace, *, keepdims: bool = False, - ord: float | Literal["fro", "nuc"] | None = "fro", + ord: Literal[1, 2, -1, -2] | JustFloat | Literal["fro", "nuc"] | None = "fro", ) -> Array: return xp.linalg.norm(x, axis=(-2, -1), keepdims=keepdims, ord=ord) @@ -155,7 +155,7 @@ def vector_norm( *, axis: int | tuple[int, ...] | None = None, keepdims: bool = False, - ord: float = 2, + ord: JustInt | JustFloat = 2, ) -> Array: # xp.linalg.norm tries to do a matrix norm whenever axis is a 2-tuple or # when axis=None and the input is 2-D, so to force a vector norm, we make diff --git a/array_api_compat/common/_typing.py b/array_api_compat/common/_typing.py index d7deade1..cd26feeb 100644 --- a/array_api_compat/common/_typing.py +++ b/array_api_compat/common/_typing.py @@ -2,7 +2,15 @@ from collections.abc import Mapping from types import ModuleType as Namespace -from typing import TYPE_CHECKING, Literal, Protocol, TypeAlias, TypedDict, TypeVar +from typing import ( + TYPE_CHECKING, + Literal, + Protocol, + TypeAlias, + TypedDict, + TypeVar, + final, +) if TYPE_CHECKING: from _typeshed import Incomplete @@ -21,6 +29,37 @@ _T_co = TypeVar("_T_co", covariant=True) +# These "Just" types are equivalent to the `Just` type from the `optype` library, +# apart from them not being `@runtime_checkable`. +# - docs: https://github.com/jorenham/optype/blob/master/README.md#just +# - code: https://github.com/jorenham/optype/blob/master/optype/_core/_just.py +@final +class JustInt(Protocol): + @property + def __class__(self, /) -> type[int]: ... + @__class__.setter + def __class__(self, value: type[int], /) -> None: ... # pyright: ignore[reportIncompatibleMethodOverride] + + +@final +class JustFloat(Protocol): + @property + def __class__(self, /) -> type[float]: ... + @__class__.setter + def __class__(self, value: type[float], /) -> None: ... # pyright: ignore[reportIncompatibleMethodOverride] + + +@final +class JustComplex(Protocol): + @property + def __class__(self, /) -> type[complex]: ... + @__class__.setter + def __class__(self, value: type[complex], /) -> None: ... # pyright: ignore[reportIncompatibleMethodOverride] + + +# + + class NestedSequence(Protocol[_T_co]): def __getitem__(self, key: int, /) -> _T_co | NestedSequence[_T_co]: ... def __len__(self, /) -> int: ... @@ -140,6 +179,9 @@ class DTypesAll(DTypesBool, DTypesNumeric): "Device", "HasShape", "Namespace", + "JustInt", + "JustFloat", + "JustComplex", "NestedSequence", "SupportsArrayNamespace", "SupportsBufferProtocol", diff --git a/array_api_compat/torch/linalg.py b/array_api_compat/torch/linalg.py index 1ff7319d..70d72405 100644 --- a/array_api_compat/torch/linalg.py +++ b/array_api_compat/torch/linalg.py @@ -16,6 +16,7 @@ # These functions are in both the main and linalg namespaces from ._aliases import matmul, matrix_transpose, tensordot from ._typing import Array, DType +from ..common._typing import JustInt, JustFloat # Note: torch.linalg.cross does not default to axis=-1 (it defaults to the # first axis with size 3), see https://github.com/pytorch/pytorch/issues/58743 @@ -84,8 +85,8 @@ def vector_norm( *, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False, - # float stands for inf | -inf, which are not valid for Literal - ord: Union[int, float] = 2, + # JustFloat stands for inf | -inf, which are not valid for Literal + ord: JustInt | JustFloat = 2, **kwargs, ) -> Array: # torch.vector_norm incorrectly treats axis=() the same as axis=None @@ -115,3 +116,6 @@ def vector_norm( _all_ignore = ['torch_linalg', 'sum'] del linalg_all + +def __dir__() -> list[str]: + return __all__ From 52e01beae335c088d25bd6d76f5ae44a231800f5 Mon Sep 17 00:00:00 2001 From: Guido Imperiale Date: Mon, 21 Apr 2025 19:38:53 +0100 Subject: [PATCH 27/28] ENH: cache helper functions (#308) * ENH: cache helper functions --- array_api_compat/common/_helpers.py | 192 ++++++++++++++++------------ 1 file changed, 108 insertions(+), 84 deletions(-) diff --git a/array_api_compat/common/_helpers.py b/array_api_compat/common/_helpers.py index db3e4cd7..d50e0d83 100644 --- a/array_api_compat/common/_helpers.py +++ b/array_api_compat/common/_helpers.py @@ -12,7 +12,8 @@ import math import sys import warnings -from collections.abc import Collection +from collections.abc import Collection, Hashable +from functools import lru_cache from typing import ( TYPE_CHECKING, Any, @@ -61,23 +62,37 @@ _API_VERSIONS: Final = _API_VERSIONS_OLD | frozenset({"2024.12"}) +@lru_cache(100) +def _issubclass_fast(cls: type, modname: str, clsname: str) -> bool: + try: + mod = sys.modules[modname] + except KeyError: + return False + parent_cls = getattr(mod, clsname) + return issubclass(cls, parent_cls) + + def _is_jax_zero_gradient_array(x: object) -> TypeGuard[_ZeroGradientArray]: """Return True if `x` is a zero-gradient array. These arrays are a design quirk of Jax that may one day be removed. See https://github.com/google/jax/issues/20620. """ - if "numpy" not in sys.modules or "jax" not in sys.modules: + # Fast exit + try: + dtype = x.dtype # type: ignore[attr-defined] + except AttributeError: + return False + cls = cast(Hashable, type(dtype)) + if not _issubclass_fast(cls, "numpy.dtypes", "VoidDType"): return False - import jax - import numpy as np + if "jax" not in sys.modules: + return False - jax_float0 = cast("np.dtype[np.void]", jax.float0) - return ( - isinstance(x, np.ndarray) - and cast("npt.NDArray[np.void]", x).dtype == jax_float0 - ) + import jax + # jax.float0 is a np.dtype([('float0', 'V')]) + return dtype == jax.float0 def is_numpy_array(x: object) -> TypeGuard[npt.NDArray[Any]]: @@ -101,15 +116,12 @@ def is_numpy_array(x: object) -> TypeGuard[npt.NDArray[Any]]: is_jax_array is_pydata_sparse_array """ - # Avoid importing NumPy if it isn't already - if "numpy" not in sys.modules: - return False - - import numpy as np - # TODO: Should we reject ndarray subclasses? - return (isinstance(x, (np.ndarray, np.generic)) - and not _is_jax_zero_gradient_array(x)) # pyright: ignore[reportUnknownArgumentType] # fmt: skip + cls = cast(Hashable, type(x)) + return ( + _issubclass_fast(cls, "numpy", "ndarray") + or _issubclass_fast(cls, "numpy", "generic") + ) and not _is_jax_zero_gradient_array(x) def is_cupy_array(x: object) -> bool: @@ -133,14 +145,8 @@ def is_cupy_array(x: object) -> bool: is_jax_array is_pydata_sparse_array """ - # Avoid importing CuPy if it isn't already - if "cupy" not in sys.modules: - return False - - import cupy as cp # pyright: ignore[reportMissingTypeStubs] - - # TODO: Should we reject ndarray subclasses? - return isinstance(x, cp.ndarray) # pyright: ignore[reportUnknownMemberType] + cls = cast(Hashable, type(x)) + return _issubclass_fast(cls, "cupy", "ndarray") def is_torch_array(x: object) -> TypeIs[torch.Tensor]: @@ -161,14 +167,8 @@ def is_torch_array(x: object) -> TypeIs[torch.Tensor]: is_jax_array is_pydata_sparse_array """ - # Avoid importing torch if it isn't already - if "torch" not in sys.modules: - return False - - import torch - - # TODO: Should we reject ndarray subclasses? - return isinstance(x, torch.Tensor) + cls = cast(Hashable, type(x)) + return _issubclass_fast(cls, "torch", "Tensor") def is_ndonnx_array(x: object) -> TypeIs[ndx.Array]: @@ -190,13 +190,8 @@ def is_ndonnx_array(x: object) -> TypeIs[ndx.Array]: is_jax_array is_pydata_sparse_array """ - # Avoid importing torch if it isn't already - if "ndonnx" not in sys.modules: - return False - - import ndonnx as ndx - - return isinstance(x, ndx.Array) + cls = cast(Hashable, type(x)) + return _issubclass_fast(cls, "ndonnx", "Array") def is_dask_array(x: object) -> TypeIs[da.Array]: @@ -218,13 +213,8 @@ def is_dask_array(x: object) -> TypeIs[da.Array]: is_jax_array is_pydata_sparse_array """ - # Avoid importing dask if it isn't already - if "dask.array" not in sys.modules: - return False - - import dask.array - - return isinstance(x, dask.array.Array) + cls = cast(Hashable, type(x)) + return _issubclass_fast(cls, "dask.array", "Array") def is_jax_array(x: object) -> TypeIs[jax.Array]: @@ -247,13 +237,8 @@ def is_jax_array(x: object) -> TypeIs[jax.Array]: is_dask_array is_pydata_sparse_array """ - # Avoid importing jax if it isn't already - if "jax" not in sys.modules: - return False - - import jax - - return isinstance(x, jax.Array) or _is_jax_zero_gradient_array(x) + cls = cast(Hashable, type(x)) + return _issubclass_fast(cls, "jax", "Array") or _is_jax_zero_gradient_array(x) def is_pydata_sparse_array(x: object) -> TypeIs[sparse.SparseArray]: @@ -276,14 +261,9 @@ def is_pydata_sparse_array(x: object) -> TypeIs[sparse.SparseArray]: is_dask_array is_jax_array """ - # Avoid importing jax if it isn't already - if "sparse" not in sys.modules: - return False - - import sparse # pyright: ignore[reportMissingTypeStubs] - # TODO: Account for other backends. - return isinstance(x, sparse.SparseArray) + cls = cast(Hashable, type(x)) + return _issubclass_fast(cls, "sparse", "SparseArray") def is_array_api_obj(x: object) -> TypeIs[_ArrayApiObj]: # pyright: ignore[reportUnknownParameterType] @@ -302,13 +282,23 @@ def is_array_api_obj(x: object) -> TypeIs[_ArrayApiObj]: # pyright: ignore[repo is_jax_array """ return ( - is_numpy_array(x) - or is_cupy_array(x) - or is_torch_array(x) - or is_dask_array(x) - or is_jax_array(x) - or is_pydata_sparse_array(x) - or hasattr(x, "__array_namespace__") + hasattr(x, '__array_namespace__') + or _is_array_api_cls(cast(Hashable, type(x))) + ) + + +@lru_cache(100) +def _is_array_api_cls(cls: type) -> bool: + return ( + # TODO: drop support for numpy<2 which didn't have __array_namespace__ + _issubclass_fast(cls, "numpy", "ndarray") + or _issubclass_fast(cls, "numpy", "generic") + or _issubclass_fast(cls, "cupy", "ndarray") + or _issubclass_fast(cls, "torch", "Tensor") + or _issubclass_fast(cls, "dask.array", "Array") + or _issubclass_fast(cls, "sparse", "SparseArray") + # TODO: drop support for jax<0.4.32 which didn't have __array_namespace__ + or _issubclass_fast(cls, "jax", "Array") ) @@ -317,6 +307,7 @@ def _compat_module_name() -> str: return __name__.removesuffix(".common._helpers") +@lru_cache(100) def is_numpy_namespace(xp: Namespace) -> bool: """ Returns True if `xp` is a NumPy namespace. @@ -338,6 +329,7 @@ def is_numpy_namespace(xp: Namespace) -> bool: return xp.__name__ in {"numpy", _compat_module_name() + ".numpy"} +@lru_cache(100) def is_cupy_namespace(xp: Namespace) -> bool: """ Returns True if `xp` is a CuPy namespace. @@ -359,6 +351,7 @@ def is_cupy_namespace(xp: Namespace) -> bool: return xp.__name__ in {"cupy", _compat_module_name() + ".cupy"} +@lru_cache(100) def is_torch_namespace(xp: Namespace) -> bool: """ Returns True if `xp` is a PyTorch namespace. @@ -399,6 +392,7 @@ def is_ndonnx_namespace(xp: Namespace) -> bool: return xp.__name__ == "ndonnx" +@lru_cache(100) def is_dask_namespace(xp: Namespace) -> bool: """ Returns True if `xp` is a Dask namespace. @@ -939,6 +933,19 @@ def size(x: HasShape[Collection[SupportsIndex | None]]) -> int | None: return None if math.isnan(out) else out +@lru_cache(100) +def _is_writeable_cls(cls: type) -> bool | None: + if ( + _issubclass_fast(cls, "numpy", "generic") + or _issubclass_fast(cls, "jax", "Array") + or _issubclass_fast(cls, "sparse", "SparseArray") + ): + return False + if _is_array_api_cls(cls): + return True + return None + + def is_writeable_array(x: object) -> bool: """ Return False if ``x.__setitem__`` is expected to raise; True otherwise. @@ -949,11 +956,32 @@ def is_writeable_array(x: object) -> bool: As there is no standard way to check if an array is writeable without actually writing to it, this function blindly returns True for all unknown array types. """ - if is_numpy_array(x): - return x.flags.writeable - if is_jax_array(x) or is_pydata_sparse_array(x): + cls = cast(Hashable, type(x)) + if _issubclass_fast(cls, "numpy", "ndarray"): + return cast("npt.NDArray", x).flags.writeable + res = _is_writeable_cls(cls) + if res is not None: + return res + return hasattr(x, '__array_namespace__') + + +@lru_cache(100) +def _is_lazy_cls(cls: type) -> bool | None: + if ( + _issubclass_fast(cls, "numpy", "ndarray") + or _issubclass_fast(cls, "numpy", "generic") + or _issubclass_fast(cls, "cupy", "ndarray") + or _issubclass_fast(cls, "torch", "Tensor") + or _issubclass_fast(cls, "sparse", "SparseArray") + ): return False - return is_array_api_obj(x) + if ( + _issubclass_fast(cls, "jax", "Array") + or _issubclass_fast(cls, "dask.array", "Array") + or _issubclass_fast(cls, "ndonnx", "Array") + ): + return True + return None def is_lazy_array(x: object) -> bool: @@ -969,14 +997,6 @@ def is_lazy_array(x: object) -> bool: This function errs on the side of caution for array types that may or may not be lazy, e.g. JAX arrays, by always returning True for them. """ - if ( - is_numpy_array(x) - or is_cupy_array(x) - or is_torch_array(x) - or is_pydata_sparse_array(x) - ): - return False - # **JAX note:** while it is possible to determine if you're inside or outside # jax.jit by testing the subclass of a jax.Array object, as well as testing bool() # as we do below for unknown arrays, this is not recommended by JAX best practices. @@ -986,10 +1006,14 @@ def is_lazy_array(x: object) -> bool: # compatibility, is highly detrimental to performance as the whole graph will end # up being computed multiple times. - if is_jax_array(x) or is_dask_array(x) or is_ndonnx_array(x): - return True + # Note: skipping reclassification of JAX zero gradient arrays, as one will + # exclusively get them once they leave a jax.grad JIT context. + cls = cast(Hashable, type(x)) + res = _is_lazy_cls(cls) + if res is not None: + return res - if not is_array_api_obj(x): + if not hasattr(x, "__array_namespace__"): return False # Unknown Array API compatible object. Note that this test may have dire consequences @@ -1042,7 +1066,7 @@ def is_lazy_array(x: object) -> bool: "to_device", ] -_all_ignore = ["sys", "math", "inspect", "warnings"] +_all_ignore = ['lru_cache', 'sys', 'math', 'inspect', 'warnings'] def __dir__() -> list[str]: return __all__ From e600449a645c2e6ce5a2276da0006491f097c096 Mon Sep 17 00:00:00 2001 From: Guido Imperiale Date: Thu, 24 Apr 2025 10:09:45 +0100 Subject: [PATCH 28/28] ENH: Simplify CuPy `asarray` and `to_device` (#314) reviewed at https://github.com/data-apis/array-api-compat/pull/314 --- array_api_compat/common/_helpers.py | 48 ++++++++++------------------- array_api_compat/cupy/_aliases.py | 30 +++++------------- cupy-xfails.txt | 3 -- tests/test_common.py | 24 +++++++++------ tests/test_cupy.py | 22 +++++++++++++ 5 files changed, 61 insertions(+), 66 deletions(-) create mode 100644 tests/test_cupy.py diff --git a/array_api_compat/common/_helpers.py b/array_api_compat/common/_helpers.py index d50e0d83..77175d0d 100644 --- a/array_api_compat/common/_helpers.py +++ b/array_api_compat/common/_helpers.py @@ -775,42 +775,28 @@ def _cupy_to_device( /, stream: int | Any | None = None, ) -> _CupyArray: - import cupy as cp # pyright: ignore[reportMissingTypeStubs] - from cupy.cuda import Device as _Device # pyright: ignore - from cupy.cuda import stream as stream_module # pyright: ignore - from cupy_backends.cuda.api import runtime # pyright: ignore + import cupy as cp - if device == x.device: - return x - elif device == "cpu": + if device == "cpu": # allowing us to use `to_device(x, "cpu")` # is useful for portable test swapping between # host and device backends return x.get() - elif not isinstance(device, _Device): - raise ValueError(f"Unsupported device {device!r}") - else: - # see cupy/cupy#5985 for the reason how we handle device/stream here - prev_device: Any = runtime.getDevice() # pyright: ignore[reportUnknownMemberType] - prev_stream = None - if stream is not None: - prev_stream: Any = stream_module.get_current_stream() # pyright: ignore - # stream can be an int as specified in __dlpack__, or a CuPy stream - if isinstance(stream, int): - stream = cp.cuda.ExternalStream(stream) # pyright: ignore - elif isinstance(stream, cp.cuda.Stream): # pyright: ignore[reportUnknownMemberType] - pass - else: - raise ValueError("the input stream is not recognized") - stream.use() # pyright: ignore[reportUnknownMemberType] - try: - runtime.setDevice(device.id) # pyright: ignore[reportUnknownMemberType] - arr = x.copy() - finally: - runtime.setDevice(prev_device) # pyright: ignore[reportUnknownMemberType] - if stream is not None: - prev_stream.use() - return arr + if not isinstance(device, cp.cuda.Device): + raise TypeError(f"Unsupported device type {device!r}") + + if stream is None: + with device: + return cp.asarray(x) + + # stream can be an int as specified in __dlpack__, or a CuPy stream + if isinstance(stream, int): + stream = cp.cuda.ExternalStream(stream) + elif not isinstance(stream, cp.cuda.Stream): + raise TypeError(f"Unsupported stream type {stream!r}") + + with device, stream: + return cp.asarray(x) def _torch_to_device( diff --git a/array_api_compat/cupy/_aliases.py b/array_api_compat/cupy/_aliases.py index fd1460ae..adb74bff 100644 --- a/array_api_compat/cupy/_aliases.py +++ b/array_api_compat/cupy/_aliases.py @@ -64,8 +64,6 @@ finfo = get_xp(cp)(_aliases.finfo) iinfo = get_xp(cp)(_aliases.iinfo) -_copy_default = object() - # asarray also adds the copy keyword, which is not present in numpy 1.0. def asarray( @@ -79,7 +77,7 @@ def asarray( *, dtype: Optional[DType] = None, device: Optional[Device] = None, - copy: Optional[bool] = _copy_default, + copy: Optional[bool] = None, **kwargs, ) -> Array: """ @@ -89,25 +87,13 @@ def asarray( specification for more details. """ with cp.cuda.Device(device): - # cupy is like NumPy 1.26 (except without _CopyMode). See the comments - # in asarray in numpy/_aliases.py. - if copy is not _copy_default: - # A future version of CuPy will change the meaning of copy=False - # to mean no-copy. We don't know for certain what version it will - # be yet, so to avoid breaking that version, we use a different - # default value for copy so asarray(obj) with no copy kwarg will - # always do the copy-if-needed behavior. - - # This will still need to be updated to remove the - # NotImplementedError for copy=False, but at least this won't - # break the default or existing behavior. - if copy is None: - copy = False - elif copy is False: - raise NotImplementedError("asarray(copy=False) is not yet supported in cupy") - kwargs['copy'] = copy - - return cp.array(obj, dtype=dtype, **kwargs) + if copy is None: + return cp.asarray(obj, dtype=dtype, **kwargs) + else: + res = cp.array(obj, dtype=dtype, copy=copy, **kwargs) + if not copy and res is not obj: + raise ValueError("Unable to avoid copy while creating an array as requested") + return res def astype( diff --git a/cupy-xfails.txt b/cupy-xfails.txt index a30572f8..df85d9ca 100644 --- a/cupy-xfails.txt +++ b/cupy-xfails.txt @@ -11,9 +11,6 @@ array_api_tests/test_array_object.py::test_scalar_casting[__index__(int64)] # testsuite bug (https://github.com/data-apis/array-api-tests/issues/172) array_api_tests/test_array_object.py::test_getitem -# copy=False is not yet implemented -array_api_tests/test_creation_functions.py::test_asarray_arrays - # attributes are np.float32 instead of float # (see also https://github.com/data-apis/array-api/issues/405) array_api_tests/test_data_type_functions.py::test_finfo[float32] diff --git a/tests/test_common.py b/tests/test_common.py index 6b1aa160..d1933899 100644 --- a/tests/test_common.py +++ b/tests/test_common.py @@ -17,6 +17,7 @@ from array_api_compat import ( device, is_array_api_obj, is_lazy_array, is_writeable_array, size, to_device ) +from array_api_compat.common._helpers import _DASK_DEVICE from ._helpers import all_libraries, import_, wrapped_libraries, xfail @@ -189,23 +190,26 @@ class C: @pytest.mark.parametrize("library", all_libraries) -def test_device(library, request): +def test_device_to_device(library, request): if library == "ndonnx": - xfail(request, reason="Needs ndonnx >=0.9.4") + xfail(request, reason="Stub raises ValueError") + if library == "sparse": + xfail(request, reason="No __array_namespace_info__()") xp = import_(library, wrapper=True) + devices = xp.__array_namespace_info__().devices() - # We can't test much for device() and to_device() other than that - # x.to_device(x.device) works. - + # Default device x = xp.asarray([1, 2, 3]) dev = device(x) - x2 = to_device(x, dev) - assert device(x2) == device(x) - - x3 = xp.asarray(x, device=dev) - assert device(x3) == device(x) + for dev in devices: + if dev is None: # JAX >=0.5.3 + continue + if dev is _DASK_DEVICE: # TODO this needs a better design + continue + y = to_device(x, dev) + assert device(y) == dev @pytest.mark.parametrize("library", wrapped_libraries) diff --git a/tests/test_cupy.py b/tests/test_cupy.py new file mode 100644 index 00000000..f8b4a4d8 --- /dev/null +++ b/tests/test_cupy.py @@ -0,0 +1,22 @@ +import pytest +from array_api_compat import device, to_device + +xp = pytest.importorskip("array_api_compat.cupy") +from cupy.cuda import Stream + + +def test_to_device_with_stream(): + devices = xp.__array_namespace_info__().devices() + streams = [ + Stream(), + Stream(non_blocking=True), + Stream(null=True), + Stream(ptds=True), + 123, # dlpack stream + ] + + a = xp.asarray([1, 2, 3]) + for dev in devices: + for stream in streams: + b = to_device(a, dev, stream=stream) + assert device(b) == dev