From cac1443942212f872909fc6b6f2b610d8a5a6b84 Mon Sep 17 00:00:00 2001
From: Evgeni Burovski
Date: Tue, 4 Mar 2025 10:22:14 +0100
Subject: [PATCH 01/80] DOC: add a changelog for the 1.11.1 release

---
 docs/changelog.md | 23 ++++++++++++++++++++++-
 1 file changed, 22 insertions(+), 1 deletion(-)

diff --git a/docs/changelog.md b/docs/changelog.md
index 1de11606..bdf5f9e1 100644
--- a/docs/changelog.md
+++ b/docs/changelog.md
@@ -1,6 +1,27 @@
 # Changelog

-## 1.11.0 (2025-XX-XX)
+## 1.11.1 (2025-03-04)
+
+This is a bugfix release with no new features compared to version 1.11.
+
+### Major Changes
+
+- fix `count_nonzero` wrappers: work around the lack of the `keepdims` argument in
+  several array libraries (torch, dask, cupy); work around numpy returning python
+  ints for some input combinations.
+
+### Minor Changes
+
+- running self-tests does not require all array libraries. Missing libraries are
+  skipped.
+
+The following users contributed to this release:
+
+Evgeni Burovski
+Guido Imperiale
+
+
+## 1.11.0 (2025-02-27)

 ### Major Changes


From e14754ba0fe4c4cd51b6f45bb11a3c6609be3b5c Mon Sep 17 00:00:00 2001
From: Guido Imperiale
Date: Wed, 5 Mar 2025 10:04:21 +0000
Subject: [PATCH 02/80] BUG: `clip(out=...)` is broken (#261)

reviewed at https://github.com/data-apis/array-api-compat/pull/261
---
 array_api_compat/common/_aliases.py | 26 +++++++++++++-------------
 tests/test_common.py                | 15 +++++++++++++++
 2 files changed, 28 insertions(+), 13 deletions(-)

diff --git a/array_api_compat/common/_aliases.py b/array_api_compat/common/_aliases.py
index 98b8e425..d7e8ef2d 100644
--- a/array_api_compat/common/_aliases.py
+++ b/array_api_compat/common/_aliases.py
@@ -12,7 +12,7 @@
 from typing import NamedTuple
 import inspect

-from ._helpers import array_namespace, _check_device, device, is_torch_array, is_cupy_namespace
+from ._helpers import array_namespace, _check_device, device, is_cupy_namespace

 # These functions are modified from the NumPy versions.
@@ -368,23 +368,23 @@ def _isscalar(a): if type(max) is int and max >= wrapped_xp.iinfo(x.dtype).max: max = None + dev = device(x) if out is None: - out = wrapped_xp.asarray(xp.broadcast_to(x, result_shape), - copy=True, device=device(x)) + out = wrapped_xp.empty(result_shape, dtype=x.dtype, device=dev) + out[()] = x + if min is not None: - if is_torch_array(x) and x.dtype == xp.float64 and _isscalar(min): - # Avoid loss of precision due to torch defaulting to float32 - min = wrapped_xp.asarray(min, dtype=xp.float64) - a = xp.broadcast_to(wrapped_xp.asarray(min, device=device(x)), result_shape) + a = wrapped_xp.asarray(min, dtype=x.dtype, device=dev) + a = xp.broadcast_to(a, result_shape) ia = (out < a) | xp.isnan(a) - # torch requires an explicit cast here - out[ia] = wrapped_xp.astype(a[ia], out.dtype) + out[ia] = a[ia] + if max is not None: - if is_torch_array(x) and x.dtype == xp.float64 and _isscalar(max): - max = wrapped_xp.asarray(max, dtype=xp.float64) - b = xp.broadcast_to(wrapped_xp.asarray(max, device=device(x)), result_shape) + b = wrapped_xp.asarray(max, dtype=x.dtype, device=dev) + b = xp.broadcast_to(b, result_shape) ib = (out > b) | xp.isnan(b) - out[ib] = wrapped_xp.astype(b[ib], out.dtype) + out[ib] = b[ib] + # Return a scalar for 0-D return out[()] diff --git a/tests/test_common.py b/tests/test_common.py index 32876e69..f86e0936 100644 --- a/tests/test_common.py +++ b/tests/test_common.py @@ -367,3 +367,18 @@ def test_asarray_copy(library): assert all(b[0] == 1.0) else: assert all(b[0] == 0.0) + + +@pytest.mark.parametrize("library", ["numpy", "cupy", "torch"]) +def test_clip_out(library): + """Test non-standard out= parameter for clip() + + (see "Avoid Restricting Behavior that is Outside the Scope of the Standard" + in https://data-apis.org/array-api-compat/dev/special-considerations.html) + """ + xp = import_(library, wrapper=True) + x = xp.asarray([10, 20, 30]) + out = xp.zeros_like(x) + xp.clip(x, 15, 25, out=out) + expect = xp.asarray([15, 20, 25]) + assert xp.all(out == expect) From 3f14b184dbbd47e81d1d47514c7b5a2772969b81 Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Sat, 15 Mar 2025 12:06:31 +0000 Subject: [PATCH 03/80] add torch xfails for scalars in binary functions --- torch-xfails.txt | 20 +++++++++++++++++++- 1 file changed, 19 insertions(+), 1 deletion(-) diff --git a/torch-xfails.txt b/torch-xfails.txt index 2899bdb3..6e8f7dc6 100644 --- a/torch-xfails.txt +++ b/torch-xfails.txt @@ -124,7 +124,6 @@ array_api_tests/test_signatures.py::test_func_signature[from_dlpack] # Argument 'max_version' missing from signature array_api_tests/test_signatures.py::test_array_method_signature[__dlpack__] - # 2024.12 support array_api_tests/test_signatures.py::test_func_signature[bitwise_and] array_api_tests/test_signatures.py::test_func_signature[bitwise_left_shift] @@ -136,3 +135,22 @@ array_api_tests/test_signatures.py::test_array_method_signature[__lshift__] array_api_tests/test_signatures.py::test_array_method_signature[__or__] array_api_tests/test_signatures.py::test_array_method_signature[__rshift__] array_api_tests/test_signatures.py::test_array_method_signature[__xor__] + +# 2024.12 support: binary functions reject python scalar arguments +array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[atan2] +array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[copysign] +array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[divide] 
+array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[hypot]
+array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[logaddexp]
+array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[maximum]
+array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[minimum]
+array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[equal]
+array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[not_equal]
+array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[less]
+array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[less_equal]
+array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[greater]
+array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[greater_equal]
+
+array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_bool[logical_and]
+array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_bool[logical_or]
+array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_bool[logical_xor]

From 05ade6738542ff556da4bbaa6ad4acd29290d989 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Magnus=20Dalen=20Kvalev=C3=A5g?=
Date: Mon, 17 Mar 2025 17:20:59 +0100
Subject: [PATCH 04/80] Fix clipping float with python int for min and max

---
 array_api_compat/common/_aliases.py | 9 +++++----
 1 file changed, 5 insertions(+), 4 deletions(-)

diff --git a/array_api_compat/common/_aliases.py b/array_api_compat/common/_aliases.py
index d7e8ef2d..35262d3a 100644
--- a/array_api_compat/common/_aliases.py
+++ b/array_api_compat/common/_aliases.py
@@ -363,10 +363,11 @@ def _isscalar(a):

     # At least handle the case of Python integers correctly (see
     # https://github.com/numpy/numpy/pull/26892).
-    if type(min) is int and min <= wrapped_xp.iinfo(x.dtype).min:
-        min = None
-    if type(max) is int and max >= wrapped_xp.iinfo(x.dtype).max:
-        max = None
+    if wrapped_xp.isdtype(x.dtype, "integral"):
+        if type(min) is int and min <= wrapped_xp.iinfo(x.dtype).min:
+            min = None
+        if type(max) is int and max >= wrapped_xp.iinfo(x.dtype).max:
+            max = None

     dev = device(x)
     if out is None:

From 58d8037f372113dfc4d0b36b3740a8b34ed85c7f Mon Sep 17 00:00:00 2001
From: Evgeni Burovski
Date: Mon, 17 Mar 2025 19:29:07 +0100
Subject: [PATCH 05/80] BUG: torch: fix `result_type` with python scalars

1. Allow inputs to be arrays or dtypes or python scalars
2. Keep the pytorch-specific additions, e.g. `result_type(int, float) -> float`,
   `result_type(scalar, scalar) -> dtype`, which are unspecified in the standard
3. Since pytorch only defines a binary `result_type` function, add a version
   with multiple inputs.

The latter is a bit tricky because we want to
- keep allowing "unspecified" behaviors
- keep standard-allowed promotions compliant
- (preferably) make result_type independent of the argument order

The last point is important because of `int,float->float` promotions which
break associativity. So what we do is always promote all scalars after all
array/dtype arguments.
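
For illustration, a sketch of the associativity problem against stock
pytorch (the int64/float16 dtypes are just example choices, and the
default dtype is assumed to be float32):

    import torch

    i64 = torch.ones(1, dtype=torch.int64)
    f16 = torch.ones(1, dtype=torch.float16)

    # a python float against an integer tensor gives the default dtype
    torch.result_type(1.0, i64)   # torch.float32
    # ... but it does not upcast a floating-point tensor
    torch.result_type(f16, 1.0)   # torch.float16

Reducing (1.0, i64, f16) left-to-right therefore gives float32, while
(i64, f16, 1.0) gives float16; promoting the scalars last makes both
orderings agree on float16.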
--- array_api_compat/torch/_aliases.py | 45 ++++++++++---- tests/test_torch.py | 98 ++++++++++++++++++++++++++++++ 2 files changed, 130 insertions(+), 13 deletions(-) create mode 100644 tests/test_torch.py diff --git a/array_api_compat/torch/_aliases.py b/array_api_compat/torch/_aliases.py index b4786320..4b727f1c 100644 --- a/array_api_compat/torch/_aliases.py +++ b/array_api_compat/torch/_aliases.py @@ -1,6 +1,6 @@ from __future__ import annotations -from functools import wraps as _wraps +from functools import reduce as _reduce, wraps as _wraps from builtins import all as _builtin_all, any as _builtin_any from ..common import _aliases @@ -124,25 +124,43 @@ def _fix_promotion(x1, x2, only_scalar=True): def result_type(*arrays_and_dtypes: Union[array, Dtype, bool, int, float, complex]) -> Dtype: - if len(arrays_and_dtypes) == 0: - raise TypeError("At least one array or dtype must be provided") - if len(arrays_and_dtypes) == 1: + num = len(arrays_and_dtypes) + + if num == 0: + raise ValueError("At least one array or dtype must be provided") + + elif num == 1: x = arrays_and_dtypes[0] if isinstance(x, torch.dtype): return x return x.dtype - if len(arrays_and_dtypes) > 2: - return result_type(arrays_and_dtypes[0], result_type(*arrays_and_dtypes[1:])) - x, y = arrays_and_dtypes - if isinstance(x, _py_scalars) or isinstance(y, _py_scalars): - return torch.result_type(x, y) + if num == 2: + x, y = arrays_and_dtypes + return _result_type(x, y) + + else: + # sort scalars so that they are treated last + scalars, others = [], [] + for x in arrays_and_dtypes: + if isinstance(x, _py_scalars): + scalars.append(x) + else: + others.append(x) + if not others: + raise ValueError("At least one array or dtype must be provided") + + # combine left-to-right + return _reduce(_result_type, others + scalars) - xdt = x.dtype if not isinstance(x, torch.dtype) else x - ydt = y.dtype if not isinstance(y, torch.dtype) else y - if (xdt, ydt) in _promotion_table: - return _promotion_table[xdt, ydt] +def _result_type(x, y): + if not (isinstance(x, _py_scalars) or isinstance(y, _py_scalars)): + xdt = x.dtype if not isinstance(x, torch.dtype) else x + ydt = y.dtype if not isinstance(y, torch.dtype) else y + + if (xdt, ydt) in _promotion_table: + return _promotion_table[xdt, ydt] # This doesn't result_type(dtype, dtype) for non-array API dtypes # because torch.result_type only accepts tensors. This does however, allow @@ -151,6 +169,7 @@ def result_type(*arrays_and_dtypes: Union[array, Dtype, bool, int, float, comple y = torch.tensor([], dtype=y) if isinstance(y, torch.dtype) else y return torch.result_type(x, y) + def can_cast(from_: Union[Dtype, array], to: Dtype, /) -> bool: if not isinstance(from_, torch.dtype): from_ = from_.dtype diff --git a/tests/test_torch.py b/tests/test_torch.py new file mode 100644 index 00000000..75b3a136 --- /dev/null +++ b/tests/test_torch.py @@ -0,0 +1,98 @@ +"""Test "unspecified" behavior which we cannot easily test in the Array API test suite. 
+""" +import itertools + +import pytest +import torch + +from array_api_compat import torch as xp + + +class TestResultType: + def test_empty(self): + with pytest.raises(ValueError): + xp.result_type() + + def test_one_arg(self): + for x in [1, 1.0, 1j, '...', None]: + with pytest.raises((ValueError, AttributeError)): + xp.result_type(x) + + for x in [xp.float32, xp.int64, torch.complex64]: + assert xp.result_type(x) == x + + for x in [xp.asarray(True, dtype=xp.bool), xp.asarray(1, dtype=xp.complex64)]: + assert xp.result_type(x) == x.dtype + + def test_two_args(self): + # Only include here things "unspecified" in the spec + + # scalar, tensor or tensor,tensor + for x, y in [ + (1., 1j), + (1j, xp.arange(3)), + (True, xp.asarray(3.)), + (xp.ones(3) == 1, 1j*xp.ones(3)), + ]: + assert xp.result_type(x, y) == torch.result_type(x, y) + + # dtype, scalar + for x, y in [ + (1j, xp.int64), + (True, xp.float64), + ]: + assert xp.result_type(x, y) == torch.result_type(x, xp.empty([], dtype=y)) + + # dtype, dtype + for x, y in [ + (xp.bool, xp.complex64) + ]: + xt, yt = xp.empty([], dtype=x), xp.empty([], dtype=y) + assert xp.result_type(x, y) == torch.result_type(xt, yt) + + def test_multi_arg(self): + torch.set_default_dtype(torch.float32) + + args = [1., 5, 3, torch.asarray([3], dtype=torch.float16), 5, 6, 1.] + assert xp.result_type(*args) == torch.float16 + + args = [1, 2, 3j, xp.arange(3, dtype=xp.float32), 4, 5, 6] + assert xp.result_type(*args) == xp.complex64 + + args = [1, 2, 3j, xp.float64, 4, 5, 6] + assert xp.result_type(*args) == xp.complex128 + + args = [1, 2, 3j, xp.float64, 4, xp.asarray(3, dtype=xp.int16), 5, 6, False] + assert xp.result_type(*args) == xp.complex128 + + i64 = xp.ones(1, dtype=xp.int64) + f16 = xp.ones(1, dtype=xp.float16) + for i in itertools.permutations([i64, f16, 1.0, 1.0]): + assert xp.result_type(*i) == xp.float16, f"{i}" + + with pytest.raises(ValueError): + xp.result_type(1, 2, 3, 4) + + + @pytest.mark.parametrize("default_dt", ['float32', 'float64']) + @pytest.mark.parametrize("dtype_a", + (xp.int32, xp.int64, xp.float32, xp.float64, xp.complex64, xp.complex128) + ) + @pytest.mark.parametrize("dtype_b", + (xp.int32, xp.int64, xp.float32, xp.float64, xp.complex64, xp.complex128) + ) + def test_gh_273(self, default_dt, dtype_a, dtype_b): + # Regression test for https://github.com/data-apis/array-api-compat/issues/273 + + try: + prev_default = torch.get_default_dtype() + default_dtype = getattr(torch, default_dt) + torch.set_default_dtype(default_dtype) + + a = xp.asarray([2, 1], dtype=dtype_a) + b = xp.asarray([1, -1], dtype=dtype_b) + dtype_1 = xp.result_type(a, b, 1.0) + dtype_2 = xp.result_type(b, a, 1.0) + assert dtype_1 == dtype_2 + finally: + torch.set_default_dtype(prev_default) From 5473d84d5c36b23e091b880279c863c32f41b828 Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Wed, 19 Mar 2025 16:28:49 +0100 Subject: [PATCH 06/80] TST: skip test_all --- tests/test_all.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/test_all.py b/tests/test_all.py index 10a2a95d..d2e9b768 100644 --- a/tests/test_all.py +++ b/tests/test_all.py @@ -16,6 +16,7 @@ import pytest +@pytest.mark.skip(reason="TODO: starts failing after adding test_torch.py in gh-277") @pytest.mark.parametrize("library", ["common"] + wrapped_libraries) def test_all(library): if library == "common": From c9622f965be76e947f2a7d3ecac827a90e67edfb Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Thu, 20 Mar 2025 13:11:49 +0100 Subject: [PATCH 07/80] DOC: add a changelog for 1.11.2 
 release

---
 docs/changelog.md | 17 +++++++++++++++++
 1 file changed, 17 insertions(+)

diff --git a/docs/changelog.md b/docs/changelog.md
index bdf5f9e1..18928e98 100644
--- a/docs/changelog.md
+++ b/docs/changelog.md
@@ -1,5 +1,22 @@
 # Changelog

+## 1.11.2 (2025-03-20)
+
+This is a bugfix release with no new features compared to version 1.11.
+
+- fix the `result_type` wrapper for pytorch. Previously, `result_type` had multiple
+  issues with scalar arguments.
+- fix several issues with `clip` wrappers. Previously, `clip` rejected behaviors
+  which are unspecified by the 2024.12 standard but allowed by the array
+  libraries.
+
+The following users contributed to this release:
+
+Evgeni Burovski
+Guido Imperiale
+Magnus Dalen Kvalevåg
+
+
 ## 1.11.1 (2025-03-04)

 This is a bugfix release with no new features compared to version 1.11.

From b8323760865a66ad03114ad80c1aa058df28dc98 Mon Sep 17 00:00:00 2001
From: Evgeni Burovski
Date: Thu, 20 Mar 2025 13:17:22 +0100
Subject: [PATCH 08/80] BLD: upper cap setuptools, do not error on
 deprecationwarnings

---
 .github/workflows/publish-package.yml | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/.github/workflows/publish-package.yml b/.github/workflows/publish-package.yml
index 7733059d..6d88066d 100644
--- a/.github/workflows/publish-package.yml
+++ b/.github/workflows/publish-package.yml
@@ -41,13 +41,14 @@ jobs:

       - name: Install python-build and twine
         run: |
-          python -m pip install --upgrade pip setuptools
+          python -m pip install --upgrade pip "setuptools<=67"
           python -m pip install build twine
           python -m pip list

       - name: Build a wheel and a sdist
         run: |
-          PYTHONWARNINGS=error,default::DeprecationWarning python -m build .
+          #PYTHONWARNINGS=error,default::DeprecationWarning python -m build .
+          python -m build .

       - name: Verify the distribution
         run: twine check --strict dist/*

From 1b0de51538deb7c21d0c268f36764a8589e40012 Mon Sep 17 00:00:00 2001
From: Evgeni Burovski
Date: Thu, 20 Mar 2025 13:33:39 +0100
Subject: [PATCH 09/80] REL: bump the version to 1.11.2

---
 array_api_compat/__init__.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/array_api_compat/__init__.py b/array_api_compat/__init__.py
index 60b37e97..96b061e7 100644
--- a/array_api_compat/__init__.py
+++ b/array_api_compat/__init__.py
@@ -17,6 +17,6 @@
 this implementation for the default when working with NumPy arrays.
""" -__version__ = '1.12.dev0' +__version__ = '1.11.2' from .common import * # noqa: F401, F403 From b1316cff516d147519a9c30f0e8327e5895598f4 Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Thu, 20 Mar 2025 16:47:38 +0100 Subject: [PATCH 10/80] TST: skip tests of binary funcs w/scalar on older numpies NumPy < 2 fails to promote an empty f32 array with a scalar, returns an empty f64 array --- numpy-1-21-xfails.txt | 3 +++ numpy-1-26-xfails.txt | 3 +++ 2 files changed, 6 insertions(+) diff --git a/numpy-1-21-xfails.txt b/numpy-1-21-xfails.txt index 28c0e13a..7c7a0757 100644 --- a/numpy-1-21-xfails.txt +++ b/numpy-1-21-xfails.txt @@ -212,3 +212,6 @@ array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(x1_i is -infinity array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(x1_i is -infinity and isfinite(x2_i) and x2_i < 0) -> +infinity] array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(isfinite(x1_i) and x1_i > 0 and x2_i is -infinity) -> -0] array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(isfinite(x1_i) and x1_i < 0 and x2_i is +infinity) -> -0] + +# numpy < 2 bug: type promotion of asarray([], 'float32') and (np.finfo(float32).max + 1) -> float64 +array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real diff --git a/numpy-1-26-xfails.txt b/numpy-1-26-xfails.txt index 80790534..57259b6f 100644 --- a/numpy-1-26-xfails.txt +++ b/numpy-1-26-xfails.txt @@ -66,3 +66,6 @@ array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(x1_i is -infinity array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(x1_i is -infinity and isfinite(x2_i) and x2_i < 0) -> +infinity] array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(isfinite(x1_i) and x1_i > 0 and x2_i is -infinity) -> -0] array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(isfinite(x1_i) and x1_i < 0 and x2_i is +infinity) -> -0] + +# numpy < 2 bug: type promotion of asarray([], 'float32') and (finfo(float32).max + 1) gives float64 not float32 +array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real From 64ab7e26b86d0cd2d4cb544fdd39699a887823e8 Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Fri, 21 Mar 2025 01:27:29 +0100 Subject: [PATCH 11/80] MAINT: update the version for 1.12.dev0 development --- array_api_compat/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/array_api_compat/__init__.py b/array_api_compat/__init__.py index 96b061e7..60b37e97 100644 --- a/array_api_compat/__init__.py +++ b/array_api_compat/__init__.py @@ -17,6 +17,6 @@ this implementation for the default when working with NumPy arrays. """ -__version__ = '1.11.2' +__version__ = '1.12.dev0' from .common import * # noqa: F401, F403 From 0080afed5b110c311cb88314d0370a2a3fcbefef Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Fri, 21 Mar 2025 12:02:54 +0100 Subject: [PATCH 12/80] Add a CuPy xfail CuPy 13.x follows NumPy 1.x without "weak scalars". In NumPy `result_type(int32, uint8, 1) != result_type(int32, uint8)` has been fixed in 2.x (or 1.x with set_promotion_state("weak"), so hopefully CuPy 14.x follows the suite, when released. Until then, just xfail the test. 
--- cupy-xfails.txt | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/cupy-xfails.txt b/cupy-xfails.txt index 63e844cd..3d20d745 100644 --- a/cupy-xfails.txt +++ b/cupy-xfails.txt @@ -183,7 +183,7 @@ array_api_tests/test_manipulation_functions.py::test_repeat array_api_tests/test_signatures.py::test_func_signature[from_dlpack] array_api_tests/test_signatures.py::test_array_method_signature[__dlpack__] -+# 2024.12 support +# 2024.12 support array_api_tests/test_signatures.py::test_func_signature[count_nonzero] array_api_tests/test_signatures.py::test_func_signature[bitwise_and] array_api_tests/test_signatures.py::test_func_signature[bitwise_left_shift] @@ -192,3 +192,5 @@ array_api_tests/test_signatures.py::test_func_signature[bitwise_right_shift] array_api_tests/test_signatures.py::test_func_signature[bitwise_xor] array_api_tests/test_special_cases.py::test_binary[nextafter(x1_i is +0 and x2_i is -0) -> -0] +# cupy 13.x follows numpy 1.x w/o weak promotion: result_type(int32, uint8, 1) != result_type(int32, uint8) +array_api_tests/test_data_type_functions.py::TestResultType::test_with_scalars From a5a1d8ba722da9b8a2783ccd63c0b60713932793 Mon Sep 17 00:00:00 2001 From: Guido Imperiale Date: Sat, 22 Mar 2025 17:34:57 +0000 Subject: [PATCH 13/80] TYP: Type annotations overhaul, part 1 (#257) * ENH: Type annotations overhaul * Re-add py.typed * code review * lint * asarray * fill_value * result_type * lint * Arrays don't need to support buffer protocol * bool is a subclass of int * reshape: copy kwarg is keyword-only * tensordot formatting * Reinstate explicit bool | complex --- array_api_compat/common/_aliases.py | 248 +++++++++++++----------- array_api_compat/common/_fft.py | 87 +++++---- array_api_compat/common/_helpers.py | 32 +-- array_api_compat/common/_linalg.py | 84 +++++--- array_api_compat/common/_typing.py | 16 +- array_api_compat/cupy/_aliases.py | 36 ++-- array_api_compat/cupy/_typing.py | 63 +++--- array_api_compat/dask/array/_aliases.py | 54 ++---- array_api_compat/dask/array/fft.py | 13 +- array_api_compat/dask/array/linalg.py | 25 +-- array_api_compat/numpy/_aliases.py | 41 ++-- array_api_compat/numpy/_typing.py | 63 +++--- array_api_compat/py.typed | 0 array_api_compat/torch/_aliases.py | 168 ++++++++-------- array_api_compat/torch/_typing.py | 4 + array_api_compat/torch/fft.py | 35 ++-- array_api_compat/torch/linalg.py | 28 ++- setup.py | 5 +- tests/test_all.py | 17 +- 19 files changed, 511 insertions(+), 508 deletions(-) create mode 100644 array_api_compat/py.typed create mode 100644 array_api_compat/torch/_typing.py diff --git a/array_api_compat/common/_aliases.py b/array_api_compat/common/_aliases.py index 35262d3a..0d123b99 100644 --- a/array_api_compat/common/_aliases.py +++ b/array_api_compat/common/_aliases.py @@ -4,15 +4,11 @@ from __future__ import annotations -from typing import TYPE_CHECKING -if TYPE_CHECKING: - from typing import Optional, Sequence, Tuple, Union - from ._typing import ndarray, Device, Dtype - -from typing import NamedTuple import inspect +from typing import NamedTuple, Optional, Sequence, Tuple, Union from ._helpers import array_namespace, _check_device, device, is_cupy_namespace +from ._typing import Array, Device, DType, Namespace # These functions are modified from the NumPy versions. 
@@ -24,29 +20,34 @@ def arange( stop: Optional[Union[int, float]] = None, step: Union[int, float] = 1, *, - xp, - dtype: Optional[Dtype] = None, + xp: Namespace, + dtype: Optional[DType] = None, device: Optional[Device] = None, - **kwargs -) -> ndarray: + **kwargs, +) -> Array: _check_device(xp, device) return xp.arange(start, stop=stop, step=step, dtype=dtype, **kwargs) def empty( shape: Union[int, Tuple[int, ...]], - xp, + xp: Namespace, *, - dtype: Optional[Dtype] = None, + dtype: Optional[DType] = None, device: Optional[Device] = None, - **kwargs -) -> ndarray: + **kwargs, +) -> Array: _check_device(xp, device) return xp.empty(shape, dtype=dtype, **kwargs) def empty_like( - x: ndarray, /, xp, *, dtype: Optional[Dtype] = None, device: Optional[Device] = None, - **kwargs -) -> ndarray: + x: Array, + /, + xp: Namespace, + *, + dtype: Optional[DType] = None, + device: Optional[Device] = None, + **kwargs, +) -> Array: _check_device(xp, device) return xp.empty_like(x, dtype=dtype, **kwargs) @@ -55,37 +56,37 @@ def eye( n_cols: Optional[int] = None, /, *, - xp, + xp: Namespace, k: int = 0, - dtype: Optional[Dtype] = None, + dtype: Optional[DType] = None, device: Optional[Device] = None, **kwargs, -) -> ndarray: +) -> Array: _check_device(xp, device) return xp.eye(n_rows, M=n_cols, k=k, dtype=dtype, **kwargs) def full( shape: Union[int, Tuple[int, ...]], - fill_value: Union[int, float], - xp, + fill_value: complex, + xp: Namespace, *, - dtype: Optional[Dtype] = None, + dtype: Optional[DType] = None, device: Optional[Device] = None, **kwargs, -) -> ndarray: +) -> Array: _check_device(xp, device) return xp.full(shape, fill_value, dtype=dtype, **kwargs) def full_like( - x: ndarray, + x: Array, /, - fill_value: Union[int, float], + fill_value: complex, *, - xp, - dtype: Optional[Dtype] = None, + xp: Namespace, + dtype: Optional[DType] = None, device: Optional[Device] = None, **kwargs, -) -> ndarray: +) -> Array: _check_device(xp, device) return xp.full_like(x, fill_value, dtype=dtype, **kwargs) @@ -95,48 +96,58 @@ def linspace( /, num: int, *, - xp, - dtype: Optional[Dtype] = None, + xp: Namespace, + dtype: Optional[DType] = None, device: Optional[Device] = None, endpoint: bool = True, **kwargs, -) -> ndarray: +) -> Array: _check_device(xp, device) return xp.linspace(start, stop, num, dtype=dtype, endpoint=endpoint, **kwargs) def ones( shape: Union[int, Tuple[int, ...]], - xp, + xp: Namespace, *, - dtype: Optional[Dtype] = None, + dtype: Optional[DType] = None, device: Optional[Device] = None, **kwargs, -) -> ndarray: +) -> Array: _check_device(xp, device) return xp.ones(shape, dtype=dtype, **kwargs) def ones_like( - x: ndarray, /, xp, *, dtype: Optional[Dtype] = None, device: Optional[Device] = None, + x: Array, + /, + xp: Namespace, + *, + dtype: Optional[DType] = None, + device: Optional[Device] = None, **kwargs, -) -> ndarray: +) -> Array: _check_device(xp, device) return xp.ones_like(x, dtype=dtype, **kwargs) def zeros( shape: Union[int, Tuple[int, ...]], - xp, + xp: Namespace, *, - dtype: Optional[Dtype] = None, + dtype: Optional[DType] = None, device: Optional[Device] = None, **kwargs, -) -> ndarray: +) -> Array: _check_device(xp, device) return xp.zeros(shape, dtype=dtype, **kwargs) def zeros_like( - x: ndarray, /, xp, *, dtype: Optional[Dtype] = None, device: Optional[Device] = None, + x: Array, + /, + xp: Namespace, + *, + dtype: Optional[DType] = None, + device: Optional[Device] = None, **kwargs, -) -> ndarray: +) -> Array: _check_device(xp, device) return xp.zeros_like(x, dtype=dtype, 
**kwargs) @@ -150,23 +161,23 @@ def zeros_like( # Note that these named tuples aren't actually part of the standard namespace, # but I don't see any issue with exporting the names here regardless. class UniqueAllResult(NamedTuple): - values: ndarray - indices: ndarray - inverse_indices: ndarray - counts: ndarray + values: Array + indices: Array + inverse_indices: Array + counts: Array class UniqueCountsResult(NamedTuple): - values: ndarray - counts: ndarray + values: Array + counts: Array class UniqueInverseResult(NamedTuple): - values: ndarray - inverse_indices: ndarray + values: Array + inverse_indices: Array -def _unique_kwargs(xp): +def _unique_kwargs(xp: Namespace) -> dict[str, bool]: # Older versions of NumPy and CuPy do not have equal_nan. Rather than # trying to parse version numbers, just check if equal_nan is in the # signature. @@ -175,7 +186,7 @@ def _unique_kwargs(xp): return {'equal_nan': False} return {} -def unique_all(x: ndarray, /, xp) -> UniqueAllResult: +def unique_all(x: Array, /, xp: Namespace) -> UniqueAllResult: kwargs = _unique_kwargs(xp) values, indices, inverse_indices, counts = xp.unique( x, @@ -195,7 +206,7 @@ def unique_all(x: ndarray, /, xp) -> UniqueAllResult: ) -def unique_counts(x: ndarray, /, xp) -> UniqueCountsResult: +def unique_counts(x: Array, /, xp: Namespace) -> UniqueCountsResult: kwargs = _unique_kwargs(xp) res = xp.unique( x, @@ -208,7 +219,7 @@ def unique_counts(x: ndarray, /, xp) -> UniqueCountsResult: return UniqueCountsResult(*res) -def unique_inverse(x: ndarray, /, xp) -> UniqueInverseResult: +def unique_inverse(x: Array, /, xp: Namespace) -> UniqueInverseResult: kwargs = _unique_kwargs(xp) values, inverse_indices = xp.unique( x, @@ -223,7 +234,7 @@ def unique_inverse(x: ndarray, /, xp) -> UniqueInverseResult: return UniqueInverseResult(values, inverse_indices) -def unique_values(x: ndarray, /, xp) -> ndarray: +def unique_values(x: Array, /, xp: Namespace) -> Array: kwargs = _unique_kwargs(xp) return xp.unique( x, @@ -236,42 +247,42 @@ def unique_values(x: ndarray, /, xp) -> ndarray: # These functions have different keyword argument names def std( - x: ndarray, + x: Array, /, - xp, + xp: Namespace, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, - correction: Union[int, float] = 0.0, # correction instead of ddof + correction: Union[int, float] = 0.0, # correction instead of ddof keepdims: bool = False, **kwargs, -) -> ndarray: +) -> Array: return xp.std(x, axis=axis, ddof=correction, keepdims=keepdims, **kwargs) def var( - x: ndarray, + x: Array, /, - xp, + xp: Namespace, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, - correction: Union[int, float] = 0.0, # correction instead of ddof + correction: Union[int, float] = 0.0, # correction instead of ddof keepdims: bool = False, **kwargs, -) -> ndarray: +) -> Array: return xp.var(x, axis=axis, ddof=correction, keepdims=keepdims, **kwargs) # cumulative_sum is renamed from cumsum, and adds the include_initial keyword # argument def cumulative_sum( - x: ndarray, + x: Array, /, - xp, + xp: Namespace, *, axis: Optional[int] = None, - dtype: Optional[Dtype] = None, + dtype: Optional[DType] = None, include_initial: bool = False, - **kwargs -) -> ndarray: + **kwargs, +) -> Array: wrapped_xp = array_namespace(x) # TODO: The standard is not clear about what should happen when x.ndim == 0. 
@@ -294,15 +305,15 @@ def cumulative_sum( def cumulative_prod( - x: ndarray, + x: Array, /, - xp, + xp: Namespace, *, axis: Optional[int] = None, - dtype: Optional[Dtype] = None, + dtype: Optional[DType] = None, include_initial: bool = False, - **kwargs -) -> ndarray: + **kwargs, +) -> Array: wrapped_xp = array_namespace(x) if axis is None: @@ -325,17 +336,18 @@ def cumulative_prod( # The min and max argument names in clip are different and not optional in numpy, and type # promotion behavior is different. def clip( - x: ndarray, + x: Array, /, - min: Optional[Union[int, float, ndarray]] = None, - max: Optional[Union[int, float, ndarray]] = None, + min: Optional[Union[int, float, Array]] = None, + max: Optional[Union[int, float, Array]] = None, *, - xp, + xp: Namespace, # TODO: np.clip has other ufunc kwargs - out: Optional[ndarray] = None, -) -> ndarray: + out: Optional[Array] = None, +) -> Array: def _isscalar(a): return isinstance(a, (int, float, type(None))) + min_shape = () if _isscalar(min) else min.shape max_shape = () if _isscalar(max) else max.shape @@ -390,15 +402,19 @@ def _isscalar(a): return out[()] # Unlike transpose(), the axes argument to permute_dims() is required. -def permute_dims(x: ndarray, /, axes: Tuple[int, ...], xp) -> ndarray: +def permute_dims(x: Array, /, axes: Tuple[int, ...], xp: Namespace) -> Array: return xp.transpose(x, axes) # np.reshape calls the keyword argument 'newshape' instead of 'shape' -def reshape(x: ndarray, - /, - shape: Tuple[int, ...], - xp, copy: Optional[bool] = None, - **kwargs) -> ndarray: +def reshape( + x: Array, + /, + shape: Tuple[int, ...], + xp: Namespace, + *, + copy: Optional[bool] = None, + **kwargs, +) -> Array: if copy is True: x = x.copy() elif copy is False: @@ -410,9 +426,15 @@ def reshape(x: ndarray, # The descending keyword is new in sort and argsort, and 'kind' replaced with # 'stable' def argsort( - x: ndarray, /, xp, *, axis: int = -1, descending: bool = False, stable: bool = True, + x: Array, + /, + xp: Namespace, + *, + axis: int = -1, + descending: bool = False, + stable: bool = True, **kwargs, -) -> ndarray: +) -> Array: # Note: this keyword argument is different, and the default is different. # We set it in kwargs like this because numpy.sort uses kind='quicksort' # as the default whereas cupy.sort uses kind=None. @@ -435,9 +457,15 @@ def argsort( return res def sort( - x: ndarray, /, xp, *, axis: int = -1, descending: bool = False, stable: bool = True, + x: Array, + /, + xp: Namespace, + *, + axis: int = -1, + descending: bool = False, + stable: bool = True, **kwargs, -) -> ndarray: +) -> Array: # Note: this keyword argument is different, and the default is different. # We set it in kwargs like this because numpy.sort uses kind='quicksort' # as the default whereas cupy.sort uses kind=None. 
@@ -449,50 +477,51 @@ def sort( return res # nonzero should error for zero-dimensional arrays -def nonzero(x: ndarray, /, xp, **kwargs) -> Tuple[ndarray, ...]: +def nonzero(x: Array, /, xp: Namespace, **kwargs) -> Tuple[Array, ...]: if x.ndim == 0: raise ValueError("nonzero() does not support zero-dimensional arrays") return xp.nonzero(x, **kwargs) # ceil, floor, and trunc return integers for integer inputs -def ceil(x: ndarray, /, xp, **kwargs) -> ndarray: +def ceil(x: Array, /, xp: Namespace, **kwargs) -> Array: if xp.issubdtype(x.dtype, xp.integer): return x return xp.ceil(x, **kwargs) -def floor(x: ndarray, /, xp, **kwargs) -> ndarray: +def floor(x: Array, /, xp: Namespace, **kwargs) -> Array: if xp.issubdtype(x.dtype, xp.integer): return x return xp.floor(x, **kwargs) -def trunc(x: ndarray, /, xp, **kwargs) -> ndarray: +def trunc(x: Array, /, xp: Namespace, **kwargs) -> Array: if xp.issubdtype(x.dtype, xp.integer): return x return xp.trunc(x, **kwargs) # linear algebra functions -def matmul(x1: ndarray, x2: ndarray, /, xp, **kwargs) -> ndarray: +def matmul(x1: Array, x2: Array, /, xp: Namespace, **kwargs) -> Array: return xp.matmul(x1, x2, **kwargs) # Unlike transpose, matrix_transpose only transposes the last two axes. -def matrix_transpose(x: ndarray, /, xp) -> ndarray: +def matrix_transpose(x: Array, /, xp: Namespace) -> Array: if x.ndim < 2: raise ValueError("x must be at least 2-dimensional for matrix_transpose") return xp.swapaxes(x, -1, -2) -def tensordot(x1: ndarray, - x2: ndarray, - /, - xp, - *, - axes: Union[int, Tuple[Sequence[int], Sequence[int]]] = 2, - **kwargs, -) -> ndarray: +def tensordot( + x1: Array, + x2: Array, + /, + xp: Namespace, + *, + axes: Union[int, Tuple[Sequence[int], Sequence[int]]] = 2, + **kwargs, +) -> Array: return xp.tensordot(x1, x2, axes=axes, **kwargs) -def vecdot(x1: ndarray, x2: ndarray, /, xp, *, axis: int = -1) -> ndarray: +def vecdot(x1: Array, x2: Array, /, xp: Namespace, *, axis: int = -1) -> Array: if x1.shape[axis] != x2.shape[axis]: raise ValueError("x1 and x2 must have the same size along the given axis") @@ -511,8 +540,11 @@ def vecdot(x1: ndarray, x2: ndarray, /, xp, *, axis: int = -1) -> ndarray: # isdtype is a new function in the 2022.12 array API specification. def isdtype( - dtype: Dtype, kind: Union[Dtype, str, Tuple[Union[Dtype, str], ...]], xp, - *, _tuple=True, # Disallow nested tuples + dtype: DType, + kind: Union[DType, str, Tuple[Union[DType, str], ...]], + xp: Namespace, + *, + _tuple: bool = True, # Disallow nested tuples ) -> bool: """ Returns a boolean indicating whether a provided dtype is of a specified data type ``kind``. @@ -551,14 +583,14 @@ def isdtype( return dtype == kind # unstack is a new function in the 2023.12 array API standard -def unstack(x: ndarray, /, xp, *, axis: int = 0) -> Tuple[ndarray, ...]: +def unstack(x: Array, /, xp: Namespace, *, axis: int = 0) -> Tuple[Array, ...]: if x.ndim == 0: raise ValueError("Input array must be at least 1-d.") return tuple(xp.moveaxis(x, axis, 0)) # numpy 1.26 does not use the standard definition for sign on complex numbers -def sign(x: ndarray, /, xp, **kwargs) -> ndarray: +def sign(x: Array, /, xp: Namespace, **kwargs) -> Array: if isdtype(x.dtype, 'complex floating', xp=xp): out = (x/xp.abs(x, **kwargs))[...] 
# sign(0) = 0 but the above formula would give nan diff --git a/array_api_compat/common/_fft.py b/array_api_compat/common/_fft.py index e5caebef..bd2a4e1a 100644 --- a/array_api_compat/common/_fft.py +++ b/array_api_compat/common/_fft.py @@ -1,149 +1,148 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Union, Optional, Literal +from collections.abc import Sequence +from typing import Union, Optional, Literal -if TYPE_CHECKING: - from ._typing import Device, ndarray, DType - from collections.abc import Sequence +from ._typing import Device, Array, DType, Namespace # Note: NumPy fft functions improperly upcast float32 and complex64 to # complex128, which is why we require wrapping them all here. def fft( - x: ndarray, + x: Array, /, - xp, + xp: Namespace, *, n: Optional[int] = None, axis: int = -1, norm: Literal["backward", "ortho", "forward"] = "backward", -) -> ndarray: +) -> Array: res = xp.fft.fft(x, n=n, axis=axis, norm=norm) if x.dtype in [xp.float32, xp.complex64]: return res.astype(xp.complex64) return res def ifft( - x: ndarray, + x: Array, /, - xp, + xp: Namespace, *, n: Optional[int] = None, axis: int = -1, norm: Literal["backward", "ortho", "forward"] = "backward", -) -> ndarray: +) -> Array: res = xp.fft.ifft(x, n=n, axis=axis, norm=norm) if x.dtype in [xp.float32, xp.complex64]: return res.astype(xp.complex64) return res def fftn( - x: ndarray, + x: Array, /, - xp, + xp: Namespace, *, s: Sequence[int] = None, axes: Sequence[int] = None, norm: Literal["backward", "ortho", "forward"] = "backward", -) -> ndarray: +) -> Array: res = xp.fft.fftn(x, s=s, axes=axes, norm=norm) if x.dtype in [xp.float32, xp.complex64]: return res.astype(xp.complex64) return res def ifftn( - x: ndarray, + x: Array, /, - xp, + xp: Namespace, *, s: Sequence[int] = None, axes: Sequence[int] = None, norm: Literal["backward", "ortho", "forward"] = "backward", -) -> ndarray: +) -> Array: res = xp.fft.ifftn(x, s=s, axes=axes, norm=norm) if x.dtype in [xp.float32, xp.complex64]: return res.astype(xp.complex64) return res def rfft( - x: ndarray, + x: Array, /, - xp, + xp: Namespace, *, n: Optional[int] = None, axis: int = -1, norm: Literal["backward", "ortho", "forward"] = "backward", -) -> ndarray: +) -> Array: res = xp.fft.rfft(x, n=n, axis=axis, norm=norm) if x.dtype == xp.float32: return res.astype(xp.complex64) return res def irfft( - x: ndarray, + x: Array, /, - xp, + xp: Namespace, *, n: Optional[int] = None, axis: int = -1, norm: Literal["backward", "ortho", "forward"] = "backward", -) -> ndarray: +) -> Array: res = xp.fft.irfft(x, n=n, axis=axis, norm=norm) if x.dtype == xp.complex64: return res.astype(xp.float32) return res def rfftn( - x: ndarray, + x: Array, /, - xp, + xp: Namespace, *, s: Sequence[int] = None, axes: Sequence[int] = None, norm: Literal["backward", "ortho", "forward"] = "backward", -) -> ndarray: +) -> Array: res = xp.fft.rfftn(x, s=s, axes=axes, norm=norm) if x.dtype == xp.float32: return res.astype(xp.complex64) return res def irfftn( - x: ndarray, + x: Array, /, - xp, + xp: Namespace, *, s: Sequence[int] = None, axes: Sequence[int] = None, norm: Literal["backward", "ortho", "forward"] = "backward", -) -> ndarray: +) -> Array: res = xp.fft.irfftn(x, s=s, axes=axes, norm=norm) if x.dtype == xp.complex64: return res.astype(xp.float32) return res def hfft( - x: ndarray, + x: Array, /, - xp, + xp: Namespace, *, n: Optional[int] = None, axis: int = -1, norm: Literal["backward", "ortho", "forward"] = "backward", -) -> ndarray: +) -> Array: res = xp.fft.hfft(x, n=n, 
axis=axis, norm=norm) if x.dtype in [xp.float32, xp.complex64]: return res.astype(xp.float32) return res def ihfft( - x: ndarray, + x: Array, /, - xp, + xp: Namespace, *, n: Optional[int] = None, axis: int = -1, norm: Literal["backward", "ortho", "forward"] = "backward", -) -> ndarray: +) -> Array: res = xp.fft.ihfft(x, n=n, axis=axis, norm=norm) if x.dtype in [xp.float32, xp.complex64]: return res.astype(xp.complex64) @@ -152,12 +151,12 @@ def ihfft( def fftfreq( n: int, /, - xp, + xp: Namespace, *, d: float = 1.0, dtype: Optional[DType] = None, - device: Optional[Device] = None -) -> ndarray: + device: Optional[Device] = None, +) -> Array: if device not in ["cpu", None]: raise ValueError(f"Unsupported device {device!r}") res = xp.fft.fftfreq(n, d=d) @@ -168,12 +167,12 @@ def fftfreq( def rfftfreq( n: int, /, - xp, + xp: Namespace, *, d: float = 1.0, dtype: Optional[DType] = None, - device: Optional[Device] = None -) -> ndarray: + device: Optional[Device] = None, +) -> Array: if device not in ["cpu", None]: raise ValueError(f"Unsupported device {device!r}") res = xp.fft.rfftfreq(n, d=d) @@ -181,10 +180,14 @@ def rfftfreq( return res.astype(dtype) return res -def fftshift(x: ndarray, /, xp, *, axes: Union[int, Sequence[int]] = None) -> ndarray: +def fftshift( + x: Array, /, xp: Namespace, *, axes: Union[int, Sequence[int]] = None +) -> Array: return xp.fft.fftshift(x, axes=axes) -def ifftshift(x: ndarray, /, xp, *, axes: Union[int, Sequence[int]] = None) -> ndarray: +def ifftshift( + x: Array, /, xp: Namespace, *, axes: Union[int, Sequence[int]] = None +) -> Array: return xp.fft.ifftshift(x, axes=axes) __all__ = [ diff --git a/array_api_compat/common/_helpers.py b/array_api_compat/common/_helpers.py index 791edb81..6d95069d 100644 --- a/array_api_compat/common/_helpers.py +++ b/array_api_compat/common/_helpers.py @@ -7,16 +7,14 @@ """ from __future__ import annotations -from typing import TYPE_CHECKING - -if TYPE_CHECKING: - from typing import Optional, Union, Any - from ._typing import Array, Device, Namespace - import sys import math import inspect import warnings +from typing import Optional, Union, Any + +from ._typing import Array, Device, Namespace + def _is_jax_zero_gradient_array(x: object) -> bool: """Return True if `x` is a zero-gradient array. @@ -268,7 +266,7 @@ def _compat_module_name() -> str: return __name__.removesuffix('.common._helpers') -def is_numpy_namespace(xp) -> bool: +def is_numpy_namespace(xp: Namespace) -> bool: """ Returns True if `xp` is a NumPy namespace. @@ -289,7 +287,7 @@ def is_numpy_namespace(xp) -> bool: return xp.__name__ in {'numpy', _compat_module_name() + '.numpy'} -def is_cupy_namespace(xp) -> bool: +def is_cupy_namespace(xp: Namespace) -> bool: """ Returns True if `xp` is a CuPy namespace. @@ -310,7 +308,7 @@ def is_cupy_namespace(xp) -> bool: return xp.__name__ in {'cupy', _compat_module_name() + '.cupy'} -def is_torch_namespace(xp) -> bool: +def is_torch_namespace(xp: Namespace) -> bool: """ Returns True if `xp` is a PyTorch namespace. @@ -331,7 +329,7 @@ def is_torch_namespace(xp) -> bool: return xp.__name__ in {'torch', _compat_module_name() + '.torch'} -def is_ndonnx_namespace(xp) -> bool: +def is_ndonnx_namespace(xp: Namespace) -> bool: """ Returns True if `xp` is an NDONNX namespace. @@ -350,7 +348,7 @@ def is_ndonnx_namespace(xp) -> bool: return xp.__name__ == 'ndonnx' -def is_dask_namespace(xp) -> bool: +def is_dask_namespace(xp: Namespace) -> bool: """ Returns True if `xp` is a Dask namespace. 
@@ -371,7 +369,7 @@ def is_dask_namespace(xp) -> bool: return xp.__name__ in {'dask.array', _compat_module_name() + '.dask.array'} -def is_jax_namespace(xp) -> bool: +def is_jax_namespace(xp: Namespace) -> bool: """ Returns True if `xp` is a JAX namespace. @@ -393,7 +391,7 @@ def is_jax_namespace(xp) -> bool: return xp.__name__ in {'jax.numpy', 'jax.experimental.array_api'} -def is_pydata_sparse_namespace(xp) -> bool: +def is_pydata_sparse_namespace(xp: Namespace) -> bool: """ Returns True if `xp` is a pydata/sparse namespace. @@ -412,7 +410,7 @@ def is_pydata_sparse_namespace(xp) -> bool: return xp.__name__ == 'sparse' -def is_array_api_strict_namespace(xp) -> bool: +def is_array_api_strict_namespace(xp: Namespace) -> bool: """ Returns True if `xp` is an array-api-strict namespace. @@ -439,7 +437,11 @@ def _check_api_version(api_version: str) -> None: raise ValueError("Only the 2024.12 version of the array API specification is currently supported") -def array_namespace(*xs, api_version=None, use_compat=None) -> Namespace: +def array_namespace( + *xs: Union[Array, bool, int, float, complex, None], + api_version: Optional[str] = None, + use_compat: Optional[bool] = None, +) -> Namespace: """ Get the array API compatible namespace for the arrays `xs`. diff --git a/array_api_compat/common/_linalg.py b/array_api_compat/common/_linalg.py index bfa1f1b9..c77ee3b8 100644 --- a/array_api_compat/common/_linalg.py +++ b/array_api_compat/common/_linalg.py @@ -1,11 +1,7 @@ from __future__ import annotations -from typing import TYPE_CHECKING, NamedTuple -if TYPE_CHECKING: - from typing import Literal, Optional, Tuple, Union - from ._typing import ndarray - import math +from typing import Literal, NamedTuple, Optional, Tuple, Union import numpy as np if np.__version__[0] == "2": @@ -15,50 +11,53 @@ from ._aliases import matmul, matrix_transpose, tensordot, vecdot, isdtype from .._internal import get_xp +from ._typing import Array, Namespace # These are in the main NumPy namespace but not in numpy.linalg -def cross(x1: ndarray, x2: ndarray, /, xp, *, axis: int = -1, **kwargs) -> ndarray: +def cross(x1: Array, x2: Array, /, xp: Namespace, *, axis: int = -1, **kwargs) -> Array: return xp.cross(x1, x2, axis=axis, **kwargs) -def outer(x1: ndarray, x2: ndarray, /, xp, **kwargs) -> ndarray: +def outer(x1: Array, x2: Array, /, xp: Namespace, **kwargs) -> Array: return xp.outer(x1, x2, **kwargs) class EighResult(NamedTuple): - eigenvalues: ndarray - eigenvectors: ndarray + eigenvalues: Array + eigenvectors: Array class QRResult(NamedTuple): - Q: ndarray - R: ndarray + Q: Array + R: Array class SlogdetResult(NamedTuple): - sign: ndarray - logabsdet: ndarray + sign: Array + logabsdet: Array class SVDResult(NamedTuple): - U: ndarray - S: ndarray - Vh: ndarray + U: Array + S: Array + Vh: Array # These functions are the same as their NumPy counterparts except they return # a namedtuple. 
-def eigh(x: ndarray, /, xp, **kwargs) -> EighResult: +def eigh(x: Array, /, xp: Namespace, **kwargs) -> EighResult: return EighResult(*xp.linalg.eigh(x, **kwargs)) -def qr(x: ndarray, /, xp, *, mode: Literal['reduced', 'complete'] = 'reduced', +def qr(x: Array, /, xp: Namespace, *, mode: Literal['reduced', 'complete'] = 'reduced', **kwargs) -> QRResult: return QRResult(*xp.linalg.qr(x, mode=mode, **kwargs)) -def slogdet(x: ndarray, /, xp, **kwargs) -> SlogdetResult: +def slogdet(x: Array, /, xp: Namespace, **kwargs) -> SlogdetResult: return SlogdetResult(*xp.linalg.slogdet(x, **kwargs)) -def svd(x: ndarray, /, xp, *, full_matrices: bool = True, **kwargs) -> SVDResult: +def svd( + x: Array, /, xp: Namespace, *, full_matrices: bool = True, **kwargs +) -> SVDResult: return SVDResult(*xp.linalg.svd(x, full_matrices=full_matrices, **kwargs)) # These functions have additional keyword arguments # The upper keyword argument is new from NumPy -def cholesky(x: ndarray, /, xp, *, upper: bool = False, **kwargs) -> ndarray: +def cholesky(x: Array, /, xp: Namespace, *, upper: bool = False, **kwargs) -> Array: L = xp.linalg.cholesky(x, **kwargs) if upper: U = get_xp(xp)(matrix_transpose)(L) @@ -69,12 +68,12 @@ def cholesky(x: ndarray, /, xp, *, upper: bool = False, **kwargs) -> ndarray: # The rtol keyword argument of matrix_rank() and pinv() is new from NumPy. # Note that it has a different semantic meaning from tol and rcond. -def matrix_rank(x: ndarray, +def matrix_rank(x: Array, /, - xp, + xp: Namespace, *, - rtol: Optional[Union[float, ndarray]] = None, - **kwargs) -> ndarray: + rtol: Optional[Union[float, Array]] = None, + **kwargs) -> Array: # this is different from xp.linalg.matrix_rank, which supports 1 # dimensional arrays. if x.ndim < 2: @@ -88,7 +87,9 @@ def matrix_rank(x: ndarray, tol = S.max(axis=-1, keepdims=True)*xp.asarray(rtol)[..., xp.newaxis] return xp.count_nonzero(S > tol, axis=-1) -def pinv(x: ndarray, /, xp, *, rtol: Optional[Union[float, ndarray]] = None, **kwargs) -> ndarray: +def pinv( + x: Array, /, xp: Namespace, *, rtol: Optional[Union[float, Array]] = None, **kwargs +) -> Array: # this is different from xp.linalg.pinv, which does not multiply the # default tolerance by max(M, N). if rtol is None: @@ -97,15 +98,30 @@ def pinv(x: ndarray, /, xp, *, rtol: Optional[Union[float, ndarray]] = None, **k # These functions are new in the array API spec -def matrix_norm(x: ndarray, /, xp, *, keepdims: bool = False, ord: Optional[Union[int, float, Literal['fro', 'nuc']]] = 'fro') -> ndarray: +def matrix_norm( + x: Array, + /, + xp: Namespace, + *, + keepdims: bool = False, + ord: Optional[Union[int, float, Literal['fro', 'nuc']]] = 'fro', +) -> Array: return xp.linalg.norm(x, axis=(-2, -1), keepdims=keepdims, ord=ord) # svdvals is not in NumPy (but it is in SciPy). It is equivalent to # xp.linalg.svd(compute_uv=False). 
-def svdvals(x: ndarray, /, xp) -> Union[ndarray, Tuple[ndarray, ...]]: +def svdvals(x: Array, /, xp: Namespace) -> Union[Array, Tuple[Array, ...]]: return xp.linalg.svd(x, compute_uv=False) -def vector_norm(x: ndarray, /, xp, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False, ord: Optional[Union[int, float]] = 2) -> ndarray: +def vector_norm( + x: Array, + /, + xp: Namespace, + *, + axis: Optional[Union[int, Tuple[int, ...]]] = None, + keepdims: bool = False, + ord: Optional[Union[int, float]] = 2, +) -> Array: # xp.linalg.norm tries to do a matrix norm whenever axis is a 2-tuple or # when axis=None and the input is 2-D, so to force a vector norm, we make # it so the input is 1-D (for axis=None), or reshape so that norm is done @@ -143,11 +159,15 @@ def vector_norm(x: ndarray, /, xp, *, axis: Optional[Union[int, Tuple[int, ...]] # xp.diagonal and xp.trace operate on the first two axes whereas these # operates on the last two -def diagonal(x: ndarray, /, xp, *, offset: int = 0, **kwargs) -> ndarray: +def diagonal(x: Array, /, xp: Namespace, *, offset: int = 0, **kwargs) -> Array: return xp.diagonal(x, offset=offset, axis1=-2, axis2=-1, **kwargs) -def trace(x: ndarray, /, xp, *, offset: int = 0, dtype=None, **kwargs) -> ndarray: - return xp.asarray(xp.trace(x, offset=offset, dtype=dtype, axis1=-2, axis2=-1, **kwargs)) +def trace( + x: Array, /, xp: Namespace, *, offset: int = 0, dtype=None, **kwargs +) -> Array: + return xp.asarray( + xp.trace(x, offset=offset, dtype=dtype, axis1=-2, axis2=-1, **kwargs) + ) __all__ = ['cross', 'matmul', 'outer', 'tensordot', 'EighResult', 'QRResult', 'SlogdetResult', 'SVDResult', 'eigh', 'qr', 'slogdet', diff --git a/array_api_compat/common/_typing.py b/array_api_compat/common/_typing.py index d8acdef7..4c3b356b 100644 --- a/array_api_compat/common/_typing.py +++ b/array_api_compat/common/_typing.py @@ -1,26 +1,24 @@ from __future__ import annotations +from types import ModuleType as Namespace +from typing import Any, TypeVar, Protocol __all__ = [ + "Array", + "DType", + "Device", + "Namespace", "NestedSequence", "SupportsBufferProtocol", ] -from types import ModuleType -from typing import ( - Any, - TypeVar, - Protocol, -) - _T_co = TypeVar("_T_co", covariant=True) class NestedSequence(Protocol[_T_co]): def __getitem__(self, key: int, /) -> _T_co | NestedSequence[_T_co]: ... def __len__(self, /) -> int: ... -SupportsBufferProtocol = Any +SupportsBufferProtocol = Any Array = Any Device = Any DType = Any -Namespace = ModuleType diff --git a/array_api_compat/cupy/_aliases.py b/array_api_compat/cupy/_aliases.py index 30d9fe48..ebc7ccd9 100644 --- a/array_api_compat/cupy/_aliases.py +++ b/array_api_compat/cupy/_aliases.py @@ -1,16 +1,14 @@ from __future__ import annotations +from typing import Optional + import cupy as cp from ..common import _aliases, _helpers +from ..common._typing import NestedSequence, SupportsBufferProtocol from .._internal import get_xp - from ._info import __array_namespace_info__ - -from typing import TYPE_CHECKING -if TYPE_CHECKING: - from typing import Optional, Union - from ._typing import ndarray, Device, Dtype, NestedSequence, SupportsBufferProtocol +from ._typing import Array, Device, DType bool = cp.bool_ @@ -66,23 +64,19 @@ _copy_default = object() + # asarray also adds the copy keyword, which is not present in numpy 1.0. 
def asarray( - obj: Union[ - ndarray, - bool, - int, - float, - NestedSequence[bool | int | float], - SupportsBufferProtocol, - ], + obj: ( + Array | bool | complex | NestedSequence[bool | complex] | SupportsBufferProtocol + ), /, *, - dtype: Optional[Dtype] = None, + dtype: Optional[DType] = None, device: Optional[Device] = None, copy: Optional[bool] = _copy_default, **kwargs, -) -> ndarray: +) -> Array: """ Array API compatibility wrapper for asarray(). @@ -112,13 +106,13 @@ def asarray( def astype( - x: ndarray, - dtype: Dtype, + x: Array, + dtype: DType, /, *, copy: bool = True, device: Optional[Device] = None, -) -> ndarray: +) -> Array: if device is None: return x.astype(dtype=dtype, copy=copy) out = _helpers.to_device(x.astype(dtype=dtype, copy=False), device) @@ -127,10 +121,10 @@ def astype( # cupy.count_nonzero does not have keepdims def count_nonzero( - x: ndarray, + x: Array, axis=None, keepdims=False -) -> ndarray: +) -> Array: result = cp.count_nonzero(x, axis) if keepdims: if axis is None: diff --git a/array_api_compat/cupy/_typing.py b/array_api_compat/cupy/_typing.py index f3d9aab6..66af5d19 100644 --- a/array_api_compat/cupy/_typing.py +++ b/array_api_compat/cupy/_typing.py @@ -1,46 +1,31 @@ from __future__ import annotations -__all__ = [ - "ndarray", - "Device", - "Dtype", -] +__all__ = ["Array", "DType", "Device"] +_all_ignore = ["cp"] -import sys -from typing import ( - Union, - TYPE_CHECKING, -) - -from cupy import ( - ndarray, - dtype, - int8, - int16, - int32, - int64, - uint8, - uint16, - uint32, - uint64, - float32, - float64, -) +from typing import TYPE_CHECKING +import cupy as cp +from cupy import ndarray as Array from cupy.cuda.device import Device -if TYPE_CHECKING or sys.version_info >= (3, 9): - Dtype = dtype[Union[ - int8, - int16, - int32, - int64, - uint8, - uint16, - uint32, - uint64, - float32, - float64, - ]] +if TYPE_CHECKING: + # NumPy 1.x on Python 3.9 and 3.10 fails to parse np.dtype[] + DType = cp.dtype[ + cp.intp + | cp.int8 + | cp.int16 + | cp.int32 + | cp.int64 + | cp.uint8 + | cp.uint16 + | cp.uint32 + | cp.uint64 + | cp.float32 + | cp.float64 + | cp.complex64 + | cp.complex128 + | cp.bool_ + ] else: - Dtype = dtype + DType = cp.dtype diff --git a/array_api_compat/dask/array/_aliases.py b/array_api_compat/dask/array/_aliases.py index 80d66281..e737cebd 100644 --- a/array_api_compat/dask/array/_aliases.py +++ b/array_api_compat/dask/array/_aliases.py @@ -1,16 +1,10 @@ from __future__ import annotations -from typing import Callable - -from ...common import _aliases, array_namespace - -from ..._internal import get_xp - -from ._info import __array_namespace_info__ +from typing import Callable, Optional, Union import numpy as np from numpy import ( - # Dtypes + # dtypes iinfo, finfo, bool_ as bool, @@ -29,22 +23,19 @@ can_cast, result_type, ) - -from typing import TYPE_CHECKING - -if TYPE_CHECKING: - from typing import Optional, Union - - from ...common._typing import ( - Device, - Dtype, - Array, - NestedSequence, - SupportsBufferProtocol, - ) - import dask.array as da +from ...common import _aliases, array_namespace +from ...common._typing import ( + Array, + Device, + DType, + NestedSequence, + SupportsBufferProtocol, +) +from ..._internal import get_xp +from ._info import __array_namespace_info__ + isdtype = get_xp(np)(_aliases.isdtype) unstack = get_xp(da)(_aliases.unstack) @@ -52,7 +43,7 @@ # da.astype doesn't respect copy=True def astype( x: Array, - dtype: Dtype, + dtype: DType, /, *, copy: bool = True, @@ -84,7 +75,7 @@ def arange( stop: 
Optional[Union[int, float]] = None, step: Union[int, float] = 1, *, - dtype: Optional[Dtype] = None, + dtype: Optional[DType] = None, device: Optional[Device] = None, **kwargs, ) -> Array: @@ -144,17 +135,12 @@ def arange( # asarray also adds the copy keyword, which is not present in numpy 1.0. def asarray( - obj: Union[ - Array, - bool, - int, - float, - NestedSequence[bool | int | float], - SupportsBufferProtocol, - ], + obj: ( + Array | bool | complex | NestedSequence[bool | complex] | SupportsBufferProtocol + ), /, *, - dtype: Optional[Dtype] = None, + dtype: Optional[DType] = None, device: Optional[Device] = None, copy: Optional[Union[bool, np._CopyMode]] = None, **kwargs, @@ -360,4 +346,4 @@ def count_nonzero( 'complex64', 'complex128', 'iinfo', 'finfo', 'can_cast', 'count_nonzero', 'result_type'] -_all_ignore = ["Callable", "array_namespace", "get_xp", "da", "np"] +_all_ignore = ["array_namespace", "get_xp", "da", "np"] diff --git a/array_api_compat/dask/array/fft.py b/array_api_compat/dask/array/fft.py index aebd86f7..3f40dffe 100644 --- a/array_api_compat/dask/array/fft.py +++ b/array_api_compat/dask/array/fft.py @@ -4,9 +4,10 @@ # from dask.array.fft import __all__ as linalg_all _n = {} exec('from dask.array.fft import *', _n) -del _n['__builtins__'] +for k in ("__builtins__", "Sequence", "annotations", "warnings"): + _n.pop(k, None) fft_all = list(_n) -del _n +del _n, k from ...common import _fft from ..._internal import get_xp @@ -16,9 +17,5 @@ fftfreq = get_xp(da)(_fft.fftfreq) rfftfreq = get_xp(da)(_fft.rfftfreq) -__all__ = [elem for elem in fft_all if elem != "annotations"] + ["fftfreq", "rfftfreq"] - -del get_xp -del da -del fft_all -del _fft +__all__ = fft_all + ["fftfreq", "rfftfreq"] +_all_ignore = ["da", "fft_all", "get_xp", "warnings"] diff --git a/array_api_compat/dask/array/linalg.py b/array_api_compat/dask/array/linalg.py index 49c26d8b..bd53f0df 100644 --- a/array_api_compat/dask/array/linalg.py +++ b/array_api_compat/dask/array/linalg.py @@ -1,33 +1,28 @@ from __future__ import annotations -from ...common import _linalg -from ..._internal import get_xp +from typing import Literal +import dask.array as da # Exports from dask.array.linalg import * # noqa: F403 from dask.array import outer - # These functions are in both the main and linalg namespaces from dask.array import matmul, tensordot -from ._aliases import matrix_transpose, vecdot -import dask.array as da - -from typing import TYPE_CHECKING -if TYPE_CHECKING: - from ...common._typing import Array - from typing import Literal +from ..._internal import get_xp +from ...common import _linalg +from ...common._typing import Array +from ._aliases import matrix_transpose, vecdot # dask.array.linalg doesn't have __all__. 
If it is added, replace this with # # from dask.array.linalg import __all__ as linalg_all _n = {} exec('from dask.array.linalg import *', _n) -del _n['__builtins__'] -if 'annotations' in _n: - del _n['annotations'] +for k in ('__builtins__', 'annotations', 'operator', 'warnings', 'Array'): + _n.pop(k, None) linalg_all = list(_n) -del _n +del _n, k EighResult = _linalg.EighResult QRResult = _linalg.QRResult @@ -70,4 +65,4 @@ def svdvals(x: Array) -> Array: "cholesky", "matrix_rank", "matrix_norm", "svdvals", "vector_norm", "diagonal"] -_all_ignore = ['get_xp', 'da', 'linalg_all'] +_all_ignore = ['get_xp', 'da', 'linalg_all', 'warnings'] diff --git a/array_api_compat/numpy/_aliases.py b/array_api_compat/numpy/_aliases.py index a47f7121..6536d9a8 100644 --- a/array_api_compat/numpy/_aliases.py +++ b/array_api_compat/numpy/_aliases.py @@ -1,17 +1,15 @@ from __future__ import annotations -from ..common import _aliases +from typing import Optional, Union from .._internal import get_xp - +from ..common import _aliases +from ..common._typing import NestedSequence, SupportsBufferProtocol from ._info import __array_namespace_info__ - -from typing import TYPE_CHECKING -if TYPE_CHECKING: - from typing import Optional, Union - from ._typing import ndarray, Device, Dtype, NestedSequence, SupportsBufferProtocol +from ._typing import Array, Device, DType import numpy as np + bool = np.bool_ # Basic renames @@ -64,6 +62,7 @@ tensordot = get_xp(np)(_aliases.tensordot) sign = get_xp(np)(_aliases.sign) + def _supports_buffer_protocol(obj): try: memoryview(obj) @@ -71,26 +70,22 @@ def _supports_buffer_protocol(obj): return False return True + # asarray also adds the copy keyword, which is not present in numpy 1.0. # asarray() is different enough between numpy, cupy, and dask, the logic # complicated enough that it's easier to define it separately for each module # rather than trying to combine everything into one function in common/ def asarray( - obj: Union[ - ndarray, - bool, - int, - float, - NestedSequence[bool | int | float], - SupportsBufferProtocol, - ], + obj: ( + Array | bool | complex | NestedSequence[bool | complex] | SupportsBufferProtocol + ), /, *, - dtype: Optional[Dtype] = None, + dtype: Optional[DType] = None, device: Optional[Device] = None, copy: "Optional[Union[bool, np._CopyMode]]" = None, **kwargs, -) -> ndarray: +) -> Array: """ Array API compatibility wrapper for asarray(). 
@@ -117,23 +112,19 @@ def asarray( def astype( - x: ndarray, - dtype: Dtype, + x: Array, + dtype: DType, /, *, copy: bool = True, device: Optional[Device] = None, -) -> ndarray: +) -> Array: return x.astype(dtype=dtype, copy=copy) # count_nonzero returns a python int for axis=None and keepdims=False # https://github.com/numpy/numpy/issues/17562 -def count_nonzero( - x : ndarray, - axis=None, - keepdims=False -) -> ndarray: +def count_nonzero(x: Array, axis=None, keepdims=False) -> Array: result = np.count_nonzero(x, axis=axis, keepdims=keepdims) if axis is None and not keepdims: return np.asarray(result) diff --git a/array_api_compat/numpy/_typing.py b/array_api_compat/numpy/_typing.py index c5ebb5ab..6a18a3b2 100644 --- a/array_api_compat/numpy/_typing.py +++ b/array_api_compat/numpy/_typing.py @@ -1,46 +1,31 @@ from __future__ import annotations -__all__ = [ - "ndarray", - "Device", - "Dtype", -] +__all__ = ["Array", "DType", "Device"] +_all_ignore = ["np"] -import sys -from typing import ( - Literal, - Union, - TYPE_CHECKING, -) +from typing import Literal, TYPE_CHECKING -from numpy import ( - ndarray, - dtype, - int8, - int16, - int32, - int64, - uint8, - uint16, - uint32, - uint64, - float32, - float64, -) +import numpy as np +from numpy import ndarray as Array Device = Literal["cpu"] -if TYPE_CHECKING or sys.version_info >= (3, 9): - Dtype = dtype[Union[ - int8, - int16, - int32, - int64, - uint8, - uint16, - uint32, - uint64, - float32, - float64, - ]] +if TYPE_CHECKING: + # NumPy 1.x on Python 3.9 and 3.10 fails to parse np.dtype[] + DType = np.dtype[ + np.intp + | np.int8 + | np.int16 + | np.int32 + | np.int64 + | np.uint8 + | np.uint16 + | np.uint32 + | np.uint64 + | np.float32 + | np.float64 + | np.complex64 + | np.complex128 + | np.bool + ] else: - Dtype = dtype + DType = np.dtype diff --git a/array_api_compat/py.typed b/array_api_compat/py.typed new file mode 100644 index 00000000..e69de29b diff --git a/array_api_compat/torch/_aliases.py b/array_api_compat/torch/_aliases.py index 4b727f1c..87d32d85 100644 --- a/array_api_compat/torch/_aliases.py +++ b/array_api_compat/torch/_aliases.py @@ -2,21 +2,14 @@ from functools import reduce as _reduce, wraps as _wraps from builtins import all as _builtin_all, any as _builtin_any - -from ..common import _aliases -from .._internal import get_xp - -from ._info import __array_namespace_info__ +from typing import List, Optional, Sequence, Tuple, Union import torch -from typing import TYPE_CHECKING -if TYPE_CHECKING: - from typing import List, Optional, Sequence, Tuple, Union - from ..common._typing import Device - from torch import dtype as Dtype - - array = torch.Tensor +from .._internal import get_xp +from ..common import _aliases +from ._info import __array_namespace_info__ +from ._typing import Array, Device, DType _int_dtypes = { torch.uint8, @@ -123,7 +116,7 @@ def _fix_promotion(x1, x2, only_scalar=True): _py_scalars = (bool, int, float, complex) -def result_type(*arrays_and_dtypes: Union[array, Dtype, bool, int, float, complex]) -> Dtype: +def result_type(*arrays_and_dtypes: Array | DType | complex) -> DType: num = len(arrays_and_dtypes) if num == 0: @@ -170,7 +163,7 @@ def _result_type(x, y): return torch.result_type(x, y) -def can_cast(from_: Union[Dtype, array], to: Dtype, /) -> bool: +def can_cast(from_: Union[DType, Array], to: DType, /) -> bool: if not isinstance(from_, torch.dtype): from_ = from_.dtype return torch.can_cast(from_, to) @@ -216,13 +209,13 @@ def can_cast(from_: Union[Dtype, array], to: Dtype, /) -> bool: # of 
'axis'. # torch.min and torch.max return a tuple and don't support multiple axes https://github.com/pytorch/pytorch/issues/58745 -def max(x: array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False) -> array: +def max(x: Array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False) -> Array: # https://github.com/pytorch/pytorch/issues/29137 if axis == (): return torch.clone(x) return torch.amax(x, axis, keepdims=keepdims) -def min(x: array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False) -> array: +def min(x: Array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False) -> Array: # https://github.com/pytorch/pytorch/issues/29137 if axis == (): return torch.clone(x) @@ -235,7 +228,7 @@ def min(x: array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, keep # torch.sort also returns a tuple # https://github.com/pytorch/pytorch/issues/70921 -def sort(x: array, /, *, axis: int = -1, descending: bool = False, stable: bool = True, **kwargs) -> array: +def sort(x: Array, /, *, axis: int = -1, descending: bool = False, stable: bool = True, **kwargs) -> Array: return torch.sort(x, dim=axis, descending=descending, stable=stable, **kwargs).values def _normalize_axes(axis, ndim): @@ -280,13 +273,13 @@ def _reduce_multiple_axes(f, x, axis, keepdims=False, **kwargs): out = torch.unsqueeze(out, a) return out -def prod(x: array, +def prod(x: Array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, - dtype: Optional[Dtype] = None, + dtype: Optional[DType] = None, keepdims: bool = False, - **kwargs) -> array: + **kwargs) -> Array: x = torch.asarray(x) ndim = x.ndim @@ -316,13 +309,13 @@ def prod(x: array, return torch.prod(x, axis, dtype=dtype, keepdims=keepdims, **kwargs) -def sum(x: array, +def sum(x: Array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, - dtype: Optional[Dtype] = None, + dtype: Optional[DType] = None, keepdims: bool = False, - **kwargs) -> array: + **kwargs) -> Array: x = torch.asarray(x) ndim = x.ndim @@ -347,12 +340,12 @@ def sum(x: array, return torch.sum(x, axis, dtype=dtype, keepdims=keepdims, **kwargs) -def any(x: array, +def any(x: Array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False, - **kwargs) -> array: + **kwargs) -> Array: x = torch.asarray(x) ndim = x.ndim if axis == (): @@ -372,12 +365,12 @@ def any(x: array, # torch.any doesn't return bool for uint8 return torch.any(x, axis, keepdims=keepdims).to(torch.bool) -def all(x: array, +def all(x: Array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False, - **kwargs) -> array: + **kwargs) -> Array: x = torch.asarray(x) ndim = x.ndim if axis == (): @@ -397,12 +390,12 @@ def all(x: array, # torch.all doesn't return bool for uint8 return torch.all(x, axis, keepdims=keepdims).to(torch.bool) -def mean(x: array, +def mean(x: Array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False, - **kwargs) -> array: + **kwargs) -> Array: # https://github.com/pytorch/pytorch/issues/29137 if axis == (): return torch.clone(x) @@ -414,13 +407,13 @@ def mean(x: array, return res return torch.mean(x, axis, keepdims=keepdims, **kwargs) -def std(x: array, +def std(x: Array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, correction: Union[int, float] = 0.0, keepdims: bool = False, - **kwargs) -> array: + **kwargs) -> Array: # Note, float correction is not supported # 
https://github.com/pytorch/pytorch/issues/61492. We don't try to # implement it here for now. @@ -445,13 +438,13 @@ def std(x: array, return res return torch.std(x, axis, correction=_correction, keepdims=keepdims, **kwargs) -def var(x: array, +def var(x: Array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, correction: Union[int, float] = 0.0, keepdims: bool = False, - **kwargs) -> array: + **kwargs) -> Array: # Note, float correction is not supported # https://github.com/pytorch/pytorch/issues/61492. We don't try to # implement it here for now. @@ -474,11 +467,11 @@ def var(x: array, # torch.concat doesn't support dim=None # https://github.com/pytorch/pytorch/issues/70925 -def concat(arrays: Union[Tuple[array, ...], List[array]], +def concat(arrays: Union[Tuple[Array, ...], List[Array]], /, *, axis: Optional[int] = 0, - **kwargs) -> array: + **kwargs) -> Array: if axis is None: arrays = tuple(ar.flatten() for ar in arrays) axis = 0 @@ -487,7 +480,7 @@ def concat(arrays: Union[Tuple[array, ...], List[array]], # torch.squeeze only accepts int dim and doesn't require it # https://github.com/pytorch/pytorch/issues/70924. Support for tuple dim was # added at https://github.com/pytorch/pytorch/pull/89017. -def squeeze(x: array, /, axis: Union[int, Tuple[int, ...]]) -> array: +def squeeze(x: Array, /, axis: Union[int, Tuple[int, ...]]) -> Array: if isinstance(axis, int): axis = (axis,) for a in axis: @@ -501,27 +494,27 @@ def squeeze(x: array, /, axis: Union[int, Tuple[int, ...]]) -> array: return x # torch.broadcast_to uses size instead of shape -def broadcast_to(x: array, /, shape: Tuple[int, ...], **kwargs) -> array: +def broadcast_to(x: Array, /, shape: Tuple[int, ...], **kwargs) -> Array: return torch.broadcast_to(x, shape, **kwargs) # torch.permute uses dims instead of axes -def permute_dims(x: array, /, axes: Tuple[int, ...]) -> array: +def permute_dims(x: Array, /, axes: Tuple[int, ...]) -> Array: return torch.permute(x, axes) # The axis parameter doesn't work for flip() and roll() # https://github.com/pytorch/pytorch/issues/71210. 
Also torch.flip() doesn't # accept axis=None -def flip(x: array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, **kwargs) -> array: +def flip(x: Array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, **kwargs) -> Array: if axis is None: axis = tuple(range(x.ndim)) # torch.flip doesn't accept dim as an int but the method does # https://github.com/pytorch/pytorch/issues/18095 return x.flip(axis, **kwargs) -def roll(x: array, /, shift: Union[int, Tuple[int, ...]], *, axis: Optional[Union[int, Tuple[int, ...]]] = None, **kwargs) -> array: +def roll(x: Array, /, shift: Union[int, Tuple[int, ...]], *, axis: Optional[Union[int, Tuple[int, ...]]] = None, **kwargs) -> Array: return torch.roll(x, shift, axis, **kwargs) -def nonzero(x: array, /, **kwargs) -> Tuple[array, ...]: +def nonzero(x: Array, /, **kwargs) -> Tuple[Array, ...]: if x.ndim == 0: raise ValueError("nonzero() does not support zero-dimensional arrays") return torch.nonzero(x, as_tuple=True, **kwargs) @@ -529,25 +522,25 @@ def nonzero(x: array, /, **kwargs) -> Tuple[array, ...]: # torch uses `dim` instead of `axis` def diff( - x: array, + x: Array, /, *, axis: int = -1, n: int = 1, - prepend: Optional[array] = None, - append: Optional[array] = None, -) -> array: + prepend: Optional[Array] = None, + append: Optional[Array] = None, +) -> Array: return torch.diff(x, dim=axis, n=n, prepend=prepend, append=append) # torch uses `dim` instead of `axis`, does not have keepdims def count_nonzero( - x: array, + x: Array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False, -) -> array: +) -> Array: result = torch.count_nonzero(x, dim=axis) if keepdims: if axis is not None: @@ -557,17 +550,17 @@ def count_nonzero( return result - -def where(condition: array, x1: array, x2: array, /) -> array: +def where(condition: Array, x1: Array, x2: Array, /) -> Array: x1, x2 = _fix_promotion(x1, x2) return torch.where(condition, x1, x2) # torch.reshape doesn't have the copy keyword -def reshape(x: array, +def reshape(x: Array, /, shape: Tuple[int, ...], + *, copy: Optional[bool] = None, - **kwargs) -> array: + **kwargs) -> Array: if copy is not None: raise NotImplementedError("torch.reshape doesn't yet support the copy keyword") return torch.reshape(x, shape, **kwargs) @@ -581,9 +574,9 @@ def arange(start: Union[int, float], stop: Optional[Union[int, float]] = None, step: Union[int, float] = 1, *, - dtype: Optional[Dtype] = None, + dtype: Optional[DType] = None, device: Optional[Device] = None, - **kwargs) -> array: + **kwargs) -> Array: if stop is None: start, stop = 0, start if step > 0 and stop <= start or step < 0 and stop >= start: @@ -602,9 +595,9 @@ def eye(n_rows: int, /, *, k: int = 0, - dtype: Optional[Dtype] = None, + dtype: Optional[DType] = None, device: Optional[Device] = None, - **kwargs) -> array: + **kwargs) -> Array: if n_cols is None: n_cols = n_rows z = torch.zeros(n_rows, n_cols, dtype=dtype, device=device, **kwargs) @@ -618,10 +611,10 @@ def linspace(start: Union[int, float], /, num: int, *, - dtype: Optional[Dtype] = None, + dtype: Optional[DType] = None, device: Optional[Device] = None, endpoint: bool = True, - **kwargs) -> array: + **kwargs) -> Array: if not endpoint: return torch.linspace(start, stop, num+1, dtype=dtype, device=device, **kwargs)[:-1] return torch.linspace(start, stop, num, dtype=dtype, device=device, **kwargs) @@ -629,11 +622,11 @@ def linspace(start: Union[int, float], # torch.full does not accept an int size # https://github.com/pytorch/pytorch/issues/70906 def 
full(shape: Union[int, Tuple[int, ...]], - fill_value: Union[bool, int, float, complex], + fill_value: complex, *, - dtype: Optional[Dtype] = None, + dtype: Optional[DType] = None, device: Optional[Device] = None, - **kwargs) -> array: + **kwargs) -> Array: if isinstance(shape, int): shape = (shape,) @@ -642,52 +635,52 @@ def full(shape: Union[int, Tuple[int, ...]], # ones, zeros, and empty do not accept shape as a keyword argument def ones(shape: Union[int, Tuple[int, ...]], *, - dtype: Optional[Dtype] = None, + dtype: Optional[DType] = None, device: Optional[Device] = None, - **kwargs) -> array: + **kwargs) -> Array: return torch.ones(shape, dtype=dtype, device=device, **kwargs) def zeros(shape: Union[int, Tuple[int, ...]], *, - dtype: Optional[Dtype] = None, + dtype: Optional[DType] = None, device: Optional[Device] = None, - **kwargs) -> array: + **kwargs) -> Array: return torch.zeros(shape, dtype=dtype, device=device, **kwargs) def empty(shape: Union[int, Tuple[int, ...]], *, - dtype: Optional[Dtype] = None, + dtype: Optional[DType] = None, device: Optional[Device] = None, - **kwargs) -> array: + **kwargs) -> Array: return torch.empty(shape, dtype=dtype, device=device, **kwargs) # tril and triu do not call the keyword argument k -def tril(x: array, /, *, k: int = 0) -> array: +def tril(x: Array, /, *, k: int = 0) -> Array: return torch.tril(x, k) -def triu(x: array, /, *, k: int = 0) -> array: +def triu(x: Array, /, *, k: int = 0) -> Array: return torch.triu(x, k) # Functions that aren't in torch https://github.com/pytorch/pytorch/issues/58742 -def expand_dims(x: array, /, *, axis: int = 0) -> array: +def expand_dims(x: Array, /, *, axis: int = 0) -> Array: return torch.unsqueeze(x, axis) def astype( - x: array, - dtype: Dtype, + x: Array, + dtype: DType, /, *, copy: bool = True, device: Optional[Device] = None, -) -> array: +) -> Array: if device is not None: return x.to(device, dtype=dtype, copy=copy) return x.to(dtype=dtype, copy=copy) -def broadcast_arrays(*arrays: array) -> List[array]: +def broadcast_arrays(*arrays: Array) -> List[Array]: shape = torch.broadcast_shapes(*[a.shape for a in arrays]) return [torch.broadcast_to(a, shape) for a in arrays] @@ -697,7 +690,7 @@ def broadcast_arrays(*arrays: array) -> List[array]: UniqueInverseResult) # https://github.com/pytorch/pytorch/issues/70920 -def unique_all(x: array) -> UniqueAllResult: +def unique_all(x: Array) -> UniqueAllResult: # torch.unique doesn't support returning indices. # https://github.com/pytorch/pytorch/issues/36748. The workaround # suggested in that issue doesn't actually function correctly (it relies @@ -710,7 +703,7 @@ def unique_all(x: array) -> UniqueAllResult: # counts[torch.isnan(values)] = 1 # return UniqueAllResult(values, indices, inverse_indices, counts) -def unique_counts(x: array) -> UniqueCountsResult: +def unique_counts(x: Array) -> UniqueCountsResult: values, counts = torch.unique(x, return_counts=True) # torch.unique incorrectly gives a 0 count for nan values. 
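[Editor's note — illustration only, not part of the patch. A minimal sketch of the NaN miscount being worked around here, assuming any recent PyTorch build:

    import torch

    x = torch.asarray([1.0, float("nan"), float("nan")])
    values, counts = torch.unique(x, return_counts=True)
    # torch.unique reports a count of 0 for NaN entries;
    # the wrapper overwrites those counts with 1:
    counts[torch.isnan(values)] = 1

End of editor's note.]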
@@ -718,14 +711,14 @@ def unique_counts(x: array) -> UniqueCountsResult: counts[torch.isnan(values)] = 1 return UniqueCountsResult(values, counts) -def unique_inverse(x: array) -> UniqueInverseResult: +def unique_inverse(x: Array) -> UniqueInverseResult: values, inverse = torch.unique(x, return_inverse=True) return UniqueInverseResult(values, inverse) -def unique_values(x: array) -> array: +def unique_values(x: Array) -> Array: return torch.unique(x) -def matmul(x1: array, x2: array, /, **kwargs) -> array: +def matmul(x1: Array, x2: Array, /, **kwargs) -> Array: # torch.matmul doesn't type promote (but differently from _fix_promotion) x1, x2 = _fix_promotion(x1, x2, only_scalar=False) return torch.matmul(x1, x2, **kwargs) @@ -733,12 +726,19 @@ def matmul(x1: array, x2: array, /, **kwargs) -> array: matrix_transpose = get_xp(torch)(_aliases.matrix_transpose) _vecdot = get_xp(torch)(_aliases.vecdot) -def vecdot(x1: array, x2: array, /, *, axis: int = -1) -> array: +def vecdot(x1: Array, x2: Array, /, *, axis: int = -1) -> Array: x1, x2 = _fix_promotion(x1, x2, only_scalar=False) return _vecdot(x1, x2, axis=axis) # torch.tensordot uses dims instead of axes -def tensordot(x1: array, x2: array, /, *, axes: Union[int, Tuple[Sequence[int], Sequence[int]]] = 2, **kwargs) -> array: +def tensordot( + x1: Array, + x2: Array, + /, + *, + axes: Union[int, Tuple[Sequence[int], Sequence[int]]] = 2, + **kwargs, +) -> Array: # Note: torch.tensordot fails with integer dtypes when there is only 1 # element in the axis (https://github.com/pytorch/pytorch/issues/84530). x1, x2 = _fix_promotion(x1, x2, only_scalar=False) @@ -746,7 +746,7 @@ def tensordot(x1: array, x2: array, /, *, axes: Union[int, Tuple[Sequence[int], def isdtype( - dtype: Dtype, kind: Union[Dtype, str, Tuple[Union[Dtype, str], ...]], + dtype: DType, kind: Union[DType, str, Tuple[Union[DType, str], ...]], *, _tuple=True, # Disallow nested tuples ) -> bool: """ @@ -781,7 +781,7 @@ def isdtype( else: return dtype == kind -def take(x: array, indices: array, /, *, axis: Optional[int] = None, **kwargs) -> array: +def take(x: Array, indices: Array, /, *, axis: Optional[int] = None, **kwargs) -> Array: if axis is None: if x.ndim != 1: raise ValueError("axis must be specified when ndim > 1") @@ -789,11 +789,11 @@ def take(x: array, indices: array, /, *, axis: Optional[int] = None, **kwargs) - return torch.index_select(x, axis, indices, **kwargs) -def take_along_axis(x: array, indices: array, /, *, axis: int = -1) -> array: +def take_along_axis(x: Array, indices: Array, /, *, axis: int = -1) -> Array: return torch.take_along_dim(x, indices, dim=axis) -def sign(x: array, /) -> array: +def sign(x: Array, /) -> Array: # torch sign() does not support complex numbers and does not propagate # nans. 
See https://github.com/data-apis/array-api-compat/issues/136 if x.dtype.is_complex: diff --git a/array_api_compat/torch/_typing.py b/array_api_compat/torch/_typing.py new file mode 100644 index 00000000..29ad3fa7 --- /dev/null +++ b/array_api_compat/torch/_typing.py @@ -0,0 +1,4 @@ +__all__ = ["Array", "DType", "Device"] + +from torch import dtype as DType, Tensor as Array +from ..common._typing import Device diff --git a/array_api_compat/torch/fft.py b/array_api_compat/torch/fft.py index 3c9117ee..50e6a0d0 100644 --- a/array_api_compat/torch/fft.py +++ b/array_api_compat/torch/fft.py @@ -1,76 +1,75 @@ from __future__ import annotations -from typing import TYPE_CHECKING -if TYPE_CHECKING: - import torch - array = torch.Tensor - from typing import Union, Sequence, Literal +from typing import Union, Sequence, Literal -from torch.fft import * # noqa: F403 +import torch import torch.fft +from torch.fft import * # noqa: F403 + +from ._typing import Array # Several torch fft functions do not map axes to dim def fftn( - x: array, + x: Array, /, *, s: Sequence[int] = None, axes: Sequence[int] = None, norm: Literal["backward", "ortho", "forward"] = "backward", **kwargs, -) -> array: +) -> Array: return torch.fft.fftn(x, s=s, dim=axes, norm=norm, **kwargs) def ifftn( - x: array, + x: Array, /, *, s: Sequence[int] = None, axes: Sequence[int] = None, norm: Literal["backward", "ortho", "forward"] = "backward", **kwargs, -) -> array: +) -> Array: return torch.fft.ifftn(x, s=s, dim=axes, norm=norm, **kwargs) def rfftn( - x: array, + x: Array, /, *, s: Sequence[int] = None, axes: Sequence[int] = None, norm: Literal["backward", "ortho", "forward"] = "backward", **kwargs, -) -> array: +) -> Array: return torch.fft.rfftn(x, s=s, dim=axes, norm=norm, **kwargs) def irfftn( - x: array, + x: Array, /, *, s: Sequence[int] = None, axes: Sequence[int] = None, norm: Literal["backward", "ortho", "forward"] = "backward", **kwargs, -) -> array: +) -> Array: return torch.fft.irfftn(x, s=s, dim=axes, norm=norm, **kwargs) def fftshift( - x: array, + x: Array, /, *, axes: Union[int, Sequence[int]] = None, **kwargs, -) -> array: +) -> Array: return torch.fft.fftshift(x, dim=axes, **kwargs) def ifftshift( - x: array, + x: Array, /, *, axes: Union[int, Sequence[int]] = None, **kwargs, -) -> array: +) -> Array: return torch.fft.ifftshift(x, dim=axes, **kwargs) diff --git a/array_api_compat/torch/linalg.py b/array_api_compat/torch/linalg.py index e26198b9..7b59a670 100644 --- a/array_api_compat/torch/linalg.py +++ b/array_api_compat/torch/linalg.py @@ -1,14 +1,7 @@ from __future__ import annotations -from typing import TYPE_CHECKING -if TYPE_CHECKING: - import torch - array = torch.Tensor - from torch import dtype as Dtype - from typing import Optional, Union, Tuple, Literal - inf = float('inf') - -from ._aliases import _fix_promotion, sum +import torch +from typing import Optional, Union, Tuple from torch.linalg import * # noqa: F403 @@ -19,15 +12,17 @@ # outer is implemented in torch but aren't in the linalg namespace from torch import outer +from ._aliases import _fix_promotion, sum # These functions are in both the main and linalg namespaces from ._aliases import matmul, matrix_transpose, tensordot +from ._typing import Array, DType # Note: torch.linalg.cross does not default to axis=-1 (it defaults to the # first axis with size 3), see https://github.com/pytorch/pytorch/issues/58743 # torch.cross also does not support broadcasting when it would add new # dimensions https://github.com/pytorch/pytorch/issues/39656 -def 
cross(x1: array, x2: array, /, *, axis: int = -1) -> array: +def cross(x1: Array, x2: Array, /, *, axis: int = -1) -> Array: x1, x2 = _fix_promotion(x1, x2, only_scalar=False) if not (-min(x1.ndim, x2.ndim) <= axis < max(x1.ndim, x2.ndim)): raise ValueError(f"axis {axis} out of bounds for cross product of arrays with shapes {x1.shape} and {x2.shape}") @@ -36,7 +31,7 @@ def cross(x1: array, x2: array, /, *, axis: int = -1) -> array: x1, x2 = torch.broadcast_tensors(x1, x2) return torch_linalg.cross(x1, x2, dim=axis) -def vecdot(x1: array, x2: array, /, *, axis: int = -1, **kwargs) -> array: +def vecdot(x1: Array, x2: Array, /, *, axis: int = -1, **kwargs) -> Array: from ._aliases import isdtype x1, x2 = _fix_promotion(x1, x2, only_scalar=False) @@ -58,7 +53,7 @@ def vecdot(x1: array, x2: array, /, *, axis: int = -1, **kwargs) -> array: return res[..., 0, 0] return torch.linalg.vecdot(x1, x2, dim=axis, **kwargs) -def solve(x1: array, x2: array, /, **kwargs) -> array: +def solve(x1: Array, x2: Array, /, **kwargs) -> Array: x1, x2 = _fix_promotion(x1, x2, only_scalar=False) # Torch tries to emulate NumPy 1 solve behavior by using batched 1-D solve # whenever @@ -79,19 +74,20 @@ def solve(x1: array, x2: array, /, **kwargs) -> array: return torch.linalg.solve(x1, x2, **kwargs) # torch.trace doesn't support the offset argument and doesn't support stacking -def trace(x: array, /, *, offset: int = 0, dtype: Optional[Dtype] = None) -> array: +def trace(x: Array, /, *, offset: int = 0, dtype: Optional[DType] = None) -> Array: # Use our wrapped sum to make sure it does upcasting correctly return sum(torch.diagonal(x, offset=offset, dim1=-2, dim2=-1), axis=-1, dtype=dtype) def vector_norm( - x: array, + x: Array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False, - ord: Union[int, float, Literal[inf, -inf]] = 2, + # float stands for inf | -inf, which are not valid for Literal + ord: Union[int, float, float] = 2, **kwargs, -) -> array: +) -> Array: # torch.vector_norm incorrectly treats axis=() the same as axis=None if axis == (): out = kwargs.get('out') diff --git a/setup.py b/setup.py index 3d2b68a2..2368ccc4 100644 --- a/setup.py +++ b/setup.py @@ -33,5 +33,8 @@ "Programming Language :: Python :: 3.13", "License :: OSI Approved :: MIT License", "Operating System :: OS Independent", - ] + ], + package_data={ + "array_api_compat": ["py.typed"], + }, ) diff --git a/tests/test_all.py b/tests/test_all.py index d2e9b768..598fab62 100644 --- a/tests/test_all.py +++ b/tests/test_all.py @@ -15,6 +15,16 @@ from ._helpers import import_, wrapped_libraries import pytest +import typing + +TYPING_NAMES = frozenset(( + "Array", + "Device", + "DType", + "Namespace", + "NestedSequence", + "SupportsBufferProtocol", +)) @pytest.mark.skip(reason="TODO: starts failing after adding test_torch.py in gh-277") @pytest.mark.parametrize("library", ["common"] + wrapped_libraries) @@ -38,8 +48,11 @@ def test_all(library): dir_names = [n for n in dir(module) if not n.startswith('_')] if '__array_namespace_info__' in dir(module): dir_names.append('__array_namespace_info__') - ignore_all_names = getattr(module, '_all_ignore', []) - ignore_all_names += ['annotations', 'TYPE_CHECKING'] + ignore_all_names = set(getattr(module, '_all_ignore', ())) + ignore_all_names |= set(dir(typing)) + ignore_all_names |= {"annotations"} + if not module.__name__.endswith("._typing"): + ignore_all_names |= TYPING_NAMES dir_names = set(dir_names) - set(ignore_all_names) all_names = module.__all__ From 
26845bd904ee66bb830463f46bb39f1cc5392275 Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Sat, 22 Mar 2025 11:12:31 +0100 Subject: [PATCH 14/80] Revert "TST: skip test_all" This reverts commit 5473d84d5c36b23e091b880279c863c32f41b828. --- tests/test_all.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/test_all.py b/tests/test_all.py index 598fab62..eeb67e4b 100644 --- a/tests/test_all.py +++ b/tests/test_all.py @@ -26,7 +26,6 @@ "SupportsBufferProtocol", )) -@pytest.mark.skip(reason="TODO: starts failing after adding test_torch.py in gh-277") @pytest.mark.parametrize("library", ["common"] + wrapped_libraries) def test_all(library): if library == "common": From 07a3cd41e1c5804b7c11d358400431e8a53a984a Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Sat, 22 Mar 2025 11:40:02 +0100 Subject: [PATCH 15/80] MAINT: run self-tests even if a library is missing --- tests/test_array_namespace.py | 6 ++++-- tests/test_dask.py | 8 ++++++-- tests/test_jax.py | 8 ++++++-- tests/test_torch.py | 6 +++++- tests/test_vendoring.py | 2 ++ 5 files changed, 23 insertions(+), 7 deletions(-) diff --git a/tests/test_array_namespace.py b/tests/test_array_namespace.py index 605c69a1..cdb80007 100644 --- a/tests/test_array_namespace.py +++ b/tests/test_array_namespace.py @@ -2,10 +2,8 @@ import sys import warnings -import jax import numpy as np import pytest -import torch import array_api_compat from array_api_compat import array_namespace @@ -76,6 +74,7 @@ def test_array_namespace(library, api_version, use_compat): subprocess.run([sys.executable, "-c", code], check=True) def test_jax_zero_gradient(): + jax = import_("jax") jx = jax.numpy.arange(4) jax_zero = jax.vmap(jax.grad(jax.numpy.float32, allow_int=True))(jx) assert array_namespace(jax_zero) is array_namespace(jx) @@ -89,11 +88,13 @@ def test_array_namespace_errors(): pytest.raises(TypeError, lambda: array_namespace(x, (x, x))) def test_array_namespace_errors_torch(): + torch = import_("torch") y = torch.asarray([1, 2]) x = np.asarray([1, 2]) pytest.raises(TypeError, lambda: array_namespace(x, y)) def test_api_version_torch(): + torch = import_("torch") x = torch.asarray([1, 2]) torch_ = import_("torch", wrapper=True) assert array_namespace(x, api_version="2023.12") == torch_ @@ -118,6 +119,7 @@ def test_get_namespace(): assert array_api_compat.get_namespace is array_namespace def test_python_scalars(): + torch = import_("torch") a = torch.asarray([1, 2]) xp = import_("torch", wrapper=True) diff --git a/tests/test_dask.py b/tests/test_dask.py index be2b1e39..69c738f6 100644 --- a/tests/test_dask.py +++ b/tests/test_dask.py @@ -1,10 +1,14 @@ from contextlib import contextmanager import array_api_strict -import dask import numpy as np import pytest -import dask.array as da + +try: + import dask + import dask.array as da +except ImportError: + pytestmark = pytest.skip(allow_module_level=True, reason="dask not found") from array_api_compat import array_namespace diff --git a/tests/test_jax.py b/tests/test_jax.py index e33cec02..285958d4 100644 --- a/tests/test_jax.py +++ b/tests/test_jax.py @@ -1,10 +1,14 @@ -import jax -import jax.numpy as jnp from numpy.testing import assert_equal import pytest from array_api_compat import device, to_device +try: + import jax + import jax.numpy as jnp +except ImportError: + pytestmark = pytest.skip(allow_module_level=True, reason="jax not found") + HAS_JAX_0_4_31 = jax.__version__ >= "0.4.31" diff --git a/tests/test_torch.py b/tests/test_torch.py index 75b3a136..e8340f31 100644 --- a/tests/test_torch.py +++ 
b/tests/test_torch.py @@ -3,7 +3,11 @@ import itertools import pytest -import torch + +try: + import torch +except ImportError: + pytestmark = pytest.skip(allow_module_level=True, reason="pytorch not found") from array_api_compat import torch as xp diff --git a/tests/test_vendoring.py b/tests/test_vendoring.py index 70083b49..8b561551 100644 --- a/tests/test_vendoring.py +++ b/tests/test_vendoring.py @@ -16,11 +16,13 @@ def test_vendoring_cupy(): def test_vendoring_torch(): + pytest.importorskip("torch") from vendor_test import uses_torch uses_torch._test_torch() def test_vendoring_dask(): + pytest.importorskip("dask") from vendor_test import uses_dask uses_dask._test_dask() From 89466a6b43672b9a4a2dbdaea2896c24e4dcdd76 Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Sat, 22 Mar 2025 14:01:44 +0100 Subject: [PATCH 16/80] MAINT: common._aliases.__all__ --- array_api_compat/common/_aliases.py | 18 +++++++++++++----- tests/test_all.py | 2 +- 2 files changed, 14 insertions(+), 6 deletions(-) diff --git a/array_api_compat/common/_aliases.py b/array_api_compat/common/_aliases.py index 0d123b99..0d1ecfbc 100644 --- a/array_api_compat/common/_aliases.py +++ b/array_api_compat/common/_aliases.py @@ -7,8 +7,14 @@ import inspect from typing import NamedTuple, Optional, Sequence, Tuple, Union -from ._helpers import array_namespace, _check_device, device, is_cupy_namespace from ._typing import Array, Device, DType, Namespace +from ._helpers import ( + array_namespace, + _check_device, + device as _get_device, + is_cupy_namespace as _is_cupy_namespace +) + # These functions are modified from the NumPy versions. @@ -298,7 +304,7 @@ def cumulative_sum( initial_shape = list(x.shape) initial_shape[axis] = 1 res = xp.concatenate( - [wrapped_xp.zeros(shape=initial_shape, dtype=res.dtype, device=device(res)), res], + [wrapped_xp.zeros(shape=initial_shape, dtype=res.dtype, device=_get_device(res)), res], axis=axis, ) return res @@ -328,7 +334,7 @@ def cumulative_prod( initial_shape = list(x.shape) initial_shape[axis] = 1 res = xp.concatenate( - [wrapped_xp.ones(shape=initial_shape, dtype=res.dtype, device=device(res)), res], + [wrapped_xp.ones(shape=initial_shape, dtype=res.dtype, device=_get_device(res)), res], axis=axis, ) return res @@ -381,7 +387,7 @@ def _isscalar(a): if type(max) is int and max >= wrapped_xp.iinfo(x.dtype).max: max = None - dev = device(x) + dev = _get_device(x) if out is None: out = wrapped_xp.empty(result_shape, dtype=x.dtype, device=dev) out[()] = x @@ -599,7 +605,7 @@ def sign(x: Array, /, xp: Namespace, **kwargs) -> Array: out = xp.sign(x, **kwargs) # CuPy sign() does not propagate nans. See # https://github.com/data-apis/array-api-compat/issues/136 - if is_cupy_namespace(xp) and isdtype(x.dtype, 'real floating', xp=xp): + if _is_cupy_namespace(xp) and isdtype(x.dtype, 'real floating', xp=xp): out[xp.isnan(x)] = xp.nan return out[()] @@ -611,3 +617,5 @@ def sign(x: Array, /, xp: Namespace, **kwargs) -> Array: 'reshape', 'argsort', 'sort', 'nonzero', 'ceil', 'floor', 'trunc', 'matmul', 'matrix_transpose', 'tensordot', 'vecdot', 'isdtype', 'unstack', 'sign'] + +_all_ignore = ['inspect', 'array_namespace', 'NamedTuple'] diff --git a/tests/test_all.py b/tests/test_all.py index eeb67e4b..4df4a361 100644 --- a/tests/test_all.py +++ b/tests/test_all.py @@ -33,7 +33,7 @@ def test_all(library): else: import_(library, wrapper=True) - for mod_name in sys.modules: + for mod_name in sys.modules.copy(): if not mod_name.startswith('array_api_compat.' 
+ library): continue From 23841dfdb319fbb66a4065e0c138235c56e611f0 Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Sun, 23 Mar 2025 09:28:03 +0100 Subject: [PATCH 17/80] TST: update the torch skiplist --- torch-xfails.txt | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/torch-xfails.txt b/torch-xfails.txt index 6e8f7dc6..f8333d90 100644 --- a/torch-xfails.txt +++ b/torch-xfails.txt @@ -144,10 +144,12 @@ array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_sc array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[logaddexp] array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[maximum] array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[minimum] + +# https://github.com/pytorch/pytorch/issues/149815 array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[equal] -array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[neq] +array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[not_equal] array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[less] -array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[les_equal] +array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[less_equal] array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[greater] array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[greater_equal] From 3b4ea593d43c3d522aa1e601a93781774606bbc3 Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Sun, 23 Mar 2025 09:33:26 +0100 Subject: [PATCH 18/80] TST: update numpy<2 skiplists --- numpy-1-21-xfails.txt | 1 + numpy-1-26-xfails.txt | 1 + 2 files changed, 2 insertions(+) diff --git a/numpy-1-21-xfails.txt b/numpy-1-21-xfails.txt index 7c7a0757..30cde668 100644 --- a/numpy-1-21-xfails.txt +++ b/numpy-1-21-xfails.txt @@ -192,6 +192,7 @@ array_api_tests/test_signatures.py::test_func_signature[bitwise_left_shift] array_api_tests/test_signatures.py::test_func_signature[bitwise_or] array_api_tests/test_signatures.py::test_func_signature[bitwise_right_shift] array_api_tests/test_signatures.py::test_func_signature[bitwise_xor] +array_api_tests/test_data_type_functions.py::TestResultType::test_with_scalars # Stubs have a comment: (**note**: libraries may return ``NaN`` to match Python behavior.); Apparently,NumPy does just that array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is +infinity and isfinite(x2_i) and x2_i > 0) -> +infinity] diff --git a/numpy-1-26-xfails.txt b/numpy-1-26-xfails.txt index 57259b6f..1ce28ef4 100644 --- a/numpy-1-26-xfails.txt +++ b/numpy-1-26-xfails.txt @@ -46,6 +46,7 @@ array_api_tests/test_signatures.py::test_func_signature[bitwise_left_shift] array_api_tests/test_signatures.py::test_func_signature[bitwise_or] array_api_tests/test_signatures.py::test_func_signature[bitwise_right_shift] array_api_tests/test_signatures.py::test_func_signature[bitwise_xor] +array_api_tests/test_data_type_functions.py::TestResultType::test_with_scalars # Stubs have a comment: (**note**: libraries may return ``NaN`` to match Python behavior.); Apparently, NumPy does just that array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is +infinity and isfinite(x2_i) and x2_i > 0) -> +infinity] From 
f19256e3e132f0c16147936d1cf320680366055a Mon Sep 17 00:00:00 2001 From: Neil Girdhar Date: Fri, 21 Mar 2025 07:02:22 -0400 Subject: [PATCH 19/80] Add pyproject.toml --- .github/workflows/docs-build.yml | 2 +- .github/workflows/tests.yml | 6 +- docs/dev/tests.md | 2 +- docs/requirements.txt | 6 -- pyproject.toml | 96 ++++++++++++++++++++++++++++++++ requirements-dev.txt | 8 --- ruff.toml | 17 ------ setup.py | 40 ------------- 8 files changed, 99 insertions(+), 78 deletions(-) delete mode 100644 docs/requirements.txt create mode 100644 pyproject.toml delete mode 100644 requirements-dev.txt delete mode 100644 ruff.toml delete mode 100644 setup.py diff --git a/.github/workflows/docs-build.yml b/.github/workflows/docs-build.yml index 04c3aa66..34b9cbc6 100644 --- a/.github/workflows/docs-build.yml +++ b/.github/workflows/docs-build.yml @@ -10,7 +10,7 @@ jobs: - uses: actions/setup-python@v5 - name: Install Dependencies run: | - python -m pip install -r docs/requirements.txt + python -m pip install .[docs] - name: Build Docs run: | cd docs diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index fcd43367..54f6f402 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -29,11 +29,7 @@ jobs: PIP_EXTRA='numpy==1.26.*' fi - if [ "${{ matrix.python-version }}" == "3.9" ]; then - sed -i '/^ndonnx/d' requirements-dev.txt - fi - - python -m pip install -r requirements-dev.txt $PIP_EXTRA + python -m pip install .[dev] $PIP_EXTRA - name: Run Tests run: | diff --git a/docs/dev/tests.md b/docs/dev/tests.md index 6d9d1d7b..18fb7cf5 100644 --- a/docs/dev/tests.md +++ b/docs/dev/tests.md @@ -7,7 +7,7 @@ the array API standard. There are also array-api-compat specific tests in These tests should be limited to things that are not tested by the test suite, e.g., tests for [helper functions](../helper-functions.rst) or for behavior that is not strictly required by the standard. To run these tests, install the -dependencies from `requirements-dev.txt` (array-api-compat has [no hard +dependencies from the `dev` optional group (array-api-compat has [no hard runtime dependencies](no-dependencies)).
array-api-tests is run against all supported libraries are tested on CI diff --git a/docs/requirements.txt b/docs/requirements.txt deleted file mode 100644 index dbec7740..00000000 --- a/docs/requirements.txt +++ /dev/null @@ -1,6 +0,0 @@ -furo -linkify-it-py -myst-parser -sphinx -sphinx-copybutton -sphinx-autobuild diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 00000000..f17c720f --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,96 @@ +[build-system] +requires = ["setuptools", "setuptools-scm"] +build-backend = "setuptools.build_meta" + +[project] +name = "array-api-compat" +dynamic = ["version"] +description = "A wrapper around NumPy and other array libraries to make them compatible with the Array API standard" +readme = "README.md" +requires-python = ">=3.9" +license = "MIT" +authors = [{name = "Consortium for Python Data API Standards"}] +classifiers = [ + "Operating System :: OS Independent", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", + "Programming Language :: Python :: 3.13", + "Topic :: Software Development :: Libraries :: Python Modules", + "Typing :: Typed", +] + +[project.optional-dependencies] +cupy = ["cupy"] +dask = ["dask"] +jax = ["jax"] +numpy = ["numpy"] +pytorch = ["torch"] +sparse = ["sparse>=0.15.1"] +docs = [ + "furo", + "linkify-it-py", + "myst-parser", + "sphinx", + "sphinx-copybutton", + "sphinx-autobuild", +] +dev = [ + "array-api-strict", + "dask[array]", + "jax[cpu]", + "numpy", + "pytest", + "torch", + "sparse>=0.15.1", + "ndonnx; python_version>=\"3.10\"" +] + +[project.urls] +homepage = "https://data-apis.org/array-api-compat/" +repository = "https://github.com/data-apis/array-api-compat/" + +[tool.setuptools.dynamic] +version = {attr = "array_api_compat.__version__"} + +[tool.setuptools.packages.find] +include = ["array_api_compat*"] +namespaces = false + +[tool.ruff.lint] +preview = true +select = [ +# Defaults +"E4", "E7", "E9", "F", +# Undefined export +"F822", +# Useless import alias +"PLC0414" +] + +ignore = [ + # Module import not at top of file + "E402", + # Do not use bare `except` + "E722" +] diff --git a/requirements-dev.txt b/requirements-dev.txt deleted file mode 100644 index c9d10f71..00000000 --- a/requirements-dev.txt +++ /dev/null @@ -1,8 +0,0 @@ -array-api-strict -dask[array] -jax[cpu] -numpy -pytest -torch -sparse >=0.15.1 -ndonnx diff --git a/ruff.toml b/ruff.toml deleted file mode 100644 index 72e111b5..00000000 --- a/ruff.toml +++ /dev/null @@ -1,17 +0,0 @@ -[lint] -preview = true -select = [ -# Defaults -"E4", "E7", "E9", "F", -# Undefined export -"F822", -# Useless import alias -"PLC0414" -] - -ignore = [ - # Module import not at top of file - "E402", - # Do not use bare `except` - "E722" -] diff --git a/setup.py b/setup.py deleted file mode 100644 index 3d2b68a2..00000000 --- a/setup.py +++ /dev/null @@ -1,40 +0,0 @@ -from setuptools import setup, find_packages - -with open("README.md", "r") as fh: - long_description = fh.read() - -import array_api_compat - -setup( - name='array_api_compat', - version=array_api_compat.__version__, - packages=find_packages(include=["array_api_compat*"]), -
author="Consortium for Python Data API Standards", - description="A wrapper around NumPy and other array libraries to make them compatible with the Array API standard", - long_description=long_description, - long_description_content_type="text/markdown", - url="https://data-apis.org/array-api-compat/", - license="MIT", - extras_require={ - "numpy": "numpy", - "cupy": "cupy", - "jax": "jax", - "pytorch": "pytorch", - "dask": "dask", - "sparse": "sparse >=0.15.1", - }, - python_requires=">=3.9", - classifiers=[ - "Programming Language :: Python :: 3", - "Programming Language :: Python :: 3.9", - "Programming Language :: Python :: 3.10", - "Programming Language :: Python :: 3.11", - "Programming Language :: Python :: 3.12", - "Programming Language :: Python :: 3.13", - "License :: OSI Approved :: MIT License", - "Operating System :: OS Independent", - ], - package_data={ - "array_api_compat": ["py.typed"], - }, -) From 1db3fae0f682199bda3ae920f8a695e4f579b439 Mon Sep 17 00:00:00 2001 From: crusaderky Date: Tue, 25 Mar 2025 18:07:45 +0000 Subject: [PATCH 20/80] ENH: correct Dask capabilities --- array_api_compat/dask/array/_info.py | 22 ++++++++++++++++------ dask-xfails.txt | 8 +++++--- 2 files changed, 21 insertions(+), 9 deletions(-) diff --git a/array_api_compat/dask/array/_info.py b/array_api_compat/dask/array/_info.py index e15a69f4..fc70b5a2 100644 --- a/array_api_compat/dask/array/_info.py +++ b/array_api_compat/dask/array/_info.py @@ -68,11 +68,22 @@ def capabilities(self): The resulting dictionary has the following keys: - **"boolean indexing"**: boolean indicating whether an array library - supports boolean indexing. Always ``False`` for Dask. + supports boolean indexing. + + Dask support boolean indexing as long as both the index + and the indexed arrays have known shapes. + Note however that the output .shape and .size properties + will contain a non-compliant math.nan instead of None. - **"data-dependent shapes"**: boolean indicating whether an array - library supports data-dependent output shapes. Always ``False`` for - Dask. + library supports data-dependent output shapes. + + Dask implements unique_values et.al. + Note however that the output .shape and .size properties + will contain a non-compliant math.nan instead of None. + + - **"max dimensions"**: integer indicating the maximum number of + dimensions supported by the array library. See https://data-apis.org/array-api/latest/API_specification/generated/array_api.info.capabilities.html @@ -99,9 +110,8 @@ def capabilities(self): """ return { - "boolean indexing": False, - "data-dependent shapes": False, - # 'max rank' will be part of the 2024.12 standard + "boolean indexing": True, + "data-dependent shapes": True, "max dimensions": 64, } diff --git a/dask-xfails.txt b/dask-xfails.txt index d2474f9f..bd65d004 100644 --- a/dask-xfails.txt +++ b/dask-xfails.txt @@ -28,12 +28,14 @@ array_api_tests/test_has_names.py::test_has_names[array_method-to_device] array_api_tests/test_has_names.py::test_has_names[array_attribute-device] array_api_tests/test_has_names.py::test_has_names[array_attribute-mT] -# Fails because shape is NaN since we don't materialize it yet +# Data-dependent output shape +# These tests fail as array-api-tests doesn't cope with unknown shapes +# Also, output shape is (math.nan, ) instead of (None, ) +# Also, da.unique() doesn't accept equals_nan which causes non-compliant +# output when there are NaNs in the input. 
array_api_tests/test_searching_functions.py::test_nonzero array_api_tests/test_set_functions.py::test_unique_all array_api_tests/test_set_functions.py::test_unique_counts - -# Different error but same cause as above, we're just trying to do ndindex on nan shape array_api_tests/test_set_functions.py::test_unique_inverse array_api_tests/test_set_functions.py::test_unique_values From 71d90ead399c03f5fcbc15d205d7cedb6bc9825c Mon Sep 17 00:00:00 2001 From: Lucas Colley Date: Sun, 30 Mar 2025 09:19:56 +0100 Subject: [PATCH 21/80] Update test_all.py Co-authored-by: Evgeni Burovski --- tests/test_all.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/test_all.py b/tests/test_all.py index 4df4a361..271cd189 100644 --- a/tests/test_all.py +++ b/tests/test_all.py @@ -33,6 +33,7 @@ def test_all(library): else: import_(library, wrapper=True) + # NB: iterate over a copy to avoid a "dictionary size changed" error for mod_name in sys.modules.copy(): if not mod_name.startswith('array_api_compat.' + library): continue From b2af137864a484908fc96fddb1e47af56f0a4adf Mon Sep 17 00:00:00 2001 From: Guido Imperiale Date: Mon, 31 Mar 2025 23:51:14 +0100 Subject: [PATCH 22/80] TYP: Type annotations overhaul, part 2 (#291) --- array_api_compat/common/_aliases.py | 4 ++-- array_api_compat/cupy/_aliases.py | 5 ++++- array_api_compat/dask/array/_aliases.py | 5 ++++- array_api_compat/numpy/_aliases.py | 5 ++++- array_api_compat/torch/_aliases.py | 14 +++++++++++--- array_api_compat/torch/linalg.py | 2 +- 6 files changed, 26 insertions(+), 9 deletions(-) diff --git a/array_api_compat/common/_aliases.py b/array_api_compat/common/_aliases.py index 0d1ecfbc..03910681 100644 --- a/array_api_compat/common/_aliases.py +++ b/array_api_compat/common/_aliases.py @@ -73,7 +73,7 @@ def eye( def full( shape: Union[int, Tuple[int, ...]], - fill_value: complex, + fill_value: bool | int | float | complex, xp: Namespace, *, dtype: Optional[DType] = None, @@ -86,7 +86,7 @@ def full( def full_like( x: Array, /, - fill_value: complex, + fill_value: bool | int | float | complex, *, xp: Namespace, dtype: Optional[DType] = None, diff --git a/array_api_compat/cupy/_aliases.py b/array_api_compat/cupy/_aliases.py index ebc7ccd9..423fd10a 100644 --- a/array_api_compat/cupy/_aliases.py +++ b/array_api_compat/cupy/_aliases.py @@ -68,7 +68,10 @@ # asarray also adds the copy keyword, which is not present in numpy 1.0. def asarray( obj: ( - Array | bool | complex | NestedSequence[bool | complex] | SupportsBufferProtocol + Array + | bool | int | float | complex + | NestedSequence[bool | int | float | complex] + | SupportsBufferProtocol ), /, *, diff --git a/array_api_compat/dask/array/_aliases.py b/array_api_compat/dask/array/_aliases.py index e737cebd..e6eff359 100644 --- a/array_api_compat/dask/array/_aliases.py +++ b/array_api_compat/dask/array/_aliases.py @@ -136,7 +136,10 @@ def arange( # asarray also adds the copy keyword, which is not present in numpy 1.0. 
def asarray( obj: ( - Array | bool | complex | NestedSequence[bool | complex] | SupportsBufferProtocol + Array + | bool | int | float | complex + | NestedSequence[bool | int | float | complex] + | SupportsBufferProtocol ), /, *, diff --git a/array_api_compat/numpy/_aliases.py b/array_api_compat/numpy/_aliases.py index 6536d9a8..1d084b2b 100644 --- a/array_api_compat/numpy/_aliases.py +++ b/array_api_compat/numpy/_aliases.py @@ -77,7 +77,10 @@ def _supports_buffer_protocol(obj): # rather than trying to combine everything into one function in common/ def asarray( obj: ( - Array | bool | complex | NestedSequence[bool | complex] | SupportsBufferProtocol + Array + | bool | int | float | complex + | NestedSequence[bool | int | float | complex] + | SupportsBufferProtocol ), /, *, diff --git a/array_api_compat/torch/_aliases.py b/array_api_compat/torch/_aliases.py index 87d32d85..982500b0 100644 --- a/array_api_compat/torch/_aliases.py +++ b/array_api_compat/torch/_aliases.py @@ -116,7 +116,9 @@ def _fix_promotion(x1, x2, only_scalar=True): _py_scalars = (bool, int, float, complex) -def result_type(*arrays_and_dtypes: Array | DType | complex) -> DType: +def result_type( + *arrays_and_dtypes: Array | DType | bool | int | float | complex +) -> DType: num = len(arrays_and_dtypes) if num == 0: @@ -550,10 +552,16 @@ def count_nonzero( return result -def where(condition: Array, x1: Array, x2: Array, /) -> Array: +def where( + condition: Array, + x1: Array | bool | int | float | complex, + x2: Array | bool | int | float | complex, + /, +) -> Array: x1, x2 = _fix_promotion(x1, x2) return torch.where(condition, x1, x2) + # torch.reshape doesn't have the copy keyword def reshape(x: Array, /, @@ -622,7 +630,7 @@ def linspace(start: Union[int, float], # torch.full does not accept an int size # https://github.com/pytorch/pytorch/issues/70906 def full(shape: Union[int, Tuple[int, ...]], - fill_value: complex, + fill_value: bool | int | float | complex, *, dtype: Optional[DType] = None, device: Optional[Device] = None, diff --git a/array_api_compat/torch/linalg.py b/array_api_compat/torch/linalg.py index 7b59a670..1ff7319d 100644 --- a/array_api_compat/torch/linalg.py +++ b/array_api_compat/torch/linalg.py @@ -85,7 +85,7 @@ def vector_norm( axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False, # float stands for inf | -inf, which are not valid for Literal - ord: Union[int, float, float] = 2, + ord: Union[int, float] = 2, **kwargs, ) -> Array: # torch.vector_norm incorrectly treats axis=() the same as axis=None From 29f494160a7657dc4da21113851bd6880e39dc7c Mon Sep 17 00:00:00 2001 From: crusaderky Date: Thu, 3 Apr 2025 10:26:59 +0100 Subject: [PATCH 23/80] TST: bump to ndonnx 0.10.1 --- tests/test_common.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tests/test_common.py b/tests/test_common.py index f86e0936..bbf14572 100644 --- a/tests/test_common.py +++ b/tests/test_common.py @@ -234,6 +234,7 @@ def test_asarray_cross_library(source_library, target_library, request): # TODO: remove xfail once # https://github.com/dask/dask/issues/8260 is resolved xfail(request, reason="Bug in dask raising error on conversion") + elif ( source_library == "ndonnx" and target_library not in ("array_api_strict", "ndonnx", "numpy") @@ -241,6 +242,9 @@ def test_asarray_cross_library(source_library, target_library, request): xfail(request, reason="The truth value of lazy Array Array(dtype=Boolean) is unknown") elif source_library == "ndonnx" and target_library == "numpy": xfail(request, reason="produces 
numpy array of ndonnx scalar arrays") + elif target_library == "ndonnx" and source_library in ("torch", "dask.array", "jax.numpy"): + xfail(request, reason="unable to infer dtype") + elif source_library == "jax.numpy" and target_library == "torch": xfail(request, reason="casts int to float") elif source_library == "cupy" and target_library != "cupy": From f80f15792ec981e943bef7f49faff687ef29b27c Mon Sep 17 00:00:00 2001 From: crusaderky Date: Wed, 2 Apr 2025 17:07:20 +0100 Subject: [PATCH 24/80] ENH: wrap iinfo/finfo --- array_api_compat/common/_aliases.py | 21 +++++++++++++++++++-- array_api_compat/cupy/_aliases.py | 2 ++ array_api_compat/dask/array/_aliases.py | 9 ++++----- array_api_compat/numpy/_aliases.py | 2 ++ array_api_compat/torch/_aliases.py | 5 ++++- cupy-xfails.txt | 11 ++++++++--- dask-xfails.txt | 10 ++++++++-- numpy-1-21-xfails.txt | 12 +++++++++--- numpy-1-26-xfails.txt | 10 ++++++++-- numpy-dev-xfails.txt | 10 ++++++++-- numpy-xfails.txt | 10 ++++++++-- torch-xfails.txt | 4 ++++ 12 files changed, 84 insertions(+), 22 deletions(-) diff --git a/array_api_compat/common/_aliases.py b/array_api_compat/common/_aliases.py index 03910681..46cbb359 100644 --- a/array_api_compat/common/_aliases.py +++ b/array_api_compat/common/_aliases.py @@ -5,7 +5,7 @@ from __future__ import annotations import inspect -from typing import NamedTuple, Optional, Sequence, Tuple, Union +from typing import Any, NamedTuple, Optional, Sequence, Tuple, Union from ._typing import Array, Device, DType, Namespace from ._helpers import ( @@ -609,6 +609,23 @@ def sign(x: Array, /, xp: Namespace, **kwargs) -> Array: out[xp.isnan(x)] = xp.nan return out[()] + +def finfo(type_: DType | Array, /, xp: Namespace) -> Any: + # It is surprisingly difficult to recognize a dtype apart from an array. + # np.int64 is not the same as np.asarray(1).dtype! 
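+    # (a scalar type like np.int64 is not an np.dtype instance, yet it
+    # compares equal to one, so isinstance() cannot tell them apart;
+    # EAFP below sidesteps the question entirely)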
+ try: + return xp.finfo(type_) + except (ValueError, TypeError): + return xp.finfo(type_.dtype) + + +def iinfo(type_: DType | Array, /, xp: Namespace) -> Any: + try: + return xp.iinfo(type_) + except (ValueError, TypeError): + return xp.iinfo(type_.dtype) + + __all__ = ['arange', 'empty', 'empty_like', 'eye', 'full', 'full_like', 'linspace', 'ones', 'ones_like', 'zeros', 'zeros_like', 'UniqueAllResult', 'UniqueCountsResult', 'UniqueInverseResult', @@ -616,6 +633,6 @@ def sign(x: Array, /, xp: Namespace, **kwargs) -> Array: 'std', 'var', 'cumulative_sum', 'cumulative_prod','clip', 'permute_dims', 'reshape', 'argsort', 'sort', 'nonzero', 'ceil', 'floor', 'trunc', 'matmul', 'matrix_transpose', 'tensordot', 'vecdot', 'isdtype', - 'unstack', 'sign'] + 'unstack', 'sign', 'finfo', 'iinfo'] _all_ignore = ['inspect', 'array_namespace', 'NamedTuple'] diff --git a/array_api_compat/cupy/_aliases.py b/array_api_compat/cupy/_aliases.py index 423fd10a..fd1460ae 100644 --- a/array_api_compat/cupy/_aliases.py +++ b/array_api_compat/cupy/_aliases.py @@ -61,6 +61,8 @@ matrix_transpose = get_xp(cp)(_aliases.matrix_transpose) tensordot = get_xp(cp)(_aliases.tensordot) sign = get_xp(cp)(_aliases.sign) +finfo = get_xp(cp)(_aliases.finfo) +iinfo = get_xp(cp)(_aliases.iinfo) _copy_default = object() diff --git a/array_api_compat/dask/array/_aliases.py b/array_api_compat/dask/array/_aliases.py index e6eff359..dca6d570 100644 --- a/array_api_compat/dask/array/_aliases.py +++ b/array_api_compat/dask/array/_aliases.py @@ -5,8 +5,6 @@ import numpy as np from numpy import ( # dtypes - iinfo, - finfo, bool_ as bool, float32, float64, @@ -131,6 +129,8 @@ def arange( matmul = get_xp(np)(_aliases.matmul) tensordot = get_xp(np)(_aliases.tensordot) sign = get_xp(np)(_aliases.sign) +finfo = get_xp(np)(_aliases.finfo) +iinfo = get_xp(np)(_aliases.iinfo) # asarray also adds the copy keyword, which is not present in numpy 1.0. 
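
A quick, self-contained demonstration of the pitfall that the finfo/iinfo
wrappers above work around -- a sketch in plain NumPy, not part of the patch:

    import numpy as np

    a = np.asarray(1)
    # A scalar type object is not a dtype instance...
    assert not isinstance(np.int64, np.dtype)
    # ...yet it compares equal to one, so isinstance() dispatch is useless:
    assert np.int64 == a.dtype
    # Whichever spelling the caller passes, both must reach the same answer,
    # which is what the try/except fallback guarantees:
    assert np.iinfo(np.int64).max == np.iinfo(a.dtype).max
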
@@ -343,10 +343,9 @@ def count_nonzero( '__array_namespace_info__', 'asarray', 'astype', 'acos', 'acosh', 'asin', 'asinh', 'atan', 'atan2', 'atanh', 'bitwise_left_shift', 'bitwise_invert', - 'bitwise_right_shift', 'concat', 'pow', 'iinfo', 'finfo', 'can_cast', + 'bitwise_right_shift', 'concat', 'pow', 'can_cast', 'result_type', 'bool', 'float32', 'float64', 'int8', 'int16', 'int32', 'int64', - 'uint8', 'uint16', 'uint32', 'uint64', - 'complex64', 'complex128', 'iinfo', 'finfo', + 'uint8', 'uint16', 'uint32', 'uint64', 'complex64', 'complex128', 'can_cast', 'count_nonzero', 'result_type'] _all_ignore = ["array_namespace", "get_xp", "da", "np"] diff --git a/array_api_compat/numpy/_aliases.py b/array_api_compat/numpy/_aliases.py index 1d084b2b..ae0d006d 100644 --- a/array_api_compat/numpy/_aliases.py +++ b/array_api_compat/numpy/_aliases.py @@ -61,6 +61,8 @@ matrix_transpose = get_xp(np)(_aliases.matrix_transpose) tensordot = get_xp(np)(_aliases.tensordot) sign = get_xp(np)(_aliases.sign) +finfo = get_xp(np)(_aliases.finfo) +iinfo = get_xp(np)(_aliases.iinfo) def _supports_buffer_protocol(obj): diff --git a/array_api_compat/torch/_aliases.py b/array_api_compat/torch/_aliases.py index 982500b0..9384e4c0 100644 --- a/array_api_compat/torch/_aliases.py +++ b/array_api_compat/torch/_aliases.py @@ -227,6 +227,9 @@ def min(x: Array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, keep unstack = get_xp(torch)(_aliases.unstack) cumulative_sum = get_xp(torch)(_aliases.cumulative_sum) cumulative_prod = get_xp(torch)(_aliases.cumulative_prod) +finfo = get_xp(torch)(_aliases.finfo) +iinfo = get_xp(torch)(_aliases.iinfo) + # torch.sort also returns a tuple # https://github.com/pytorch/pytorch/issues/70921 @@ -832,6 +835,6 @@ def sign(x: Array, /) -> Array: 'UniqueAllResult', 'UniqueCountsResult', 'UniqueInverseResult', 'unique_all', 'unique_counts', 'unique_inverse', 'unique_values', 'matmul', 'matrix_transpose', 'vecdot', 'tensordot', 'isdtype', - 'take', 'take_along_axis', 'sign'] + 'take', 'take_along_axis', 'sign', 'finfo', 'iinfo'] _all_ignore = ['torch', 'get_xp'] diff --git a/cupy-xfails.txt b/cupy-xfails.txt index 3d20d745..f4cd1e36 100644 --- a/cupy-xfails.txt +++ b/cupy-xfails.txt @@ -14,9 +14,14 @@ array_api_tests/test_array_object.py::test_getitem # copy=False is not yet implemented array_api_tests/test_creation_functions.py::test_asarray_arrays -# finfo test is testing that the result is a float instead of float32 (see -# also https://github.com/data-apis/array-api/issues/405) -array_api_tests/test_data_type_functions.py::test_finfo[float32] +# attributes are np.float32 instead of float +# (see also https://github.com/data-apis/array-api/issues/405) +array_api_tests/test_data_type_functions.py::test_finfo[float32-dtype] +array_api_tests/test_data_type_functions.py::test_finfo[float32-array] +array_api_tests/test_data_type_functions.py::test_finfo[float32-array.dtype] +array_api_tests/test_data_type_functions.py::test_finfo[complex64-dtype] +array_api_tests/test_data_type_functions.py::test_finfo[complex64-array] +array_api_tests/test_data_type_functions.py::test_finfo[complex64-array.dtype] # Some array attributes are missing, and we do not wrap the array object array_api_tests/test_has_names.py::test_has_names[array_method-__array_namespace__] diff --git a/dask-xfails.txt b/dask-xfails.txt index bd65d004..abab825c 100644 --- a/dask-xfails.txt +++ b/dask-xfails.txt @@ -12,8 +12,14 @@ array_api_tests/test_array_object.py::test_getitem_masking # zero division error, and typeerror: 
tuple indices must be integers or slices not tuple array_api_tests/test_creation_functions.py::test_eye -# finfo(float32).eps returns float32 but should return float -array_api_tests/test_data_type_functions.py::test_finfo[float32] +# attributes are np.float32 instead of float +# (see also https://github.com/data-apis/array-api/issues/405) +array_api_tests/test_data_type_functions.py::test_finfo[float32-dtype] +array_api_tests/test_data_type_functions.py::test_finfo[float32-array] +array_api_tests/test_data_type_functions.py::test_finfo[float32-array.dtype] +array_api_tests/test_data_type_functions.py::test_finfo[complex64-dtype] +array_api_tests/test_data_type_functions.py::test_finfo[complex64-array] +array_api_tests/test_data_type_functions.py::test_finfo[complex64-array.dtype] # out[-1]=dask.array but should be some floating number # (I think the test is not forcing the op to be computed?) diff --git a/numpy-1-21-xfails.txt b/numpy-1-21-xfails.txt index 30cde668..93a90757 100644 --- a/numpy-1-21-xfails.txt +++ b/numpy-1-21-xfails.txt @@ -1,8 +1,14 @@ # asarray(copy=False) is not yet implemented array_api_tests/test_creation_functions.py::test_asarray_arrays -# finfo(float32).eps returns float32 but should return float -array_api_tests/test_data_type_functions.py::test_finfo[float32] +# attributes are np.float32 instead of float +# (see also https://github.com/data-apis/array-api/issues/405) +array_api_tests/test_data_type_functions.py::test_finfo[float32-dtype] +array_api_tests/test_data_type_functions.py::test_finfo[float32-array] +array_api_tests/test_data_type_functions.py::test_finfo[float32-array.dtype] +array_api_tests/test_data_type_functions.py::test_finfo[complex64-dtype] +array_api_tests/test_data_type_functions.py::test_finfo[complex64-array] +array_api_tests/test_data_type_functions.py::test_finfo[complex64-array.dtype] # Array methods and attributes not already on np.ndarray cannot be wrapped array_api_tests/test_has_names.py::test_has_names[array_method-__array_namespace__] @@ -41,7 +47,7 @@ array_api_tests/meta/test_hypothesis_helpers.py::test_symmetric_matrices ############################ # finfo has no smallest_normal -array_api_tests/test_data_type_functions.py::test_finfo[float64] +array_api_tests/test_data_type_functions.py::test_finfo # dlpack stuff array_api_tests/test_has_names.py::test_has_names[creation-from_dlpack] diff --git a/numpy-1-26-xfails.txt b/numpy-1-26-xfails.txt index 1ce28ef4..84916e73 100644 --- a/numpy-1-26-xfails.txt +++ b/numpy-1-26-xfails.txt @@ -1,5 +1,11 @@ -# finfo(float32).eps returns float32 but should return float -array_api_tests/test_data_type_functions.py::test_finfo[float32] +# attributes are np.float32 instead of float +# (see also https://github.com/data-apis/array-api/issues/405) +array_api_tests/test_data_type_functions.py::test_finfo[float32-dtype] +array_api_tests/test_data_type_functions.py::test_finfo[float32-array] +array_api_tests/test_data_type_functions.py::test_finfo[float32-array.dtype] +array_api_tests/test_data_type_functions.py::test_finfo[complex64-dtype] +array_api_tests/test_data_type_functions.py::test_finfo[complex64-array] +array_api_tests/test_data_type_functions.py::test_finfo[complex64-array.dtype] # Array methods and attributes not already on np.ndarray cannot be wrapped array_api_tests/test_has_names.py::test_has_names[array_method-__array_namespace__] diff --git a/numpy-dev-xfails.txt b/numpy-dev-xfails.txt index 98659710..31bcb63b 100644 --- a/numpy-dev-xfails.txt +++ b/numpy-dev-xfails.txt @@ -1,5 
+1,11 @@ -# finfo(float32).eps returns float32 but should return float -array_api_tests/test_data_type_functions.py::test_finfo[float32] +# attributes are np.float32 instead of float +# (see also https://github.com/data-apis/array-api/issues/405) +array_api_tests/test_data_type_functions.py::test_finfo[float32-dtype] +array_api_tests/test_data_type_functions.py::test_finfo[float32-array] +array_api_tests/test_data_type_functions.py::test_finfo[float32-array.dtype] +array_api_tests/test_data_type_functions.py::test_finfo[complex64-dtype] +array_api_tests/test_data_type_functions.py::test_finfo[complex64-array] +array_api_tests/test_data_type_functions.py::test_finfo[complex64-array.dtype] # The test suite cannot properly get the signature for vecdot # https://github.com/numpy/numpy/pull/26237 diff --git a/numpy-xfails.txt b/numpy-xfails.txt index 0885dcaa..0810aea6 100644 --- a/numpy-xfails.txt +++ b/numpy-xfails.txt @@ -1,5 +1,11 @@ -# finfo(float32).eps returns float32 but should return float -array_api_tests/test_data_type_functions.py::test_finfo[float32] +# attributes are np.float32 instead of float +# (see also https://github.com/data-apis/array-api/issues/405) +array_api_tests/test_data_type_functions.py::test_finfo[float32-dtype] +array_api_tests/test_data_type_functions.py::test_finfo[float32-array] +array_api_tests/test_data_type_functions.py::test_finfo[float32-array.dtype] +array_api_tests/test_data_type_functions.py::test_finfo[complex64-dtype] +array_api_tests/test_data_type_functions.py::test_finfo[complex64-array] +array_api_tests/test_data_type_functions.py::test_finfo[complex64-array.dtype] # The test suite cannot properly get the signature for vecdot # https://github.com/numpy/numpy/pull/26237 diff --git a/torch-xfails.txt b/torch-xfails.txt index f8333d90..e556fa4f 100644 --- a/torch-xfails.txt +++ b/torch-xfails.txt @@ -115,6 +115,10 @@ array_api_tests/test_operators_and_elementwise_functions.py::test_round array_api_tests/test_set_functions.py::test_unique_counts array_api_tests/test_set_functions.py::test_unique_values +# finfo/iinfo.dtype is a string instead of a dtype +array_api_tests/test_data_type_functions.py::test_finfo_dtype +array_api_tests/test_data_type_functions.py::test_iinfo_dtype + # 2023.12 support array_api_tests/test_has_names.py::test_has_names[manipulation-repeat] array_api_tests/test_manipulation_functions.py::test_repeat From 37b1c475c98fb092135ef021f11b7f79cd46debd Mon Sep 17 00:00:00 2001 From: crusaderky Date: Thu, 3 Apr 2025 12:42:35 +0100 Subject: [PATCH 25/80] MAINT: validate device on numpy and dask --- array_api_compat/common/_helpers.py | 24 +++++++++++++++++++++--- array_api_compat/dask/array/_aliases.py | 5 ++++- array_api_compat/numpy/_aliases.py | 6 +++--- 3 files changed, 28 insertions(+), 7 deletions(-) diff --git a/array_api_compat/common/_helpers.py b/array_api_compat/common/_helpers.py index 6d95069d..67c619b8 100644 --- a/array_api_compat/common/_helpers.py +++ b/array_api_compat/common/_helpers.py @@ -595,11 +595,29 @@ def your_function(x, y): # backwards compatibility alias get_namespace = array_namespace -def _check_device(xp, device): - if xp == sys.modules.get('numpy'): - if device not in ["cpu", None]: + +def _check_device(bare_xp, device): + """ + Validate dummy device on device-less array backends. + + Notes + ----- + This function is also invoked by CuPy, which does have multiple devices + if there are multiple GPUs available. 
+ However, CuPy multi-device support is currently impossible + without using the global device or a context manager: + + https://github.com/data-apis/array-api-compat/pull/293 + """ + if bare_xp is sys.modules.get('numpy'): + if device not in ("cpu", None): raise ValueError(f"Unsupported device for NumPy: {device!r}") + elif bare_xp is sys.modules.get('dask.array'): + if device not in ("cpu", _DASK_DEVICE, None): + raise ValueError(f"Unsupported device for Dask: {device!r}") + + # Placeholder object to represent the dask device # when the array backend is not the CPU. # (since it is not easy to tell which device a dask array is on) diff --git a/array_api_compat/dask/array/_aliases.py b/array_api_compat/dask/array/_aliases.py index e6eff359..c5cd7489 100644 --- a/array_api_compat/dask/array/_aliases.py +++ b/array_api_compat/dask/array/_aliases.py @@ -25,7 +25,7 @@ ) import dask.array as da -from ...common import _aliases, array_namespace +from ...common import _aliases, _helpers, array_namespace from ...common._typing import ( Array, Device, @@ -56,6 +56,7 @@ def astype( specification for more details. """ # TODO: respect device keyword? + _helpers._check_device(da, device) if not copy and dtype == x.dtype: return x @@ -86,6 +87,7 @@ def arange( specification for more details. """ # TODO: respect device keyword? + _helpers._check_device(da, device) args = [start] if stop is not None: @@ -155,6 +157,7 @@ def asarray( specification for more details. """ # TODO: respect device keyword? + _helpers._check_device(da, device) if isinstance(obj, da.Array): if dtype is not None and dtype != obj.dtype: diff --git a/array_api_compat/numpy/_aliases.py b/array_api_compat/numpy/_aliases.py index 1d084b2b..d5b7feac 100644 --- a/array_api_compat/numpy/_aliases.py +++ b/array_api_compat/numpy/_aliases.py @@ -3,7 +3,7 @@ from typing import Optional, Union from .._internal import get_xp -from ..common import _aliases +from ..common import _aliases, _helpers from ..common._typing import NestedSequence, SupportsBufferProtocol from ._info import __array_namespace_info__ from ._typing import Array, Device, DType @@ -95,8 +95,7 @@ def asarray( See the corresponding documentation in the array library and/or the array API specification for more details. 
""" - if device not in ["cpu", None]: - raise ValueError(f"Unsupported device for NumPy: {device!r}") + _helpers._check_device(np, device) if hasattr(np, '_CopyMode'): if copy is None: @@ -122,6 +121,7 @@ def astype( copy: bool = True, device: Optional[Device] = None, ) -> Array: + _helpers._check_device(np, device) return x.astype(dtype=dtype, copy=copy) From 2c1cb6b515849048cd062e31462b6a193b81471c Mon Sep 17 00:00:00 2001 From: crusaderky Date: Thu, 3 Apr 2025 13:30:49 +0100 Subject: [PATCH 26/80] BUG: Don't import helpers in namespaces --- array_api_compat/common/_linalg.py | 2 ++ array_api_compat/cupy/__init__.py | 3 --- array_api_compat/dask/array/__init__.py | 1 + array_api_compat/numpy/__init__.py | 9 --------- array_api_compat/numpy/_aliases.py | 2 +- array_api_compat/torch/__init__.py | 6 ++---- tests/test_common.py | 2 +- 7 files changed, 7 insertions(+), 18 deletions(-) diff --git a/array_api_compat/common/_linalg.py b/array_api_compat/common/_linalg.py index c77ee3b8..d1e7ebd8 100644 --- a/array_api_compat/common/_linalg.py +++ b/array_api_compat/common/_linalg.py @@ -174,3 +174,5 @@ def trace( 'svd', 'cholesky', 'matrix_rank', 'pinv', 'matrix_norm', 'matrix_transpose', 'svdvals', 'vecdot', 'vector_norm', 'diagonal', 'trace'] + +_all_ignore = ['math', 'normalize_axis_tuple', 'get_xp', 'np', 'isdtype'] diff --git a/array_api_compat/cupy/__init__.py b/array_api_compat/cupy/__init__.py index 59e01058..9a30f95d 100644 --- a/array_api_compat/cupy/__init__.py +++ b/array_api_compat/cupy/__init__.py @@ -8,9 +8,6 @@ # See the comment in the numpy __init__.py __import__(__package__ + '.linalg') - __import__(__package__ + '.fft') -from ..common._helpers import * # noqa: F401,F403 - __array_api_version__ = '2024.12' diff --git a/array_api_compat/dask/array/__init__.py b/array_api_compat/dask/array/__init__.py index a6e69ad3..bb649306 100644 --- a/array_api_compat/dask/array/__init__.py +++ b/array_api_compat/dask/array/__init__.py @@ -5,5 +5,6 @@ __array_api_version__ = '2024.12' +# See the comment in the numpy __init__.py __import__(__package__ + '.linalg') __import__(__package__ + '.fft') diff --git a/array_api_compat/numpy/__init__.py b/array_api_compat/numpy/__init__.py index 02c55d28..6a5d9867 100644 --- a/array_api_compat/numpy/__init__.py +++ b/array_api_compat/numpy/__init__.py @@ -14,17 +14,8 @@ # It doesn't overwrite np.linalg from above. The import is generated # dynamically so that the library can be vendored. __import__(__package__ + '.linalg') - __import__(__package__ + '.fft') from .linalg import matrix_transpose, vecdot # noqa: F401 -from ..common._helpers import * # noqa: F403 - -try: - # Used in asarray(). Not present in older versions. - from numpy import _CopyMode # noqa: F401 -except ImportError: - pass - __array_api_version__ = '2024.12' diff --git a/array_api_compat/numpy/_aliases.py b/array_api_compat/numpy/_aliases.py index 1d084b2b..9e4f1174 100644 --- a/array_api_compat/numpy/_aliases.py +++ b/array_api_compat/numpy/_aliases.py @@ -86,7 +86,7 @@ def asarray( *, dtype: Optional[DType] = None, device: Optional[Device] = None, - copy: "Optional[Union[bool, np._CopyMode]]" = None, + copy: Optional[Union[bool, np._CopyMode]] = None, **kwargs, ) -> Array: """ diff --git a/array_api_compat/torch/__init__.py b/array_api_compat/torch/__init__.py index a985986e..69fd19ce 100644 --- a/array_api_compat/torch/__init__.py +++ b/array_api_compat/torch/__init__.py @@ -9,16 +9,14 @@ or 'cpu' in n or 'backward' in n): continue - exec(n + ' = torch.' 
+ n) + exec(f"{n} = torch.{n}") +del n # These imports may overwrite names from the import * above. from ._aliases import * # noqa: F403 # See the comment in the numpy __init__.py __import__(__package__ + '.linalg') - __import__(__package__ + '.fft') -from ..common._helpers import * # noqa: F403 - __array_api_version__ = '2024.12' diff --git a/tests/test_common.py b/tests/test_common.py index bbf14572..54024d47 100644 --- a/tests/test_common.py +++ b/tests/test_common.py @@ -276,7 +276,7 @@ def test_asarray_copy(library): is_lib_func = globals()[is_array_functions[library]] all = xp.all if library != 'dask.array' else lambda x: xp.all(x).compute() - if library == 'numpy' and xp.__version__[0] < '2' and not hasattr(xp, '_CopyMode') : + if library == 'numpy' and xp.__version__[0] < '2' and not hasattr(np, "_CopyMode"): supports_copy_false_other_ns = False supports_copy_false_same_ns = False elif library == 'cupy': From 621494be1bd8682f1d76ae874272c12464953d3d Mon Sep 17 00:00:00 2001 From: Guido Imperiale Date: Fri, 4 Apr 2025 12:21:20 +0100 Subject: [PATCH 27/80] ENH: torch.asarray device propagation (#299) --- array_api_compat/torch/_aliases.py | 31 ++++++++++++++++++++++++------ array_api_compat/torch/_typing.py | 5 ++--- 2 files changed, 27 insertions(+), 9 deletions(-) diff --git a/array_api_compat/torch/_aliases.py b/array_api_compat/torch/_aliases.py index 982500b0..0891525a 100644 --- a/array_api_compat/torch/_aliases.py +++ b/array_api_compat/torch/_aliases.py @@ -2,12 +2,13 @@ from functools import reduce as _reduce, wraps as _wraps from builtins import all as _builtin_all, any as _builtin_any -from typing import List, Optional, Sequence, Tuple, Union +from typing import Any, List, Optional, Sequence, Tuple, Union import torch from .._internal import get_xp from ..common import _aliases +from ..common._typing import NestedSequence, SupportsBufferProtocol from ._info import __array_namespace_info__ from ._typing import Array, Device, DType @@ -207,6 +208,28 @@ def can_cast(from_: Union[DType, Array], to: DType, /) -> bool: remainder = _two_arg(torch.remainder) subtract = _two_arg(torch.subtract) + +def asarray( + obj: ( + Array + | bool | int | float | complex + | NestedSequence[bool | int | float | complex] + | SupportsBufferProtocol + ), + /, + *, + dtype: DType | None = None, + device: Device | None = None, + copy: bool | None = None, + **kwargs: Any, +) -> Array: + # torch.asarray does not respect input->output device propagation + # https://github.com/pytorch/pytorch/issues/150199 + if device is None and isinstance(obj, torch.Tensor): + device = obj.device + return torch.asarray(obj, dtype=dtype, device=device, copy=copy, **kwargs) + + # These wrappers are mostly based on the fact that pytorch uses 'dim' instead # of 'axis'. @@ -282,7 +305,6 @@ def prod(x: Array, dtype: Optional[DType] = None, keepdims: bool = False, **kwargs) -> Array: - x = torch.asarray(x) ndim = x.ndim # https://github.com/pytorch/pytorch/issues/29137. Separate from the logic @@ -318,7 +340,6 @@ def sum(x: Array, dtype: Optional[DType] = None, keepdims: bool = False, **kwargs) -> Array: - x = torch.asarray(x) ndim = x.ndim # https://github.com/pytorch/pytorch/issues/29137. 
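
A usage sketch of the device propagation added above (hypothetical session;
assumes this patch is applied and, for the interesting case, a CUDA build of
PyTorch -- on a CPU-only build the assertion holds trivially):

    import torch
    import array_api_compat.torch as xp

    device = "cuda" if torch.cuda.is_available() else "cpu"
    x = torch.arange(3, device=device)
    # Bare torch.asarray(x, copy=True) may land on the default (CPU) device,
    # per pytorch#150199; the wrapper instead pins the result to x.device
    # whenever no explicit device= is given:
    y = xp.asarray(x, copy=True)
    assert y.device == x.device
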
@@ -348,7 +369,6 @@ def any(x: Array, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False, **kwargs) -> Array: - x = torch.asarray(x) ndim = x.ndim if axis == (): return x.to(torch.bool) @@ -373,7 +393,6 @@ def all(x: Array, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False, **kwargs) -> Array: - x = torch.asarray(x) ndim = x.ndim if axis == (): return x.to(torch.bool) @@ -816,7 +835,7 @@ def sign(x: Array, /) -> Array: return out -__all__ = ['__array_namespace_info__', 'result_type', 'can_cast', +__all__ = ['__array_namespace_info__', 'asarray', 'result_type', 'can_cast', 'permute_dims', 'bitwise_invert', 'newaxis', 'conj', 'add', 'atan2', 'bitwise_and', 'bitwise_left_shift', 'bitwise_or', 'bitwise_right_shift', 'bitwise_xor', 'copysign', 'count_nonzero', diff --git a/array_api_compat/torch/_typing.py b/array_api_compat/torch/_typing.py index 29ad3fa7..52670871 100644 --- a/array_api_compat/torch/_typing.py +++ b/array_api_compat/torch/_typing.py @@ -1,4 +1,3 @@ -__all__ = ["Array", "DType", "Device"] +__all__ = ["Array", "Device", "DType"] -from torch import dtype as DType, Tensor as Array -from ..common._typing import Device +from torch import device as Device, dtype as DType, Tensor as Array From c629a64c928bd76fdf0bec28a1399467801364be Mon Sep 17 00:00:00 2001 From: crusaderky Date: Mon, 7 Apr 2025 10:27:05 +0100 Subject: [PATCH 28/80] Simplify test parametrization --- cupy-xfails.txt | 8 ++------ dask-xfails.txt | 8 ++------ numpy-1-21-xfails.txt | 8 ++------ numpy-1-26-xfails.txt | 8 ++------ numpy-dev-xfails.txt | 8 ++------ numpy-xfails.txt | 8 ++------ 6 files changed, 12 insertions(+), 36 deletions(-) diff --git a/cupy-xfails.txt b/cupy-xfails.txt index f4cd1e36..a30572f8 100644 --- a/cupy-xfails.txt +++ b/cupy-xfails.txt @@ -16,12 +16,8 @@ array_api_tests/test_creation_functions.py::test_asarray_arrays # attributes are np.float32 instead of float # (see also https://github.com/data-apis/array-api/issues/405) -array_api_tests/test_data_type_functions.py::test_finfo[float32-dtype] -array_api_tests/test_data_type_functions.py::test_finfo[float32-array] -array_api_tests/test_data_type_functions.py::test_finfo[float32-array.dtype] -array_api_tests/test_data_type_functions.py::test_finfo[complex64-dtype] -array_api_tests/test_data_type_functions.py::test_finfo[complex64-array] -array_api_tests/test_data_type_functions.py::test_finfo[complex64-array.dtype] +array_api_tests/test_data_type_functions.py::test_finfo[float32] +array_api_tests/test_data_type_functions.py::test_finfo[complex64] # Some array attributes are missing, and we do not wrap the array object array_api_tests/test_has_names.py::test_has_names[array_method-__array_namespace__] diff --git a/dask-xfails.txt b/dask-xfails.txt index abab825c..932aeada 100644 --- a/dask-xfails.txt +++ b/dask-xfails.txt @@ -14,12 +14,8 @@ array_api_tests/test_creation_functions.py::test_eye # attributes are np.float32 instead of float # (see also https://github.com/data-apis/array-api/issues/405) -array_api_tests/test_data_type_functions.py::test_finfo[float32-dtype] -array_api_tests/test_data_type_functions.py::test_finfo[float32-array] -array_api_tests/test_data_type_functions.py::test_finfo[float32-array.dtype] -array_api_tests/test_data_type_functions.py::test_finfo[complex64-dtype] -array_api_tests/test_data_type_functions.py::test_finfo[complex64-array] -array_api_tests/test_data_type_functions.py::test_finfo[complex64-array.dtype] 
+array_api_tests/test_data_type_functions.py::test_finfo[float32] +array_api_tests/test_data_type_functions.py::test_finfo[complex64] # out[-1]=dask.array but should be some floating number # (I think the test is not forcing the op to be computed?) diff --git a/numpy-1-21-xfails.txt b/numpy-1-21-xfails.txt index 93a90757..66443a73 100644 --- a/numpy-1-21-xfails.txt +++ b/numpy-1-21-xfails.txt @@ -3,12 +3,8 @@ array_api_tests/test_creation_functions.py::test_asarray_arrays # attributes are np.float32 instead of float # (see also https://github.com/data-apis/array-api/issues/405) -array_api_tests/test_data_type_functions.py::test_finfo[float32-dtype] -array_api_tests/test_data_type_functions.py::test_finfo[float32-array] -array_api_tests/test_data_type_functions.py::test_finfo[float32-array.dtype] -array_api_tests/test_data_type_functions.py::test_finfo[complex64-dtype] -array_api_tests/test_data_type_functions.py::test_finfo[complex64-array] -array_api_tests/test_data_type_functions.py::test_finfo[complex64-array.dtype] +array_api_tests/test_data_type_functions.py::test_finfo[float32] +array_api_tests/test_data_type_functions.py::test_finfo[complex64] # Array methods and attributes not already on np.ndarray cannot be wrapped array_api_tests/test_has_names.py::test_has_names[array_method-__array_namespace__] diff --git a/numpy-1-26-xfails.txt b/numpy-1-26-xfails.txt index 84916e73..ed95083a 100644 --- a/numpy-1-26-xfails.txt +++ b/numpy-1-26-xfails.txt @@ -1,11 +1,7 @@ # attributes are np.float32 instead of float # (see also https://github.com/data-apis/array-api/issues/405) -array_api_tests/test_data_type_functions.py::test_finfo[float32-dtype] -array_api_tests/test_data_type_functions.py::test_finfo[float32-array] -array_api_tests/test_data_type_functions.py::test_finfo[float32-array.dtype] -array_api_tests/test_data_type_functions.py::test_finfo[complex64-dtype] -array_api_tests/test_data_type_functions.py::test_finfo[complex64-array] -array_api_tests/test_data_type_functions.py::test_finfo[complex64-array.dtype] +array_api_tests/test_data_type_functions.py::test_finfo[float32] +array_api_tests/test_data_type_functions.py::test_finfo[complex64] # Array methods and attributes not already on np.ndarray cannot be wrapped array_api_tests/test_has_names.py::test_has_names[array_method-__array_namespace__] diff --git a/numpy-dev-xfails.txt b/numpy-dev-xfails.txt index 31bcb63b..972d2346 100644 --- a/numpy-dev-xfails.txt +++ b/numpy-dev-xfails.txt @@ -1,11 +1,7 @@ # attributes are np.float32 instead of float # (see also https://github.com/data-apis/array-api/issues/405) -array_api_tests/test_data_type_functions.py::test_finfo[float32-dtype] -array_api_tests/test_data_type_functions.py::test_finfo[float32-array] -array_api_tests/test_data_type_functions.py::test_finfo[float32-array.dtype] -array_api_tests/test_data_type_functions.py::test_finfo[complex64-dtype] -array_api_tests/test_data_type_functions.py::test_finfo[complex64-array] -array_api_tests/test_data_type_functions.py::test_finfo[complex64-array.dtype] +array_api_tests/test_data_type_functions.py::test_finfo[float32] +array_api_tests/test_data_type_functions.py::test_finfo[complex64] # The test suite cannot properly get the signature for vecdot # https://github.com/numpy/numpy/pull/26237 diff --git a/numpy-xfails.txt b/numpy-xfails.txt index 0810aea6..0f09985e 100644 --- a/numpy-xfails.txt +++ b/numpy-xfails.txt @@ -1,11 +1,7 @@ # attributes are np.float32 instead of float # (see also https://github.com/data-apis/array-api/issues/405) 
-array_api_tests/test_data_type_functions.py::test_finfo[float32-dtype]
-array_api_tests/test_data_type_functions.py::test_finfo[float32-array]
-array_api_tests/test_data_type_functions.py::test_finfo[float32-array.dtype]
-array_api_tests/test_data_type_functions.py::test_finfo[complex64-dtype]
-array_api_tests/test_data_type_functions.py::test_finfo[complex64-array]
-array_api_tests/test_data_type_functions.py::test_finfo[complex64-array.dtype]
+array_api_tests/test_data_type_functions.py::test_finfo[float32]
+array_api_tests/test_data_type_functions.py::test_finfo[complex64]

 # The test suite cannot properly get the signature for vecdot
 # https://github.com/numpy/numpy/pull/26237

From bff3bf467d6f126015179558f1b8c71242014cbc Mon Sep 17 00:00:00 2001
From: Guido Imperiale
Date: Tue, 15 Apr 2025 11:48:45 +0100
Subject: [PATCH 29/80] Drop Python 3.9; test on Python 3.13; drop NumPy 1.21;
 skip CUDA install (#304)

reviewed at https://github.com/data-apis/array-api-compat/pull/304
---
 .github/workflows/array-api-tests-dask.yml    |  2 +-
 .../workflows/array-api-tests-numpy-1-21.yml  | 11 ---
 .../workflows/array-api-tests-numpy-1-22.yml  | 12 +++
 .../workflows/array-api-tests-numpy-1-26.yml  |  1 +
 .../workflows/array-api-tests-numpy-dev.yml   |  1 +
 .../array-api-tests-numpy-latest.yml          |  3 +-
 .github/workflows/array-api-tests-torch.yml   |  4 +-
 .github/workflows/array-api-tests.yml         | 23 +++--
 .github/workflows/tests.yml                   | 58 ++++++++-----
 array_api_compat/cupy/_typing.py              |  2 +-
 array_api_compat/dask/array/_aliases.py       |  2 +-
 array_api_compat/numpy/_aliases.py            | 18 ++--
 array_api_compat/numpy/_typing.py             |  2 +-
 docs/supported-array-libraries.md             | 17 +---
 ...y-1-21-xfails.txt => numpy-1-22-xfails.txt | 83 +++----------------
 numpy-1-26-xfails.txt                         |  3 -
 numpy-skips.txt                               | 11 ---
 numpy-xfails.txt                              |  4 +-
 pyproject.toml                                | 16 ++--
 tests/test_common.py                          |  5 +-
 tests/test_dask.py                            |  6 +-
 torch-skips.txt                               | 11 ---
 torch-xfails.txt                              |  4 +
 23 files changed, 114 insertions(+), 185 deletions(-)
 delete mode 100644 .github/workflows/array-api-tests-numpy-1-21.yml
 create mode 100644 .github/workflows/array-api-tests-numpy-1-22.yml
 rename numpy-1-21-xfails.txt => numpy-1-22-xfails.txt (68%)

diff --git a/.github/workflows/array-api-tests-dask.yml b/.github/workflows/array-api-tests-dask.yml
index 2ad98586..afc67975 100644
--- a/.github/workflows/array-api-tests-dask.yml
+++ b/.github/workflows/array-api-tests-dask.yml
@@ -7,7 +7,6 @@ jobs:
     uses: ./.github/workflows/array-api-tests.yml
     with:
       package-name: dask
-      package-version: '>= 2024.9.0'
      module-name: dask.array
      extra-requires: numpy
      # Dask is substantially slower than other libraries on unit tests.
      # Running fewer examples compensates for this and also reduces
      # flakiness. Before changes to dask-xfails.txt or dask-skips.txt, please run
      # the full test suite with at least 200 examples.
pytest-extra-args: --max-examples=5 + python-versions: '[''3.10'', ''3.13'']' diff --git a/.github/workflows/array-api-tests-numpy-1-21.yml b/.github/workflows/array-api-tests-numpy-1-21.yml deleted file mode 100644 index 2d81c3cd..00000000 --- a/.github/workflows/array-api-tests-numpy-1-21.yml +++ /dev/null @@ -1,11 +0,0 @@ -name: Array API Tests (NumPy 1.21) - -on: [push, pull_request] - -jobs: - array-api-tests-numpy-1-21: - uses: ./.github/workflows/array-api-tests.yml - with: - package-name: numpy - package-version: '== 1.21.*' - xfails-file-extra: '-1-21' diff --git a/.github/workflows/array-api-tests-numpy-1-22.yml b/.github/workflows/array-api-tests-numpy-1-22.yml new file mode 100644 index 00000000..d8f60432 --- /dev/null +++ b/.github/workflows/array-api-tests-numpy-1-22.yml @@ -0,0 +1,12 @@ +name: Array API Tests (NumPy 1.22) + +on: [push, pull_request] + +jobs: + array-api-tests-numpy-1-22: + uses: ./.github/workflows/array-api-tests.yml + with: + package-name: numpy + package-version: '== 1.22.*' + xfails-file-extra: '-1-22' + python-versions: '[''3.10'']' diff --git a/.github/workflows/array-api-tests-numpy-1-26.yml b/.github/workflows/array-api-tests-numpy-1-26.yml index 660935f0..33780760 100644 --- a/.github/workflows/array-api-tests-numpy-1-26.yml +++ b/.github/workflows/array-api-tests-numpy-1-26.yml @@ -9,3 +9,4 @@ jobs: package-name: numpy package-version: '== 1.26.*' xfails-file-extra: '-1-26' + python-versions: '[''3.10'', ''3.12'']' diff --git a/.github/workflows/array-api-tests-numpy-dev.yml b/.github/workflows/array-api-tests-numpy-dev.yml index eef4269d..d6de1a53 100644 --- a/.github/workflows/array-api-tests-numpy-dev.yml +++ b/.github/workflows/array-api-tests-numpy-dev.yml @@ -9,3 +9,4 @@ jobs: package-name: numpy extra-requires: '--pre --extra-index https://pypi.anaconda.org/scientific-python-nightly-wheels/simple' xfails-file-extra: '-dev' + python-versions: '[''3.11'', ''3.13'']' diff --git a/.github/workflows/array-api-tests-numpy-latest.yml b/.github/workflows/array-api-tests-numpy-latest.yml index 36984345..4d3667f6 100644 --- a/.github/workflows/array-api-tests-numpy-latest.yml +++ b/.github/workflows/array-api-tests-numpy-latest.yml @@ -1,4 +1,4 @@ -name: Array API Tests (NumPy Latest) +name: Array API Tests (NumPy latest) on: [push, pull_request] @@ -7,3 +7,4 @@ jobs: uses: ./.github/workflows/array-api-tests.yml with: package-name: numpy + python-versions: '[''3.10'', ''3.13'']' diff --git a/.github/workflows/array-api-tests-torch.yml b/.github/workflows/array-api-tests-torch.yml index 56ab81a3..ac20df25 100644 --- a/.github/workflows/array-api-tests-torch.yml +++ b/.github/workflows/array-api-tests-torch.yml @@ -1,4 +1,4 @@ -name: Array API Tests (PyTorch Latest) +name: Array API Tests (PyTorch CPU) on: [push, pull_request] @@ -7,5 +7,7 @@ jobs: uses: ./.github/workflows/array-api-tests.yml with: package-name: torch + extra-requires: '--index-url https://download.pytorch.org/whl/cpu' extra-env-vars: | ARRAY_API_TESTS_SKIP_DTYPES=uint16,uint32,uint64 + python-versions: '[''3.10'', ''3.13'']' diff --git a/.github/workflows/array-api-tests.yml b/.github/workflows/array-api-tests.yml index 6ace193a..31bedde6 100644 --- a/.github/workflows/array-api-tests.yml +++ b/.github/workflows/array-api-tests.yml @@ -16,6 +16,10 @@ on: required: false type: string default: '>= 0' + python-versions: + required: true + type: string + description: JSON array of Python versions to test against. 
pytest-extra-args: required: false type: string @@ -30,7 +34,7 @@ on: extra-env-vars: required: false type: string - description: "Multiline string of environment variables to set for the test run." + description: Multiline string of environment variables to set for the test run. env: PYTEST_ARGS: "--max-examples 200 -v -rxXfE --ci ${{ inputs.pytest-extra-args }} --hypothesis-disable-deadline --durations 10" @@ -39,41 +43,44 @@ jobs: tests: runs-on: ubuntu-latest strategy: + fail-fast: false matrix: - # Min version of dask we need dropped support for Python 3.9 - # There is no numpy git tip for Python 3.9 or 3.10 - python-version: ${{ (inputs.package-name == 'dask' && fromJson('[''3.10'', ''3.11'', ''3.12'']')) || (inputs.package-name == 'numpy' && inputs.xfails-file-extra == '-dev' && fromJson('[''3.11'', ''3.12'']')) || fromJson('[''3.9'', ''3.10'', ''3.11'', ''3.12'']') }} + python-version: ${{ fromJson(inputs.python-versions) }} steps: - name: Checkout array-api-compat uses: actions/checkout@v4 with: path: array-api-compat + - name: Checkout array-api-tests uses: actions/checkout@v4 with: repository: data-apis/array-api-tests submodules: 'true' path: array-api-tests + - name: Set up Python ${{ matrix.python-version }} uses: actions/setup-python@v5 with: python-version: ${{ matrix.python-version }} + - name: Set Extra Environment Variables # Set additional environment variables if provided if: inputs.extra-env-vars run: | echo "${{ inputs.extra-env-vars }}" >> $GITHUB_ENV + - name: Install dependencies - # NumPy 1.21 doesn't support Python 3.11. There doesn't seem to be a way - # to put this in the numpy 1.21 config file. - if: "! ((matrix.python-version == '3.11' || matrix.python-version == '3.12') && inputs.package-name == 'numpy' && contains(inputs.package-version, '1.21'))" run: | python -m pip install --upgrade pip python -m pip install '${{ inputs.package-name }} ${{ inputs.package-version }}' ${{ inputs.extra-requires }} python -m pip install -r ${GITHUB_WORKSPACE}/array-api-tests/requirements.txt + + - name: Dump pip environment + run: pip freeze + - name: Run the array API testsuite (${{ inputs.package-name }}) - if: "! 
((matrix.python-version == '3.11' || matrix.python-version == '3.12') && inputs.package-name == 'numpy' && contains(inputs.package-version, '1.21'))" env: ARRAY_API_TESTS_MODULE: array_api_compat.${{ inputs.module-name || inputs.package-name }} ARRAY_API_TESTS_VERSION: 2024.12 diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 54f6f402..81a05b3f 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -4,15 +4,24 @@ jobs: tests: runs-on: ubuntu-latest strategy: + fail-fast: false matrix: - python-version: ['3.9', '3.10', '3.11', '3.12'] - numpy-version: ['1.21', '1.26', '2.0', 'dev'] - exclude: - - python-version: '3.11' - numpy-version: '1.21' - - python-version: '3.12' - numpy-version: '1.21' - fail-fast: true + include: + - numpy-version: '1.22' + python-version: '3.10' + - numpy-version: '1.26' + python-version: '3.10' + - numpy-version: '1.26' + python-version: '3.12' + - numpy-version: 'latest' + python-version: '3.10' + - numpy-version: 'latest' + python-version: '3.13' + - numpy-version: 'dev' + python-version: '3.11' + - numpy-version: 'dev' + python-version: '3.13' + steps: - uses: actions/checkout@v4 - uses: actions/setup-python@v5 @@ -21,22 +30,29 @@ jobs: - name: Install Dependencies run: | python -m pip install --upgrade pip + python -m pip install pytest + if [ "${{ matrix.numpy-version }}" == "dev" ]; then - PIP_EXTRA='numpy --pre --extra-index-url https://pypi.anaconda.org/scientific-python-nightly-wheels/simple' - elif [ "${{ matrix.numpy-version }}" == "1.21" ]; then - PIP_EXTRA='numpy==1.21.*' + python -m pip install numpy --pre --extra-index-url https://pypi.anaconda.org/scientific-python-nightly-wheels/simple + elif [ "${{ matrix.numpy-version }}" == "1.22" ]; then + python -m pip install 'numpy==1.22.*' + elif [ "${{ matrix.numpy-version }}" == "1.26" ]; then + python -m pip install 'numpy==1.26.*' else - PIP_EXTRA='numpy==1.26.*' + # Don't `pip install .[dev]` as it would pull in the whole torch cuda stack + python -m pip install array-api-strict dask[array] jax[cpu] numpy sparse + python -m pip install torch --index-url https://download.pytorch.org/whl/cpu + if [ "${{ matrix.python-version }}" != "3.13" ]; then + # onnx wheels are not available on Python 3.13 at the moment of writing + python -m pip install ndonnx + fi fi - python -m pip install .[dev] $PIP_EXTRA + - name: Dump pip environment + run: pip freeze - - name: Run Tests - run: | - if [[ "${{ matrix.numpy-version }}" == "1.21" || "${{ matrix.numpy-version }}" == "dev" ]]; then - PYTEST_EXTRA=(-k "numpy and not jax and not torch and not dask and not sparse") - fi - pytest -v "${PYTEST_EXTRA[@]}" + - name: Test it installs + run: python -m pip install . - # Make sure it installs - python -m pip install . 
+ - name: Run Tests + run: pytest -v diff --git a/array_api_compat/cupy/_typing.py b/array_api_compat/cupy/_typing.py index 66af5d19..d8e49ca7 100644 --- a/array_api_compat/cupy/_typing.py +++ b/array_api_compat/cupy/_typing.py @@ -10,7 +10,7 @@ from cupy.cuda.device import Device if TYPE_CHECKING: - # NumPy 1.x on Python 3.9 and 3.10 fails to parse np.dtype[] + # NumPy 1.x on Python 3.10 fails to parse np.dtype[] DType = cp.dtype[ cp.intp | cp.int8 diff --git a/array_api_compat/dask/array/_aliases.py b/array_api_compat/dask/array/_aliases.py index 4733b1a6..e7ddde78 100644 --- a/array_api_compat/dask/array/_aliases.py +++ b/array_api_compat/dask/array/_aliases.py @@ -147,7 +147,7 @@ def asarray( *, dtype: Optional[DType] = None, device: Optional[Device] = None, - copy: Optional[Union[bool, np._CopyMode]] = None, + copy: Optional[bool] = None, **kwargs, ) -> Array: """ diff --git a/array_api_compat/numpy/_aliases.py b/array_api_compat/numpy/_aliases.py index 59a0b8f4..d1fd46a1 100644 --- a/array_api_compat/numpy/_aliases.py +++ b/array_api_compat/numpy/_aliases.py @@ -99,18 +99,12 @@ def asarray( """ _helpers._check_device(np, device) - if hasattr(np, '_CopyMode'): - if copy is None: - copy = np._CopyMode.IF_NEEDED - elif copy is False: - copy = np._CopyMode.NEVER - elif copy is True: - copy = np._CopyMode.ALWAYS - else: - # Not present in older NumPys. In this case, we cannot really support - # copy=False. - if copy is False: - raise NotImplementedError("asarray(copy=False) requires a newer version of NumPy.") + if copy is None: + copy = np._CopyMode.IF_NEEDED + elif copy is False: + copy = np._CopyMode.NEVER + elif copy is True: + copy = np._CopyMode.ALWAYS return np.array(obj, copy=copy, dtype=dtype, **kwargs) diff --git a/array_api_compat/numpy/_typing.py b/array_api_compat/numpy/_typing.py index 6a18a3b2..a6c96924 100644 --- a/array_api_compat/numpy/_typing.py +++ b/array_api_compat/numpy/_typing.py @@ -10,7 +10,7 @@ Device = Literal["cpu"] if TYPE_CHECKING: - # NumPy 1.x on Python 3.9 and 3.10 fails to parse np.dtype[] + # NumPy 1.x on Python 3.10 fails to parse np.dtype[] DType = np.dtype[ np.intp | np.int8 diff --git a/docs/supported-array-libraries.md b/docs/supported-array-libraries.md index 4519c4ac..46fcdc27 100644 --- a/docs/supported-array-libraries.md +++ b/docs/supported-array-libraries.md @@ -36,23 +36,16 @@ deviations from the standard should be noted: 50](https://numpy.org/neps/nep-0050-scalar-promotion.html) and https://github.com/numpy/numpy/issues/22341) -- `asarray()` does not support `copy=False`. - - Functions which are not wrapped may not have the same type annotations as the spec. - Functions which are not wrapped may not use positional-only arguments. -The minimum supported NumPy version is 1.21. However, this older version of +The minimum supported NumPy version is 1.22. However, this older version of NumPy has a few issues: - `unique_*` will not compare nans as unequal. -- `finfo()` has no `smallest_normal`. - No `from_dlpack` or `__dlpack__`. -- `argmax()` and `argmin()` do not have `keepdims`. -- `qr()` doesn't support matrix stacks. -- `asarray()` doesn't support `copy=True` (as noted above, `copy=False` is not - supported even in the latest NumPy). - Type promotion behavior will be value based for 0-D arrays (and there is no `NPY_PROMOTION_STATE=weak` to disable this). @@ -72,8 +65,8 @@ version. attribute in the spec. Use the {func}`~.size()` helper function as a portable workaround. 
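
For illustration, a minimal sketch of the size() helper referenced here,
assuming only array-api-compat and a CPU build of torch:

    import torch
    from array_api_compat import size

    t = torch.ones(2, 3)
    assert t.size(0) == 2  # torch: size() is a method taking a dim index
    assert size(t) == 6    # portable helper: total number of elements
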
-- PyTorch does not have unsigned integer types other than `uint8`, and no - attempt is made to implement them here. +- PyTorch has incomplete support for unsigned integer types other + than `uint8`, and no attempt is made to implement them here. - PyTorch has type promotion semantics that differ from the array API specification for 0-D tensor objects. The array functions in this wrapper @@ -100,8 +93,6 @@ version. - As with NumPy, type annotations and positional-only arguments may not exactly match the spec for functions that are not wrapped at all. -The minimum supported PyTorch version is 1.13. - (jax-support)= ## [JAX](https://jax.readthedocs.io/en/latest/) @@ -131,8 +122,6 @@ For `linalg`, several methods are missing, for example: - `matrix_rank` Other methods may only be partially implemented or return incorrect results at times. -The minimum supported Dask version is 2023.12.0. - (sparse-support)= ## [Sparse](https://sparse.pydata.org/en/stable/) diff --git a/numpy-1-21-xfails.txt b/numpy-1-22-xfails.txt similarity index 68% rename from numpy-1-21-xfails.txt rename to numpy-1-22-xfails.txt index 66443a73..93edf311 100644 --- a/numpy-1-21-xfails.txt +++ b/numpy-1-22-xfails.txt @@ -1,6 +1,3 @@ -# asarray(copy=False) is not yet implemented -array_api_tests/test_creation_functions.py::test_asarray_arrays - # attributes are np.float32 instead of float # (see also https://github.com/data-apis/array-api/issues/405) array_api_tests/test_data_type_functions.py::test_finfo[float32] @@ -39,38 +36,24 @@ array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(isfinite(x1_i) and # https://github.com/numpy/numpy/issues/21213 array_api_tests/meta/test_hypothesis_helpers.py::test_symmetric_matrices -# NumPy 1.21 specific XFAILS +# NumPy 1.22 specific XFAILS ############################ -# finfo has no smallest_normal -array_api_tests/test_data_type_functions.py::test_finfo - -# dlpack stuff -array_api_tests/test_has_names.py::test_has_names[creation-from_dlpack] -array_api_tests/test_has_names.py::test_has_names[array_method-__dlpack__] -array_api_tests/test_has_names.py::test_has_names[array_method-__dlpack_device__] -array_api_tests/test_signatures.py::test_func_signature[from_dlpack] -array_api_tests/test_signatures.py::test_array_method_signature[__dlpack__] -array_api_tests/test_signatures.py::test_array_method_signature[__dlpack_device__] - -# qr() doesn't support matrix stacks -array_api_tests/test_linalg.py::test_qr - # cross has some promotion bug that is fixed in newer numpy versions array_api_tests/test_linalg.py::test_cross +# linspace(-0.0, -1.0, num=1) returns +0.0 instead of -0.0. +# Fixed in newer numpy versions. 
+array_api_tests/test_creation_functions.py::test_linspace + # vector_norm with ord=-1 which has since been fixed # https://github.com/numpy/numpy/issues/21083 array_api_tests/test_linalg.py::test_vector_norm -# argmax and argmin do not support keepdims -array_api_tests/test_searching_functions.py::test_argmax -array_api_tests/test_searching_functions.py::test_argmin -array_api_tests/test_signatures.py::test_func_signature[argmax] -array_api_tests/test_signatures.py::test_func_signature[argmin] - -# NumPy 1.21 doesn't support NPY_PROMOTION_STATE=weak, so many tests fail with +# NumPy 1.22 doesn't support NPY_PROMOTION_STATE=weak, so many tests fail with # type promotion issues +# NOTE: some of these may not fail until one runs array-api-tests with +# --max-examples 100000 array_api_tests/test_manipulation_functions.py::test_concat array_api_tests/test_operators_and_elementwise_functions.py::test_add[__add__(x, s)] array_api_tests/test_operators_and_elementwise_functions.py::test_add[__add__(x1, x2)] @@ -109,6 +92,7 @@ array_api_tests/test_operators_and_elementwise_functions.py::test_floor_divide[_ array_api_tests/test_operators_and_elementwise_functions.py::test_floor_divide[floor_divide(x1, x2)] array_api_tests/test_operators_and_elementwise_functions.py::test_greater[__gt__(x1, x2)] array_api_tests/test_operators_and_elementwise_functions.py::test_greater[greater(x1, x2)] +array_api_tests/test_operators_and_elementwise_functions.py::test_greater_equal[__ge__(x1, x2)] array_api_tests/test_operators_and_elementwise_functions.py::test_greater_equal[greater_equal(x1, x2)] array_api_tests/test_operators_and_elementwise_functions.py::test_hypot array_api_tests/test_operators_and_elementwise_functions.py::test_less[__lt__(x1, x2)] @@ -136,53 +120,11 @@ array_api_tests/test_operators_and_elementwise_functions.py::test_subtract[__isu array_api_tests/test_operators_and_elementwise_functions.py::test_subtract[__sub__(x, s)] array_api_tests/test_operators_and_elementwise_functions.py::test_subtract[__sub__(x1, x2)] array_api_tests/test_operators_and_elementwise_functions.py::test_subtract[subtract(x1, x2)] -array_api_tests/test_special_cases.py::test_binary[__floordiv__(x1_i < 0 and x2_i is +0) -> -infinity] -array_api_tests/test_special_cases.py::test_binary[__floordiv__(x1_i < 0 and x2_i is -0) -> +infinity] -array_api_tests/test_special_cases.py::test_binary[__floordiv__(x1_i > 0 and x2_i is +0) -> +infinity] -array_api_tests/test_special_cases.py::test_binary[__floordiv__(x1_i > 0 and x2_i is -0) -> -infinity] -array_api_tests/test_special_cases.py::test_binary[__floordiv__(x1_i is +0 and x2_i < 0) -> -0] -array_api_tests/test_special_cases.py::test_binary[__floordiv__(x1_i is +0 and x2_i > 0) -> +0] -array_api_tests/test_special_cases.py::test_binary[__floordiv__(x1_i is -0 and x2_i < 0) -> +0] -array_api_tests/test_special_cases.py::test_binary[__floordiv__(x1_i is -0 and x2_i > 0) -> -0] -array_api_tests/test_special_cases.py::test_binary[__mod__(isfinite(x1_i) and x1_i < 0 and x2_i is +infinity) -> x2_i] -array_api_tests/test_special_cases.py::test_binary[__mod__(isfinite(x1_i) and x1_i < 0 and x2_i is -infinity) -> x1_i] -array_api_tests/test_special_cases.py::test_binary[__mod__(isfinite(x1_i) and x1_i > 0 and x2_i is +infinity) -> x1_i] -array_api_tests/test_special_cases.py::test_binary[__mod__(isfinite(x1_i) and x1_i > 0 and x2_i is -infinity) -> x2_i] -array_api_tests/test_special_cases.py::test_binary[__mod__(x1_i is +0 and x2_i < 0) -> -0] 
-array_api_tests/test_special_cases.py::test_binary[__mod__(x1_i is +0 and x2_i > 0) -> +0] -array_api_tests/test_special_cases.py::test_binary[__mod__(x1_i is -0 and x2_i < 0) -> -0] -array_api_tests/test_special_cases.py::test_binary[__mod__(x1_i is -0 and x2_i > 0) -> +0] -array_api_tests/test_special_cases.py::test_binary[__pow__(x1_i < 0 and isfinite(x1_i) and isfinite(x2_i) and not x2_i.is_integer()) -> NaN] -array_api_tests/test_special_cases.py::test_binary[__pow__(x1_i is +0 and x2_i < 0) -> +infinity] -array_api_tests/test_special_cases.py::test_binary[__pow__(x1_i is +0 and x2_i > 0) -> +0] -array_api_tests/test_special_cases.py::test_binary[__pow__(x1_i is -0 and x2_i < 0 and not (x2_i.is_integer() and x2_i % 2 == 1)) -> +infinity] -array_api_tests/test_special_cases.py::test_binary[__pow__(x1_i is -0 and x2_i < 0 and x2_i.is_integer() and x2_i % 2 == 1) -> -infinity] -array_api_tests/test_special_cases.py::test_binary[__pow__(x1_i is -0 and x2_i > 0 and x2_i.is_integer() and x2_i % 2 == 1) -> -0] -array_api_tests/test_special_cases.py::test_binary[__pow__(x1_i is -infinity and x2_i < 0 and not (x2_i.is_integer() and x2_i % 2 == 1)) -> +0] -array_api_tests/test_special_cases.py::test_binary[__pow__(x1_i is -infinity and x2_i < 0 and x2_i.is_integer() and x2_i % 2 == 1) -> -0] -array_api_tests/test_special_cases.py::test_binary[atan2(x1_i < 0 and x2_i is +0) -> roughly -pi/2] -array_api_tests/test_special_cases.py::test_binary[atan2(x1_i < 0 and x2_i is -0) -> roughly -pi/2] -array_api_tests/test_special_cases.py::test_binary[atan2(x1_i > 0 and x2_i is +0) -> roughly +pi/2] -array_api_tests/test_special_cases.py::test_binary[atan2(x1_i > 0 and x2_i is -0) -> roughly +pi/2] -array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i < 0 and x2_i is +0) -> -infinity] -array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i < 0 and x2_i is -0) -> +infinity] -array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i > 0 and x2_i is +0) -> +infinity] -array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i > 0 and x2_i is -0) -> -infinity] -array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is +0 and x2_i < 0) -> -0] -array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is +0 and x2_i > 0) -> +0] -array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is -0 and x2_i < 0) -> +0] -array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is -0 and x2_i > 0) -> -0] -array_api_tests/test_special_cases.py::test_binary[remainder(isfinite(x1_i) and x1_i < 0 and x2_i is +infinity) -> x2_i] -array_api_tests/test_special_cases.py::test_binary[remainder(isfinite(x1_i) and x1_i < 0 and x2_i is -infinity) -> x1_i] -array_api_tests/test_special_cases.py::test_binary[remainder(isfinite(x1_i) and x1_i > 0 and x2_i is +infinity) -> x1_i] -array_api_tests/test_special_cases.py::test_binary[remainder(isfinite(x1_i) and x1_i > 0 and x2_i is -infinity) -> x2_i] -array_api_tests/test_special_cases.py::test_binary[remainder(x1_i is +0 and x2_i < 0) -> -0] -array_api_tests/test_special_cases.py::test_binary[remainder(x1_i is +0 and x2_i > 0) -> +0] -array_api_tests/test_special_cases.py::test_binary[remainder(x1_i is -0 and x2_i < 0) -> -0] -array_api_tests/test_special_cases.py::test_binary[remainder(x1_i is -0 and x2_i > 0) -> +0] +array_api_tests/test_searching_functions.py::test_where array_api_tests/test_special_cases.py::test_iop[__iadd__(x1_i is -0 and x2_i is -0) -> -0] # 2023.12 support 
+array_api_tests/test_has_names.py::test_has_names[creation-from_dlpack] array_api_tests/test_signatures.py::test_func_signature[from_dlpack] array_api_tests/test_signatures.py::test_array_method_signature[__dlpack__] # uint64 repeats not supported @@ -215,6 +157,3 @@ array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(x1_i is -infinity array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(x1_i is -infinity and isfinite(x2_i) and x2_i < 0) -> +infinity] array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(isfinite(x1_i) and x1_i > 0 and x2_i is -infinity) -> -0] array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(isfinite(x1_i) and x1_i < 0 and x2_i is +infinity) -> -0] - -# numpy < 2 bug: type promotion of asarray([], 'float32') and (np.finfo(float32).max + 1) -> float64 -array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real diff --git a/numpy-1-26-xfails.txt b/numpy-1-26-xfails.txt index ed95083a..51e1a658 100644 --- a/numpy-1-26-xfails.txt +++ b/numpy-1-26-xfails.txt @@ -69,6 +69,3 @@ array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(x1_i is -infinity array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(x1_i is -infinity and isfinite(x2_i) and x2_i < 0) -> +infinity] array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(isfinite(x1_i) and x1_i > 0 and x2_i is -infinity) -> -0] array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(isfinite(x1_i) and x1_i < 0 and x2_i is +infinity) -> -0] - -# numpy < 2 bug: type promotion of asarray([], 'float32') and (finfo(float32).max + 1) gives float64 not float32 -array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real diff --git a/numpy-skips.txt b/numpy-skips.txt index cbf7235b..e69de29b 100644 --- a/numpy-skips.txt +++ b/numpy-skips.txt @@ -1,11 +0,0 @@ -# These tests cause a core dump on CI, so we have to skip them entirely -array_api_tests/test_operators_and_elementwise_functions.py::test_floor_divide[__floordiv__(x, s)] -array_api_tests/test_operators_and_elementwise_functions.py::test_floor_divide[__floordiv__(x1, x2)] -array_api_tests/test_operators_and_elementwise_functions.py::test_floor_divide[__ifloordiv__(x, s)] -array_api_tests/test_operators_and_elementwise_functions.py::test_floor_divide[__ifloordiv__(x1, x2)] -array_api_tests/test_operators_and_elementwise_functions.py::test_floor_divide[floor_divide(x1, x2)] -array_api_tests/test_operators_and_elementwise_functions.py::test_remainder[__imod__(x, s)] -array_api_tests/test_operators_and_elementwise_functions.py::test_remainder[__imod__(x1, x2)] -array_api_tests/test_operators_and_elementwise_functions.py::test_remainder[__mod__(x, s)] -array_api_tests/test_operators_and_elementwise_functions.py::test_remainder[__mod__(x1, x2)] -array_api_tests/test_operators_and_elementwise_functions.py::test_remainder[remainder(x1, x2)] diff --git a/numpy-xfails.txt b/numpy-xfails.txt index 0f09985e..632b4ec3 100644 --- a/numpy-xfails.txt +++ b/numpy-xfails.txt @@ -9,8 +9,6 @@ array_api_tests/test_signatures.py::test_func_signature[vecdot] array_api_tests/test_signatures.py::test_extension_func_signature[linalg.vecdot] # 2023.12 support -array_api_tests/test_signatures.py::test_func_signature[from_dlpack] -array_api_tests/test_signatures.py::test_array_method_signature[__dlpack__] # uint64 repeats not supported array_api_tests/test_manipulation_functions.py::test_repeat @@ -20,6 +18,8 @@ 
array_api_tests/test_signatures.py::test_func_signature[bitwise_left_shift]
array_api_tests/test_signatures.py::test_func_signature[bitwise_or]
array_api_tests/test_signatures.py::test_func_signature[bitwise_right_shift]
array_api_tests/test_signatures.py::test_func_signature[bitwise_xor]
+
+# The standard's stubs note that libraries may return ``NaN`` here to match Python behavior; NumPy does just that
array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is +infinity and isfinite(x2_i) and x2_i > 0) -> +infinity]
array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is +infinity and isfinite(x2_i) and x2_i < 0) -> -infinity]
array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is -infinity and isfinite(x2_i) and x2_i > 0) -> -infinity]
diff --git a/pyproject.toml b/pyproject.toml
index f17c720f..aacebd11 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -7,13 +7,12 @@ name = "array-api-compat"
 dynamic = ["version"]
 description = "A wrapper around NumPy and other array libraries to make them compatible with the Array API standard"
 readme = "README.md"
-requires-python = ">=3.9"
+requires-python = ">=3.10"
 license = "MIT"
 authors = [{name = "Consortium for Python Data API Standards"}]
 classifiers = [
     "Operating System :: OS Independent",
     "Programming Language :: Python :: 3",
-    "Programming Language :: Python :: 3.9",
     "Programming Language :: Python :: 3.10",
     "Programming Language :: Python :: 3.11",
     "Programming Language :: Python :: 3.12",
@@ -24,11 +23,14 @@ classifiers = [

 [project.optional-dependencies]
 cupy = ["cupy"]
-dask = ["dask"]
+dask = ["dask>=2024.9.0"]
 jax = ["jax"]
-numpy = ["numpy"]
+# Note: array-api-compat follows scikit-learn's minimum dependencies, which support
+# much older versions of NumPy than SPEC0 recommends.
+numpy = ["numpy>=1.22"] pytorch = ["torch"] sparse = ["sparse>=0.15.1"] +ndonnx = ["ndonnx"] docs = [ "furo", "linkify-it-py", @@ -39,13 +41,13 @@ docs = [ ] dev = [ "array-api-strict", - "dask[array]", + "dask[array]>=2024.9.0", "jax[cpu]", - "numpy", + "numpy>=1.22", "pytest", "torch", "sparse>=0.15.1", - "ndonnx; python_version>=\"3.10\"" + "ndonnx" ] [project.urls] diff --git a/tests/test_common.py b/tests/test_common.py index 54024d47..6b1aa160 100644 --- a/tests/test_common.py +++ b/tests/test_common.py @@ -276,10 +276,7 @@ def test_asarray_copy(library): is_lib_func = globals()[is_array_functions[library]] all = xp.all if library != 'dask.array' else lambda x: xp.all(x).compute() - if library == 'numpy' and xp.__version__[0] < '2' and not hasattr(np, "_CopyMode"): - supports_copy_false_other_ns = False - supports_copy_false_same_ns = False - elif library == 'cupy': + if library == 'cupy': supports_copy_false_other_ns = False supports_copy_false_same_ns = False elif library == 'dask.array': diff --git a/tests/test_dask.py b/tests/test_dask.py index 69c738f6..fb0a84d4 100644 --- a/tests/test_dask.py +++ b/tests/test_dask.py @@ -1,6 +1,5 @@ from contextlib import contextmanager -import array_api_strict import numpy as np import pytest @@ -171,9 +170,10 @@ def test_sort_argsort_chunk_size(xp, func, shape, chunks): @pytest.mark.parametrize("func", ["sort", "argsort"]) def test_sort_argsort_meta(xp, func): """Test meta-namespace other than numpy""" - typ = type(array_api_strict.asarray(0)) + mxp = pytest.importorskip("array_api_strict") + typ = type(mxp.asarray(0)) a = da.random.random(10) - b = a.map_blocks(array_api_strict.asarray) + b = a.map_blocks(mxp.asarray) assert isinstance(b._meta, typ) c = getattr(xp, func)(b) assert isinstance(c._meta, typ) diff --git a/torch-skips.txt b/torch-skips.txt index cbf7235b..e69de29b 100644 --- a/torch-skips.txt +++ b/torch-skips.txt @@ -1,11 +0,0 @@ -# These tests cause a core dump on CI, so we have to skip them entirely -array_api_tests/test_operators_and_elementwise_functions.py::test_floor_divide[__floordiv__(x, s)] -array_api_tests/test_operators_and_elementwise_functions.py::test_floor_divide[__floordiv__(x1, x2)] -array_api_tests/test_operators_and_elementwise_functions.py::test_floor_divide[__ifloordiv__(x, s)] -array_api_tests/test_operators_and_elementwise_functions.py::test_floor_divide[__ifloordiv__(x1, x2)] -array_api_tests/test_operators_and_elementwise_functions.py::test_floor_divide[floor_divide(x1, x2)] -array_api_tests/test_operators_and_elementwise_functions.py::test_remainder[__imod__(x, s)] -array_api_tests/test_operators_and_elementwise_functions.py::test_remainder[__imod__(x1, x2)] -array_api_tests/test_operators_and_elementwise_functions.py::test_remainder[__mod__(x, s)] -array_api_tests/test_operators_and_elementwise_functions.py::test_remainder[__mod__(x1, x2)] -array_api_tests/test_operators_and_elementwise_functions.py::test_remainder[remainder(x1, x2)] diff --git a/torch-xfails.txt b/torch-xfails.txt index e556fa4f..abee88b1 100644 --- a/torch-xfails.txt +++ b/torch-xfails.txt @@ -29,6 +29,10 @@ array_api_tests/test_operators_and_elementwise_functions.py::test_divide[__trued array_api_tests/test_operators_and_elementwise_functions.py::test_equal[__eq__(x1, x2)] array_api_tests/test_operators_and_elementwise_functions.py::test_not_equal[__ne__(x1, x2)] array_api_tests/test_operators_and_elementwise_functions.py::test_floor_divide[__floordiv__(x1, x2)] 
+array_api_tests/test_operators_and_elementwise_functions.py::test_floor_divide[__floordiv__(x, s)]
+array_api_tests/test_operators_and_elementwise_functions.py::test_floor_divide[__ifloordiv__(x, s)]
+array_api_tests/test_operators_and_elementwise_functions.py::test_floor_divide[__ifloordiv__(x1, x2)]
+array_api_tests/test_operators_and_elementwise_functions.py::test_floor_divide[floor_divide(x1, x2)]
 array_api_tests/test_operators_and_elementwise_functions.py::test_greater[__gt__(x1, x2)]
 array_api_tests/test_operators_and_elementwise_functions.py::test_greater_equal[__ge__(x1, x2)]
 array_api_tests/test_operators_and_elementwise_functions.py::test_less[__lt__(x1, x2)]

From 00e7cceb338025d9428af2bb6afbe7eaac8cf414 Mon Sep 17 00:00:00 2001
From: Evgeni Burovski
Date: Tue, 15 Apr 2025 11:53:21 +0200
Subject: [PATCH 30/80] BUG: add torch.repeat

---
 array_api_compat/torch/_aliases.py | 7 ++++++-
 torch-xfails.txt                   | 3 +--
 2 files changed, 7 insertions(+), 3 deletions(-)

diff --git a/array_api_compat/torch/_aliases.py b/array_api_compat/torch/_aliases.py
index a2ed1449..0a604b8c 100644
--- a/array_api_compat/torch/_aliases.py
+++ b/array_api_compat/torch/_aliases.py
@@ -574,6 +574,11 @@ def count_nonzero(
     return result


+# torch spells the standard's "repeat" as torch.repeat_interleave, and names the axis argument "dim"
+def repeat(x: Array, repeats: int | Array, /, *, axis: int | None = None) -> Array:
+    return torch.repeat_interleave(x, repeats, axis)
+
+
 def where(
     condition: Array,
     x1: Array | bool | int | float | complex,
@@ -854,6 +859,6 @@ def sign(x: Array, /) -> Array:
     'UniqueAllResult', 'UniqueCountsResult', 'UniqueInverseResult',
     'unique_all', 'unique_counts', 'unique_inverse', 'unique_values',
     'matmul', 'matrix_transpose', 'vecdot', 'tensordot', 'isdtype',
-    'take', 'take_along_axis', 'sign', 'finfo', 'iinfo']
+    'take', 'take_along_axis', 'sign', 'finfo', 'iinfo', 'repeat']

 _all_ignore = ['torch', 'get_xp']

diff --git a/torch-xfails.txt b/torch-xfails.txt
index e556fa4f..ab11f457 100644
--- a/torch-xfails.txt
+++ b/torch-xfails.txt
@@ -120,9 +120,8 @@ array_api_tests/test_data_type_functions.py::test_finfo_dtype
 array_api_tests/test_data_type_functions.py::test_iinfo_dtype

 # 2023.12 support
-array_api_tests/test_has_names.py::test_has_names[manipulation-repeat]
+# https://github.com/pytorch/pytorch/issues/151311: torch.repeat_interleave rejects short integers
 array_api_tests/test_manipulation_functions.py::test_repeat
-array_api_tests/test_signatures.py::test_func_signature[repeat]
 # Argument 'device' missing from signature
 array_api_tests/test_signatures.py::test_func_signature[from_dlpack]
 # Argument 'max_version' missing from signature

From d743dc13e16a2328e3ce0951dd3633629b6537a6 Mon Sep 17 00:00:00 2001
From: Guido Imperiale
Date: Tue, 15 Apr 2025 11:54:15 +0100
Subject: [PATCH 31/80] MAINT: `__array_namespace_info__` docstrings tweaks
 (#300)

---
 array_api_compat/common/_aliases.py  |  2 +-
 array_api_compat/cupy/_info.py       | 20 ++++++++++----
 array_api_compat/dask/array/_info.py | 19 +++++++------
 array_api_compat/numpy/_info.py      |  8 +++---
 array_api_compat/torch/_info.py      | 41 ++++++++++++++++++----------
 5 files changed, 56 insertions(+), 34 deletions(-)

diff --git a/array_api_compat/common/_aliases.py b/array_api_compat/common/_aliases.py
index 46cbb359..351b5bd6 100644
--- a/array_api_compat/common/_aliases.py
+++ b/array_api_compat/common/_aliases.py
@@ -18,7 +18,7 @@

 # These functions are modified from the NumPy versions.
-# Creation functions add the device keyword (which does nothing for NumPy) +# Creation functions add the device keyword (which does nothing for NumPy and Dask) def arange( start: Union[int, float], diff --git a/array_api_compat/cupy/_info.py b/array_api_compat/cupy/_info.py index 790621e4..78e48a33 100644 --- a/array_api_compat/cupy/_info.py +++ b/array_api_compat/cupy/_info.py @@ -26,6 +26,7 @@ complex128, ) + class __array_namespace_info__: """ Get the array API inspection namespace for CuPy. @@ -49,7 +50,7 @@ class __array_namespace_info__: Examples -------- - >>> info = np.__array_namespace_info__() + >>> info = xp.__array_namespace_info__() >>> info.default_dtypes() {'real floating': cupy.float64, 'complex floating': cupy.complex128, @@ -94,13 +95,13 @@ def capabilities(self): >>> info = xp.__array_namespace_info__() >>> info.capabilities() {'boolean indexing': True, - 'data-dependent shapes': True} + 'data-dependent shapes': True, + 'max dimensions': 64} """ return { "boolean indexing": True, "data-dependent shapes": True, - # 'max rank' will be part of the 2024.12 standard "max dimensions": 64, } @@ -117,7 +118,7 @@ def default_device(self): Returns ------- - device : str + device : Device The default device used for new CuPy arrays. Examples @@ -126,6 +127,15 @@ def default_device(self): >>> info.default_device() Device(0) + Notes + ----- + This method returns the static default device when CuPy is initialized. + However, the *current* device used by creation functions (``empty`` etc.) + can be changed globally or with a context manager. + + See Also + -------- + https://github.com/data-apis/array-api/issues/835 """ return cuda.Device(0) @@ -312,7 +322,7 @@ def devices(self): Returns ------- - devices : list of str + devices : list[Device] The devices supported by CuPy. See Also diff --git a/array_api_compat/dask/array/_info.py b/array_api_compat/dask/array/_info.py index fc70b5a2..614f43d9 100644 --- a/array_api_compat/dask/array/_info.py +++ b/array_api_compat/dask/array/_info.py @@ -50,7 +50,7 @@ class __array_namespace_info__: Examples -------- - >>> info = np.__array_namespace_info__() + >>> info = xp.__array_namespace_info__() >>> info.default_dtypes() {'real floating': dask.float64, 'complex floating': dask.complex128, @@ -103,10 +103,11 @@ def capabilities(self): Examples -------- - >>> info = np.__array_namespace_info__() + >>> info = xp.__array_namespace_info__() >>> info.capabilities() {'boolean indexing': True, - 'data-dependent shapes': True} + 'data-dependent shapes': True, + 'max dimensions': 64} """ return { @@ -130,12 +131,12 @@ def default_device(self): Returns ------- - device : str + device : Device The default device used for new Dask arrays. Examples -------- - >>> info = np.__array_namespace_info__() + >>> info = xp.__array_namespace_info__() >>> info.default_device() 'cpu' @@ -173,7 +174,7 @@ def default_dtypes(self, *, device=None): Examples -------- - >>> info = np.__array_namespace_info__() + >>> info = xp.__array_namespace_info__() >>> info.default_dtypes() {'real floating': dask.float64, 'complex floating': dask.complex128, @@ -239,7 +240,7 @@ def dtypes(self, *, device=None, kind=None): Examples -------- - >>> info = np.__array_namespace_info__() + >>> info = xp.__array_namespace_info__() >>> info.dtypes(kind='signed integer') {'int8': dask.int8, 'int16': dask.int16, @@ -335,7 +336,7 @@ def devices(self): Returns ------- - devices : list of str + devices : list[Device] The devices supported by Dask. 
See Also @@ -347,7 +348,7 @@ def devices(self): Examples -------- - >>> info = np.__array_namespace_info__() + >>> info = xp.__array_namespace_info__() >>> info.devices() ['cpu', DASK_DEVICE] diff --git a/array_api_compat/numpy/_info.py b/array_api_compat/numpy/_info.py index e706d118..365855b8 100644 --- a/array_api_compat/numpy/_info.py +++ b/array_api_compat/numpy/_info.py @@ -94,13 +94,13 @@ def capabilities(self): >>> info = np.__array_namespace_info__() >>> info.capabilities() {'boolean indexing': True, - 'data-dependent shapes': True} + 'data-dependent shapes': True, + 'max dimensions': 64} """ return { "boolean indexing": True, "data-dependent shapes": True, - # 'max rank' will be part of the 2024.12 standard "max dimensions": 64, } @@ -119,7 +119,7 @@ def default_device(self): Returns ------- - device : str + device : Device The default device used for new NumPy arrays. Examples @@ -326,7 +326,7 @@ def devices(self): Returns ------- - devices : list of str + devices : list[Device] The devices supported by NumPy. See Also diff --git a/array_api_compat/torch/_info.py b/array_api_compat/torch/_info.py index 34fbcb21..818e5d37 100644 --- a/array_api_compat/torch/_info.py +++ b/array_api_compat/torch/_info.py @@ -34,7 +34,7 @@ class __array_namespace_info__: Examples -------- - >>> info = np.__array_namespace_info__() + >>> info = xp.__array_namespace_info__() >>> info.default_dtypes() {'real floating': numpy.float64, 'complex floating': numpy.complex128, @@ -76,16 +76,16 @@ def capabilities(self): Examples -------- - >>> info = np.__array_namespace_info__() + >>> info = xp.__array_namespace_info__() >>> info.capabilities() {'boolean indexing': True, - 'data-dependent shapes': True} + 'data-dependent shapes': True, + 'max dimensions': 64} """ return { "boolean indexing": True, "data-dependent shapes": True, - # 'max rank' will be part of the 2024.12 standard "max dimensions": 64, } @@ -102,15 +102,24 @@ def default_device(self): Returns ------- - device : str + device : Device The default device used for new PyTorch arrays. Examples -------- - >>> info = np.__array_namespace_info__() + >>> info = xp.__array_namespace_info__() >>> info.default_device() - 'cpu' + device(type='cpu') + Notes + ----- + This method returns the static default device when PyTorch is initialized. + However, the *current* device used by creation functions (``empty`` etc.) + can be changed at runtime. + + See Also + -------- + https://github.com/data-apis/array-api/issues/835 """ return torch.device("cpu") @@ -120,9 +129,9 @@ def default_dtypes(self, *, device=None): Parameters ---------- - device : str, optional - The device to get the default data types for. For PyTorch, only - ``'cpu'`` is allowed. + device : Device, optional + The device to get the default data types for. + Unused for PyTorch, as all devices use the same default dtypes. Returns ------- @@ -139,7 +148,7 @@ def default_dtypes(self, *, device=None): Examples -------- - >>> info = np.__array_namespace_info__() + >>> info = xp.__array_namespace_info__() >>> info.default_dtypes() {'real floating': torch.float32, 'complex floating': torch.complex64, @@ -250,8 +259,9 @@ def dtypes(self, *, device=None, kind=None): Parameters ---------- - device : str, optional + device : Device, optional The device to get the data types for. + Unused for PyTorch, as all devices use the same dtypes. kind : str or tuple of str, optional The kind of data types to return. If ``None``, all data types are returned. If a string, only data types of that kind are returned. 
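A minimal sketch of exercising the inspection namespace these docstrings describe, assuming `array_api_compat` and `torch` are installed and torch's default dtype settings are unchanged:

    import array_api_compat.torch as xp

    info = xp.__array_namespace_info__()

    # "max dimensions" is the entry this patch documents next to the two booleans
    assert info.capabilities() == {
        "boolean indexing": True,
        "data-dependent shapes": True,
        "max dimensions": 64,
    }

    # unlike NumPy, torch defaults to single precision
    assert info.default_dtypes()["real floating"] == xp.float32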
@@ -287,7 +297,7 @@ def dtypes(self, *, device=None, kind=None): Examples -------- - >>> info = np.__array_namespace_info__() + >>> info = xp.__array_namespace_info__() >>> info.dtypes(kind='signed integer') {'int8': numpy.int8, 'int16': numpy.int16, @@ -310,7 +320,7 @@ def devices(self): Returns ------- - devices : list of str + devices : list[Device] The devices supported by PyTorch. See Also @@ -322,7 +332,7 @@ def devices(self): Examples -------- - >>> info = np.__array_namespace_info__() + >>> info = xp.__array_namespace_info__() >>> info.devices() [device(type='cpu'), device(type='mps', index=0), device(type='meta')] @@ -333,6 +343,7 @@ def devices(self): # device: try: torch.device('notadevice') + raise AssertionError("unreachable") # pragma: nocover except RuntimeError as e: # The error message is something like: # "Expected one of cpu, cuda, ipu, xpu, mkldnn, opengl, opencl, ideep, hip, ve, fpga, ort, xla, lazy, vulkan, mps, meta, hpu, mtia, privateuseone device type at start of device string: notadevice" From 9194c5cb7706e08f1a1092aece1fce76ac6e089a Mon Sep 17 00:00:00 2001 From: Guido Imperiale Date: Tue, 15 Apr 2025 12:04:09 +0100 Subject: [PATCH 32/80] MAINT: simplify `torch` dtype promotion (#303) reviewed at https://github.com/data-apis/array-api-compat/pull/303 --- array_api_compat/torch/_aliases.py | 99 ++++++++++++------------------ 1 file changed, 40 insertions(+), 59 deletions(-) diff --git a/array_api_compat/torch/_aliases.py b/array_api_compat/torch/_aliases.py index a2ed1449..5370803f 100644 --- a/array_api_compat/torch/_aliases.py +++ b/array_api_compat/torch/_aliases.py @@ -35,47 +35,23 @@ torch.complex128, } -_promotion_table = { - # bool - (torch.bool, torch.bool): torch.bool, +_promotion_table = { # ints - (torch.int8, torch.int8): torch.int8, (torch.int8, torch.int16): torch.int16, (torch.int8, torch.int32): torch.int32, (torch.int8, torch.int64): torch.int64, - (torch.int16, torch.int8): torch.int16, - (torch.int16, torch.int16): torch.int16, (torch.int16, torch.int32): torch.int32, (torch.int16, torch.int64): torch.int64, - (torch.int32, torch.int8): torch.int32, - (torch.int32, torch.int16): torch.int32, - (torch.int32, torch.int32): torch.int32, (torch.int32, torch.int64): torch.int64, - (torch.int64, torch.int8): torch.int64, - (torch.int64, torch.int16): torch.int64, - (torch.int64, torch.int32): torch.int64, - (torch.int64, torch.int64): torch.int64, - # uints - (torch.uint8, torch.uint8): torch.uint8, # ints and uints (mixed sign) - (torch.int8, torch.uint8): torch.int16, - (torch.int16, torch.uint8): torch.int16, - (torch.int32, torch.uint8): torch.int32, - (torch.int64, torch.uint8): torch.int64, (torch.uint8, torch.int8): torch.int16, (torch.uint8, torch.int16): torch.int16, (torch.uint8, torch.int32): torch.int32, (torch.uint8, torch.int64): torch.int64, # floats - (torch.float32, torch.float32): torch.float32, (torch.float32, torch.float64): torch.float64, - (torch.float64, torch.float32): torch.float64, - (torch.float64, torch.float64): torch.float64, # complexes - (torch.complex64, torch.complex64): torch.complex64, (torch.complex64, torch.complex128): torch.complex128, - (torch.complex128, torch.complex64): torch.complex128, - (torch.complex128, torch.complex128): torch.complex128, # Mixed float and complex (torch.float32, torch.complex64): torch.complex64, (torch.float32, torch.complex128): torch.complex128, @@ -83,6 +59,9 @@ (torch.float64, torch.complex128): torch.complex128, } +_promotion_table.update({(b, a): c for (a, b), c in 
_promotion_table.items()}) +_promotion_table.update({(a, a): a for a in _array_api_dtypes}) + def _two_arg(f): @_wraps(f) @@ -150,13 +129,18 @@ def result_type( return _reduce(_result_type, others + scalars) -def _result_type(x, y): +def _result_type( + x: Array | DType | bool | int | float | complex, + y: Array | DType | bool | int | float | complex, +) -> DType: if not (isinstance(x, _py_scalars) or isinstance(y, _py_scalars)): - xdt = x.dtype if not isinstance(x, torch.dtype) else x - ydt = y.dtype if not isinstance(y, torch.dtype) else y + xdt = x if isinstance(x, torch.dtype) else x.dtype + ydt = y if isinstance(y, torch.dtype) else y.dtype - if (xdt, ydt) in _promotion_table: + try: return _promotion_table[xdt, ydt] + except KeyError: + pass # This doesn't result_type(dtype, dtype) for non-array API dtypes # because torch.result_type only accepts tensors. This does however, allow @@ -301,6 +285,25 @@ def _reduce_multiple_axes(f, x, axis, keepdims=False, **kwargs): out = torch.unsqueeze(out, a) return out + +def _sum_prod_no_axis(x: Array, dtype: DType | None) -> Array: + """ + Implements `sum(..., axis=())` and `prod(..., axis=())`. + + Works around https://github.com/pytorch/pytorch/issues/29137 + """ + if dtype is not None: + return x.clone() if dtype == x.dtype else x.to(dtype) + + # We can't upcast uint8 according to the spec because there is no + # torch.uint64, so at least upcast to int64 which is what prod does + # when axis=None. + if x.dtype in (torch.uint8, torch.int8, torch.int16, torch.int32): + return x.to(torch.int64) + + return x.clone() + + def prod(x: Array, /, *, @@ -308,20 +311,9 @@ def prod(x: Array, dtype: Optional[DType] = None, keepdims: bool = False, **kwargs) -> Array: - ndim = x.ndim - # https://github.com/pytorch/pytorch/issues/29137. Separate from the logic - # below because it still needs to upcast. if axis == (): - if dtype is None: - # We can't upcast uint8 according to the spec because there is no - # torch.uint64, so at least upcast to int64 which is what sum does - # when axis=None. - if x.dtype in [torch.int8, torch.int16, torch.int32, torch.uint8]: - return x.to(torch.int64) - return x.clone() - return x.to(dtype) - + return _sum_prod_no_axis(x, dtype) # torch.prod doesn't support multiple axes # (https://github.com/pytorch/pytorch/issues/56586). if isinstance(axis, tuple): @@ -330,7 +322,7 @@ def prod(x: Array, # torch doesn't support keepdims with axis=None # (https://github.com/pytorch/pytorch/issues/71209) res = torch.prod(x, dtype=dtype, **kwargs) - res = _axis_none_keepdims(res, ndim, keepdims) + res = _axis_none_keepdims(res, x.ndim, keepdims) return res return torch.prod(x, axis, dtype=dtype, keepdims=keepdims, **kwargs) @@ -343,25 +335,14 @@ def sum(x: Array, dtype: Optional[DType] = None, keepdims: bool = False, **kwargs) -> Array: - ndim = x.ndim - # https://github.com/pytorch/pytorch/issues/29137. - # Make sure it upcasts. if axis == (): - if dtype is None: - # We can't upcast uint8 according to the spec because there is no - # torch.uint64, so at least upcast to int64 which is what sum does - # when axis=None. 
- if x.dtype in [torch.int8, torch.int16, torch.int32, torch.uint8]: - return x.to(torch.int64) - return x.clone() - return x.to(dtype) - + return _sum_prod_no_axis(x, dtype) if axis is None: # torch doesn't support keepdims with axis=None # (https://github.com/pytorch/pytorch/issues/71209) res = torch.sum(x, dtype=dtype, **kwargs) - res = _axis_none_keepdims(res, ndim, keepdims) + res = _axis_none_keepdims(res, x.ndim, keepdims) return res return torch.sum(x, axis, dtype=dtype, keepdims=keepdims, **kwargs) @@ -372,7 +353,7 @@ def any(x: Array, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False, **kwargs) -> Array: - ndim = x.ndim + if axis == (): return x.to(torch.bool) # torch.any doesn't support multiple axes @@ -384,7 +365,7 @@ def any(x: Array, # torch doesn't support keepdims with axis=None # (https://github.com/pytorch/pytorch/issues/71209) res = torch.any(x, **kwargs) - res = _axis_none_keepdims(res, ndim, keepdims) + res = _axis_none_keepdims(res, x.ndim, keepdims) return res.to(torch.bool) # torch.any doesn't return bool for uint8 @@ -396,7 +377,7 @@ def all(x: Array, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False, **kwargs) -> Array: - ndim = x.ndim + if axis == (): return x.to(torch.bool) # torch.all doesn't support multiple axes @@ -408,7 +389,7 @@ def all(x: Array, # torch doesn't support keepdims with axis=None # (https://github.com/pytorch/pytorch/issues/71209) res = torch.all(x, **kwargs) - res = _axis_none_keepdims(res, ndim, keepdims) + res = _axis_none_keepdims(res, x.ndim, keepdims) return res.to(torch.bool) # torch.all doesn't return bool for uint8 From b94efc1f5e490a23c0ca74aafb93cc3118471f46 Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Tue, 15 Apr 2025 14:37:45 +0200 Subject: [PATCH 33/80] TST: skip testing nextafter with scalars on torch --- torch-xfails.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/torch-xfails.txt b/torch-xfails.txt index f8333d90..538403a3 100644 --- a/torch-xfails.txt +++ b/torch-xfails.txt @@ -144,6 +144,7 @@ array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_sc array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[logaddexp] array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[maximum] array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[minimum] +array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[nextafter] # https://github.com/pytorch/pytorch/issues/149815 array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[equal] From 205c967d658de24b2738dcae8d91684a1f99d2cd Mon Sep 17 00:00:00 2001 From: Joren Hammudoglu Date: Thu, 17 Apr 2025 20:14:48 +0200 Subject: [PATCH 34/80] TYP: Type annotations overhaul, episode 2 (#288) * TYP: annotate `_internal.get_xp` (and curse at `ParamSpec` for being so useless) * TYP: fix (or ignore) typing errors in `common._helpers` (and curse at cupy) * TYP: fix typing errors in `common._fft` * TYP: fix typing errors in `common._aliases` * TYP: fix typing errors in `common._linalg` * TYP: fix/ignore typing errors in `numpy.__init__` * TYP: fix typing errors in `numpy._typing` * TYP: fix typing errors in `numpy._aliases` * TYP: fix typing errors in `numpy._info` * TYP: fix typing errors in `numpy._fft` * TYP: it's a bad idea to import `TypeAlias` from `typing` on `python<3.10` * TYP: it's also a bad idea to import `TypeGuard` from `typing` 
on `python<3.10` * TYP: don't scare the prehistoric `dtype` from numpy 1.21 * TYP: dust off the DeLorean * TYP: figure out how to drive a DeLorean * TYP: apply review suggestions Co-authored-by: crusaderky * TYP: sprinkle some `TypeAlias`es and `Final`s around * TYP: `__dir__` * TYP: fix typing errors in `numpy.linalg` * TYP: add a `common._typing.Capabilities` typed dict type * TYP: `__array_namespace_info__` helper types * TYP: `dask.array` typing fixes and improvements * STY: give the `=` some breathing room Co-authored-by: Lucas Colley * STY: apply review suggestions Co-authored-by: lucascolley --------- Co-authored-by: crusaderky Co-authored-by: Lucas Colley --- array_api_compat/_internal.py | 25 +- array_api_compat/common/__init__.py | 2 +- array_api_compat/common/_aliases.py | 331 +++++++++++++++--------- array_api_compat/common/_fft.py | 69 ++--- array_api_compat/common/_helpers.py | 287 +++++++++++++------- array_api_compat/common/_linalg.py | 110 ++++++-- array_api_compat/common/_typing.py | 148 ++++++++++- array_api_compat/dask/array/__init__.py | 8 +- array_api_compat/dask/array/_aliases.py | 162 +++++++----- array_api_compat/dask/array/_info.py | 96 +++++-- array_api_compat/dask/array/linalg.py | 22 +- array_api_compat/numpy/__init__.py | 28 +- array_api_compat/numpy/_aliases.py | 86 +++--- array_api_compat/numpy/_info.py | 42 ++- array_api_compat/numpy/_typing.py | 35 ++- array_api_compat/numpy/fft.py | 16 +- array_api_compat/numpy/linalg.py | 97 +++++-- 17 files changed, 1076 insertions(+), 488 deletions(-) diff --git a/array_api_compat/_internal.py b/array_api_compat/_internal.py index 170a1ff9..cd8d939f 100644 --- a/array_api_compat/_internal.py +++ b/array_api_compat/_internal.py @@ -2,10 +2,16 @@ Internal helpers """ +from collections.abc import Callable from functools import wraps from inspect import signature +from types import ModuleType +from typing import TypeVar -def get_xp(xp): +_T = TypeVar("_T") + + +def get_xp(xp: ModuleType) -> Callable[[Callable[..., _T]], Callable[..., _T]]: """ Decorator to automatically replace xp with the corresponding array module. @@ -22,14 +28,14 @@ def func(x, /, xp, kwarg=None): """ - def inner(f): + def inner(f: Callable[..., _T], /) -> Callable[..., _T]: @wraps(f) - def wrapped_f(*args, **kwargs): + def wrapped_f(*args: object, **kwargs: object) -> object: return f(*args, xp=xp, **kwargs) sig = signature(f) new_sig = sig.replace( - parameters=[sig.parameters[i] for i in sig.parameters if i != "xp"] + parameters=[par for i, par in sig.parameters.items() if i != "xp"] ) if wrapped_f.__doc__ is None: @@ -40,7 +46,14 @@ def wrapped_f(*args, **kwargs): specification for more details. 
""" - wrapped_f.__signature__ = new_sig - return wrapped_f + wrapped_f.__signature__ = new_sig # pyright: ignore[reportAttributeAccessIssue] + return wrapped_f # pyright: ignore[reportReturnType] return inner + + +__all__ = ["get_xp"] + + +def __dir__() -> list[str]: + return __all__ diff --git a/array_api_compat/common/__init__.py b/array_api_compat/common/__init__.py index 91ab1c40..82360807 100644 --- a/array_api_compat/common/__init__.py +++ b/array_api_compat/common/__init__.py @@ -1 +1 @@ -from ._helpers import * # noqa: F403 +from ._helpers import * # noqa: F403 diff --git a/array_api_compat/common/_aliases.py b/array_api_compat/common/_aliases.py index 351b5bd6..8ea9162a 100644 --- a/array_api_compat/common/_aliases.py +++ b/array_api_compat/common/_aliases.py @@ -5,158 +5,170 @@ from __future__ import annotations import inspect -from typing import Any, NamedTuple, Optional, Sequence, Tuple, Union +from typing import TYPE_CHECKING, Any, NamedTuple, Optional, Sequence, cast +from ._helpers import _check_device, array_namespace +from ._helpers import device as _get_device +from ._helpers import is_cupy_namespace as _is_cupy_namespace from ._typing import Array, Device, DType, Namespace -from ._helpers import ( - array_namespace, - _check_device, - device as _get_device, - is_cupy_namespace as _is_cupy_namespace -) +if TYPE_CHECKING: + # TODO: import from typing (requires Python >=3.13) + from typing_extensions import TypeIs # These functions are modified from the NumPy versions. # Creation functions add the device keyword (which does nothing for NumPy and Dask) + def arange( - start: Union[int, float], + start: float, /, - stop: Optional[Union[int, float]] = None, - step: Union[int, float] = 1, + stop: float | None = None, + step: float = 1, *, xp: Namespace, - dtype: Optional[DType] = None, - device: Optional[Device] = None, - **kwargs, + dtype: DType | None = None, + device: Device | None = None, + **kwargs: object, ) -> Array: _check_device(xp, device) return xp.arange(start, stop=stop, step=step, dtype=dtype, **kwargs) + def empty( - shape: Union[int, Tuple[int, ...]], + shape: int | tuple[int, ...], xp: Namespace, *, - dtype: Optional[DType] = None, - device: Optional[Device] = None, - **kwargs, + dtype: DType | None = None, + device: Device | None = None, + **kwargs: object, ) -> Array: _check_device(xp, device) return xp.empty(shape, dtype=dtype, **kwargs) + def empty_like( x: Array, /, xp: Namespace, *, - dtype: Optional[DType] = None, - device: Optional[Device] = None, - **kwargs, + dtype: DType | None = None, + device: Device | None = None, + **kwargs: object, ) -> Array: _check_device(xp, device) return xp.empty_like(x, dtype=dtype, **kwargs) + def eye( n_rows: int, - n_cols: Optional[int] = None, + n_cols: int | None = None, /, *, xp: Namespace, k: int = 0, - dtype: Optional[DType] = None, - device: Optional[Device] = None, - **kwargs, + dtype: DType | None = None, + device: Device | None = None, + **kwargs: object, ) -> Array: _check_device(xp, device) return xp.eye(n_rows, M=n_cols, k=k, dtype=dtype, **kwargs) + def full( - shape: Union[int, Tuple[int, ...]], - fill_value: bool | int | float | complex, + shape: int | tuple[int, ...], + fill_value: complex, xp: Namespace, *, - dtype: Optional[DType] = None, - device: Optional[Device] = None, - **kwargs, + dtype: DType | None = None, + device: Device | None = None, + **kwargs: object, ) -> Array: _check_device(xp, device) return xp.full(shape, fill_value, dtype=dtype, **kwargs) + def full_like( x: Array, /, - fill_value: 
bool | int | float | complex, + fill_value: complex, *, xp: Namespace, - dtype: Optional[DType] = None, - device: Optional[Device] = None, - **kwargs, + dtype: DType | None = None, + device: Device | None = None, + **kwargs: object, ) -> Array: _check_device(xp, device) return xp.full_like(x, fill_value, dtype=dtype, **kwargs) + def linspace( - start: Union[int, float], - stop: Union[int, float], + start: float, + stop: float, /, num: int, *, xp: Namespace, - dtype: Optional[DType] = None, - device: Optional[Device] = None, + dtype: DType | None = None, + device: Device | None = None, endpoint: bool = True, - **kwargs, + **kwargs: object, ) -> Array: _check_device(xp, device) return xp.linspace(start, stop, num, dtype=dtype, endpoint=endpoint, **kwargs) + def ones( - shape: Union[int, Tuple[int, ...]], + shape: int | tuple[int, ...], xp: Namespace, *, - dtype: Optional[DType] = None, - device: Optional[Device] = None, - **kwargs, + dtype: DType | None = None, + device: Device | None = None, + **kwargs: object, ) -> Array: _check_device(xp, device) return xp.ones(shape, dtype=dtype, **kwargs) + def ones_like( x: Array, /, xp: Namespace, *, - dtype: Optional[DType] = None, - device: Optional[Device] = None, - **kwargs, + dtype: DType | None = None, + device: Device | None = None, + **kwargs: object, ) -> Array: _check_device(xp, device) return xp.ones_like(x, dtype=dtype, **kwargs) + def zeros( - shape: Union[int, Tuple[int, ...]], + shape: int | tuple[int, ...], xp: Namespace, *, - dtype: Optional[DType] = None, - device: Optional[Device] = None, - **kwargs, + dtype: DType | None = None, + device: Device | None = None, + **kwargs: object, ) -> Array: _check_device(xp, device) return xp.zeros(shape, dtype=dtype, **kwargs) + def zeros_like( x: Array, /, xp: Namespace, *, - dtype: Optional[DType] = None, - device: Optional[Device] = None, - **kwargs, + dtype: DType | None = None, + device: Device | None = None, + **kwargs: object, ) -> Array: _check_device(xp, device) return xp.zeros_like(x, dtype=dtype, **kwargs) + # np.unique() is split into four functions in the array API: # unique_all, unique_counts, unique_inverse, and unique_values (this is done # to remove polymorphic return types). @@ -164,6 +176,7 @@ def zeros_like( # The functions here return namedtuples (np.unique() returns a normal # tuple). + # Note that these named tuples aren't actually part of the standard namespace, # but I don't see any issue with exporting the names here regardless. class UniqueAllResult(NamedTuple): @@ -188,10 +201,11 @@ def _unique_kwargs(xp: Namespace) -> dict[str, bool]: # trying to parse version numbers, just check if equal_nan is in the # signature. 
s = inspect.signature(xp.unique) - if 'equal_nan' in s.parameters: - return {'equal_nan': False} + if "equal_nan" in s.parameters: + return {"equal_nan": False} return {} + def unique_all(x: Array, /, xp: Namespace) -> UniqueAllResult: kwargs = _unique_kwargs(xp) values, indices, inverse_indices, counts = xp.unique( @@ -215,11 +229,7 @@ def unique_all(x: Array, /, xp: Namespace) -> UniqueAllResult: def unique_counts(x: Array, /, xp: Namespace) -> UniqueCountsResult: kwargs = _unique_kwargs(xp) res = xp.unique( - x, - return_counts=True, - return_index=False, - return_inverse=False, - **kwargs + x, return_counts=True, return_index=False, return_inverse=False, **kwargs ) return UniqueCountsResult(*res) @@ -250,51 +260,58 @@ def unique_values(x: Array, /, xp: Namespace) -> Array: **kwargs, ) + # These functions have different keyword argument names + def std( x: Array, /, xp: Namespace, *, - axis: Optional[Union[int, Tuple[int, ...]]] = None, - correction: Union[int, float] = 0.0, # correction instead of ddof + axis: int | tuple[int, ...] | None = None, + correction: float = 0.0, # correction instead of ddof keepdims: bool = False, - **kwargs, + **kwargs: object, ) -> Array: return xp.std(x, axis=axis, ddof=correction, keepdims=keepdims, **kwargs) + def var( x: Array, /, xp: Namespace, *, - axis: Optional[Union[int, Tuple[int, ...]]] = None, - correction: Union[int, float] = 0.0, # correction instead of ddof + axis: int | tuple[int, ...] | None = None, + correction: float = 0.0, # correction instead of ddof keepdims: bool = False, - **kwargs, + **kwargs: object, ) -> Array: return xp.var(x, axis=axis, ddof=correction, keepdims=keepdims, **kwargs) + # cumulative_sum is renamed from cumsum, and adds the include_initial keyword # argument + def cumulative_sum( x: Array, /, xp: Namespace, *, - axis: Optional[int] = None, - dtype: Optional[DType] = None, + axis: int | None = None, + dtype: DType | None = None, include_initial: bool = False, - **kwargs, + **kwargs: object, ) -> Array: wrapped_xp = array_namespace(x) # TODO: The standard is not clear about what should happen when x.ndim == 0. 
if axis is None: if x.ndim > 1: - raise ValueError("axis must be specified in cumulative_sum for more than one dimension") + raise ValueError( + "axis must be specified in cumulative_sum for more than one dimension" + ) axis = 0 res = xp.cumsum(x, axis=axis, dtype=dtype, **kwargs) @@ -304,7 +321,12 @@ def cumulative_sum( initial_shape = list(x.shape) initial_shape[axis] = 1 res = xp.concatenate( - [wrapped_xp.zeros(shape=initial_shape, dtype=res.dtype, device=_get_device(res)), res], + [ + wrapped_xp.zeros( + shape=initial_shape, dtype=res.dtype, device=_get_device(res) + ), + res, + ], axis=axis, ) return res @@ -315,16 +337,18 @@ def cumulative_prod( /, xp: Namespace, *, - axis: Optional[int] = None, - dtype: Optional[DType] = None, + axis: int | None = None, + dtype: DType | None = None, include_initial: bool = False, - **kwargs, + **kwargs: object, ) -> Array: wrapped_xp = array_namespace(x) if axis is None: if x.ndim > 1: - raise ValueError("axis must be specified in cumulative_prod for more than one dimension") + raise ValueError( + "axis must be specified in cumulative_prod for more than one dimension" + ) axis = 0 res = xp.cumprod(x, axis=axis, dtype=dtype, **kwargs) @@ -334,24 +358,30 @@ def cumulative_prod( initial_shape = list(x.shape) initial_shape[axis] = 1 res = xp.concatenate( - [wrapped_xp.ones(shape=initial_shape, dtype=res.dtype, device=_get_device(res)), res], + [ + wrapped_xp.ones( + shape=initial_shape, dtype=res.dtype, device=_get_device(res) + ), + res, + ], axis=axis, ) return res + # The min and max argument names in clip are different and not optional in numpy, and type # promotion behavior is different. def clip( x: Array, /, - min: Optional[Union[int, float, Array]] = None, - max: Optional[Union[int, float, Array]] = None, + min: float | Array | None = None, + max: float | Array | None = None, *, xp: Namespace, # TODO: np.clip has other ufunc kwargs - out: Optional[Array] = None, + out: Array | None = None, ) -> Array: - def _isscalar(a): + def _isscalar(a: object) -> TypeIs[int | float | None]: return isinstance(a, (int, float, type(None))) min_shape = () if _isscalar(min) else min.shape @@ -378,7 +408,6 @@ def _isscalar(a): # but an answer of 0 might be preferred. See # https://github.com/numpy/numpy/issues/24976 for more discussion on this issue. - # At least handle the case of Python integers correctly (see # https://github.com/numpy/numpy/pull/26892). if wrapped_xp.isdtype(x.dtype, "integral"): @@ -390,6 +419,7 @@ def _isscalar(a): dev = _get_device(x) if out is None: out = wrapped_xp.empty(result_shape, dtype=x.dtype, device=dev) + assert out is not None # workaround for a type-narrowing issue in pyright out[()] = x if min is not None: @@ -407,19 +437,21 @@ def _isscalar(a): # Return a scalar for 0-D return out[()] + # Unlike transpose(), the axes argument to permute_dims() is required. 
-def permute_dims(x: Array, /, axes: Tuple[int, ...], xp: Namespace) -> Array: +def permute_dims(x: Array, /, axes: tuple[int, ...], xp: Namespace) -> Array: return xp.transpose(x, axes) + # np.reshape calls the keyword argument 'newshape' instead of 'shape' def reshape( x: Array, /, - shape: Tuple[int, ...], + shape: tuple[int, ...], xp: Namespace, *, copy: Optional[bool] = None, - **kwargs, + **kwargs: object, ) -> Array: if copy is True: x = x.copy() @@ -429,6 +461,7 @@ def reshape( return y return xp.reshape(x, shape, **kwargs) + # The descending keyword is new in sort and argsort, and 'kind' replaced with # 'stable' def argsort( @@ -439,13 +472,13 @@ def argsort( axis: int = -1, descending: bool = False, stable: bool = True, - **kwargs, + **kwargs: object, ) -> Array: # Note: this keyword argument is different, and the default is different. # We set it in kwargs like this because numpy.sort uses kind='quicksort' # as the default whereas cupy.sort uses kind=None. if stable: - kwargs['kind'] = "stable" + kwargs["kind"] = "stable" if not descending: res = xp.argsort(x, axis=axis, **kwargs) else: @@ -462,6 +495,7 @@ def argsort( res = max_i - res return res + def sort( x: Array, /, @@ -470,68 +504,78 @@ def sort( axis: int = -1, descending: bool = False, stable: bool = True, - **kwargs, + **kwargs: object, ) -> Array: # Note: this keyword argument is different, and the default is different. # We set it in kwargs like this because numpy.sort uses kind='quicksort' # as the default whereas cupy.sort uses kind=None. if stable: - kwargs['kind'] = "stable" + kwargs["kind"] = "stable" res = xp.sort(x, axis=axis, **kwargs) if descending: res = xp.flip(res, axis=axis) return res + # nonzero should error for zero-dimensional arrays -def nonzero(x: Array, /, xp: Namespace, **kwargs) -> Tuple[Array, ...]: +def nonzero(x: Array, /, xp: Namespace, **kwargs: object) -> tuple[Array, ...]: if x.ndim == 0: raise ValueError("nonzero() does not support zero-dimensional arrays") return xp.nonzero(x, **kwargs) + # ceil, floor, and trunc return integers for integer inputs -def ceil(x: Array, /, xp: Namespace, **kwargs) -> Array: + +def ceil(x: Array, /, xp: Namespace, **kwargs: object) -> Array: if xp.issubdtype(x.dtype, xp.integer): return x return xp.ceil(x, **kwargs) -def floor(x: Array, /, xp: Namespace, **kwargs) -> Array: + +def floor(x: Array, /, xp: Namespace, **kwargs: object) -> Array: if xp.issubdtype(x.dtype, xp.integer): return x return xp.floor(x, **kwargs) -def trunc(x: Array, /, xp: Namespace, **kwargs) -> Array: + +def trunc(x: Array, /, xp: Namespace, **kwargs: object) -> Array: if xp.issubdtype(x.dtype, xp.integer): return x return xp.trunc(x, **kwargs) + # linear algebra functions -def matmul(x1: Array, x2: Array, /, xp: Namespace, **kwargs) -> Array: + +def matmul(x1: Array, x2: Array, /, xp: Namespace, **kwargs: object) -> Array: return xp.matmul(x1, x2, **kwargs) + # Unlike transpose, matrix_transpose only transposes the last two axes. 
def matrix_transpose(x: Array, /, xp: Namespace) -> Array: if x.ndim < 2: raise ValueError("x must be at least 2-dimensional for matrix_transpose") return xp.swapaxes(x, -1, -2) + def tensordot( x1: Array, x2: Array, /, xp: Namespace, *, - axes: Union[int, Tuple[Sequence[int], Sequence[int]]] = 2, - **kwargs, + axes: int | tuple[Sequence[int], Sequence[int]] = 2, + **kwargs: object, ) -> Array: return xp.tensordot(x1, x2, axes=axes, **kwargs) + def vecdot(x1: Array, x2: Array, /, xp: Namespace, *, axis: int = -1) -> Array: if x1.shape[axis] != x2.shape[axis]: raise ValueError("x1 and x2 must have the same size along the given axis") - if hasattr(xp, 'broadcast_tensors'): + if hasattr(xp, "broadcast_tensors"): _broadcast = xp.broadcast_tensors else: _broadcast = xp.broadcast_arrays @@ -543,14 +587,16 @@ def vecdot(x1: Array, x2: Array, /, xp: Namespace, *, axis: int = -1) -> Array: res = xp.conj(x1_[..., None, :]) @ x2_[..., None] return res[..., 0, 0] + # isdtype is a new function in the 2022.12 array API specification. + def isdtype( dtype: DType, - kind: Union[DType, str, Tuple[Union[DType, str], ...]], + kind: DType | str | tuple[DType | str, ...], xp: Namespace, *, - _tuple: bool = True, # Disallow nested tuples + _tuple: bool = True, # Disallow nested tuples ) -> bool: """ Returns a boolean indicating whether a provided dtype is of a specified data type ``kind``. @@ -563,21 +609,24 @@ def isdtype( for more details """ if isinstance(kind, tuple) and _tuple: - return any(isdtype(dtype, k, xp, _tuple=False) for k in kind) + return any( + isdtype(dtype, k, xp, _tuple=False) + for k in cast("tuple[DType | str, ...]", kind) + ) elif isinstance(kind, str): - if kind == 'bool': + if kind == "bool": return dtype == xp.bool_ - elif kind == 'signed integer': + elif kind == "signed integer": return xp.issubdtype(dtype, xp.signedinteger) - elif kind == 'unsigned integer': + elif kind == "unsigned integer": return xp.issubdtype(dtype, xp.unsignedinteger) - elif kind == 'integral': + elif kind == "integral": return xp.issubdtype(dtype, xp.integer) - elif kind == 'real floating': + elif kind == "real floating": return xp.issubdtype(dtype, xp.floating) - elif kind == 'complex floating': + elif kind == "complex floating": return xp.issubdtype(dtype, xp.complexfloating) - elif kind == 'numeric': + elif kind == "numeric": return xp.issubdtype(dtype, xp.number) else: raise ValueError(f"Unrecognized data type kind: {kind!r}") @@ -588,24 +637,27 @@ def isdtype( # array_api_strict implementation will be very strict. return dtype == kind + # unstack is a new function in the 2023.12 array API standard -def unstack(x: Array, /, xp: Namespace, *, axis: int = 0) -> Tuple[Array, ...]: +def unstack(x: Array, /, xp: Namespace, *, axis: int = 0) -> tuple[Array, ...]: if x.ndim == 0: raise ValueError("Input array must be at least 1-d.") return tuple(xp.moveaxis(x, axis, 0)) + # numpy 1.26 does not use the standard definition for sign on complex numbers -def sign(x: Array, /, xp: Namespace, **kwargs) -> Array: - if isdtype(x.dtype, 'complex floating', xp=xp): - out = (x/xp.abs(x, **kwargs))[...] + +def sign(x: Array, /, xp: Namespace, **kwargs: object) -> Array: + if isdtype(x.dtype, "complex floating", xp=xp): + out = (x / xp.abs(x, **kwargs))[...] # sign(0) = 0 but the above formula would give nan - out[x == 0+0j] = 0+0j + out[x == 0j] = 0j else: out = xp.sign(x, **kwargs) # CuPy sign() does not propagate nans. 
See # https://github.com/data-apis/array-api-compat/issues/136 - if _is_cupy_namespace(xp) and isdtype(x.dtype, 'real floating', xp=xp): + if _is_cupy_namespace(xp) and isdtype(x.dtype, "real floating", xp=xp): out[xp.isnan(x)] = xp.nan return out[()] @@ -626,13 +678,50 @@ def iinfo(type_: DType | Array, /, xp: Namespace) -> Any: return xp.iinfo(type_.dtype) -__all__ = ['arange', 'empty', 'empty_like', 'eye', 'full', 'full_like', - 'linspace', 'ones', 'ones_like', 'zeros', 'zeros_like', - 'UniqueAllResult', 'UniqueCountsResult', 'UniqueInverseResult', - 'unique_all', 'unique_counts', 'unique_inverse', 'unique_values', - 'std', 'var', 'cumulative_sum', 'cumulative_prod','clip', 'permute_dims', - 'reshape', 'argsort', 'sort', 'nonzero', 'ceil', 'floor', 'trunc', - 'matmul', 'matrix_transpose', 'tensordot', 'vecdot', 'isdtype', - 'unstack', 'sign', 'finfo', 'iinfo'] - -_all_ignore = ['inspect', 'array_namespace', 'NamedTuple'] +__all__ = [ + "arange", + "empty", + "empty_like", + "eye", + "full", + "full_like", + "linspace", + "ones", + "ones_like", + "zeros", + "zeros_like", + "UniqueAllResult", + "UniqueCountsResult", + "UniqueInverseResult", + "unique_all", + "unique_counts", + "unique_inverse", + "unique_values", + "std", + "var", + "cumulative_sum", + "cumulative_prod", + "clip", + "permute_dims", + "reshape", + "argsort", + "sort", + "nonzero", + "ceil", + "floor", + "trunc", + "matmul", + "matrix_transpose", + "tensordot", + "vecdot", + "isdtype", + "unstack", + "sign", + "finfo", + "iinfo", +] +_all_ignore = ["inspect", "array_namespace", "NamedTuple"] + + +def __dir__() -> list[str]: + return __all__ diff --git a/array_api_compat/common/_fft.py b/array_api_compat/common/_fft.py index bd2a4e1a..18839d37 100644 --- a/array_api_compat/common/_fft.py +++ b/array_api_compat/common/_fft.py @@ -1,9 +1,11 @@ from __future__ import annotations from collections.abc import Sequence -from typing import Union, Optional, Literal +from typing import Literal, TypeAlias -from ._typing import Device, Array, DType, Namespace +from ._typing import Array, Device, DType, Namespace + +_Norm: TypeAlias = Literal["backward", "ortho", "forward"] # Note: NumPy fft functions improperly upcast float32 and complex64 to # complex128, which is why we require wrapping them all here. 
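For context on the comment above, a small demonstration of the upcast being worked around, assuming NumPy 1.x (NumPy 2 preserves single precision in `np.fft`):

    import numpy as np

    x = np.ones(8, dtype=np.complex64)
    # np.fft computes in double precision, so on NumPy 1.x this prints
    # complex128; the wrappers in this file cast such results back to the
    # input's single precision
    print(np.fft.fft(x).dtype)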
@@ -13,9 +15,9 @@ def fft( /, xp: Namespace, *, - n: Optional[int] = None, + n: int | None = None, axis: int = -1, - norm: Literal["backward", "ortho", "forward"] = "backward", + norm: _Norm = "backward", ) -> Array: res = xp.fft.fft(x, n=n, axis=axis, norm=norm) if x.dtype in [xp.float32, xp.complex64]: @@ -27,9 +29,9 @@ def ifft( /, xp: Namespace, *, - n: Optional[int] = None, + n: int | None = None, axis: int = -1, - norm: Literal["backward", "ortho", "forward"] = "backward", + norm: _Norm = "backward", ) -> Array: res = xp.fft.ifft(x, n=n, axis=axis, norm=norm) if x.dtype in [xp.float32, xp.complex64]: @@ -41,9 +43,9 @@ def fftn( /, xp: Namespace, *, - s: Sequence[int] = None, - axes: Sequence[int] = None, - norm: Literal["backward", "ortho", "forward"] = "backward", + s: Sequence[int] | None = None, + axes: Sequence[int] | None = None, + norm: _Norm = "backward", ) -> Array: res = xp.fft.fftn(x, s=s, axes=axes, norm=norm) if x.dtype in [xp.float32, xp.complex64]: @@ -55,9 +57,9 @@ def ifftn( /, xp: Namespace, *, - s: Sequence[int] = None, - axes: Sequence[int] = None, - norm: Literal["backward", "ortho", "forward"] = "backward", + s: Sequence[int] | None = None, + axes: Sequence[int] | None = None, + norm: _Norm = "backward", ) -> Array: res = xp.fft.ifftn(x, s=s, axes=axes, norm=norm) if x.dtype in [xp.float32, xp.complex64]: @@ -69,9 +71,9 @@ def rfft( /, xp: Namespace, *, - n: Optional[int] = None, + n: int | None = None, axis: int = -1, - norm: Literal["backward", "ortho", "forward"] = "backward", + norm: _Norm = "backward", ) -> Array: res = xp.fft.rfft(x, n=n, axis=axis, norm=norm) if x.dtype == xp.float32: @@ -83,9 +85,9 @@ def irfft( /, xp: Namespace, *, - n: Optional[int] = None, + n: int | None = None, axis: int = -1, - norm: Literal["backward", "ortho", "forward"] = "backward", + norm: _Norm = "backward", ) -> Array: res = xp.fft.irfft(x, n=n, axis=axis, norm=norm) if x.dtype == xp.complex64: @@ -97,9 +99,9 @@ def rfftn( /, xp: Namespace, *, - s: Sequence[int] = None, - axes: Sequence[int] = None, - norm: Literal["backward", "ortho", "forward"] = "backward", + s: Sequence[int] | None = None, + axes: Sequence[int] | None = None, + norm: _Norm = "backward", ) -> Array: res = xp.fft.rfftn(x, s=s, axes=axes, norm=norm) if x.dtype == xp.float32: @@ -111,9 +113,9 @@ def irfftn( /, xp: Namespace, *, - s: Sequence[int] = None, - axes: Sequence[int] = None, - norm: Literal["backward", "ortho", "forward"] = "backward", + s: Sequence[int] | None = None, + axes: Sequence[int] | None = None, + norm: _Norm = "backward", ) -> Array: res = xp.fft.irfftn(x, s=s, axes=axes, norm=norm) if x.dtype == xp.complex64: @@ -125,9 +127,9 @@ def hfft( /, xp: Namespace, *, - n: Optional[int] = None, + n: int | None = None, axis: int = -1, - norm: Literal["backward", "ortho", "forward"] = "backward", + norm: _Norm = "backward", ) -> Array: res = xp.fft.hfft(x, n=n, axis=axis, norm=norm) if x.dtype in [xp.float32, xp.complex64]: @@ -139,9 +141,9 @@ def ihfft( /, xp: Namespace, *, - n: Optional[int] = None, + n: int | None = None, axis: int = -1, - norm: Literal["backward", "ortho", "forward"] = "backward", + norm: _Norm = "backward", ) -> Array: res = xp.fft.ihfft(x, n=n, axis=axis, norm=norm) if x.dtype in [xp.float32, xp.complex64]: @@ -154,8 +156,8 @@ def fftfreq( xp: Namespace, *, d: float = 1.0, - dtype: Optional[DType] = None, - device: Optional[Device] = None, + dtype: DType | None = None, + device: Device | None = None, ) -> Array: if device not in ["cpu", None]: raise ValueError(f"Unsupported 
device {device!r}") @@ -170,8 +172,8 @@ def rfftfreq( xp: Namespace, *, d: float = 1.0, - dtype: Optional[DType] = None, - device: Optional[Device] = None, + dtype: DType | None = None, + device: Device | None = None, ) -> Array: if device not in ["cpu", None]: raise ValueError(f"Unsupported device {device!r}") @@ -181,12 +183,12 @@ def rfftfreq( return res def fftshift( - x: Array, /, xp: Namespace, *, axes: Union[int, Sequence[int]] = None + x: Array, /, xp: Namespace, *, axes: int | Sequence[int] | None = None ) -> Array: return xp.fft.fftshift(x, axes=axes) def ifftshift( - x: Array, /, xp: Namespace, *, axes: Union[int, Sequence[int]] = None + x: Array, /, xp: Namespace, *, axes: int | Sequence[int] | None = None ) -> Array: return xp.fft.ifftshift(x, axes=axes) @@ -206,3 +208,6 @@ def ifftshift( "fftshift", "ifftshift", ] + +def __dir__() -> list[str]: + return __all__ diff --git a/array_api_compat/common/_helpers.py b/array_api_compat/common/_helpers.py index 67c619b8..db3e4cd7 100644 --- a/array_api_compat/common/_helpers.py +++ b/array_api_compat/common/_helpers.py @@ -5,33 +5,82 @@ that are in __all__ are intended as additional helper functions for use by end users of the compat library. """ + from __future__ import annotations -import sys -import math import inspect +import math +import sys import warnings -from typing import Optional, Union, Any +from collections.abc import Collection +from typing import ( + TYPE_CHECKING, + Any, + Final, + Literal, + SupportsIndex, + TypeAlias, + TypeGuard, + TypeVar, + cast, + overload, +) + +from ._typing import Array, Device, HasShape, Namespace, SupportsArrayNamespace + +if TYPE_CHECKING: + + import dask.array as da + import jax + import ndonnx as ndx + import numpy as np + import numpy.typing as npt + import sparse # pyright: ignore[reportMissingTypeStubs] + import torch + + # TODO: import from typing (requires Python >=3.13) + from typing_extensions import TypeIs, TypeVar -from ._typing import Array, Device, Namespace + _SizeT = TypeVar("_SizeT", bound = int | None) + _ZeroGradientArray: TypeAlias = npt.NDArray[np.void] + _CupyArray: TypeAlias = Any # cupy has no py.typed -def _is_jax_zero_gradient_array(x: object) -> bool: + _ArrayApiObj: TypeAlias = ( + npt.NDArray[Any] + | da.Array + | jax.Array + | ndx.Array + | sparse.SparseArray + | torch.Tensor + | SupportsArrayNamespace[Any] + | _CupyArray + ) + +_API_VERSIONS_OLD: Final = frozenset({"2021.12", "2022.12", "2023.12"}) +_API_VERSIONS: Final = _API_VERSIONS_OLD | frozenset({"2024.12"}) + + +def _is_jax_zero_gradient_array(x: object) -> TypeGuard[_ZeroGradientArray]: """Return True if `x` is a zero-gradient array. These arrays are a design quirk of Jax that may one day be removed. See https://github.com/google/jax/issues/20620. """ - if 'numpy' not in sys.modules or 'jax' not in sys.modules: + if "numpy" not in sys.modules or "jax" not in sys.modules: return False - import numpy as np import jax + import numpy as np - return isinstance(x, np.ndarray) and x.dtype == jax.float0 + jax_float0 = cast("np.dtype[np.void]", jax.float0) + return ( + isinstance(x, np.ndarray) + and cast("npt.NDArray[np.void]", x).dtype == jax_float0 + ) -def is_numpy_array(x: object) -> bool: +def is_numpy_array(x: object) -> TypeGuard[npt.NDArray[Any]]: """ Return True if `x` is a NumPy array. 
@@ -53,14 +102,14 @@ def is_numpy_array(x: object) -> bool: is_pydata_sparse_array """ # Avoid importing NumPy if it isn't already - if 'numpy' not in sys.modules: + if "numpy" not in sys.modules: return False import numpy as np # TODO: Should we reject ndarray subclasses? return (isinstance(x, (np.ndarray, np.generic)) - and not _is_jax_zero_gradient_array(x)) + and not _is_jax_zero_gradient_array(x)) # pyright: ignore[reportUnknownArgumentType] # fmt: skip def is_cupy_array(x: object) -> bool: @@ -85,16 +134,16 @@ def is_cupy_array(x: object) -> bool: is_pydata_sparse_array """ # Avoid importing CuPy if it isn't already - if 'cupy' not in sys.modules: + if "cupy" not in sys.modules: return False - import cupy as cp + import cupy as cp # pyright: ignore[reportMissingTypeStubs] # TODO: Should we reject ndarray subclasses? - return isinstance(x, cp.ndarray) + return isinstance(x, cp.ndarray) # pyright: ignore[reportUnknownMemberType] -def is_torch_array(x: object) -> bool: +def is_torch_array(x: object) -> TypeIs[torch.Tensor]: """ Return True if `x` is a PyTorch tensor. @@ -113,7 +162,7 @@ def is_torch_array(x: object) -> bool: is_pydata_sparse_array """ # Avoid importing torch if it isn't already - if 'torch' not in sys.modules: + if "torch" not in sys.modules: return False import torch @@ -122,7 +171,7 @@ def is_torch_array(x: object) -> bool: return isinstance(x, torch.Tensor) -def is_ndonnx_array(x: object) -> bool: +def is_ndonnx_array(x: object) -> TypeIs[ndx.Array]: """ Return True if `x` is a ndonnx Array. @@ -142,7 +191,7 @@ def is_ndonnx_array(x: object) -> bool: is_pydata_sparse_array """ # Avoid importing torch if it isn't already - if 'ndonnx' not in sys.modules: + if "ndonnx" not in sys.modules: return False import ndonnx as ndx @@ -150,7 +199,7 @@ def is_ndonnx_array(x: object) -> bool: return isinstance(x, ndx.Array) -def is_dask_array(x: object) -> bool: +def is_dask_array(x: object) -> TypeIs[da.Array]: """ Return True if `x` is a dask.array Array. @@ -170,7 +219,7 @@ def is_dask_array(x: object) -> bool: is_pydata_sparse_array """ # Avoid importing dask if it isn't already - if 'dask.array' not in sys.modules: + if "dask.array" not in sys.modules: return False import dask.array @@ -178,7 +227,7 @@ def is_dask_array(x: object) -> bool: return isinstance(x, dask.array.Array) -def is_jax_array(x: object) -> bool: +def is_jax_array(x: object) -> TypeIs[jax.Array]: """ Return True if `x` is a JAX array. @@ -199,7 +248,7 @@ def is_jax_array(x: object) -> bool: is_pydata_sparse_array """ # Avoid importing jax if it isn't already - if 'jax' not in sys.modules: + if "jax" not in sys.modules: return False import jax @@ -207,7 +256,7 @@ def is_jax_array(x: object) -> bool: return isinstance(x, jax.Array) or _is_jax_zero_gradient_array(x) -def is_pydata_sparse_array(x) -> bool: +def is_pydata_sparse_array(x: object) -> TypeIs[sparse.SparseArray]: """ Return True if `x` is an array from the `sparse` package. @@ -228,16 +277,16 @@ def is_pydata_sparse_array(x) -> bool: is_jax_array """ # Avoid importing jax if it isn't already - if 'sparse' not in sys.modules: + if "sparse" not in sys.modules: return False - import sparse + import sparse # pyright: ignore[reportMissingTypeStubs] # TODO: Account for other backends. return isinstance(x, sparse.SparseArray) -def is_array_api_obj(x: object) -> bool: +def is_array_api_obj(x: object) -> TypeIs[_ArrayApiObj]: # pyright: ignore[reportUnknownParameterType] """ Return True if `x` is an array API compatible array object. 
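Downstream code typically pairs this predicate with `array_namespace` to stay library-agnostic; a minimal sketch with a hypothetical helper (`double` is illustrative only):

    from array_api_compat import array_namespace, is_array_api_obj

    def double(x):
        # accepts numpy, cupy, torch, dask, jax and sparse arrays alike
        if not is_array_api_obj(x):
            raise TypeError("expected an array API compatible array")
        xp = array_namespace(x)
        return xp.add(x, x)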
@@ -252,18 +301,20 @@ def is_array_api_obj(x: object) -> bool: is_dask_array is_jax_array """ - return is_numpy_array(x) \ - or is_cupy_array(x) \ - or is_torch_array(x) \ - or is_dask_array(x) \ - or is_jax_array(x) \ - or is_pydata_sparse_array(x) \ - or hasattr(x, '__array_namespace__') + return ( + is_numpy_array(x) + or is_cupy_array(x) + or is_torch_array(x) + or is_dask_array(x) + or is_jax_array(x) + or is_pydata_sparse_array(x) + or hasattr(x, "__array_namespace__") + ) def _compat_module_name() -> str: - assert __name__.endswith('.common._helpers') - return __name__.removesuffix('.common._helpers') + assert __name__.endswith(".common._helpers") + return __name__.removesuffix(".common._helpers") def is_numpy_namespace(xp: Namespace) -> bool: @@ -284,7 +335,7 @@ def is_numpy_namespace(xp: Namespace) -> bool: is_pydata_sparse_namespace is_array_api_strict_namespace """ - return xp.__name__ in {'numpy', _compat_module_name() + '.numpy'} + return xp.__name__ in {"numpy", _compat_module_name() + ".numpy"} def is_cupy_namespace(xp: Namespace) -> bool: @@ -305,7 +356,7 @@ def is_cupy_namespace(xp: Namespace) -> bool: is_pydata_sparse_namespace is_array_api_strict_namespace """ - return xp.__name__ in {'cupy', _compat_module_name() + '.cupy'} + return xp.__name__ in {"cupy", _compat_module_name() + ".cupy"} def is_torch_namespace(xp: Namespace) -> bool: @@ -326,7 +377,7 @@ def is_torch_namespace(xp: Namespace) -> bool: is_pydata_sparse_namespace is_array_api_strict_namespace """ - return xp.__name__ in {'torch', _compat_module_name() + '.torch'} + return xp.__name__ in {"torch", _compat_module_name() + ".torch"} def is_ndonnx_namespace(xp: Namespace) -> bool: @@ -345,7 +396,7 @@ def is_ndonnx_namespace(xp: Namespace) -> bool: is_pydata_sparse_namespace is_array_api_strict_namespace """ - return xp.__name__ == 'ndonnx' + return xp.__name__ == "ndonnx" def is_dask_namespace(xp: Namespace) -> bool: @@ -366,7 +417,7 @@ def is_dask_namespace(xp: Namespace) -> bool: is_pydata_sparse_namespace is_array_api_strict_namespace """ - return xp.__name__ in {'dask.array', _compat_module_name() + '.dask.array'} + return xp.__name__ in {"dask.array", _compat_module_name() + ".dask.array"} def is_jax_namespace(xp: Namespace) -> bool: @@ -388,7 +439,7 @@ def is_jax_namespace(xp: Namespace) -> bool: is_pydata_sparse_namespace is_array_api_strict_namespace """ - return xp.__name__ in {'jax.numpy', 'jax.experimental.array_api'} + return xp.__name__ in {"jax.numpy", "jax.experimental.array_api"} def is_pydata_sparse_namespace(xp: Namespace) -> bool: @@ -407,7 +458,7 @@ def is_pydata_sparse_namespace(xp: Namespace) -> bool: is_jax_namespace is_array_api_strict_namespace """ - return xp.__name__ == 'sparse' + return xp.__name__ == "sparse" def is_array_api_strict_namespace(xp: Namespace) -> bool: @@ -426,21 +477,24 @@ def is_array_api_strict_namespace(xp: Namespace) -> bool: is_jax_namespace is_pydata_sparse_namespace """ - return xp.__name__ == 'array_api_strict' + return xp.__name__ == "array_api_strict" -def _check_api_version(api_version: str) -> None: - if api_version in ['2021.12', '2022.12', '2023.12']: - warnings.warn(f"The {api_version} version of the array API specification was requested but the returned namespace is actually version 2024.12") - elif api_version is not None and api_version not in ['2021.12', '2022.12', - '2023.12', '2024.12']: - raise ValueError("Only the 2024.12 version of the array API specification is currently supported") +def _check_api_version(api_version: str | None) -> None: + 
if api_version in _API_VERSIONS_OLD: + warnings.warn( + f"The {api_version} version of the array API specification was requested but the returned namespace is actually version 2024.12" + ) + elif api_version is not None and api_version not in _API_VERSIONS: + raise ValueError( + "Only the 2024.12 version of the array API specification is currently supported" + ) def array_namespace( - *xs: Union[Array, bool, int, float, complex, None], - api_version: Optional[str] = None, - use_compat: Optional[bool] = None, + *xs: Array | complex | None, + api_version: str | None = None, + use_compat: bool | None = None, ) -> Namespace: """ Get the array API compatible namespace for the arrays `xs`. @@ -510,11 +564,13 @@ def your_function(x, y): _use_compat = use_compat in [None, True] - namespaces = set() + namespaces: set[Namespace] = set() for x in xs: if is_numpy_array(x): - from .. import numpy as numpy_namespace import numpy as np + + from .. import numpy as numpy_namespace + if use_compat is True: _check_api_version(api_version) namespaces.add(numpy_namespace) @@ -528,25 +584,31 @@ def your_function(x, y): if _use_compat: _check_api_version(api_version) from .. import cupy as cupy_namespace + namespaces.add(cupy_namespace) else: - import cupy as cp + import cupy as cp # pyright: ignore[reportMissingTypeStubs] + namespaces.add(cp) elif is_torch_array(x): if _use_compat: _check_api_version(api_version) from .. import torch as torch_namespace + namespaces.add(torch_namespace) else: import torch + namespaces.add(torch) elif is_dask_array(x): if _use_compat: _check_api_version(api_version) from ..dask import array as dask_namespace + namespaces.add(dask_namespace) else: import dask.array as da + namespaces.add(da) elif is_jax_array(x): if use_compat is True: @@ -558,23 +620,27 @@ def your_function(x, y): # JAX v0.4.32 and newer implements the array API directly in jax.numpy. # For older JAX versions, it is available via jax.experimental.array_api. import jax.numpy + if hasattr(jax.numpy, "__array_api_version__"): jnp = jax.numpy else: - import jax.experimental.array_api as jnp + import jax.experimental.array_api as jnp # pyright: ignore[reportMissingImports] namespaces.add(jnp) elif is_pydata_sparse_array(x): if use_compat is True: _check_api_version(api_version) raise ValueError("`sparse` does not have an array-api-compat wrapper") else: - import sparse + import sparse # pyright: ignore[reportMissingTypeStubs] # `sparse` is already an array namespace. We do not have a wrapper # submodule for it. namespaces.add(sparse) - elif hasattr(x, '__array_namespace__'): + elif hasattr(x, "__array_namespace__"): if use_compat is True: - raise ValueError("The given array does not have an array-api-compat wrapper") + raise ValueError( + "The given array does not have an array-api-compat wrapper" + ) + x = cast("SupportsArrayNamespace[Any]", x) namespaces.add(x.__array_namespace__(api_version=api_version)) elif isinstance(x, (bool, int, float, complex, type(None))): continue @@ -588,15 +654,16 @@ def your_function(x, y): if len(namespaces) != 1: raise TypeError(f"Multiple namespaces for array inputs: {namespaces}") - xp, = namespaces + (xp,) = namespaces return xp + # backwards compatibility alias get_namespace = array_namespace -def _check_device(bare_xp, device): +def _check_device(bare_xp: Namespace, device: Device) -> None: # pyright: ignore[reportUnusedFunction] """ Validate dummy device on device-less array backends. 
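In practice, the `array_namespace()` dispatch above is what lets consuming code be written once against `xp` and run on any supported backend. A minimal sketch (`standardize` is a made-up name):

    import numpy as np
    from array_api_compat import array_namespace

    def standardize(x):
        xp = array_namespace(x)  # compat namespace wrapping x's library
        return (x - xp.mean(x)) / xp.std(x)

    standardize(np.asarray([1.0, 2.0, 3.0]))  # same code works for torch, dask, ...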
@@ -609,11 +676,11 @@ def _check_device(bare_xp, device): https://github.com/data-apis/array-api-compat/pull/293 """ - if bare_xp is sys.modules.get('numpy'): + if bare_xp is sys.modules.get("numpy"): if device not in ("cpu", None): raise ValueError(f"Unsupported device for NumPy: {device!r}") - elif bare_xp is sys.modules.get('dask.array'): + elif bare_xp is sys.modules.get("dask.array"): if device not in ("cpu", _DASK_DEVICE, None): raise ValueError(f"Unsupported device for Dask: {device!r}") @@ -622,18 +689,20 @@ def _check_device(bare_xp, device): # when the array backend is not the CPU. # (since it is not easy to tell which device a dask array is on) class _dask_device: - def __repr__(self): + def __repr__(self) -> Literal["DASK_DEVICE"]: return "DASK_DEVICE" + _DASK_DEVICE = _dask_device() + # device() is not on numpy.ndarray or dask.array and to_device() is not on numpy.ndarray # or cupy.ndarray. They are not included in array objects of this library # because this library just reuses the respective ndarray classes without # wrapping or subclassing them. These helper functions can be used instead of # the wrapper functions for libraries that need to support both NumPy/CuPy and # other libraries that use devices. -def device(x: Array, /) -> Device: +def device(x: _ArrayApiObj, /) -> Device: """ Hardware device the array data resides on. @@ -669,7 +738,7 @@ def device(x: Array, /) -> Device: return "cpu" elif is_dask_array(x): # Peek at the metadata of the Dask array to determine type - if is_numpy_array(x._meta): + if is_numpy_array(x._meta): # pyright: ignore # Must be on CPU since backed by numpy return "cpu" return _DASK_DEVICE @@ -679,7 +748,7 @@ def device(x: Array, /) -> Device: # Return None in this case. Note that this workaround breaks # the standard and will result in new arrays being created on the # default device instead of the same device as the input array(s). - x_device = getattr(x, 'device', None) + x_device = getattr(x, "device", None) # Older JAX releases had .device() as a method, which has been replaced # with a property in accordance with the standard. if inspect.ismethod(x_device): @@ -688,27 +757,34 @@ def device(x: Array, /) -> Device: return x_device elif is_pydata_sparse_array(x): # `sparse` will gain `.device`, so check for this first. - x_device = getattr(x, 'device', None) + x_device = getattr(x, "device", None) if x_device is not None: return x_device # Everything but DOK has this attr. 
try: - inner = x.data + inner = x.data # pyright: ignore except AttributeError: return "cpu" # Return the device of the constituent array - return device(inner) - return x.device + return device(inner) # pyright: ignore + return x.device # pyright: ignore + # Prevent shadowing, used below _device = device + # Based on cupy.array_api.Array.to_device -def _cupy_to_device(x, device, /, stream=None): - import cupy as cp - from cupy.cuda import Device as _Device - from cupy.cuda import stream as stream_module - from cupy_backends.cuda.api import runtime +def _cupy_to_device( + x: _CupyArray, + device: Device, + /, + stream: int | Any | None = None, +) -> _CupyArray: + import cupy as cp # pyright: ignore[reportMissingTypeStubs] + from cupy.cuda import Device as _Device # pyright: ignore + from cupy.cuda import stream as stream_module # pyright: ignore + from cupy_backends.cuda.api import runtime # pyright: ignore if device == x.device: return x @@ -721,33 +797,40 @@ def _cupy_to_device(x, device, /, stream=None): raise ValueError(f"Unsupported device {device!r}") else: # see cupy/cupy#5985 for the reason how we handle device/stream here - prev_device = runtime.getDevice() - prev_stream: stream_module.Stream = None + prev_device: Any = runtime.getDevice() # pyright: ignore[reportUnknownMemberType] + prev_stream = None if stream is not None: - prev_stream = stream_module.get_current_stream() + prev_stream: Any = stream_module.get_current_stream() # pyright: ignore # stream can be an int as specified in __dlpack__, or a CuPy stream if isinstance(stream, int): - stream = cp.cuda.ExternalStream(stream) - elif isinstance(stream, cp.cuda.Stream): + stream = cp.cuda.ExternalStream(stream) # pyright: ignore + elif isinstance(stream, cp.cuda.Stream): # pyright: ignore[reportUnknownMemberType] pass else: - raise ValueError('the input stream is not recognized') - stream.use() + raise ValueError("the input stream is not recognized") + stream.use() # pyright: ignore[reportUnknownMemberType] try: - runtime.setDevice(device.id) + runtime.setDevice(device.id) # pyright: ignore[reportUnknownMemberType] arr = x.copy() finally: - runtime.setDevice(prev_device) + runtime.setDevice(prev_device) # pyright: ignore[reportUnknownMemberType] if stream is not None: prev_stream.use() return arr -def _torch_to_device(x, device, /, stream=None): + +def _torch_to_device( + x: torch.Tensor, + device: torch.device | str | int, + /, + stream: None = None, +) -> torch.Tensor: if stream is not None: raise NotImplementedError return x.to(device) -def to_device(x: Array, device: Device, /, *, stream: Optional[Union[int, Any]] = None) -> Array: + +def to_device(x: Array, device: Device, /, *, stream: int | Any | None = None) -> Array: """ Copy the array from the device on which it currently resides to the specified ``device``. @@ -767,7 +850,7 @@ def to_device(x: Array, device: Device, /, *, stream: Optional[Union[int, Any]] a ``device`` object (see the `Device Support `__ section of the array API specification). - stream: Optional[Union[int, Any]] + stream: int | Any | None stream object to use during copy. 
In addition to the types supported in ``array.__dlpack__``, implementations may choose to support any library-specific stream object with the caveat that any code using @@ -799,25 +882,26 @@ def to_device(x: Array, device: Device, /, *, stream: Optional[Union[int, Any]] if is_numpy_array(x): if stream is not None: raise ValueError("The stream argument to to_device() is not supported") - if device == 'cpu': + if device == "cpu": return x raise ValueError(f"Unsupported device {device!r}") elif is_cupy_array(x): # cupy does not yet have to_device return _cupy_to_device(x, device, stream=stream) elif is_torch_array(x): - return _torch_to_device(x, device, stream=stream) + return _torch_to_device(x, device, stream=stream) # pyright: ignore[reportArgumentType] elif is_dask_array(x): if stream is not None: raise ValueError("The stream argument to to_device() is not supported") # TODO: What if our array is on the GPU already? - if device == 'cpu': + if device == "cpu": return x raise ValueError(f"Unsupported device {device!r}") elif is_jax_array(x): if not hasattr(x, "__array_namespace__"): # In JAX v0.4.31 and older, this import adds to_device method to x... - import jax.experimental.array_api # noqa: F401 + import jax.experimental.array_api # noqa: F401 # pyright: ignore + # ... but only on eager JAX. It won't work inside jax.jit. if not hasattr(x, "to_device"): return x @@ -826,10 +910,16 @@ def to_device(x: Array, device: Device, /, *, stream: Optional[Union[int, Any]] # Perform trivial check to return the same array if # device is same instead of err-ing. return x - return x.to_device(device, stream=stream) + return x.to_device(device, stream=stream) # pyright: ignore -def size(x: Array) -> int | None: +@overload +def size(x: HasShape[Collection[SupportsIndex]]) -> int: ... +@overload +def size(x: HasShape[Collection[None]]) -> None: ... +@overload +def size(x: HasShape[Collection[SupportsIndex | None]]) -> int | None: ... +def size(x: HasShape[Collection[SupportsIndex | None]]) -> int | None: """ Return the total number of elements of x. @@ -844,7 +934,7 @@ def size(x: Array) -> int | None: # Lazy API compliant arrays, such as ndonnx, can contain None in their shape if None in x.shape: return None - out = math.prod(x.shape) + out = math.prod(cast("Collection[SupportsIndex]", x.shape)) # dask.array.Array.shape can contain NaN return None if math.isnan(out) else out @@ -907,7 +997,7 @@ def is_lazy_array(x: object) -> bool: # on __bool__ (dask is one such example, which however is special-cased above). 
# Select a single point of the array - s = size(x) + s = size(cast("HasShape[Collection[SupportsIndex | None]]", x)) if s is None: return True xp = array_namespace(x) @@ -952,4 +1042,7 @@ def is_lazy_array(x: object) -> bool: "to_device", ] -_all_ignore = ['sys', 'math', 'inspect', 'warnings'] +_all_ignore = ["sys", "math", "inspect", "warnings"] + +def __dir__() -> list[str]: + return __all__ diff --git a/array_api_compat/common/_linalg.py b/array_api_compat/common/_linalg.py index d1e7ebd8..7e002aed 100644 --- a/array_api_compat/common/_linalg.py +++ b/array_api_compat/common/_linalg.py @@ -1,23 +1,33 @@ from __future__ import annotations import math -from typing import Literal, NamedTuple, Optional, Tuple, Union +from typing import Literal, NamedTuple, cast import numpy as np + if np.__version__[0] == "2": from numpy.lib.array_utils import normalize_axis_tuple else: from numpy.core.numeric import normalize_axis_tuple -from ._aliases import matmul, matrix_transpose, tensordot, vecdot, isdtype from .._internal import get_xp -from ._typing import Array, Namespace +from ._aliases import isdtype, matmul, matrix_transpose, tensordot, vecdot +from ._typing import Array, DType, Namespace + # These are in the main NumPy namespace but not in numpy.linalg -def cross(x1: Array, x2: Array, /, xp: Namespace, *, axis: int = -1, **kwargs) -> Array: +def cross( + x1: Array, + x2: Array, + /, + xp: Namespace, + *, + axis: int = -1, + **kwargs: object, +) -> Array: return xp.cross(x1, x2, axis=axis, **kwargs) -def outer(x1: Array, x2: Array, /, xp: Namespace, **kwargs) -> Array: +def outer(x1: Array, x2: Array, /, xp: Namespace, **kwargs: object) -> Array: return xp.outer(x1, x2, **kwargs) class EighResult(NamedTuple): @@ -39,46 +49,66 @@ class SVDResult(NamedTuple): # These functions are the same as their NumPy counterparts except they return # a namedtuple. -def eigh(x: Array, /, xp: Namespace, **kwargs) -> EighResult: +def eigh(x: Array, /, xp: Namespace, **kwargs: object) -> EighResult: return EighResult(*xp.linalg.eigh(x, **kwargs)) -def qr(x: Array, /, xp: Namespace, *, mode: Literal['reduced', 'complete'] = 'reduced', - **kwargs) -> QRResult: +def qr( + x: Array, + /, + xp: Namespace, + *, + mode: Literal["reduced", "complete"] = "reduced", + **kwargs: object, +) -> QRResult: return QRResult(*xp.linalg.qr(x, mode=mode, **kwargs)) -def slogdet(x: Array, /, xp: Namespace, **kwargs) -> SlogdetResult: +def slogdet(x: Array, /, xp: Namespace, **kwargs: object) -> SlogdetResult: return SlogdetResult(*xp.linalg.slogdet(x, **kwargs)) def svd( - x: Array, /, xp: Namespace, *, full_matrices: bool = True, **kwargs + x: Array, + /, + xp: Namespace, + *, + full_matrices: bool = True, + **kwargs: object, ) -> SVDResult: return SVDResult(*xp.linalg.svd(x, full_matrices=full_matrices, **kwargs)) # These functions have additional keyword arguments # The upper keyword argument is new from NumPy -def cholesky(x: Array, /, xp: Namespace, *, upper: bool = False, **kwargs) -> Array: +def cholesky( + x: Array, + /, + xp: Namespace, + *, + upper: bool = False, + **kwargs: object, +) -> Array: L = xp.linalg.cholesky(x, **kwargs) if upper: U = get_xp(xp)(matrix_transpose)(L) if get_xp(xp)(isdtype)(U.dtype, 'complex floating'): - U = xp.conj(U) + U = xp.conj(U) # pyright: ignore[reportConstantRedefinition] return U return L # The rtol keyword argument of matrix_rank() and pinv() is new from NumPy. # Note that it has a different semantic meaning from tol and rcond. 
-def matrix_rank(x: Array, - /, - xp: Namespace, - *, - rtol: Optional[Union[float, Array]] = None, - **kwargs) -> Array: +def matrix_rank( + x: Array, + /, + xp: Namespace, + *, + rtol: float | Array | None = None, + **kwargs: object, +) -> Array: # this is different from xp.linalg.matrix_rank, which supports 1 # dimensional arrays. if x.ndim < 2: raise xp.linalg.LinAlgError("1-dimensional array given. Array must be at least two-dimensional") - S = get_xp(xp)(svdvals)(x, **kwargs) + S: Array = get_xp(xp)(svdvals)(x, **kwargs) if rtol is None: tol = S.max(axis=-1, keepdims=True) * max(x.shape[-2:]) * xp.finfo(S.dtype).eps else: @@ -88,7 +118,12 @@ def matrix_rank(x: Array, return xp.count_nonzero(S > tol, axis=-1) def pinv( - x: Array, /, xp: Namespace, *, rtol: Optional[Union[float, Array]] = None, **kwargs + x: Array, + /, + xp: Namespace, + *, + rtol: float | Array | None = None, + **kwargs: object, ) -> Array: # this is different from xp.linalg.pinv, which does not multiply the # default tolerance by max(M, N). @@ -104,13 +139,13 @@ def matrix_norm( xp: Namespace, *, keepdims: bool = False, - ord: Optional[Union[int, float, Literal['fro', 'nuc']]] = 'fro', + ord: float | Literal["fro", "nuc"] | None = "fro", ) -> Array: return xp.linalg.norm(x, axis=(-2, -1), keepdims=keepdims, ord=ord) # svdvals is not in NumPy (but it is in SciPy). It is equivalent to # xp.linalg.svd(compute_uv=False). -def svdvals(x: Array, /, xp: Namespace) -> Union[Array, Tuple[Array, ...]]: +def svdvals(x: Array, /, xp: Namespace) -> Array | tuple[Array, ...]: return xp.linalg.svd(x, compute_uv=False) def vector_norm( @@ -118,9 +153,9 @@ def vector_norm( /, xp: Namespace, *, - axis: Optional[Union[int, Tuple[int, ...]]] = None, + axis: int | tuple[int, ...] | None = None, keepdims: bool = False, - ord: Optional[Union[int, float]] = 2, + ord: float = 2, ) -> Array: # xp.linalg.norm tries to do a matrix norm whenever axis is a 2-tuple or # when axis=None and the input is 2-D, so to force a vector norm, we make @@ -133,7 +168,10 @@ def vector_norm( elif isinstance(axis, tuple): # Note: The axis argument supports any number of axes, whereas # xp.linalg.norm() only supports a single axis for vector norm. - normalized_axis = normalize_axis_tuple(axis, x.ndim) + normalized_axis = cast( + "tuple[int, ...]", + normalize_axis_tuple(axis, x.ndim), # pyright: ignore[reportCallIssue] + ) rest = tuple(i for i in range(x.ndim) if i not in normalized_axis) newshape = axis + rest _x = xp.transpose(x, newshape).reshape( @@ -149,7 +187,13 @@ def vector_norm( # We can't reuse xp.linalg.norm(keepdims) because of the reshape hacks # above to avoid matrix norm logic. 
     shape = list(x.shape)
-    _axis = normalize_axis_tuple(range(x.ndim) if axis is None else axis, x.ndim)
+    _axis = cast(
+        "tuple[int, ...]",
+        normalize_axis_tuple(  # pyright: ignore[reportCallIssue]
+            range(x.ndim) if axis is None else axis,
+            x.ndim,
+        ),
+    )
     for i in _axis:
         shape[i] = 1
     res = xp.reshape(res, tuple(shape))
@@ -159,11 +203,17 @@ def vector_norm(
 
 # xp.diagonal and xp.trace operate on the first two axes whereas these
 # operates on the last two
-def diagonal(x: Array, /, xp: Namespace, *, offset: int = 0, **kwargs) -> Array:
+def diagonal(x: Array, /, xp: Namespace, *, offset: int = 0, **kwargs: object) -> Array:
     return xp.diagonal(x, offset=offset, axis1=-2, axis2=-1, **kwargs)
 
 def trace(
-    x: Array, /, xp: Namespace, *, offset: int = 0, dtype=None, **kwargs
+    x: Array,
+    /,
+    xp: Namespace,
+    *,
+    offset: int = 0,
+    dtype: DType | None = None,
+    **kwargs: object,
 ) -> Array:
     return xp.asarray(
         xp.trace(x, offset=offset, dtype=dtype, axis1=-2, axis2=-1, **kwargs)
@@ -176,3 +226,7 @@ def trace(
     'trace']
 
 _all_ignore = ['math', 'normalize_axis_tuple', 'get_xp', 'np', 'isdtype']
+
+
+def __dir__() -> list[str]:
+    return __all__
diff --git a/array_api_compat/common/_typing.py b/array_api_compat/common/_typing.py
index 4c3b356b..d7deade1 100644
--- a/array_api_compat/common/_typing.py
+++ b/array_api_compat/common/_typing.py
@@ -1,24 +1,150 @@
 from __future__ import annotations
+
+from collections.abc import Mapping
 from types import ModuleType as Namespace
-from typing import Any, TypeVar, Protocol
+from typing import TYPE_CHECKING, Literal, Protocol, TypeAlias, TypedDict, TypeVar
+
+if TYPE_CHECKING:
+    from _typeshed import Incomplete
+
+    SupportsBufferProtocol: TypeAlias = Incomplete
+    Array: TypeAlias = Incomplete
+    Device: TypeAlias = Incomplete
+    DType: TypeAlias = Incomplete
+else:
+    SupportsBufferProtocol = object
+    Array = object
+    Device = object
+    DType = object
+
+
+_T_co = TypeVar("_T_co", covariant=True)
+
+
+class NestedSequence(Protocol[_T_co]):
+    def __getitem__(self, key: int, /) -> _T_co | NestedSequence[_T_co]: ...
+    def __len__(self, /) -> int: ...
+
+
+class SupportsArrayNamespace(Protocol[_T_co]):
+    def __array_namespace__(self, /, *, api_version: str | None) -> _T_co: ...
+
+
+class HasShape(Protocol[_T_co]):
+    @property
+    def shape(self, /) -> _T_co: ...
+
+
+# Return type of `__array_namespace_info__.capabilities`
+Capabilities = TypedDict(
+    "Capabilities",
+    {
+        "boolean indexing": bool,
+        "data-dependent shapes": bool,
+        "max dimensions": int,
+    },
+)
+
+# Return type of `__array_namespace_info__.default_dtypes`
+DefaultDTypes = TypedDict(
+    "DefaultDTypes",
+    {
+        "real floating": DType,
+        "complex floating": DType,
+        "integral": DType,
+        "indexing": DType,
+    },
+)
+
+
+_DTypeKind: TypeAlias = Literal[
+    "bool",
+    "signed integer",
+    "unsigned integer",
+    "integral",
+    "real floating",
+    "complex floating",
+    "numeric",
+]
+# Type of the `kind` parameter in `__array_namespace_info__.dtypes`
+DTypeKind: TypeAlias = _DTypeKind | tuple[_DTypeKind, ...]
+ + +# `__array_namespace_info__.dtypes(kind="bool")` +class DTypesBool(TypedDict): + bool: DType + + +# `__array_namespace_info__.dtypes(kind="signed integer")` +class DTypesSigned(TypedDict): + int8: DType + int16: DType + int32: DType + int64: DType + + +# `__array_namespace_info__.dtypes(kind="unsigned integer")` +class DTypesUnsigned(TypedDict): + uint8: DType + uint16: DType + uint32: DType + uint64: DType + + +# `__array_namespace_info__.dtypes(kind="integral")` +class DTypesIntegral(DTypesSigned, DTypesUnsigned): + pass + + +# `__array_namespace_info__.dtypes(kind="real floating")` +class DTypesReal(TypedDict): + float32: DType + float64: DType + + +# `__array_namespace_info__.dtypes(kind="complex floating")` +class DTypesComplex(TypedDict): + complex64: DType + complex128: DType + + +# `__array_namespace_info__.dtypes(kind="numeric")` +class DTypesNumeric(DTypesIntegral, DTypesReal, DTypesComplex): + pass + + +# `__array_namespace_info__.dtypes(kind=None)` (default) +class DTypesAll(DTypesBool, DTypesNumeric): + pass + + +# `__array_namespace_info__.dtypes(kind=?)` (fallback) +DTypesAny: TypeAlias = Mapping[str, DType] + __all__ = [ "Array", + "Capabilities", "DType", + "DTypeKind", + "DTypesAny", + "DTypesAll", + "DTypesBool", + "DTypesNumeric", + "DTypesIntegral", + "DTypesSigned", + "DTypesUnsigned", + "DTypesReal", + "DTypesComplex", + "DefaultDTypes", "Device", + "HasShape", "Namespace", "NestedSequence", + "SupportsArrayNamespace", "SupportsBufferProtocol", ] -_T_co = TypeVar("_T_co", covariant=True) - -class NestedSequence(Protocol[_T_co]): - def __getitem__(self, key: int, /) -> _T_co | NestedSequence[_T_co]: ... - def __len__(self, /) -> int: ... - -SupportsBufferProtocol = Any -Array = Any -Device = Any -DType = Any +def __dir__() -> list[str]: + return __all__ diff --git a/array_api_compat/dask/array/__init__.py b/array_api_compat/dask/array/__init__.py index bb649306..1e47b960 100644 --- a/array_api_compat/dask/array/__init__.py +++ b/array_api_compat/dask/array/__init__.py @@ -1,9 +1,11 @@ -from dask.array import * # noqa: F403 +from typing import Final + +from dask.array import * # noqa: F403 # These imports may overwrite names from the import * above. 
-from ._aliases import * # noqa: F403 +from ._aliases import * # noqa: F403 -__array_api_version__ = '2024.12' +__array_api_version__: Final = "2024.12" # See the comment in the numpy __init__.py __import__(__package__ + '.linalg') diff --git a/array_api_compat/dask/array/_aliases.py b/array_api_compat/dask/array/_aliases.py index e7ddde78..9687a9cd 100644 --- a/array_api_compat/dask/array/_aliases.py +++ b/array_api_compat/dask/array/_aliases.py @@ -1,28 +1,38 @@ +# pyright: reportPrivateUsage=false +# pyright: reportUnknownArgumentType=false +# pyright: reportUnknownMemberType=false +# pyright: reportUnknownVariableType=false + from __future__ import annotations -from typing import Callable, Optional, Union +from builtins import bool as py_bool +from collections.abc import Callable +from typing import TYPE_CHECKING, Any + +if TYPE_CHECKING: + from typing_extensions import TypeIs +import dask.array as da import numpy as np +from numpy import bool_ as bool from numpy import ( - # dtypes - bool_ as bool, + can_cast, + complex64, + complex128, float32, float64, int8, int16, int32, int64, + result_type, uint8, uint16, uint32, uint64, - complex64, - complex128, - can_cast, - result_type, ) -import dask.array as da +from ..._internal import get_xp from ...common import _aliases, _helpers, array_namespace from ...common._typing import ( Array, @@ -31,7 +41,6 @@ NestedSequence, SupportsBufferProtocol, ) -from ..._internal import get_xp from ._info import __array_namespace_info__ isdtype = get_xp(np)(_aliases.isdtype) @@ -44,8 +53,8 @@ def astype( dtype: DType, /, *, - copy: bool = True, - device: Optional[Device] = None, + copy: py_bool = True, + device: Device | None = None, ) -> Array: """ Array API compatibility wrapper for astype(). @@ -69,14 +78,14 @@ def astype( # not pass stop/step as keyword arguments, which will cause # an error with dask def arange( - start: Union[int, float], + start: float, /, - stop: Optional[Union[int, float]] = None, - step: Union[int, float] = 1, + stop: float | None = None, + step: float = 1, *, - dtype: Optional[DType] = None, - device: Optional[Device] = None, - **kwargs, + dtype: DType | None = None, + device: Device | None = None, + **kwargs: object, ) -> Array: """ Array API compatibility wrapper for arange(). @@ -87,7 +96,7 @@ def arange( # TODO: respect device keyword? _helpers._check_device(da, device) - args = [start] + args: list[Any] = [start] if stop is not None: args.append(stop) else: @@ -137,18 +146,13 @@ def arange( # asarray also adds the copy keyword, which is not present in numpy 1.0. def asarray( - obj: ( - Array - | bool | int | float | complex - | NestedSequence[bool | int | float | complex] - | SupportsBufferProtocol - ), + obj: complex | NestedSequence[complex] | Array | SupportsBufferProtocol, /, *, - dtype: Optional[DType] = None, - device: Optional[Device] = None, - copy: Optional[bool] = None, - **kwargs, + dtype: DType | None = None, + device: Device | None = None, + copy: py_bool | None = None, + **kwargs: object, ) -> Array: """ Array API compatibility wrapper for asarray(). 
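The copy/dtype semantics that the following hunk settles on are easiest to see in a short sketch, assuming dask is installed:

    import array_api_compat.dask.array as xp

    x = xp.asarray([1, 2, 3])            # dask graph backed by numpy
    y = xp.asarray(x, dtype=xp.float64)  # lazy astype; nothing is computed yet
    # A no-copy guarantee is impossible when converting non-dask input:
    try:
        xp.asarray([1, 2, 3], copy=False)
    except NotImplementedError:
        pass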
@@ -164,7 +168,7 @@ def asarray( if copy is False: raise ValueError("Unable to avoid copy when changing dtype") obj = obj.astype(dtype) - return obj.copy() if copy else obj + return obj.copy() if copy else obj # pyright: ignore[reportAttributeAccessIssue] if copy is False: raise NotImplementedError( @@ -177,22 +181,21 @@ def asarray( return da.from_array(obj) -from dask.array import ( - # Element wise aliases - arccos as acos, - arccosh as acosh, - arcsin as asin, - arcsinh as asinh, - arctan as atan, - arctan2 as atan2, - arctanh as atanh, - left_shift as bitwise_left_shift, - right_shift as bitwise_right_shift, - invert as bitwise_invert, - power as pow, - # Other - concatenate as concat, -) +# Element wise aliases +from dask.array import arccos as acos +from dask.array import arccosh as acosh +from dask.array import arcsin as asin +from dask.array import arcsinh as asinh +from dask.array import arctan as atan +from dask.array import arctan2 as atan2 +from dask.array import arctanh as atanh + +# Other +from dask.array import concatenate as concat +from dask.array import invert as bitwise_invert +from dask.array import left_shift as bitwise_left_shift +from dask.array import power as pow +from dask.array import right_shift as bitwise_right_shift # dask.array.clip does not work unless all three arguments are provided. @@ -202,8 +205,8 @@ def asarray( def clip( x: Array, /, - min: Optional[Union[int, float, Array]] = None, - max: Optional[Union[int, float, Array]] = None, + min: float | Array | None = None, + max: float | Array | None = None, ) -> Array: """ Array API compatibility wrapper for clip(). @@ -212,8 +215,8 @@ def clip( specification for more details. """ - def _isscalar(a): - return isinstance(a, (int, float, type(None))) + def _isscalar(a: float | Array | None, /) -> TypeIs[float | None]: + return a is None or isinstance(a, (int, float)) min_shape = () if _isscalar(min) else min.shape max_shape = () if _isscalar(max) else max.shape @@ -266,7 +269,12 @@ def _ensure_single_chunk(x: Array, axis: int) -> tuple[Array, Callable[[Array], def sort( - x: Array, /, *, axis: int = -1, descending: bool = False, stable: bool = True + x: Array, + /, + *, + axis: int = -1, + descending: py_bool = False, + stable: py_bool = True, ) -> Array: """ Array API compatibility layer around the lack of sort() in Dask. @@ -296,7 +304,12 @@ def sort( def argsort( - x: Array, /, *, axis: int = -1, descending: bool = False, stable: bool = True + x: Array, + /, + *, + axis: int = -1, + descending: py_bool = False, + stable: py_bool = True, ) -> Array: """ Array API compatibility layer around the lack of argsort() in Dask. 
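Dask itself ships no `sort`/`argsort`, so the wrappers above rechunk to a single chunk along the sort axis, sort via the backing library, and (per `_ensure_single_chunk`) restore the original chunking afterwards. A hedged usage sketch:

    import array_api_compat.dask.array as xp

    x = xp.asarray([3, 1, 2])
    s = xp.sort(x, descending=True)  # single-chunk rechunk happens internally
    i = xp.argsort(x)
    print(s.compute(), i.compute())  # expected: [3 2 1] [1 2 0]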
@@ -330,25 +343,34 @@ def argsort( # dask.array.count_nonzero does not have keepdims def count_nonzero( x: Array, - axis=None, - keepdims=False + axis: int | None = None, + keepdims: py_bool = False, ) -> Array: - result = da.count_nonzero(x, axis) - if keepdims: - if axis is None: - return da.reshape(result, [1]*x.ndim) - return da.expand_dims(result, axis) - return result - - + result = da.count_nonzero(x, axis) + if keepdims: + if axis is None: + return da.reshape(result, [1] * x.ndim) + return da.expand_dims(result, axis) + return result + + +__all__ = [ + "__array_namespace_info__", + "count_nonzero", + "bool", + "int8", "int16", "int32", "int64", + "uint8", "uint16", "uint32", "uint64", + "float32", "float64", + "complex64", "complex128", + "asarray", "astype", "can_cast", "result_type", + "pow", + "concat", + "acos", "acosh", "asin", "asinh", "atan", "atan2", "atanh", + "bitwise_left_shift", "bitwise_right_shift", "bitwise_invert", +] # fmt: skip +__all__ += _aliases.__all__ +_all_ignore = ["array_namespace", "get_xp", "da", "np"] -__all__ = _aliases.__all__ + [ - '__array_namespace_info__', 'asarray', 'astype', 'acos', - 'acosh', 'asin', 'asinh', 'atan', 'atan2', - 'atanh', 'bitwise_left_shift', 'bitwise_invert', - 'bitwise_right_shift', 'concat', 'pow', 'can_cast', - 'result_type', 'bool', 'float32', 'float64', 'int8', 'int16', 'int32', 'int64', - 'uint8', 'uint16', 'uint32', 'uint64', 'complex64', 'complex128', - 'can_cast', 'count_nonzero', 'result_type'] -_all_ignore = ["array_namespace", "get_xp", "da", "np"] +def __dir__() -> list[str]: + return __all__ diff --git a/array_api_compat/dask/array/_info.py b/array_api_compat/dask/array/_info.py index 614f43d9..9e4d736f 100644 --- a/array_api_compat/dask/array/_info.py +++ b/array_api_compat/dask/array/_info.py @@ -7,25 +7,51 @@ more details. """ + +# pyright: reportPrivateUsage=false + +from __future__ import annotations + +from typing import Literal as L +from typing import TypeAlias, overload + +from numpy import bool_ as bool from numpy import ( + complex64, + complex128, dtype, - bool_ as bool, - intp, + float32, + float64, int8, int16, int32, int64, + intp, uint8, uint16, uint32, uint64, - float32, - float64, - complex64, - complex128, ) -from ...common._helpers import _DASK_DEVICE +from ...common._helpers import _DASK_DEVICE, _dask_device +from ...common._typing import ( + Capabilities, + DefaultDTypes, + DType, + DTypeKind, + DTypesAll, + DTypesAny, + DTypesBool, + DTypesComplex, + DTypesIntegral, + DTypesNumeric, + DTypesReal, + DTypesSigned, + DTypesUnsigned, +) + +_Device: TypeAlias = L["cpu"] | _dask_device + class __array_namespace_info__: """ @@ -59,9 +85,9 @@ class __array_namespace_info__: """ - __module__ = 'dask.array' + __module__ = "dask.array" - def capabilities(self): + def capabilities(self) -> Capabilities: """ Return a dictionary of array API library capabilities. @@ -116,7 +142,7 @@ def capabilities(self): "max dimensions": 64, } - def default_device(self): + def default_device(self) -> L["cpu"]: """ The default device used for new Dask arrays. @@ -143,7 +169,7 @@ def default_device(self): """ return "cpu" - def default_dtypes(self, *, device=None): + def default_dtypes(self, /, *, device: _Device | None = None) -> DefaultDTypes: """ The default data types used for new Dask arrays. @@ -184,8 +210,8 @@ def default_dtypes(self, *, device=None): """ if device not in ["cpu", _DASK_DEVICE, None]: raise ValueError( - 'Device not understood. 
Only "cpu" or _DASK_DEVICE is allowed, but received:' - f' {device}' + f'Device not understood. Only "cpu" or _DASK_DEVICE is allowed, ' + f"but received: {device!r}" ) return { "real floating": dtype(float64), @@ -194,7 +220,41 @@ def default_dtypes(self, *, device=None): "indexing": dtype(intp), } - def dtypes(self, *, device=None, kind=None): + @overload + def dtypes( + self, /, *, device: _Device | None = None, kind: None = None + ) -> DTypesAll: ... + @overload + def dtypes( + self, /, *, device: _Device | None = None, kind: L["bool"] + ) -> DTypesBool: ... + @overload + def dtypes( + self, /, *, device: _Device | None = None, kind: L["signed integer"] + ) -> DTypesSigned: ... + @overload + def dtypes( + self, /, *, device: _Device | None = None, kind: L["unsigned integer"] + ) -> DTypesUnsigned: ... + @overload + def dtypes( + self, /, *, device: _Device | None = None, kind: L["integral"] + ) -> DTypesIntegral: ... + @overload + def dtypes( + self, /, *, device: _Device | None = None, kind: L["real floating"] + ) -> DTypesReal: ... + @overload + def dtypes( + self, /, *, device: _Device | None = None, kind: L["complex floating"] + ) -> DTypesComplex: ... + @overload + def dtypes( + self, /, *, device: _Device | None = None, kind: L["numeric"] + ) -> DTypesNumeric: ... + def dtypes( + self, /, *, device: _Device | None = None, kind: DTypeKind | None = None + ) -> DTypesAny: """ The array API data types supported by Dask. @@ -251,7 +311,7 @@ def dtypes(self, *, device=None, kind=None): if device not in ["cpu", _DASK_DEVICE, None]: raise ValueError( 'Device not understood. Only "cpu" or _DASK_DEVICE is allowed, but received:' - f' {device}' + f" {device}" ) if kind is None: return { @@ -321,14 +381,14 @@ def dtypes(self, *, device=None, kind=None): "complex64": dtype(complex64), "complex128": dtype(complex128), } - if isinstance(kind, tuple): - res = {} + if isinstance(kind, tuple): # type: ignore[reportUnnecessaryIsinstanceCall] + res: dict[str, DType] = {} for k in kind: res.update(self.dtypes(kind=k)) return res raise ValueError(f"unsupported kind: {kind!r}") - def devices(self): + def devices(self) -> list[_Device]: """ The devices supported by Dask. diff --git a/array_api_compat/dask/array/linalg.py b/array_api_compat/dask/array/linalg.py index bd53f0df..0825386e 100644 --- a/array_api_compat/dask/array/linalg.py +++ b/array_api_compat/dask/array/linalg.py @@ -3,15 +3,16 @@ from typing import Literal import dask.array as da + +# The `matmul` and `tensordot` functions are in both the main and linalg namespaces +from dask.array import matmul, outer, tensordot + # Exports -from dask.array.linalg import * # noqa: F403 -from dask.array import outer -# These functions are in both the main and linalg namespaces -from dask.array import matmul, tensordot +from dask.array.linalg import * # noqa: F403 from ..._internal import get_xp from ...common import _linalg -from ...common._typing import Array +from ...common._typing import Array as _Array from ._aliases import matrix_transpose, vecdot # dask.array.linalg doesn't have __all__. 
If it is added, replace this with @@ -32,8 +33,11 @@ # supports the mode keyword on QR # https://github.com/dask/dask/issues/10388 #qr = get_xp(da)(_linalg.qr) -def qr(x: Array, mode: Literal['reduced', 'complete'] = 'reduced', - **kwargs) -> QRResult: +def qr( + x: _Array, + mode: Literal["reduced", "complete"] = "reduced", + **kwargs: object, +) -> QRResult: if mode != "reduced": raise ValueError("dask arrays only support using mode='reduced'") return QRResult(*da.linalg.qr(x, **kwargs)) @@ -46,12 +50,12 @@ def qr(x: Array, mode: Literal['reduced', 'complete'] = 'reduced', # Wrap the svd functions to not pass full_matrices to dask # when full_matrices=False (as that is the default behavior for dask), # and dask doesn't have the full_matrices keyword -def svd(x: Array, full_matrices: bool = True, **kwargs) -> SVDResult: +def svd(x: _Array, full_matrices: bool = True, **kwargs) -> SVDResult: if full_matrices: raise ValueError("full_matrics=True is not supported by dask.") return da.linalg.svd(x, coerce_signs=False, **kwargs) -def svdvals(x: Array) -> Array: +def svdvals(x: _Array) -> _Array: # TODO: can't avoid computing U or V for dask _, s, _ = svd(x) return s diff --git a/array_api_compat/numpy/__init__.py b/array_api_compat/numpy/__init__.py index 6a5d9867..f7b558ba 100644 --- a/array_api_compat/numpy/__init__.py +++ b/array_api_compat/numpy/__init__.py @@ -1,10 +1,16 @@ -from numpy import * # noqa: F403 +# ruff: noqa: PLC0414 +from typing import Final + +from numpy import * # noqa: F403 # pyright: ignore[reportWildcardImportFromLibrary] # from numpy import * doesn't overwrite these builtin names -from numpy import abs, max, min, round # noqa: F401 +from numpy import abs as abs +from numpy import max as max +from numpy import min as min +from numpy import round as round # These imports may overwrite names from the import * above. -from ._aliases import * # noqa: F403 +from ._aliases import * # noqa: F403 # Don't know why, but we have to do an absolute import to import linalg. If we # instead do @@ -13,9 +19,17 @@ # # It doesn't overwrite np.linalg from above. The import is generated # dynamically so that the library can be vendored. -__import__(__package__ + '.linalg') -__import__(__package__ + '.fft') +__import__(__package__ + ".linalg") + +__import__(__package__ + ".fft") + +from ..common._helpers import * # noqa: F403 +from .linalg import matrix_transpose, vecdot # noqa: F401 -from .linalg import matrix_transpose, vecdot # noqa: F401 +try: + # Used in asarray(). Not present in older versions. 
+ from numpy import _CopyMode # noqa: F401 +except ImportError: + pass -__array_api_version__ = '2024.12' +__array_api_version__: Final = "2024.12" diff --git a/array_api_compat/numpy/_aliases.py b/array_api_compat/numpy/_aliases.py index d1fd46a1..d8792611 100644 --- a/array_api_compat/numpy/_aliases.py +++ b/array_api_compat/numpy/_aliases.py @@ -1,6 +1,10 @@ +# pyright: reportPrivateUsage=false from __future__ import annotations -from typing import Optional, Union +from builtins import bool as py_bool +from typing import TYPE_CHECKING, Any, Literal, TypeAlias, cast + +import numpy as np from .._internal import get_xp from ..common import _aliases, _helpers @@ -8,7 +12,12 @@ from ._info import __array_namespace_info__ from ._typing import Array, Device, DType -import numpy as np +if TYPE_CHECKING: + from typing_extensions import Buffer, TypeIs + +# The values of the `_CopyMode` enum can be either `False`, `True`, or `2`: +# https://github.com/numpy/numpy/blob/5a8a6a79d9c2fff8f07dcab5d41e14f8508d673f/numpy/_globals.pyi#L7-L10 +_Copy: TypeAlias = py_bool | Literal[2] | np._CopyMode bool = np.bool_ @@ -65,9 +74,9 @@ iinfo = get_xp(np)(_aliases.iinfo) -def _supports_buffer_protocol(obj): +def _supports_buffer_protocol(obj: object) -> TypeIs[Buffer]: # pyright: ignore[reportUnusedFunction] try: - memoryview(obj) + memoryview(obj) # pyright: ignore[reportArgumentType] except TypeError: return False return True @@ -78,18 +87,13 @@ def _supports_buffer_protocol(obj): # complicated enough that it's easier to define it separately for each module # rather than trying to combine everything into one function in common/ def asarray( - obj: ( - Array - | bool | int | float | complex - | NestedSequence[bool | int | float | complex] - | SupportsBufferProtocol - ), + obj: Array | complex | NestedSequence[complex] | SupportsBufferProtocol, /, *, - dtype: Optional[DType] = None, - device: Optional[Device] = None, - copy: Optional[Union[bool, np._CopyMode]] = None, - **kwargs, + dtype: DType | None = None, + device: Device | None = None, + copy: _Copy | None = None, + **kwargs: Any, ) -> Array: """ Array API compatibility wrapper for asarray(). @@ -106,7 +110,7 @@ def asarray( elif copy is True: copy = np._CopyMode.ALWAYS - return np.array(obj, copy=copy, dtype=dtype, **kwargs) + return np.array(obj, copy=copy, dtype=dtype, **kwargs) # pyright: ignore def astype( @@ -114,8 +118,8 @@ def astype( dtype: DType, /, *, - copy: bool = True, - device: Optional[Device] = None, + copy: py_bool = True, + device: Device | None = None, ) -> Array: _helpers._check_device(np, device) return x.astype(dtype=dtype, copy=copy) @@ -123,8 +127,14 @@ def astype( # count_nonzero returns a python int for axis=None and keepdims=False # https://github.com/numpy/numpy/issues/17562 -def count_nonzero(x: Array, axis=None, keepdims=False) -> Array: - result = np.count_nonzero(x, axis=axis, keepdims=keepdims) +def count_nonzero( + x: Array, + axis: int | tuple[int, ...] | None = None, + keepdims: py_bool = False, +) -> Array: + # NOTE: this is currently incorrectly typed in numpy, but will be fixed in + # numpy 2.2.5 and 2.3.0: https://github.com/numpy/numpy/pull/28750 + result = cast("Any", np.count_nonzero(x, axis=axis, keepdims=keepdims)) # pyright: ignore[reportArgumentType, reportCallIssue] if axis is None and not keepdims: return np.asarray(result) return result @@ -132,25 +142,43 @@ def count_nonzero(x: Array, axis=None, keepdims=False) -> Array: # These functions are completely new here. 
If the library already has them # (i.e., numpy 2.0), use the library version instead of our wrapper. -if hasattr(np, 'vecdot'): +if hasattr(np, "vecdot"): vecdot = np.vecdot else: vecdot = get_xp(np)(_aliases.vecdot) -if hasattr(np, 'isdtype'): +if hasattr(np, "isdtype"): isdtype = np.isdtype else: isdtype = get_xp(np)(_aliases.isdtype) -if hasattr(np, 'unstack'): +if hasattr(np, "unstack"): unstack = np.unstack else: unstack = get_xp(np)(_aliases.unstack) -__all__ = _aliases.__all__ + ['__array_namespace_info__', 'asarray', 'astype', - 'acos', 'acosh', 'asin', 'asinh', 'atan', - 'atan2', 'atanh', 'bitwise_left_shift', - 'bitwise_invert', 'bitwise_right_shift', - 'bool', 'concat', 'count_nonzero', 'pow'] - -_all_ignore = ['np', 'get_xp'] +__all__ = [ + "__array_namespace_info__", + "asarray", + "astype", + "acos", + "acosh", + "asin", + "asinh", + "atan", + "atan2", + "atanh", + "bitwise_left_shift", + "bitwise_invert", + "bitwise_right_shift", + "bool", + "concat", + "count_nonzero", + "pow", +] +__all__ += _aliases.__all__ +_all_ignore = ["np", "get_xp"] + + +def __dir__() -> list[str]: + return __all__ diff --git a/array_api_compat/numpy/_info.py b/array_api_compat/numpy/_info.py index 365855b8..f307f62c 100644 --- a/array_api_compat/numpy/_info.py +++ b/array_api_compat/numpy/_info.py @@ -7,24 +7,28 @@ more details. """ +from __future__ import annotations + +from numpy import bool_ as bool from numpy import ( + complex64, + complex128, dtype, - bool_ as bool, - intp, + float32, + float64, int8, int16, int32, int64, + intp, uint8, uint16, uint32, uint64, - float32, - float64, - complex64, - complex128, ) +from ._typing import Device, DType + class __array_namespace_info__: """ @@ -131,7 +135,11 @@ def default_device(self): """ return "cpu" - def default_dtypes(self, *, device=None): + def default_dtypes( + self, + *, + device: Device | None = None, + ) -> dict[str, dtype[intp | float64 | complex128]]: """ The default data types used for new NumPy arrays. @@ -183,7 +191,12 @@ def default_dtypes(self, *, device=None): "indexing": dtype(intp), } - def dtypes(self, *, device=None, kind=None): + def dtypes( + self, + *, + device: Device | None = None, + kind: str | tuple[str, ...] | None = None, + ) -> dict[str, DType]: """ The array API data types supported by NumPy. @@ -260,7 +273,7 @@ def dtypes(self, *, device=None, kind=None): "complex128": dtype(complex128), } if kind == "bool": - return {"bool": bool} + return {"bool": dtype(bool)} if kind == "signed integer": return { "int8": dtype(int8), @@ -312,13 +325,13 @@ def dtypes(self, *, device=None, kind=None): "complex128": dtype(complex128), } if isinstance(kind, tuple): - res = {} + res: dict[str, DType] = {} for k in kind: res.update(self.dtypes(kind=k)) return res raise ValueError(f"unsupported kind: {kind!r}") - def devices(self): + def devices(self) -> list[Device]: """ The devices supported by NumPy. 
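With the annotations above, the inspection API reads naturally under a type checker; at runtime it behaves as before. A minimal sketch:

    import array_api_compat.numpy as xp

    info = xp.__array_namespace_info__()
    print(info.default_dtypes()["real floating"])       # float64
    print(list(info.dtypes(kind="integral")))           # the eight integer dtype names
    print(info.dtypes(kind=("bool", "real floating")))  # tuple kinds merge
    print(info.devices())                               # ['cpu']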
@@ -344,3 +357,10 @@ def devices(self): """ return ["cpu"] + + +__all__ = ["__array_namespace_info__"] + + +def __dir__() -> list[str]: + return __all__ diff --git a/array_api_compat/numpy/_typing.py b/array_api_compat/numpy/_typing.py index a6c96924..e771c788 100644 --- a/array_api_compat/numpy/_typing.py +++ b/array_api_compat/numpy/_typing.py @@ -1,31 +1,30 @@ from __future__ import annotations -__all__ = ["Array", "DType", "Device"] -_all_ignore = ["np"] - -from typing import Literal, TYPE_CHECKING +from typing import TYPE_CHECKING, Any, Literal, TypeAlias import numpy as np -from numpy import ndarray as Array -Device = Literal["cpu"] +Device: TypeAlias = Literal["cpu"] + if TYPE_CHECKING: + # NumPy 1.x on Python 3.10 fails to parse np.dtype[] - DType = np.dtype[ - np.intp - | np.int8 - | np.int16 - | np.int32 - | np.int64 - | np.uint8 - | np.uint16 - | np.uint32 - | np.uint64 + DType: TypeAlias = np.dtype[ + np.bool_ + | np.integer[Any] | np.float32 | np.float64 | np.complex64 | np.complex128 - | np.bool ] + Array: TypeAlias = np.ndarray[Any, DType] else: - DType = np.dtype + DType: TypeAlias = np.dtype + Array: TypeAlias = np.ndarray + +__all__ = ["Array", "DType", "Device"] +_all_ignore = ["np"] + + +def __dir__() -> list[str]: + return __all__ diff --git a/array_api_compat/numpy/fft.py b/array_api_compat/numpy/fft.py index 28667594..06875f00 100644 --- a/array_api_compat/numpy/fft.py +++ b/array_api_compat/numpy/fft.py @@ -1,10 +1,9 @@ -from numpy.fft import * # noqa: F403 +import numpy as np from numpy.fft import __all__ as fft_all +from numpy.fft import fft2, ifft2, irfft2, rfft2 -from ..common import _fft from .._internal import get_xp - -import numpy as np +from ..common import _fft fft = get_xp(np)(_fft.fft) ifft = get_xp(np)(_fft.ifft) @@ -21,7 +20,14 @@ fftshift = get_xp(np)(_fft.fftshift) ifftshift = get_xp(np)(_fft.ifftshift) -__all__ = fft_all + _fft.__all__ + +__all__ = ["rfft2", "irfft2", "fft2", "ifft2"] +__all__ += _fft.__all__ + + +def __dir__() -> list[str]: + return __all__ + del get_xp del np diff --git a/array_api_compat/numpy/linalg.py b/array_api_compat/numpy/linalg.py index 8f01593b..2d3e731d 100644 --- a/array_api_compat/numpy/linalg.py +++ b/array_api_compat/numpy/linalg.py @@ -1,14 +1,35 @@ -from numpy.linalg import * # noqa: F403 -from numpy.linalg import __all__ as linalg_all -import numpy as _np +# pyright: reportAttributeAccessIssue=false +# pyright: reportUnknownArgumentType=false +# pyright: reportUnknownMemberType=false +# pyright: reportUnknownVariableType=false + +from __future__ import annotations + +import numpy as np + +# intersection of `np.linalg.__all__` on numpy 1.22 and 2.2, minus `_linalg.__all__` +from numpy.linalg import ( + LinAlgError, + cond, + det, + eig, + eigvals, + eigvalsh, + inv, + lstsq, + matrix_power, + multi_dot, + norm, + tensorinv, + tensorsolve, +) -from ..common import _linalg from .._internal import get_xp +from ..common import _linalg # These functions are in both the main and linalg namespaces -from ._aliases import matmul, matrix_transpose, tensordot, vecdot # noqa: F401 - -import numpy as np +from ._aliases import matmul, matrix_transpose, tensordot, vecdot # noqa: F401 +from ._typing import Array cross = get_xp(np)(_linalg.cross) outer = get_xp(np)(_linalg.outer) @@ -38,19 +59,28 @@ # To workaround this, the below is the code from np.linalg.solve except # only calling solve1 in the exactly 1D case. + # This code is here instead of in common because it is numpy specific. 
Also # note that CuPy's solve() does not currently support broadcasting (see # https://github.com/cupy/cupy/blob/main/cupy/cublas.py#L43). -def solve(x1: _np.ndarray, x2: _np.ndarray, /) -> _np.ndarray: +def solve(x1: Array, x2: Array, /) -> Array: try: from numpy.linalg._linalg import ( - _makearray, _assert_stacked_2d, _assert_stacked_square, - _commonType, isComplexType, _raise_linalgerror_singular + _assert_stacked_2d, + _assert_stacked_square, + _commonType, + _makearray, + _raise_linalgerror_singular, + isComplexType, ) except ImportError: from numpy.linalg.linalg import ( - _makearray, _assert_stacked_2d, _assert_stacked_square, - _commonType, isComplexType, _raise_linalgerror_singular + _assert_stacked_2d, + _assert_stacked_square, + _commonType, + _makearray, + _raise_linalgerror_singular, + isComplexType, ) from numpy.linalg import _umath_linalg @@ -61,6 +91,7 @@ def solve(x1: _np.ndarray, x2: _np.ndarray, /) -> _np.ndarray: t, result_t = _commonType(x1, x2) # This part is different from np.linalg.solve + gufunc: np.ufunc if x2.ndim == 1: gufunc = _umath_linalg.solve1 else: @@ -68,23 +99,45 @@ def solve(x1: _np.ndarray, x2: _np.ndarray, /) -> _np.ndarray: # This does nothing currently but is left in because it will be relevant # when complex dtype support is added to the spec in 2022. - signature = 'DD->D' if isComplexType(t) else 'dd->d' - with _np.errstate(call=_raise_linalgerror_singular, invalid='call', - over='ignore', divide='ignore', under='ignore'): - r = gufunc(x1, x2, signature=signature) + signature = "DD->D" if isComplexType(t) else "dd->d" + with np.errstate( + call=_raise_linalgerror_singular, + invalid="call", + over="ignore", + divide="ignore", + under="ignore", + ): + r: Array = gufunc(x1, x2, signature=signature) return wrap(r.astype(result_t, copy=False)) + # These functions are completely new here. If the library already has them # (i.e., numpy 2.0), use the library version instead of our wrapper. 
-if hasattr(np.linalg, 'vector_norm'): +if hasattr(np.linalg, "vector_norm"): vector_norm = np.linalg.vector_norm else: vector_norm = get_xp(np)(_linalg.vector_norm) -__all__ = linalg_all + _linalg.__all__ + ['solve'] -del get_xp -del np -del linalg_all -del _linalg +__all__ = [ + "LinAlgError", + "cond", + "det", + "eig", + "eigvals", + "eigvalsh", + "inv", + "lstsq", + "matrix_power", + "multi_dot", + "norm", + "tensorinv", + "tensorsolve", +] +__all__ += _linalg.__all__ +__all__ += ["solve", "vector_norm"] + + +def __dir__() -> list[str]: + return __all__ From 5e14b53a3558765a8f9b921c72f0249cc0c1c5b9 Mon Sep 17 00:00:00 2001 From: Joren Hammudoglu Date: Sat, 19 Apr 2025 16:08:41 +0200 Subject: [PATCH 35/80] TYP: reject `bool` in the `ord` params of `vector_norm` and `matrix_norm` (#310) * TYP: auto-plagiarize the optypean `Just*` types * TYP: reject `bool` in the `ord` params of `vector_norm` and `matrix_norm` * TYP: remove accidental type alias * TYP: Tighten the `ord` param of `matrix_norm` Co-authored-by: Lucas Colley --------- Co-authored-by: Lucas Colley --- array_api_compat/common/_linalg.py | 6 ++-- array_api_compat/common/_typing.py | 44 +++++++++++++++++++++++++++++- array_api_compat/torch/linalg.py | 8 ++++-- 3 files changed, 52 insertions(+), 6 deletions(-) diff --git a/array_api_compat/common/_linalg.py b/array_api_compat/common/_linalg.py index 7e002aed..7ad87a1b 100644 --- a/array_api_compat/common/_linalg.py +++ b/array_api_compat/common/_linalg.py @@ -12,7 +12,7 @@ from .._internal import get_xp from ._aliases import isdtype, matmul, matrix_transpose, tensordot, vecdot -from ._typing import Array, DType, Namespace +from ._typing import Array, DType, JustFloat, JustInt, Namespace # These are in the main NumPy namespace but not in numpy.linalg @@ -139,7 +139,7 @@ def matrix_norm( xp: Namespace, *, keepdims: bool = False, - ord: float | Literal["fro", "nuc"] | None = "fro", + ord: Literal[1, 2, -1, -2] | JustFloat | Literal["fro", "nuc"] | None = "fro", ) -> Array: return xp.linalg.norm(x, axis=(-2, -1), keepdims=keepdims, ord=ord) @@ -155,7 +155,7 @@ def vector_norm( *, axis: int | tuple[int, ...] | None = None, keepdims: bool = False, - ord: float = 2, + ord: JustInt | JustFloat = 2, ) -> Array: # xp.linalg.norm tries to do a matrix norm whenever axis is a 2-tuple or # when axis=None and the input is 2-D, so to force a vector norm, we make diff --git a/array_api_compat/common/_typing.py b/array_api_compat/common/_typing.py index d7deade1..cd26feeb 100644 --- a/array_api_compat/common/_typing.py +++ b/array_api_compat/common/_typing.py @@ -2,7 +2,15 @@ from collections.abc import Mapping from types import ModuleType as Namespace -from typing import TYPE_CHECKING, Literal, Protocol, TypeAlias, TypedDict, TypeVar +from typing import ( + TYPE_CHECKING, + Literal, + Protocol, + TypeAlias, + TypedDict, + TypeVar, + final, +) if TYPE_CHECKING: from _typeshed import Incomplete @@ -21,6 +29,37 @@ _T_co = TypeVar("_T_co", covariant=True) +# These "Just" types are equivalent to the `Just` type from the `optype` library, +# apart from them not being `@runtime_checkable`. +# - docs: https://github.com/jorenham/optype/blob/master/README.md#just +# - code: https://github.com/jorenham/optype/blob/master/optype/_core/_just.py +@final +class JustInt(Protocol): + @property + def __class__(self, /) -> type[int]: ... + @__class__.setter + def __class__(self, value: type[int], /) -> None: ... 
# pyright: ignore[reportIncompatibleMethodOverride] + + +@final +class JustFloat(Protocol): + @property + def __class__(self, /) -> type[float]: ... + @__class__.setter + def __class__(self, value: type[float], /) -> None: ... # pyright: ignore[reportIncompatibleMethodOverride] + + +@final +class JustComplex(Protocol): + @property + def __class__(self, /) -> type[complex]: ... + @__class__.setter + def __class__(self, value: type[complex], /) -> None: ... # pyright: ignore[reportIncompatibleMethodOverride] + + +# + + class NestedSequence(Protocol[_T_co]): def __getitem__(self, key: int, /) -> _T_co | NestedSequence[_T_co]: ... def __len__(self, /) -> int: ... @@ -140,6 +179,9 @@ class DTypesAll(DTypesBool, DTypesNumeric): "Device", "HasShape", "Namespace", + "JustInt", + "JustFloat", + "JustComplex", "NestedSequence", "SupportsArrayNamespace", "SupportsBufferProtocol", diff --git a/array_api_compat/torch/linalg.py b/array_api_compat/torch/linalg.py index 1ff7319d..70d72405 100644 --- a/array_api_compat/torch/linalg.py +++ b/array_api_compat/torch/linalg.py @@ -16,6 +16,7 @@ # These functions are in both the main and linalg namespaces from ._aliases import matmul, matrix_transpose, tensordot from ._typing import Array, DType +from ..common._typing import JustInt, JustFloat # Note: torch.linalg.cross does not default to axis=-1 (it defaults to the # first axis with size 3), see https://github.com/pytorch/pytorch/issues/58743 @@ -84,8 +85,8 @@ def vector_norm( *, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False, - # float stands for inf | -inf, which are not valid for Literal - ord: Union[int, float] = 2, + # JustFloat stands for inf | -inf, which are not valid for Literal + ord: JustInt | JustFloat = 2, **kwargs, ) -> Array: # torch.vector_norm incorrectly treats axis=() the same as axis=None @@ -115,3 +116,6 @@ def vector_norm( _all_ignore = ['torch_linalg', 'sum'] del linalg_all + +def __dir__() -> list[str]: + return __all__ From 52e01beae335c088d25bd6d76f5ae44a231800f5 Mon Sep 17 00:00:00 2001 From: Guido Imperiale Date: Mon, 21 Apr 2025 19:38:53 +0100 Subject: [PATCH 36/80] ENH: cache helper functions (#308) * ENH: cache helper functions --- array_api_compat/common/_helpers.py | 192 ++++++++++++++++------------ 1 file changed, 108 insertions(+), 84 deletions(-) diff --git a/array_api_compat/common/_helpers.py b/array_api_compat/common/_helpers.py index db3e4cd7..d50e0d83 100644 --- a/array_api_compat/common/_helpers.py +++ b/array_api_compat/common/_helpers.py @@ -12,7 +12,8 @@ import math import sys import warnings -from collections.abc import Collection +from collections.abc import Collection, Hashable +from functools import lru_cache from typing import ( TYPE_CHECKING, Any, @@ -61,23 +62,37 @@ _API_VERSIONS: Final = _API_VERSIONS_OLD | frozenset({"2024.12"}) +@lru_cache(100) +def _issubclass_fast(cls: type, modname: str, clsname: str) -> bool: + try: + mod = sys.modules[modname] + except KeyError: + return False + parent_cls = getattr(mod, clsname) + return issubclass(cls, parent_cls) + + def _is_jax_zero_gradient_array(x: object) -> TypeGuard[_ZeroGradientArray]: """Return True if `x` is a zero-gradient array. These arrays are a design quirk of Jax that may one day be removed. See https://github.com/google/jax/issues/20620. 
""" - if "numpy" not in sys.modules or "jax" not in sys.modules: + # Fast exit + try: + dtype = x.dtype # type: ignore[attr-defined] + except AttributeError: + return False + cls = cast(Hashable, type(dtype)) + if not _issubclass_fast(cls, "numpy.dtypes", "VoidDType"): return False - import jax - import numpy as np + if "jax" not in sys.modules: + return False - jax_float0 = cast("np.dtype[np.void]", jax.float0) - return ( - isinstance(x, np.ndarray) - and cast("npt.NDArray[np.void]", x).dtype == jax_float0 - ) + import jax + # jax.float0 is a np.dtype([('float0', 'V')]) + return dtype == jax.float0 def is_numpy_array(x: object) -> TypeGuard[npt.NDArray[Any]]: @@ -101,15 +116,12 @@ def is_numpy_array(x: object) -> TypeGuard[npt.NDArray[Any]]: is_jax_array is_pydata_sparse_array """ - # Avoid importing NumPy if it isn't already - if "numpy" not in sys.modules: - return False - - import numpy as np - # TODO: Should we reject ndarray subclasses? - return (isinstance(x, (np.ndarray, np.generic)) - and not _is_jax_zero_gradient_array(x)) # pyright: ignore[reportUnknownArgumentType] # fmt: skip + cls = cast(Hashable, type(x)) + return ( + _issubclass_fast(cls, "numpy", "ndarray") + or _issubclass_fast(cls, "numpy", "generic") + ) and not _is_jax_zero_gradient_array(x) def is_cupy_array(x: object) -> bool: @@ -133,14 +145,8 @@ def is_cupy_array(x: object) -> bool: is_jax_array is_pydata_sparse_array """ - # Avoid importing CuPy if it isn't already - if "cupy" not in sys.modules: - return False - - import cupy as cp # pyright: ignore[reportMissingTypeStubs] - - # TODO: Should we reject ndarray subclasses? - return isinstance(x, cp.ndarray) # pyright: ignore[reportUnknownMemberType] + cls = cast(Hashable, type(x)) + return _issubclass_fast(cls, "cupy", "ndarray") def is_torch_array(x: object) -> TypeIs[torch.Tensor]: @@ -161,14 +167,8 @@ def is_torch_array(x: object) -> TypeIs[torch.Tensor]: is_jax_array is_pydata_sparse_array """ - # Avoid importing torch if it isn't already - if "torch" not in sys.modules: - return False - - import torch - - # TODO: Should we reject ndarray subclasses? 
- return isinstance(x, torch.Tensor) + cls = cast(Hashable, type(x)) + return _issubclass_fast(cls, "torch", "Tensor") def is_ndonnx_array(x: object) -> TypeIs[ndx.Array]: @@ -190,13 +190,8 @@ def is_ndonnx_array(x: object) -> TypeIs[ndx.Array]: is_jax_array is_pydata_sparse_array """ - # Avoid importing torch if it isn't already - if "ndonnx" not in sys.modules: - return False - - import ndonnx as ndx - - return isinstance(x, ndx.Array) + cls = cast(Hashable, type(x)) + return _issubclass_fast(cls, "ndonnx", "Array") def is_dask_array(x: object) -> TypeIs[da.Array]: @@ -218,13 +213,8 @@ def is_dask_array(x: object) -> TypeIs[da.Array]: is_jax_array is_pydata_sparse_array """ - # Avoid importing dask if it isn't already - if "dask.array" not in sys.modules: - return False - - import dask.array - - return isinstance(x, dask.array.Array) + cls = cast(Hashable, type(x)) + return _issubclass_fast(cls, "dask.array", "Array") def is_jax_array(x: object) -> TypeIs[jax.Array]: @@ -247,13 +237,8 @@ def is_jax_array(x: object) -> TypeIs[jax.Array]: is_dask_array is_pydata_sparse_array """ - # Avoid importing jax if it isn't already - if "jax" not in sys.modules: - return False - - import jax - - return isinstance(x, jax.Array) or _is_jax_zero_gradient_array(x) + cls = cast(Hashable, type(x)) + return _issubclass_fast(cls, "jax", "Array") or _is_jax_zero_gradient_array(x) def is_pydata_sparse_array(x: object) -> TypeIs[sparse.SparseArray]: @@ -276,14 +261,9 @@ def is_pydata_sparse_array(x: object) -> TypeIs[sparse.SparseArray]: is_dask_array is_jax_array """ - # Avoid importing jax if it isn't already - if "sparse" not in sys.modules: - return False - - import sparse # pyright: ignore[reportMissingTypeStubs] - # TODO: Account for other backends. - return isinstance(x, sparse.SparseArray) + cls = cast(Hashable, type(x)) + return _issubclass_fast(cls, "sparse", "SparseArray") def is_array_api_obj(x: object) -> TypeIs[_ArrayApiObj]: # pyright: ignore[reportUnknownParameterType] @@ -302,13 +282,23 @@ def is_array_api_obj(x: object) -> TypeIs[_ArrayApiObj]: # pyright: ignore[repo is_jax_array """ return ( - is_numpy_array(x) - or is_cupy_array(x) - or is_torch_array(x) - or is_dask_array(x) - or is_jax_array(x) - or is_pydata_sparse_array(x) - or hasattr(x, "__array_namespace__") + hasattr(x, '__array_namespace__') + or _is_array_api_cls(cast(Hashable, type(x))) + ) + + +@lru_cache(100) +def _is_array_api_cls(cls: type) -> bool: + return ( + # TODO: drop support for numpy<2 which didn't have __array_namespace__ + _issubclass_fast(cls, "numpy", "ndarray") + or _issubclass_fast(cls, "numpy", "generic") + or _issubclass_fast(cls, "cupy", "ndarray") + or _issubclass_fast(cls, "torch", "Tensor") + or _issubclass_fast(cls, "dask.array", "Array") + or _issubclass_fast(cls, "sparse", "SparseArray") + # TODO: drop support for jax<0.4.32 which didn't have __array_namespace__ + or _issubclass_fast(cls, "jax", "Array") ) @@ -317,6 +307,7 @@ def _compat_module_name() -> str: return __name__.removesuffix(".common._helpers") +@lru_cache(100) def is_numpy_namespace(xp: Namespace) -> bool: """ Returns True if `xp` is a NumPy namespace. @@ -338,6 +329,7 @@ def is_numpy_namespace(xp: Namespace) -> bool: return xp.__name__ in {"numpy", _compat_module_name() + ".numpy"} +@lru_cache(100) def is_cupy_namespace(xp: Namespace) -> bool: """ Returns True if `xp` is a CuPy namespace. 
@@ -359,6 +351,7 @@ def is_cupy_namespace(xp: Namespace) -> bool: return xp.__name__ in {"cupy", _compat_module_name() + ".cupy"} +@lru_cache(100) def is_torch_namespace(xp: Namespace) -> bool: """ Returns True if `xp` is a PyTorch namespace. @@ -399,6 +392,7 @@ def is_ndonnx_namespace(xp: Namespace) -> bool: return xp.__name__ == "ndonnx" +@lru_cache(100) def is_dask_namespace(xp: Namespace) -> bool: """ Returns True if `xp` is a Dask namespace. @@ -939,6 +933,19 @@ def size(x: HasShape[Collection[SupportsIndex | None]]) -> int | None: return None if math.isnan(out) else out +@lru_cache(100) +def _is_writeable_cls(cls: type) -> bool | None: + if ( + _issubclass_fast(cls, "numpy", "generic") + or _issubclass_fast(cls, "jax", "Array") + or _issubclass_fast(cls, "sparse", "SparseArray") + ): + return False + if _is_array_api_cls(cls): + return True + return None + + def is_writeable_array(x: object) -> bool: """ Return False if ``x.__setitem__`` is expected to raise; True otherwise. @@ -949,11 +956,32 @@ def is_writeable_array(x: object) -> bool: As there is no standard way to check if an array is writeable without actually writing to it, this function blindly returns True for all unknown array types. """ - if is_numpy_array(x): - return x.flags.writeable - if is_jax_array(x) or is_pydata_sparse_array(x): + cls = cast(Hashable, type(x)) + if _issubclass_fast(cls, "numpy", "ndarray"): + return cast("npt.NDArray", x).flags.writeable + res = _is_writeable_cls(cls) + if res is not None: + return res + return hasattr(x, '__array_namespace__') + + +@lru_cache(100) +def _is_lazy_cls(cls: type) -> bool | None: + if ( + _issubclass_fast(cls, "numpy", "ndarray") + or _issubclass_fast(cls, "numpy", "generic") + or _issubclass_fast(cls, "cupy", "ndarray") + or _issubclass_fast(cls, "torch", "Tensor") + or _issubclass_fast(cls, "sparse", "SparseArray") + ): return False - return is_array_api_obj(x) + if ( + _issubclass_fast(cls, "jax", "Array") + or _issubclass_fast(cls, "dask.array", "Array") + or _issubclass_fast(cls, "ndonnx", "Array") + ): + return True + return None def is_lazy_array(x: object) -> bool: @@ -969,14 +997,6 @@ def is_lazy_array(x: object) -> bool: This function errs on the side of caution for array types that may or may not be lazy, e.g. JAX arrays, by always returning True for them. """ - if ( - is_numpy_array(x) - or is_cupy_array(x) - or is_torch_array(x) - or is_pydata_sparse_array(x) - ): - return False - # **JAX note:** while it is possible to determine if you're inside or outside # jax.jit by testing the subclass of a jax.Array object, as well as testing bool() # as we do below for unknown arrays, this is not recommended by JAX best practices. @@ -986,10 +1006,14 @@ def is_lazy_array(x: object) -> bool: # compatibility, is highly detrimental to performance as the whole graph will end # up being computed multiple times. - if is_jax_array(x) or is_dask_array(x) or is_ndonnx_array(x): - return True + # Note: skipping reclassification of JAX zero gradient arrays, as one will + # exclusively get them once they leave a jax.grad JIT context. + cls = cast(Hashable, type(x)) + res = _is_lazy_cls(cls) + if res is not None: + return res - if not is_array_api_obj(x): + if not hasattr(x, "__array_namespace__"): return False # Unknown Array API compatible object. 
Note that this test may have dire consequences @@ -1042,7 +1066,7 @@ def is_lazy_array(x: object) -> bool: "to_device", ] -_all_ignore = ["sys", "math", "inspect", "warnings"] +_all_ignore = ['lru_cache', 'sys', 'math', 'inspect', 'warnings'] def __dir__() -> list[str]: return __all__ From e600449a645c2e6ce5a2276da0006491f097c096 Mon Sep 17 00:00:00 2001 From: Guido Imperiale Date: Thu, 24 Apr 2025 10:09:45 +0100 Subject: [PATCH 37/80] ENH: Simplify CuPy `asarray` and `to_device` (#314) reviewed at https://github.com/data-apis/array-api-compat/pull/314 --- array_api_compat/common/_helpers.py | 48 ++++++++++------------------- array_api_compat/cupy/_aliases.py | 30 +++++------------- cupy-xfails.txt | 3 -- tests/test_common.py | 24 +++++++++------ tests/test_cupy.py | 22 +++++++++++++ 5 files changed, 61 insertions(+), 66 deletions(-) create mode 100644 tests/test_cupy.py diff --git a/array_api_compat/common/_helpers.py b/array_api_compat/common/_helpers.py index d50e0d83..77175d0d 100644 --- a/array_api_compat/common/_helpers.py +++ b/array_api_compat/common/_helpers.py @@ -775,42 +775,28 @@ def _cupy_to_device( /, stream: int | Any | None = None, ) -> _CupyArray: - import cupy as cp # pyright: ignore[reportMissingTypeStubs] - from cupy.cuda import Device as _Device # pyright: ignore - from cupy.cuda import stream as stream_module # pyright: ignore - from cupy_backends.cuda.api import runtime # pyright: ignore + import cupy as cp - if device == x.device: - return x - elif device == "cpu": + if device == "cpu": # allowing us to use `to_device(x, "cpu")` # is useful for portable test swapping between # host and device backends return x.get() - elif not isinstance(device, _Device): - raise ValueError(f"Unsupported device {device!r}") - else: - # see cupy/cupy#5985 for the reason how we handle device/stream here - prev_device: Any = runtime.getDevice() # pyright: ignore[reportUnknownMemberType] - prev_stream = None - if stream is not None: - prev_stream: Any = stream_module.get_current_stream() # pyright: ignore - # stream can be an int as specified in __dlpack__, or a CuPy stream - if isinstance(stream, int): - stream = cp.cuda.ExternalStream(stream) # pyright: ignore - elif isinstance(stream, cp.cuda.Stream): # pyright: ignore[reportUnknownMemberType] - pass - else: - raise ValueError("the input stream is not recognized") - stream.use() # pyright: ignore[reportUnknownMemberType] - try: - runtime.setDevice(device.id) # pyright: ignore[reportUnknownMemberType] - arr = x.copy() - finally: - runtime.setDevice(prev_device) # pyright: ignore[reportUnknownMemberType] - if stream is not None: - prev_stream.use() - return arr + if not isinstance(device, cp.cuda.Device): + raise TypeError(f"Unsupported device type {device!r}") + + if stream is None: + with device: + return cp.asarray(x) + + # stream can be an int as specified in __dlpack__, or a CuPy stream + if isinstance(stream, int): + stream = cp.cuda.ExternalStream(stream) + elif not isinstance(stream, cp.cuda.Stream): + raise TypeError(f"Unsupported stream type {stream!r}") + + with device, stream: + return cp.asarray(x) def _torch_to_device( diff --git a/array_api_compat/cupy/_aliases.py b/array_api_compat/cupy/_aliases.py index fd1460ae..adb74bff 100644 --- a/array_api_compat/cupy/_aliases.py +++ b/array_api_compat/cupy/_aliases.py @@ -64,8 +64,6 @@ finfo = get_xp(cp)(_aliases.finfo) iinfo = get_xp(cp)(_aliases.iinfo) -_copy_default = object() - # asarray also adds the copy keyword, which is not present in numpy 1.0. 
def asarray( @@ -79,7 +77,7 @@ def asarray( *, dtype: Optional[DType] = None, device: Optional[Device] = None, - copy: Optional[bool] = _copy_default, + copy: Optional[bool] = None, **kwargs, ) -> Array: """ @@ -89,25 +87,13 @@ def asarray( specification for more details. """ with cp.cuda.Device(device): - # cupy is like NumPy 1.26 (except without _CopyMode). See the comments - # in asarray in numpy/_aliases.py. - if copy is not _copy_default: - # A future version of CuPy will change the meaning of copy=False - # to mean no-copy. We don't know for certain what version it will - # be yet, so to avoid breaking that version, we use a different - # default value for copy so asarray(obj) with no copy kwarg will - # always do the copy-if-needed behavior. - - # This will still need to be updated to remove the - # NotImplementedError for copy=False, but at least this won't - # break the default or existing behavior. - if copy is None: - copy = False - elif copy is False: - raise NotImplementedError("asarray(copy=False) is not yet supported in cupy") - kwargs['copy'] = copy - - return cp.array(obj, dtype=dtype, **kwargs) + if copy is None: + return cp.asarray(obj, dtype=dtype, **kwargs) + else: + res = cp.array(obj, dtype=dtype, copy=copy, **kwargs) + if not copy and res is not obj: + raise ValueError("Unable to avoid copy while creating an array as requested") + return res def astype( diff --git a/cupy-xfails.txt b/cupy-xfails.txt index a30572f8..df85d9ca 100644 --- a/cupy-xfails.txt +++ b/cupy-xfails.txt @@ -11,9 +11,6 @@ array_api_tests/test_array_object.py::test_scalar_casting[__index__(int64)] # testsuite bug (https://github.com/data-apis/array-api-tests/issues/172) array_api_tests/test_array_object.py::test_getitem -# copy=False is not yet implemented -array_api_tests/test_creation_functions.py::test_asarray_arrays - # attributes are np.float32 instead of float # (see also https://github.com/data-apis/array-api/issues/405) array_api_tests/test_data_type_functions.py::test_finfo[float32] diff --git a/tests/test_common.py b/tests/test_common.py index 6b1aa160..d1933899 100644 --- a/tests/test_common.py +++ b/tests/test_common.py @@ -17,6 +17,7 @@ from array_api_compat import ( device, is_array_api_obj, is_lazy_array, is_writeable_array, size, to_device ) +from array_api_compat.common._helpers import _DASK_DEVICE from ._helpers import all_libraries, import_, wrapped_libraries, xfail @@ -189,23 +190,26 @@ class C: @pytest.mark.parametrize("library", all_libraries) -def test_device(library, request): +def test_device_to_device(library, request): if library == "ndonnx": - xfail(request, reason="Needs ndonnx >=0.9.4") + xfail(request, reason="Stub raises ValueError") + if library == "sparse": + xfail(request, reason="No __array_namespace_info__()") xp = import_(library, wrapper=True) + devices = xp.__array_namespace_info__().devices() - # We can't test much for device() and to_device() other than that - # x.to_device(x.device) works. 
- + # Default device x = xp.asarray([1, 2, 3]) dev = device(x) - x2 = to_device(x, dev) - assert device(x2) == device(x) - - x3 = xp.asarray(x, device=dev) - assert device(x3) == device(x) + for dev in devices: + if dev is None: # JAX >=0.5.3 + continue + if dev is _DASK_DEVICE: # TODO this needs a better design + continue + y = to_device(x, dev) + assert device(y) == dev @pytest.mark.parametrize("library", wrapped_libraries) diff --git a/tests/test_cupy.py b/tests/test_cupy.py new file mode 100644 index 00000000..f8b4a4d8 --- /dev/null +++ b/tests/test_cupy.py @@ -0,0 +1,22 @@ +import pytest +from array_api_compat import device, to_device + +xp = pytest.importorskip("array_api_compat.cupy") +from cupy.cuda import Stream + + +def test_to_device_with_stream(): + devices = xp.__array_namespace_info__().devices() + streams = [ + Stream(), + Stream(non_blocking=True), + Stream(null=True), + Stream(ptds=True), + 123, # dlpack stream + ] + + a = xp.asarray([1, 2, 3]) + for dev in devices: + for stream in streams: + b = to_device(a, dev, stream=stream) + assert device(b) == dev From 1acba0c1cd06bd26eb526bd08168f2c60f22f0b9 Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Fri, 9 May 2025 13:13:46 +0200 Subject: [PATCH 38/80] BUG: take_along_axis: add numpy and cupy aliases, skip testing on dask (#317) --- array_api_compat/cupy/_aliases.py | 8 +++++++- array_api_compat/numpy/_aliases.py | 6 ++++++ dask-xfails.txt | 3 +++ 3 files changed, 16 insertions(+), 1 deletion(-) diff --git a/array_api_compat/cupy/_aliases.py b/array_api_compat/cupy/_aliases.py index adb74bff..90b48f05 100644 --- a/array_api_compat/cupy/_aliases.py +++ b/array_api_compat/cupy/_aliases.py @@ -124,6 +124,11 @@ def count_nonzero( return result +# take_along_axis: axis defaults to -1 but in cupy (and numpy) axis is a required arg +def take_along_axis(x: Array, indices: Array, /, *, axis: int = -1): + return cp.take_along_axis(x, indices, axis=axis) + + # These functions are completely new here. If the library already has them # (i.e., numpy 2.0), use the library version instead of our wrapper. if hasattr(cp, 'vecdot'): @@ -145,6 +150,7 @@ def count_nonzero( 'acos', 'acosh', 'asin', 'asinh', 'atan', 'atan2', 'atanh', 'bitwise_left_shift', 'bitwise_invert', 'bitwise_right_shift', - 'bool', 'concat', 'count_nonzero', 'pow', 'sign'] + 'bool', 'concat', 'count_nonzero', 'pow', 'sign', + 'take_along_axis'] _all_ignore = ['cp', 'get_xp'] diff --git a/array_api_compat/numpy/_aliases.py b/array_api_compat/numpy/_aliases.py index d8792611..a1aee5c0 100644 --- a/array_api_compat/numpy/_aliases.py +++ b/array_api_compat/numpy/_aliases.py @@ -140,6 +140,11 @@ def count_nonzero( return result +# take_along_axis: axis defaults to -1 but in numpy axis is a required arg +def take_along_axis(x: Array, indices: Array, /, *, axis: int = -1): + return np.take_along_axis(x, indices, axis=axis) + + # These functions are completely new here. If the library already has them # (i.e., numpy 2.0), use the library version instead of our wrapper. 
if hasattr(np, "vecdot"): @@ -175,6 +180,7 @@ def count_nonzero( "concat", "count_nonzero", "pow", + "take_along_axis" ] __all__ += _aliases.__all__ _all_ignore = ["np", "get_xp"] diff --git a/dask-xfails.txt b/dask-xfails.txt index 932aeada..3efb4f96 100644 --- a/dask-xfails.txt +++ b/dask-xfails.txt @@ -24,6 +24,9 @@ array_api_tests/test_creation_functions.py::test_linspace # Shape mismatch array_api_tests/test_indexing_functions.py::test_take +# missing `take_along_axis`, https://github.com/dask/dask/issues/3663 +array_api_tests/test_indexing_functions.py::test_take_along_axis + # Array methods and attributes not already on da.Array cannot be wrapped array_api_tests/test_has_names.py::test_has_names[array_method-__array_namespace__] array_api_tests/test_has_names.py::test_has_names[array_method-to_device] From ddbbc35ab2bebed4637f18e227d6a9138c0f7669 Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Fri, 9 May 2025 16:19:59 +0200 Subject: [PATCH 39/80] TST: add a skip for CuPy pow(er) is not fully NEP50 compatible in CuPy 13.x --- cupy-xfails.txt | 2 ++ 1 file changed, 2 insertions(+) diff --git a/cupy-xfails.txt b/cupy-xfails.txt index a30572f8..55c6437d 100644 --- a/cupy-xfails.txt +++ b/cupy-xfails.txt @@ -37,6 +37,8 @@ array_api_tests/test_operators_and_elementwise_functions.py::test_subtract[__sub array_api_tests/test_operators_and_elementwise_functions.py::test_add[__add__(x, s)] # floating point inaccuracy array_api_tests/test_operators_and_elementwise_functions.py::test_remainder[remainder(x1, x2)] +# incomplete NEP50 support in CuPy 13.x (fixed in 14.0.0a1) +array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[pow] # cupy (arg)min/max wrong with infinities # https://github.com/cupy/cupy/issues/7424 From 7d7a85862b345e0247e20dd64dbe6d327098a869 Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Fri, 9 May 2025 14:46:06 +0000 Subject: [PATCH 40/80] TST: update CuPy skips --- cupy-xfails.txt | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/cupy-xfails.txt b/cupy-xfails.txt index 55c6437d..89e9af54 100644 --- a/cupy-xfails.txt +++ b/cupy-xfails.txt @@ -39,6 +39,9 @@ array_api_tests/test_operators_and_elementwise_functions.py::test_add[__add__(x, array_api_tests/test_operators_and_elementwise_functions.py::test_remainder[remainder(x1, x2)] # incomplete NEP50 support in CuPy 13.x (fixed in 14.0.0a1) array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[pow] +array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[logaddexp] +array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[nextafter] +array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[hypot] # cupy (arg)min/max wrong with infinities # https://github.com/cupy/cupy/issues/7424 @@ -187,7 +190,6 @@ array_api_tests/test_signatures.py::test_func_signature[from_dlpack] array_api_tests/test_signatures.py::test_array_method_signature[__dlpack__] # 2024.12 support -array_api_tests/test_signatures.py::test_func_signature[count_nonzero] array_api_tests/test_signatures.py::test_func_signature[bitwise_and] array_api_tests/test_signatures.py::test_func_signature[bitwise_left_shift] array_api_tests/test_signatures.py::test_func_signature[bitwise_or] From 4e3d809646653a919d9c494b8afd730291e441fb Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Fri, 9 May 2025 17:10:49 +0200 Subject: [PATCH 41/80] TST: add xfails for NumPy 1.22 and 
1.26 / python scalars --- numpy-1-22-xfails.txt | 7 +++++++ numpy-1-26-xfails.txt | 2 ++ 2 files changed, 9 insertions(+) diff --git a/numpy-1-22-xfails.txt b/numpy-1-22-xfails.txt index 93edf311..c1de77d8 100644 --- a/numpy-1-22-xfails.txt +++ b/numpy-1-22-xfails.txt @@ -123,6 +123,13 @@ array_api_tests/test_operators_and_elementwise_functions.py::test_subtract[subtr array_api_tests/test_searching_functions.py::test_where array_api_tests/test_special_cases.py::test_iop[__iadd__(x1_i is -0 and x2_i is -0) -> -0] +array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[add] +array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[divide] +array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[hypot] +array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[subtract] + +array_api_tests/test_operators_and_elementwise_functions.py::test_where_with_scalars + # 2023.12 support array_api_tests/test_has_names.py::test_has_names[creation-from_dlpack] array_api_tests/test_signatures.py::test_func_signature[from_dlpack] diff --git a/numpy-1-26-xfails.txt b/numpy-1-26-xfails.txt index 51e1a658..98cb9f6c 100644 --- a/numpy-1-26-xfails.txt +++ b/numpy-1-26-xfails.txt @@ -50,6 +50,8 @@ array_api_tests/test_signatures.py::test_func_signature[bitwise_right_shift] array_api_tests/test_signatures.py::test_func_signature[bitwise_xor] array_api_tests/test_data_type_functions.py::TestResultType::test_with_scalars +array_api_tests/test_operators_and_elementwise_functions.py::test_where_with_scalars + # Stubs have a comment: (**note**: libraries may return ``NaN`` to match Python behavior.); Apparently, NumPy does just that array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is +infinity and isfinite(x2_i) and x2_i > 0) -> +infinity] array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is +infinity and isfinite(x2_i) and x2_i < 0) -> -infinity] From 9b8f252683bdd90090649b801dc31402c58fdc96 Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Thu, 8 May 2025 22:49:09 +0200 Subject: [PATCH 42/80] BUG: torch: fix count_nonzero with axis tuple and keepdims --- array_api_compat/torch/_aliases.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/array_api_compat/torch/_aliases.py b/array_api_compat/torch/_aliases.py index 027a0261..335008e4 100644 --- a/array_api_compat/torch/_aliases.py +++ b/array_api_compat/torch/_aliases.py @@ -548,8 +548,12 @@ def count_nonzero( ) -> Array: result = torch.count_nonzero(x, dim=axis) if keepdims: - if axis is not None: + if isinstance(axis, int): return result.unsqueeze(axis) + elif isinstance(axis, tuple): + n_axis = [x.ndim + ax if ax < 0 else ax for ax in axis] + sh = [1 if i in n_axis else x.shape[i] for i in range(x.ndim)] + return torch.reshape(result, sh) return _axis_none_keepdims(result, x.ndim, keepdims) else: return result From 8c62443da64b2dee5fbf0623f9fd510e62577c45 Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Fri, 9 May 2025 23:35:22 +0200 Subject: [PATCH 43/80] TST: update numpy 1.22 xfails --- numpy-1-22-xfails.txt | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/numpy-1-22-xfails.txt b/numpy-1-22-xfails.txt index c1de77d8..cacb95b7 100644 --- a/numpy-1-22-xfails.txt +++ b/numpy-1-22-xfails.txt @@ -127,6 +127,13 @@ array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_sc 
array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[divide] array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[hypot] array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[subtract] +array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[logaddexp] +array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[nextafter] +array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[multiply] +array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[maximum] +array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[copysign] +array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[pow] +array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[pow] array_api_tests/test_operators_and_elementwise_functions.py::test_where_with_scalars From 5597ec755d44cb005f01601b3c2193f9f56b604f Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Sun, 11 May 2025 12:56:49 +0200 Subject: [PATCH 44/80] CI: use ARRAY_API_TESTS_XFAIL_MARK on CI --- .github/workflows/array-api-tests-dask.yml | 2 ++ .github/workflows/array-api-tests-numpy-1-22.yml | 2 ++ .github/workflows/array-api-tests-numpy-1-26.yml | 2 ++ .github/workflows/array-api-tests-numpy-dev.yml | 2 ++ .github/workflows/array-api-tests-numpy-latest.yml | 2 ++ .github/workflows/array-api-tests-torch.yml | 1 + 6 files changed, 11 insertions(+) diff --git a/.github/workflows/array-api-tests-dask.yml b/.github/workflows/array-api-tests-dask.yml index afc67975..964fb52d 100644 --- a/.github/workflows/array-api-tests-dask.yml +++ b/.github/workflows/array-api-tests-dask.yml @@ -16,3 +16,5 @@ jobs: # the full test suite with at least 200 examples. 
pytest-extra-args: --max-examples=5 python-versions: '[''3.10'', ''3.13'']' + extra-env-vars: | + ARRAY_API_TESTS_XFAIL_MARK=skip diff --git a/.github/workflows/array-api-tests-numpy-1-22.yml b/.github/workflows/array-api-tests-numpy-1-22.yml index d8f60432..1cf6e26d 100644 --- a/.github/workflows/array-api-tests-numpy-1-22.yml +++ b/.github/workflows/array-api-tests-numpy-1-22.yml @@ -10,3 +10,5 @@ jobs: package-version: '== 1.22.*' xfails-file-extra: '-1-22' python-versions: '[''3.10'']' + extra-env-vars: | + ARRAY_API_TESTS_XFAIL_MARK=skip diff --git a/.github/workflows/array-api-tests-numpy-1-26.yml b/.github/workflows/array-api-tests-numpy-1-26.yml index 33780760..a2788d2f 100644 --- a/.github/workflows/array-api-tests-numpy-1-26.yml +++ b/.github/workflows/array-api-tests-numpy-1-26.yml @@ -10,3 +10,5 @@ jobs: package-version: '== 1.26.*' xfails-file-extra: '-1-26' python-versions: '[''3.10'', ''3.12'']' + extra-env-vars: | + ARRAY_API_TESTS_XFAIL_MARK=skip diff --git a/.github/workflows/array-api-tests-numpy-dev.yml b/.github/workflows/array-api-tests-numpy-dev.yml index d6de1a53..dce0813f 100644 --- a/.github/workflows/array-api-tests-numpy-dev.yml +++ b/.github/workflows/array-api-tests-numpy-dev.yml @@ -10,3 +10,5 @@ jobs: extra-requires: '--pre --extra-index https://pypi.anaconda.org/scientific-python-nightly-wheels/simple' xfails-file-extra: '-dev' python-versions: '[''3.11'', ''3.13'']' + extra-env-vars: | + ARRAY_API_TESTS_XFAIL_MARK=skip diff --git a/.github/workflows/array-api-tests-numpy-latest.yml b/.github/workflows/array-api-tests-numpy-latest.yml index 4d3667f6..54b21a25 100644 --- a/.github/workflows/array-api-tests-numpy-latest.yml +++ b/.github/workflows/array-api-tests-numpy-latest.yml @@ -8,3 +8,5 @@ jobs: with: package-name: numpy python-versions: '[''3.10'', ''3.13'']' + extra-env-vars: | + ARRAY_API_TESTS_XFAIL_MARK=skip diff --git a/.github/workflows/array-api-tests-torch.yml b/.github/workflows/array-api-tests-torch.yml index ac20df25..4dcb3347 100644 --- a/.github/workflows/array-api-tests-torch.yml +++ b/.github/workflows/array-api-tests-torch.yml @@ -10,4 +10,5 @@ jobs: extra-requires: '--index-url https://download.pytorch.org/whl/cpu' extra-env-vars: | ARRAY_API_TESTS_SKIP_DTYPES=uint16,uint32,uint64 + ARRAY_API_TESTS_XFAIL_MARK=skip python-versions: '[''3.10'', ''3.13'']' From 5cf5d8f404b18ff67543762ed8e92cb0f359f885 Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Sun, 11 May 2025 21:58:35 +0200 Subject: [PATCH 45/80] BUG: torch: meshgrid defaults to indexing="xy" As of version 2.6, torch defaults to indexing='ij', and is planning to transition to 'xy' at some point. When it does, we'll be able to drop our wrapper. 
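For illustration, a minimal sketch of the behaviour difference being papered
over (editor's sketch, not part of the patch; assumes only `torch`):

    import torch

    x = torch.tensor([1, 2])
    y = torch.tensor([4])

    # torch's native default is matrix ('ij') indexing: shapes (2, 1)
    X_ij, Y_ij = torch.meshgrid(x, y, indexing="ij")

    # the array API default is Cartesian ('xy') indexing: shapes (1, 2),
    # i.e. the first two dimensions are swapped relative to 'ij'
    X_xy, Y_xy = torch.meshgrid(x, y, indexing="xy")

    assert X_ij.shape == (2, 1) and X_xy.shape == (1, 2)
    assert torch.equal(X_xy, X_ij.T)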
---
 array_api_compat/torch/_aliases.py | 10 ++++++++--
 tests/test_torch.py                | 17 +++++++++++++++++
 2 files changed, 25 insertions(+), 2 deletions(-)

diff --git a/array_api_compat/torch/_aliases.py b/array_api_compat/torch/_aliases.py
index 335008e4..de5d1a5d 100644
--- a/array_api_compat/torch/_aliases.py
+++ b/array_api_compat/torch/_aliases.py
@@ -2,7 +2,7 @@
 from functools import reduce as _reduce, wraps as _wraps
 from builtins import all as _builtin_all, any as _builtin_any
 
-from typing import Any, List, Optional, Sequence, Tuple, Union
+from typing import Any, List, Optional, Sequence, Tuple, Union, Literal
 
 import torch
 
@@ -828,6 +828,12 @@ def sign(x: Array, /) -> Array:
     return out
 
 
+def meshgrid(*arrays: Array, indexing: Literal['xy', 'ij'] = 'xy') -> List[Array]:
+    # enforce the default of 'xy'
+    # TODO: is the return type a list or a tuple
+    return list(torch.meshgrid(*arrays, indexing=indexing))
+
+
 __all__ = ['__array_namespace_info__', 'asarray', 'result_type', 'can_cast',
            'permute_dims', 'bitwise_invert', 'newaxis', 'conj', 'add',
            'atan2', 'bitwise_and', 'bitwise_left_shift', 'bitwise_or',
@@ -844,6 +850,6 @@
     'UniqueAllResult', 'UniqueCountsResult', 'UniqueInverseResult',
     'unique_all', 'unique_counts', 'unique_inverse', 'unique_values',
     'matmul', 'matrix_transpose', 'vecdot', 'tensordot', 'isdtype',
-    'take', 'take_along_axis', 'sign', 'finfo', 'iinfo', 'repeat']
+    'take', 'take_along_axis', 'sign', 'finfo', 'iinfo', 'repeat', 'meshgrid']
 
 _all_ignore = ['torch', 'get_xp']
diff --git a/tests/test_torch.py b/tests/test_torch.py
index e8340f31..7adb4ab3 100644
--- a/tests/test_torch.py
+++ b/tests/test_torch.py
@@ -100,3 +100,20 @@ def test_gh_273(self, default_dt, dtype_a, dtype_b):
             assert dtype_1 == dtype_2
         finally:
             torch.set_default_dtype(prev_default)
+
+
+def test_meshgrid():
+    """Verify that array_api_compat.torch.meshgrid defaults to indexing='xy'."""
+
+    x, y = xp.asarray([1, 2]), xp.asarray([4])
+
+    X, Y = xp.meshgrid(x, y)
+
+    # output of torch.meshgrid(x, y, indexing='xy') -- indexing='ij' is different
+    X_xy, Y_xy = xp.asarray([[1, 2]]), xp.asarray([[4, 4]])
+
+    assert X.shape == X_xy.shape
+    assert xp.all(X == X_xy)
+
+    assert Y.shape == Y_xy.shape
+    assert xp.all(Y == Y_xy)

From 6488ad81748a1b92f7a0de42e5a10461a9df6b62 Mon Sep 17 00:00:00 2001
From: crusaderky
Date: Thu, 15 May 2025 08:55:20 +0100
Subject: [PATCH 46/80] TST: revisit test for `asarray` `copy=` parameter

---
 array_api_compat/dask/array/_aliases.py |  2 +-
 tests/test_common.py                    | 67 ++++++++++---------------
 2 files changed, 28 insertions(+), 41 deletions(-)

diff --git a/array_api_compat/dask/array/_aliases.py b/array_api_compat/dask/array/_aliases.py
index 9687a9cd..d43881ab 100644
--- a/array_api_compat/dask/array/_aliases.py
+++ b/array_api_compat/dask/array/_aliases.py
@@ -171,7 +171,7 @@ def asarray(
         return obj.copy() if copy else obj  # pyright: ignore[reportAttributeAccessIssue]
 
     if copy is False:
-        raise NotImplementedError(
+        raise ValueError(
             "Unable to avoid copy when converting a non-dask object to dask"
         )
 
diff --git a/tests/test_common.py b/tests/test_common.py
index d1933899..fe4fe598 100644
--- a/tests/test_common.py
+++ b/tests/test_common.py
@@ -278,87 +278,73 @@ def test_asarray_copy(library):
     xp = import_(library, wrapper=True)
     asarray = xp.asarray
     is_lib_func = globals()[is_array_functions[library]]
-    all = xp.all if library != 'dask.array' else lambda x: xp.all(x).compute()
-
-    if library == 'cupy':
-        supports_copy_false_other_ns = False
-
supports_copy_false_same_ns = False - elif library == 'dask.array': - supports_copy_false_other_ns = False - supports_copy_false_same_ns = True - else: - supports_copy_false_other_ns = True - supports_copy_false_same_ns = True a = asarray([1]) b = asarray(a, copy=True) assert is_lib_func(b) a[0] = 0 - assert all(b[0] == 1) - assert all(a[0] == 0) + assert b[0] == 1 + assert a[0] == 0 a = asarray([1]) - if supports_copy_false_same_ns: - b = asarray(a, copy=False) - assert is_lib_func(b) - a[0] = 0 - assert all(b[0] == 0) - else: - pytest.raises(NotImplementedError, lambda: asarray(a, copy=False)) - a = asarray([1]) - if supports_copy_false_same_ns: - pytest.raises(ValueError, lambda: asarray(a, copy=False, - dtype=xp.float64)) - else: - pytest.raises(NotImplementedError, lambda: asarray(a, copy=False, dtype=xp.float64)) + # Test copy=False within the same namespace + b = asarray(a, copy=False) + assert is_lib_func(b) + a[0] = 0 + assert b[0] == 0 + with pytest.raises(ValueError): + asarray(a, copy=False, dtype=xp.float64) + # copy=None defaults to False when possible a = asarray([1]) b = asarray(a, copy=None) assert is_lib_func(b) a[0] = 0 - assert all(b[0] == 0) + assert b[0] == 0 + # copy=None defaults to True when impossible a = asarray([1.0], dtype=xp.float32) assert a.dtype == xp.float32 b = asarray(a, dtype=xp.float64, copy=None) assert is_lib_func(b) assert b.dtype == xp.float64 a[0] = 0.0 - assert all(b[0] == 1.0) + assert b[0] == 1.0 + # copy=None defaults to False when possible a = asarray([1.0], dtype=xp.float64) assert a.dtype == xp.float64 b = asarray(a, dtype=xp.float64, copy=None) assert is_lib_func(b) assert b.dtype == xp.float64 a[0] = 0.0 - assert all(b[0] == 0.0) + assert b[0] == 0.0 # Python built-in types for obj in [True, 0, 0.0, 0j, [0], [[0]]]: asarray(obj, copy=True) # No error asarray(obj, copy=None) # No error - if supports_copy_false_other_ns: - pytest.raises(ValueError, lambda: asarray(obj, copy=False)) - else: - pytest.raises(NotImplementedError, lambda: asarray(obj, copy=False)) + + with pytest.raises(ValueError): + asarray(obj, copy=False) # Use the standard library array to test the buffer protocol a = array.array('f', [1.0]) b = asarray(a, copy=True) assert is_lib_func(b) a[0] = 0.0 - assert all(b[0] == 1.0) + assert b[0] == 1.0 a = array.array('f', [1.0]) - if supports_copy_false_other_ns: + if library in ('cupy', 'dask.array'): + with pytest.raises(ValueError): + asarray(a, copy=False) + else: b = asarray(a, copy=False) assert is_lib_func(b) a[0] = 0.0 - assert all(b[0] == 0.0) - else: - pytest.raises(NotImplementedError, lambda: asarray(a, copy=False)) + assert b[0] == 0.0 a = array.array('f', [1.0]) b = asarray(a, copy=None) @@ -369,9 +355,10 @@ def test_asarray_copy(library): # dask changed behaviour of copy=None in 2024.12 to copy; # this wrapper ensures the same behaviour in older versions too. 
# https://github.com/dask/dask/pull/11524/ - assert all(b[0] == 1.0) + assert b[0] == 1.0 else: - assert all(b[0] == 0.0) + # copy=None defaults to False when possible + assert b[0] == 0.0 @pytest.mark.parametrize("library", ["numpy", "cupy", "torch"]) From 7c3d68c47147663399cf4f23de24b9a4193d6f65 Mon Sep 17 00:00:00 2001 From: crusaderky Date: Thu, 15 May 2025 09:01:14 +0100 Subject: [PATCH 47/80] TST: fix cupy `to_device` test on multiple devices --- tests/test_cupy.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/tests/test_cupy.py b/tests/test_cupy.py index f8b4a4d8..fb0c69e4 100644 --- a/tests/test_cupy.py +++ b/tests/test_cupy.py @@ -8,15 +8,17 @@ def test_to_device_with_stream(): devices = xp.__array_namespace_info__().devices() streams = [ - Stream(), - Stream(non_blocking=True), - Stream(null=True), - Stream(ptds=True), - 123, # dlpack stream + lambda: Stream(), + lambda: Stream(non_blocking=True), + lambda: Stream(null=True), + lambda: Stream(ptds=True), + lambda: 123, # dlpack stream ] a = xp.asarray([1, 2, 3]) for dev in devices: - for stream in streams: + for stream_gen in streams: + with dev: + stream = stream_gen() b = to_device(a, dev, stream=stream) assert device(b) == dev From c829ef744cb04474b8eedf520557f1ca05bb77dc Mon Sep 17 00:00:00 2001 From: crusaderky Date: Thu, 15 May 2025 09:07:15 +0100 Subject: [PATCH 48/80] nits --- tests/test_cupy.py | 28 +++++++++++++++++----------- 1 file changed, 17 insertions(+), 11 deletions(-) diff --git a/tests/test_cupy.py b/tests/test_cupy.py index fb0c69e4..5aac36f8 100644 --- a/tests/test_cupy.py +++ b/tests/test_cupy.py @@ -5,20 +5,26 @@ from cupy.cuda import Stream -def test_to_device_with_stream(): - devices = xp.__array_namespace_info__().devices() - streams = [ +@pytest.mark.parametrize( + "make_stream", + [ lambda: Stream(), - lambda: Stream(non_blocking=True), + lambda: Stream(non_blocking=True), lambda: Stream(null=True), - lambda: Stream(ptds=True), + lambda: Stream(ptds=True), lambda: 123, # dlpack stream - ] + ], +) +def test_to_device_with_stream(make_stream): + devices = xp.__array_namespace_info__().devices() a = xp.asarray([1, 2, 3]) for dev in devices: - for stream_gen in streams: - with dev: - stream = stream_gen() - b = to_device(a, dev, stream=stream) - assert device(b) == dev + # Streams are device-specific and must be created within + # the context of the device... + with dev: + stream = make_stream() + # ... however, to_device() does not need to be inside the + # device context. 
+ b = to_device(a, dev, stream=stream) + assert device(b) == dev From 44e7828b0666e5edd26958fb2337d755cf1a6002 Mon Sep 17 00:00:00 2001 From: crusaderky Date: Thu, 15 May 2025 09:08:18 +0100 Subject: [PATCH 49/80] lint --- tests/test_common.py | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/tests/test_common.py b/tests/test_common.py index fe4fe598..54b5ed69 100644 --- a/tests/test_common.py +++ b/tests/test_common.py @@ -268,7 +268,6 @@ def test_asarray_cross_library(source_library, target_library, request): assert b.dtype == tgt_lib.int32 - @pytest.mark.parametrize("library", wrapped_libraries) def test_asarray_copy(library): # Note, we have this test here because the test suite currently doesn't @@ -323,21 +322,21 @@ def test_asarray_copy(library): # Python built-in types for obj in [True, 0, 0.0, 0j, [0], [[0]]]: - asarray(obj, copy=True) # No error - asarray(obj, copy=None) # No error + asarray(obj, copy=True) # No error + asarray(obj, copy=None) # No error with pytest.raises(ValueError): asarray(obj, copy=False) # Use the standard library array to test the buffer protocol - a = array.array('f', [1.0]) + a = array.array("f", [1.0]) b = asarray(a, copy=True) assert is_lib_func(b) a[0] = 0.0 assert b[0] == 1.0 - a = array.array('f', [1.0]) - if library in ('cupy', 'dask.array'): + a = array.array("f", [1.0]) + if library in ("cupy", "dask.array"): with pytest.raises(ValueError): asarray(a, copy=False) else: @@ -346,11 +345,11 @@ def test_asarray_copy(library): a[0] = 0.0 assert b[0] == 0.0 - a = array.array('f', [1.0]) + a = array.array("f", [1.0]) b = asarray(a, copy=None) assert is_lib_func(b) a[0] = 0.0 - if library in ('cupy', 'dask.array'): + if library in ("cupy", "dask.array"): # A copy is required for libraries where the default device is not CPU # dask changed behaviour of copy=None in 2024.12 to copy; # this wrapper ensures the same behaviour in older versions too. From 0433b8e94ca802d9f6402acacb81f9c4fef6f84a Mon Sep 17 00:00:00 2001 From: crusaderky Date: Thu, 15 May 2025 09:26:53 +0100 Subject: [PATCH 50/80] skip segmentation fault --- tests/test_cupy.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/tests/test_cupy.py b/tests/test_cupy.py index 5aac36f8..8b71d978 100644 --- a/tests/test_cupy.py +++ b/tests/test_cupy.py @@ -12,7 +12,11 @@ lambda: Stream(non_blocking=True), lambda: Stream(null=True), lambda: Stream(ptds=True), - lambda: 123, # dlpack stream + pytest.param( + lambda: 123, + id="dlpack stream", + marks=pytest.mark.skip(reason="segmentation fault reported (#326)") + ), ], ) def test_to_device_with_stream(make_stream): From ebd3fd9356664c0502506adba96d1df72c47ec49 Mon Sep 17 00:00:00 2001 From: crusaderky Date: Thu, 15 May 2025 09:35:08 +0100 Subject: [PATCH 51/80] Use pointers --- tests/test_cupy.py | 21 ++++++++++++++++----- 1 file changed, 16 insertions(+), 5 deletions(-) diff --git a/tests/test_cupy.py b/tests/test_cupy.py index 8b71d978..4745b983 100644 --- a/tests/test_cupy.py +++ b/tests/test_cupy.py @@ -12,11 +12,6 @@ lambda: Stream(non_blocking=True), lambda: Stream(null=True), lambda: Stream(ptds=True), - pytest.param( - lambda: 123, - id="dlpack stream", - marks=pytest.mark.skip(reason="segmentation fault reported (#326)") - ), ], ) def test_to_device_with_stream(make_stream): @@ -32,3 +27,19 @@ def test_to_device_with_stream(make_stream): # device context. 
b = to_device(a, dev, stream=stream) assert device(b) == dev + + +def test_to_device_with_dlpack_stream(): + devices = xp.__array_namespace_info__().devices() + + a = xp.asarray([1, 2, 3]) + for dev in devices: + # Streams are device-specific and must be created within + # the context of the device... + with dev: + s1 = Stream() + + # ... however, to_device() does not need to be inside the + # device context. + b = to_device(a, dev, stream=s1.ptr) + assert device(b) == dev From 1c53eeb895c5d1ec93db82e813912589a9aa3b41 Mon Sep 17 00:00:00 2001 From: crusaderky Date: Thu, 15 May 2025 09:42:29 +0100 Subject: [PATCH 52/80] MAINT: don't import helpers in numpy namespace --- array_api_compat/numpy/__init__.py | 9 +-------- 1 file changed, 1 insertion(+), 8 deletions(-) diff --git a/array_api_compat/numpy/__init__.py b/array_api_compat/numpy/__init__.py index f7b558ba..8eab0405 100644 --- a/array_api_compat/numpy/__init__.py +++ b/array_api_compat/numpy/__init__.py @@ -23,13 +23,6 @@ __import__(__package__ + ".fft") -from ..common._helpers import * # noqa: F403 -from .linalg import matrix_transpose, vecdot # noqa: F401 - -try: - # Used in asarray(). Not present in older versions. - from numpy import _CopyMode # noqa: F401 -except ImportError: - pass +from .linalg import matrix_transpose, vecdot # type: ignore[no-redef] # noqa: F401 __array_api_version__: Final = "2024.12" From e945af9debb715da60b807c028776a4e3d1a0c52 Mon Sep 17 00:00:00 2001 From: Guido Imperiale Date: Thu, 15 May 2025 09:45:10 +0100 Subject: [PATCH 53/80] Update array_api_compat/numpy/__init__.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- array_api_compat/numpy/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/array_api_compat/numpy/__init__.py b/array_api_compat/numpy/__init__.py index 8eab0405..3e138f53 100644 --- a/array_api_compat/numpy/__init__.py +++ b/array_api_compat/numpy/__init__.py @@ -23,6 +23,6 @@ __import__(__package__ + ".fft") -from .linalg import matrix_transpose, vecdot # type: ignore[no-redef] # noqa: F401 +from .linalg import matrix_transpose, vecdot # type: ignore[no-redef] # noqa: F401 __array_api_version__: Final = "2024.12" From 6b3ec935eb325d443c327d6490e16e69f273da06 Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Mon, 12 May 2025 13:48:46 +0200 Subject: [PATCH 54/80] DOC: add 1.12 changelog --- docs/changelog.md | 37 +++++++++++++++++++++++++++++++++++++ 1 file changed, 37 insertions(+) diff --git a/docs/changelog.md b/docs/changelog.md index 18928e98..c2d5b2c5 100644 --- a/docs/changelog.md +++ b/docs/changelog.md @@ -1,5 +1,42 @@ # Changelog +## 1.12.0 (2025-05-13) + + +### Major changes + +- The build system has been updated to use `pyproject.toml` instead of `setup.py` +- Support for Python 3.9 has been dropped. The minimum supported Python version is now + 3.10; the minimum supported NumPy version is 1.22; the minimum supported `ndonnx` + version is 0.10.1. +- The `linalg` extension works correctly with `pytorch==2.7`. +- Multiple improvements to handling of `device` arguments in `numpy`, `cupy`, `torch`, + and `dask` backends. Support for multiple devices is still relatively immature, + and rough edges can be expected. Please report any issues you encounter. + +### Minor changes + +- `finfo` and `iinfo` functions now accept array arguments, in accordance with the + Array API spec; +- `torch.asarray` function propagates the device of the input array. 
This works around
+  the [pytorch issue #150199](https://github.com/pytorch/pytorch/issues/150199);
+- `torch.repeat` function is now available;
+- `torch.count_nonzero` function now correctly handles the case of a tuple `axis`
+  argument and `keepdims=True`;
+- `torch.meshgrid` wrapper defaults to `indexing="xy"`, in accordance with the
+  array API specification;
+- `cupy.asarray` function now implements the `copy=True` argument;
+
+
+The following users contributed to this release:
+
+Evgeni Burovski,
+Lucas Colley,
+Neil Girdhar,
+Joren Hammudoglu,
+Guido Imperiale
+
+
 ## 1.11.2 (2025-03-20)
 
 This is a bugfix release with no new features compared to version 1.11.

From 97e3cc5b1b32bd0a0d5c2a9810df9145012992ed Mon Sep 17 00:00:00 2001
From: Evgeni Burovski
Date: Tue, 13 May 2025 15:45:34 +0200
Subject: [PATCH 55/80] MAINT: update numpy 1.22 xfails

---
 numpy-1-22-xfails.txt | 1 +
 1 file changed, 1 insertion(+)

diff --git a/numpy-1-22-xfails.txt b/numpy-1-22-xfails.txt
index cacb95b7..e0b96c61 100644
--- a/numpy-1-22-xfails.txt
+++ b/numpy-1-22-xfails.txt
@@ -131,6 +131,7 @@
 array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[nextafter]
 array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[multiply]
 array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[maximum]
+array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[minimum]
 array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[copysign]
 array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[pow]
 array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[pow]

From 34e2c6f2799e7f0237c035f61b8f1891baae02d7 Mon Sep 17 00:00:00 2001
From: Evgeni Burovski
Date: Thu, 15 May 2025 11:47:52 +0200
Subject: [PATCH 56/80] DOC: update the changelog

Co-authored-by: Guido Imperiale

---
 docs/changelog.md | 17 ++++++++++-------
 1 file changed, 10 insertions(+), 7 deletions(-)

diff --git a/docs/changelog.md b/docs/changelog.md
index c2d5b2c5..c00c62db 100644
--- a/docs/changelog.md
+++ b/docs/changelog.md
@@ -7,12 +7,12 @@
 - The build system has been updated to use `pyproject.toml` instead of `setup.py`
 - Support for Python 3.9 has been dropped. The minimum supported Python version is now
-  3.10; the minimum supported NumPy version is 1.22; the minimum supported `ndonnx`
-  version is 0.10.1.
-- The `linalg` extension works correctly with `pytorch==2.7`.
-- Multiple improvements to handling of `device` arguments in `numpy`, `cupy`, `torch`,
-  and `dask` backends. Support for multiple devices is still relatively immature,
-  and rough edges can be expected. Please report any issues you encounter.
+  3.10; the minimum supported NumPy version is 1.22.
+- The `linalg` extension works correctly with `pytorch>=2.7`.
+- Multiple improvements to handling of devices in CuPy and PyTorch backends.
+  Support for multiple devices in CuPy is still immature and you should use
+  context managers rather than relying on input-output device propagation or
+  on the `device` parameter.
 ### Minor changes
 
@@ -25,7 +25,10 @@
 argument and `keepdims=True`;
 - `torch.meshgrid` wrapper defaults to `indexing="xy"`, in accordance with the
   array API specification;
-- `cupy.asarray` function now implements the `copy=True` argument;
+- `cupy.asarray` function now implements the `copy=False` argument, albeit
+  at the risk of making a temporary copy.
+- In `numpy.take_along_axis` and `cupy.take_along_axis` the `axis` parameter now
+  defaults to -1, in accordance with the Array API spec.
 
 
 The following users contributed to this release:

From cdd1213ea28af34b721d105a25d4b7ff2414ef18 Mon Sep 17 00:00:00 2001
From: Evgeni Burovski
Date: Thu, 15 May 2025 11:48:47 +0200
Subject: [PATCH 57/80] Update docs/changelog.md

---
 docs/changelog.md | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/docs/changelog.md b/docs/changelog.md
index c00c62db..6f6c1251 100644
--- a/docs/changelog.md
+++ b/docs/changelog.md
@@ -12,7 +12,7 @@
 - Multiple improvements to handling of devices in CuPy and PyTorch backends.
   Support for multiple devices in CuPy is still immature and you should use
   context managers rather than relying on input-output device propagation or
-  on the `device` parameter.
+  on the `device` parameter. Please report any issues you encounter.

From 26a1d2016517ae3bb86ddfef137247fa15ddb512 Mon Sep 17 00:00:00 2001
From: Evgeni Burovski
Date: Thu, 15 May 2025 12:02:16 +0200
Subject: [PATCH 58/80] MAINT: update CuPy xfails

---
 cupy-xfails.txt | 5 +++++
 1 file changed, 5 insertions(+)

diff --git a/cupy-xfails.txt b/cupy-xfails.txt
index 77def129..0a91cafe 100644
--- a/cupy-xfails.txt
+++ b/cupy-xfails.txt
@@ -39,6 +39,11 @@
 array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[logaddexp]
 array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[nextafter]
 array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[hypot]
+array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[atan2]
+array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[divide]
+array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[copysign]
+array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[divide]
+array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[maximum]

From 8005d6d02c0f1717881de37a710871bb955eb5cd Mon Sep 17 00:00:00 2001
From: Evgeni Burovski
Date: Tue, 13 May 2025 14:31:01 +0200
Subject: [PATCH 59/80] REL: bump the version to 1.12.0

---
 array_api_compat/__init__.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/array_api_compat/__init__.py b/array_api_compat/__init__.py
index 60b37e97..653cb40a 100644
--- a/array_api_compat/__init__.py
+++ b/array_api_compat/__init__.py
@@ -17,6 +17,6 @@
 this implementation for the default when working with NumPy arrays.
""" -__version__ = '1.12.dev0' +__version__ = '1.12.0' from .common import * # noqa: F401, F403 From 91dd626ce8b2612979e513af235be3809791f94b Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Fri, 16 May 2025 10:58:43 +0200 Subject: [PATCH 60/80] MAINT: bump version to 1.13.0.dev0 --- array_api_compat/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/array_api_compat/__init__.py b/array_api_compat/__init__.py index 653cb40a..a00e8cbc 100644 --- a/array_api_compat/__init__.py +++ b/array_api_compat/__init__.py @@ -17,6 +17,6 @@ this implementation for the default when working with NumPy arrays. """ -__version__ = '1.12.0' +__version__ = '1.13.0.dev0' from .common import * # noqa: F401, F403 From 3350f670e1b67a37888228c102e0e560f43077bd Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Sun, 11 May 2025 14:12:51 +0200 Subject: [PATCH 61/80] CI: run 500 examples on NumPy and PyTorch; 50 on Dask --- .github/workflows/array-api-tests-dask.yml | 2 +- .github/workflows/array-api-tests.yml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/array-api-tests-dask.yml b/.github/workflows/array-api-tests-dask.yml index 964fb52d..a60b28a4 100644 --- a/.github/workflows/array-api-tests-dask.yml +++ b/.github/workflows/array-api-tests-dask.yml @@ -14,7 +14,7 @@ jobs: # workflow is barely more than a smoke test, and one should expect extreme # flakiness. Before changes to dask-xfails.txt or dask-skips.txt, please run # the full test suite with at least 200 examples. - pytest-extra-args: --max-examples=5 + pytest-extra-args: --max-examples=50 python-versions: '[''3.10'', ''3.13'']' extra-env-vars: | ARRAY_API_TESTS_XFAIL_MARK=skip diff --git a/.github/workflows/array-api-tests.yml b/.github/workflows/array-api-tests.yml index 31bedde6..f652438b 100644 --- a/.github/workflows/array-api-tests.yml +++ b/.github/workflows/array-api-tests.yml @@ -37,7 +37,7 @@ on: description: Multiline string of environment variables to set for the test run. env: - PYTEST_ARGS: "--max-examples 200 -v -rxXfE --ci ${{ inputs.pytest-extra-args }} --hypothesis-disable-deadline --durations 10" + PYTEST_ARGS: "--max-examples 500 -v -rxXfE --ci ${{ inputs.pytest-extra-args }} --hypothesis-disable-deadline --durations 20" jobs: tests: From 5b1ece468fb9b9b789304b57035de6801a39c7b1 Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Mon, 12 May 2025 14:18:30 +0200 Subject: [PATCH 62/80] CI: use 4 workers; bump the # of examples to 1000 (np/torch), 200 (dask) --- .github/workflows/array-api-tests-dask.yml | 2 +- .github/workflows/array-api-tests-numpy-1-22.yml | 1 + .github/workflows/array-api-tests-numpy-1-26.yml | 1 + .github/workflows/array-api-tests-numpy-dev.yml | 1 + .github/workflows/array-api-tests-numpy-latest.yml | 1 + .github/workflows/array-api-tests-torch.yml | 1 + .github/workflows/array-api-tests.yml | 3 ++- 7 files changed, 8 insertions(+), 2 deletions(-) diff --git a/.github/workflows/array-api-tests-dask.yml b/.github/workflows/array-api-tests-dask.yml index a60b28a4..ef430d9c 100644 --- a/.github/workflows/array-api-tests-dask.yml +++ b/.github/workflows/array-api-tests-dask.yml @@ -14,7 +14,7 @@ jobs: # workflow is barely more than a smoke test, and one should expect extreme # flakiness. Before changes to dask-xfails.txt or dask-skips.txt, please run # the full test suite with at least 200 examples. 
- pytest-extra-args: --max-examples=50 + pytest-extra-args: --max-examples=200 -n 4 python-versions: '[''3.10'', ''3.13'']' extra-env-vars: | ARRAY_API_TESTS_XFAIL_MARK=skip diff --git a/.github/workflows/array-api-tests-numpy-1-22.yml b/.github/workflows/array-api-tests-numpy-1-22.yml index 1cf6e26d..83d4cf1d 100644 --- a/.github/workflows/array-api-tests-numpy-1-22.yml +++ b/.github/workflows/array-api-tests-numpy-1-22.yml @@ -10,5 +10,6 @@ jobs: package-version: '== 1.22.*' xfails-file-extra: '-1-22' python-versions: '[''3.10'']' + pytest-extra-args: -n 4 extra-env-vars: | ARRAY_API_TESTS_XFAIL_MARK=skip diff --git a/.github/workflows/array-api-tests-numpy-1-26.yml b/.github/workflows/array-api-tests-numpy-1-26.yml index a2788d2f..13124644 100644 --- a/.github/workflows/array-api-tests-numpy-1-26.yml +++ b/.github/workflows/array-api-tests-numpy-1-26.yml @@ -10,5 +10,6 @@ jobs: package-version: '== 1.26.*' xfails-file-extra: '-1-26' python-versions: '[''3.10'', ''3.12'']' + pytest-extra-args: -n 4 extra-env-vars: | ARRAY_API_TESTS_XFAIL_MARK=skip diff --git a/.github/workflows/array-api-tests-numpy-dev.yml b/.github/workflows/array-api-tests-numpy-dev.yml index dce0813f..dec4c7ae 100644 --- a/.github/workflows/array-api-tests-numpy-dev.yml +++ b/.github/workflows/array-api-tests-numpy-dev.yml @@ -10,5 +10,6 @@ jobs: extra-requires: '--pre --extra-index https://pypi.anaconda.org/scientific-python-nightly-wheels/simple' xfails-file-extra: '-dev' python-versions: '[''3.11'', ''3.13'']' + pytest-extra-args: -n 4 extra-env-vars: | ARRAY_API_TESTS_XFAIL_MARK=skip diff --git a/.github/workflows/array-api-tests-numpy-latest.yml b/.github/workflows/array-api-tests-numpy-latest.yml index 54b21a25..65bbc9a2 100644 --- a/.github/workflows/array-api-tests-numpy-latest.yml +++ b/.github/workflows/array-api-tests-numpy-latest.yml @@ -8,5 +8,6 @@ jobs: with: package-name: numpy python-versions: '[''3.10'', ''3.13'']' + pytest-extra-args: -n 4 extra-env-vars: | ARRAY_API_TESTS_XFAIL_MARK=skip diff --git a/.github/workflows/array-api-tests-torch.yml b/.github/workflows/array-api-tests-torch.yml index 4dcb3347..4b4b945e 100644 --- a/.github/workflows/array-api-tests-torch.yml +++ b/.github/workflows/array-api-tests-torch.yml @@ -12,3 +12,4 @@ jobs: ARRAY_API_TESTS_SKIP_DTYPES=uint16,uint32,uint64 ARRAY_API_TESTS_XFAIL_MARK=skip python-versions: '[''3.10'', ''3.13'']' + pytest-extra-args: -n 4 diff --git a/.github/workflows/array-api-tests.yml b/.github/workflows/array-api-tests.yml index f652438b..53c1474d 100644 --- a/.github/workflows/array-api-tests.yml +++ b/.github/workflows/array-api-tests.yml @@ -37,7 +37,7 @@ on: description: Multiline string of environment variables to set for the test run. 
env: - PYTEST_ARGS: "--max-examples 500 -v -rxXfE --ci ${{ inputs.pytest-extra-args }} --hypothesis-disable-deadline --durations 20" + PYTEST_ARGS: "--max-examples 1000 -v -rxXfE --ci ${{ inputs.pytest-extra-args }} --hypothesis-disable-deadline --durations 20" jobs: tests: @@ -76,6 +76,7 @@ jobs: python -m pip install --upgrade pip python -m pip install '${{ inputs.package-name }} ${{ inputs.package-version }}' ${{ inputs.extra-requires }} python -m pip install -r ${GITHUB_WORKSPACE}/array-api-tests/requirements.txt + python -m pip install pytest-xdist - name: Dump pip environment run: pip freeze From 37d5d668674f10ae41709a7ebc3d2ab2ae6b25c4 Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Mon, 12 May 2025 16:42:01 +0200 Subject: [PATCH 63/80] MAINT: update numpy-1.22 xfails --- numpy-1-22-xfails.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/numpy-1-22-xfails.txt b/numpy-1-22-xfails.txt index cacb95b7..d4022b31 100644 --- a/numpy-1-22-xfails.txt +++ b/numpy-1-22-xfails.txt @@ -133,7 +133,7 @@ array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_sc array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[maximum] array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[copysign] array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[pow] -array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[pow] +array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[minimum] array_api_tests/test_operators_and_elementwise_functions.py::test_where_with_scalars From 43435808041951df2c7b7cae28204b3ce61f6e46 Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Mon, 12 May 2025 16:42:29 +0200 Subject: [PATCH 64/80] MAINT: remove --ci pytest switch The warning says it's deprecated. --- .github/workflows/array-api-tests.yml | 2 +- numpy-1-22-xfails.txt | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/.github/workflows/array-api-tests.yml b/.github/workflows/array-api-tests.yml index 53c1474d..e832f870 100644 --- a/.github/workflows/array-api-tests.yml +++ b/.github/workflows/array-api-tests.yml @@ -37,7 +37,7 @@ on: description: Multiline string of environment variables to set for the test run. 
env:
-  PYTEST_ARGS: "--max-examples 1000 -v -rxXfE --ci ${{ inputs.pytest-extra-args }} --hypothesis-disable-deadline --durations 20"
+  PYTEST_ARGS: "--max-examples 1000 -v -rxXfE ${{ inputs.pytest-extra-args }} --hypothesis-disable-deadline --durations 20"

 jobs:
   tests:
diff --git a/numpy-1-22-xfails.txt b/numpy-1-22-xfails.txt
index d4022b31..e1c4f832 100644
--- a/numpy-1-22-xfails.txt
+++ b/numpy-1-22-xfails.txt
@@ -134,6 +134,7 @@ array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_sc
 array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[copysign]
 array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[pow]
 array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[minimum]
+array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[atan2]

 array_api_tests/test_operators_and_elementwise_functions.py::test_where_with_scalars

From c03daa36c09d51162d240b77e223a49cc8a6076e Mon Sep 17 00:00:00 2001
From: Evgeni Burovski
Date: Sun, 1 Jun 2025 13:36:11 +0200
Subject: [PATCH 65/80] CI: install jax/sparse/torch in more jobs

Also, `ndonnx` has wheels for all Python versions now, and we do not
bother with jax or dask on NumPy < 2.
---
 .github/workflows/tests.yml | 14 +++++++-------
 1 file changed, 7 insertions(+), 7 deletions(-)

diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml
index 81a05b3f..c995b370 100644
--- a/.github/workflows/tests.yml
+++ b/.github/workflows/tests.yml
@@ -32,20 +32,20 @@ jobs:
           python -m pip install --upgrade pip
           python -m pip install pytest

+          # Don't `pip install .[dev]` as it would pull in the whole torch cuda stack
+          python -m pip install array-api-strict
+          python -m pip install torch --index-url https://download.pytorch.org/whl/cpu
+
           if [ "${{ matrix.numpy-version }}" == "dev" ]; then
             python -m pip install numpy --pre --extra-index-url https://pypi.anaconda.org/scientific-python-nightly-wheels/simple
+            python -m pip install dask[array] jax[cpu] sparse ndonnx
           elif [ "${{ matrix.numpy-version }}" == "1.22" ]; then
             python -m pip install 'numpy==1.22.*'
           elif [ "${{ matrix.numpy-version }}" == "1.26" ]; then
             python -m pip install 'numpy==1.26.*'
           else
-            # Don't `pip install .[dev]` as it would pull in the whole torch cuda stack
-            python -m pip install array-api-strict dask[array] jax[cpu] numpy sparse
-            python -m pip install torch --index-url https://download.pytorch.org/whl/cpu
-            if [ "${{ matrix.python-version }}" != "3.13" ]; then
-              # onnx wheels are not available on Python 3.13 at the moment of writing
-              python -m pip install ndonnx
-            fi
+            python -m pip install numpy
+            python -m pip install dask[array] jax[cpu] sparse ndonnx
           fi

       - name: Dump pip environment

From a8e19835092335ab8e1846f1e3dda335d8eb4c4a Mon Sep 17 00:00:00 2001
From: Evgeni Burovski
Date: Sun, 1 Jun 2025 13:38:31 +0200
Subject: [PATCH 66/80] TST: xfail test_device_to_device with numpy < 2

The test assumes that `asarray` has the `copy` kwarg, which is not true
in NumPy < 2.
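As a minimal sketch of the incompatibility (illustrative only, not part of
this patch): the array API standard gives `asarray` a `copy` keyword, which
NumPy only grew in 2.0, so a standard-compliant call raises `TypeError` on
NumPy 1.x:

    import numpy as np

    x = np.asarray([1, 2, 3])
    try:
        # `copy=True` follows the array API standard; NumPy < 2 rejects it.
        y = np.asarray(x, copy=True)
    except TypeError:
        # On NumPy 1.x an explicit copy can be requested via np.array instead.
        y = np.array(x, copy=True)
    assert y is not x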
--- tests/test_common.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/test_common.py b/tests/test_common.py index 54b5ed69..85ed032e 100644 --- a/tests/test_common.py +++ b/tests/test_common.py @@ -195,6 +195,9 @@ def test_device_to_device(library, request): xfail(request, reason="Stub raises ValueError") if library == "sparse": xfail(request, reason="No __array_namespace_info__()") + if library == "array_api_strict": + if np.__version__ < "2": + xfail(request, reason="no copy argument of np.asarray") xp = import_(library, wrapper=True) devices = xp.__array_namespace_info__().devices() From 8e3ab3e7c5c6794f66196ec435d2f6bdd1492404 Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Sun, 1 Jun 2025 13:39:28 +0200 Subject: [PATCH 67/80] MAINT: filter out some warning noise --- tests/test_array_namespace.py | 19 +++++++++++++------ 1 file changed, 13 insertions(+), 6 deletions(-) diff --git a/tests/test_array_namespace.py b/tests/test_array_namespace.py index cdb80007..2fbb0339 100644 --- a/tests/test_array_namespace.py +++ b/tests/test_array_namespace.py @@ -23,7 +23,9 @@ def test_array_namespace(library, api_version, use_compat): if library == "ndonnx" and api_version in ("2021.12", "2022.12"): pytest.skip("Unsupported API version") - namespace = array_namespace(array, api_version=api_version, use_compat=use_compat) + with warnings.catch_warnings(): + warnings.simplefilter('ignore', UserWarning) + namespace = array_namespace(array, api_version=api_version, use_compat=use_compat) if use_compat is False or use_compat is None and library not in wrapped_libraries: if library == "jax.numpy" and use_compat is None: @@ -45,10 +47,13 @@ def test_array_namespace(library, api_version, use_compat): if library == "numpy": # check that the same namespace is returned for NumPy scalars - scalar_namespace = array_namespace( - xp.float64(0.0), api_version=api_version, use_compat=use_compat - ) - assert scalar_namespace == namespace + with warnings.catch_warnings(): + warnings.simplefilter('ignore', UserWarning) + + scalar_namespace = array_namespace( + xp.float64(0.0), api_version=api_version, use_compat=use_compat + ) + assert scalar_namespace == namespace # Check that array_namespace works even if jax.experimental.array_api # hasn't been imported yet (it monkeypatches __array_namespace__ @@ -97,7 +102,9 @@ def test_api_version_torch(): torch = import_("torch") x = torch.asarray([1, 2]) torch_ = import_("torch", wrapper=True) - assert array_namespace(x, api_version="2023.12") == torch_ + with warnings.catch_warnings(): + warnings.simplefilter('ignore', UserWarning) + assert array_namespace(x, api_version="2023.12") == torch_ assert array_namespace(x, api_version=None) == torch_ assert array_namespace(x) == torch_ # Should issue a warning From 9959873e351ecab696538650893afc2faef17a38 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 2 Jun 2025 10:10:33 +0000 Subject: [PATCH 68/80] Bump dawidd6/action-download-artifact from 9 to 10 in the actions group Bumps the actions group with 1 update: [dawidd6/action-download-artifact](https://github.com/dawidd6/action-download-artifact). 
Updates `dawidd6/action-download-artifact` from 9 to 10 - [Release notes](https://github.com/dawidd6/action-download-artifact/releases) - [Commits](https://github.com/dawidd6/action-download-artifact/compare/v9...v10) --- updated-dependencies: - dependency-name: dawidd6/action-download-artifact dependency-version: '10' dependency-type: direct:production update-type: version-update:semver-major dependency-group: actions ... Signed-off-by: dependabot[bot] --- .github/workflows/docs-deploy.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/docs-deploy.yml b/.github/workflows/docs-deploy.yml index fc612588..4e3efb39 100644 --- a/.github/workflows/docs-deploy.yml +++ b/.github/workflows/docs-deploy.yml @@ -13,7 +13,7 @@ jobs: steps: - uses: actions/checkout@v4 - name: Download Artifact - uses: dawidd6/action-download-artifact@v9 + uses: dawidd6/action-download-artifact@v10 with: workflow: docs-build.yml name: docs-build From 6ae28ee9538820ae09ba45d8ef3d15d4a6570900 Mon Sep 17 00:00:00 2001 From: Guido Imperiale Date: Mon, 2 Jun 2025 18:38:34 +0100 Subject: [PATCH 69/80] ENH: speed up `array_namespace` * ENH: speed up `array_namespace` * jax 0.6.1 Reviewed at https://github.com/data-apis/array-api-compat/pull/329 --- array_api_compat/common/_helpers.py | 196 +++++++++++++++------------- tests/test_array_namespace.py | 121 ++++++++--------- 2 files changed, 161 insertions(+), 156 deletions(-) diff --git a/array_api_compat/common/_helpers.py b/array_api_compat/common/_helpers.py index 77175d0d..a152e4c0 100644 --- a/array_api_compat/common/_helpers.py +++ b/array_api_compat/common/_helpers.py @@ -8,6 +8,7 @@ from __future__ import annotations +import enum import inspect import math import sys @@ -485,6 +486,86 @@ def _check_api_version(api_version: str | None) -> None: ) +class _ClsToXPInfo(enum.Enum): + SCALAR = 0 + MAYBE_JAX_ZERO_GRADIENT = 1 + + +@lru_cache(100) +def _cls_to_namespace( + cls: type, + api_version: str | None, + use_compat: bool | None, +) -> tuple[Namespace | None, _ClsToXPInfo | None]: + if use_compat not in (None, True, False): + raise ValueError("use_compat must be None, True, or False") + _use_compat = use_compat in (None, True) + cls_ = cast(Hashable, cls) # Make mypy happy + + if ( + _issubclass_fast(cls_, "numpy", "ndarray") + or _issubclass_fast(cls_, "numpy", "generic") + ): + if use_compat is True: + _check_api_version(api_version) + from .. import numpy as xp + elif use_compat is False: + import numpy as xp # type: ignore[no-redef] + else: + # NumPy 2.0+ have __array_namespace__; however they are not + # yet fully array API compatible. + from .. import numpy as xp # type: ignore[no-redef] + return xp, _ClsToXPInfo.MAYBE_JAX_ZERO_GRADIENT + + # Note: this must happen _after_ the test for np.generic, + # because np.float64 and np.complex128 are subclasses of float and complex. + if issubclass(cls, int | float | complex | type(None)): + return None, _ClsToXPInfo.SCALAR + + if _issubclass_fast(cls_, "cupy", "ndarray"): + if _use_compat: + _check_api_version(api_version) + from .. import cupy as xp # type: ignore[no-redef] + else: + import cupy as xp # type: ignore[no-redef] + return xp, None + + if _issubclass_fast(cls_, "torch", "Tensor"): + if _use_compat: + _check_api_version(api_version) + from .. 
import torch as xp # type: ignore[no-redef] + else: + import torch as xp # type: ignore[no-redef] + return xp, None + + if _issubclass_fast(cls_, "dask.array", "Array"): + if _use_compat: + _check_api_version(api_version) + from ..dask import array as xp # type: ignore[no-redef] + else: + import dask.array as xp # type: ignore[no-redef] + return xp, None + + # Backwards compatibility for jax<0.4.32 + if _issubclass_fast(cls_, "jax", "Array"): + return _jax_namespace(api_version, use_compat), None + + return None, None + + +def _jax_namespace(api_version: str | None, use_compat: bool | None) -> Namespace: + if use_compat: + raise ValueError("JAX does not have an array-api-compat wrapper") + import jax.numpy as jnp + if not hasattr(jnp, "__array_namespace_info__"): + # JAX v0.4.32 and newer implements the array API directly in jax.numpy. + # For older JAX versions, it is available via jax.experimental.array_api. + # jnp.Array objects gain the __array_namespace__ method. + import jax.experimental.array_api # noqa: F401 + # Test api_version + return jnp.empty(0).__array_namespace__(api_version=api_version) + + def array_namespace( *xs: Array | complex | None, api_version: str | None = None, @@ -553,105 +634,40 @@ def your_function(x, y): is_pydata_sparse_array """ - if use_compat not in [None, True, False]: - raise ValueError("use_compat must be None, True, or False") - - _use_compat = use_compat in [None, True] - namespaces: set[Namespace] = set() for x in xs: - if is_numpy_array(x): - import numpy as np - - from .. import numpy as numpy_namespace - - if use_compat is True: - _check_api_version(api_version) - namespaces.add(numpy_namespace) - elif use_compat is False: - namespaces.add(np) - else: - # numpy 2.0+ have __array_namespace__, however, they are not yet fully array API - # compatible. - namespaces.add(numpy_namespace) - elif is_cupy_array(x): - if _use_compat: - _check_api_version(api_version) - from .. import cupy as cupy_namespace - - namespaces.add(cupy_namespace) - else: - import cupy as cp # pyright: ignore[reportMissingTypeStubs] - - namespaces.add(cp) - elif is_torch_array(x): - if _use_compat: - _check_api_version(api_version) - from .. import torch as torch_namespace - - namespaces.add(torch_namespace) - else: - import torch - - namespaces.add(torch) - elif is_dask_array(x): - if _use_compat: - _check_api_version(api_version) - from ..dask import array as dask_namespace - - namespaces.add(dask_namespace) - else: - import dask.array as da - - namespaces.add(da) - elif is_jax_array(x): - if use_compat is True: - _check_api_version(api_version) - raise ValueError("JAX does not have an array-api-compat wrapper") - elif use_compat is False: - import jax.numpy as jnp - else: - # JAX v0.4.32 and newer implements the array API directly in jax.numpy. - # For older JAX versions, it is available via jax.experimental.array_api. - import jax.numpy - - if hasattr(jax.numpy, "__array_api_version__"): - jnp = jax.numpy - else: - import jax.experimental.array_api as jnp # pyright: ignore[reportMissingImports] - namespaces.add(jnp) - elif is_pydata_sparse_array(x): - if use_compat is True: - _check_api_version(api_version) - raise ValueError("`sparse` does not have an array-api-compat wrapper") - else: - import sparse # pyright: ignore[reportMissingTypeStubs] - # `sparse` is already an array namespace. We do not have a wrapper - # submodule for it. 
- namespaces.add(sparse) - elif hasattr(x, "__array_namespace__"): - if use_compat is True: + xp, info = _cls_to_namespace(cast(Hashable, type(x)), api_version, use_compat) + if info is _ClsToXPInfo.SCALAR: + continue + + if ( + info is _ClsToXPInfo.MAYBE_JAX_ZERO_GRADIENT + and _is_jax_zero_gradient_array(x) + ): + xp = _jax_namespace(api_version, use_compat) + + if xp is None: + get_ns = getattr(x, "__array_namespace__", None) + if get_ns is None: + raise TypeError(f"{type(x).__name__} is not a supported array type") + if use_compat: raise ValueError( "The given array does not have an array-api-compat wrapper" ) - x = cast("SupportsArrayNamespace[Any]", x) - namespaces.add(x.__array_namespace__(api_version=api_version)) - elif isinstance(x, (bool, int, float, complex, type(None))): - continue - else: - # TODO: Support Python scalars? - raise TypeError(f"{type(x).__name__} is not a supported array type") + xp = get_ns(api_version=api_version) - if not namespaces: - raise TypeError("Unrecognized array input") + namespaces.add(xp) - if len(namespaces) != 1: + try: + (xp,) = namespaces + return xp + except ValueError: + if not namespaces: + raise TypeError( + "array_namespace requires at least one non-scalar array input" + ) raise TypeError(f"Multiple namespaces for array inputs: {namespaces}") - (xp,) = namespaces - - return xp - # backwards compatibility alias get_namespace = array_namespace diff --git a/tests/test_array_namespace.py b/tests/test_array_namespace.py index 2fbb0339..311efc37 100644 --- a/tests/test_array_namespace.py +++ b/tests/test_array_namespace.py @@ -8,42 +8,41 @@ import array_api_compat from array_api_compat import array_namespace -from ._helpers import import_, all_libraries, wrapped_libraries +from ._helpers import all_libraries, wrapped_libraries, xfail + @pytest.mark.parametrize("use_compat", [True, False, None]) -@pytest.mark.parametrize("api_version", [None, "2021.12", "2022.12", "2023.12"]) +@pytest.mark.parametrize( + "api_version", [None, "2021.12", "2022.12", "2023.12", "2024.12"] +) @pytest.mark.parametrize("library", all_libraries) -def test_array_namespace(library, api_version, use_compat): - xp = import_(library) +def test_array_namespace(request, library, api_version, use_compat): + xp = pytest.importorskip(library) array = xp.asarray([1.0, 2.0, 3.0]) if use_compat and library not in wrapped_libraries: pytest.raises(ValueError, lambda: array_namespace(array, use_compat=use_compat)) return - if library == "ndonnx" and api_version in ("2021.12", "2022.12"): - pytest.skip("Unsupported API version") + if (library == "sparse" and api_version in ("2023.12", "2024.12")) or ( + library == "jax.numpy" and api_version in ("2021.12", "2022.12", "2023.12") + ): + xfail(request, "Unsupported API version") with warnings.catch_warnings(): warnings.simplefilter('ignore', UserWarning) namespace = array_namespace(array, api_version=api_version, use_compat=use_compat) if use_compat is False or use_compat is None and library not in wrapped_libraries: - if library == "jax.numpy" and use_compat is None: - import jax.numpy - if hasattr(jax.numpy, "__array_api_version__"): - # JAX v0.4.32 or later uses jax.numpy directly - assert namespace == jax.numpy - else: - # JAX v0.4.31 or earlier uses jax.experimental.array_api - import jax.experimental.array_api - assert namespace == jax.experimental.array_api + if library == "jax.numpy" and not hasattr(xp, "__array_api_version__"): + # Backwards compatibility for JAX <0.4.32 + import jax.experimental.array_api + assert namespace == 
jax.experimental.array_api else: assert namespace == xp + elif library == "dask.array": + assert namespace == array_api_compat.dask.array else: - if library == "dask.array": - assert namespace == array_api_compat.dask.array - else: - assert namespace == getattr(array_api_compat, library) + assert namespace == getattr(array_api_compat, library) if library == "numpy": # check that the same namespace is returned for NumPy scalars @@ -55,20 +54,20 @@ def test_array_namespace(library, api_version, use_compat): ) assert scalar_namespace == namespace - # Check that array_namespace works even if jax.experimental.array_api - # hasn't been imported yet (it monkeypatches __array_namespace__ - # onto JAX arrays, but we should support them regardless). The only way to - # do this is to use a subprocess, since we cannot un-import it and another - # test probably already imported it. - if library == "jax.numpy" and sys.version_info >= (3, 9): - code = f"""\ + +def test_jax_backwards_compat(): + """On JAX <0.4.32, test that array_namespace works even if + jax.experimental.array_api has not been imported yet. + """ + pytest.importorskip("jax") + code = """\ import sys import jax.numpy import array_api_compat -array = jax.numpy.asarray([1.0, 2.0, 3.0]) +array = jax.numpy.asarray([1.0, 2.0, 3.0]) assert 'jax.experimental.array_api' not in sys.modules -namespace = array_api_compat.array_namespace(array, api_version={api_version!r}) +namespace = array_api_compat.array_namespace(array) if hasattr(jax.numpy, '__array_api_version__'): assert namespace == jax.numpy @@ -76,14 +75,16 @@ def test_array_namespace(library, api_version, use_compat): import jax.experimental.array_api assert namespace == jax.experimental.array_api """ - subprocess.run([sys.executable, "-c", code], check=True) + subprocess.check_call([sys.executable, "-c", code]) + def test_jax_zero_gradient(): - jax = import_("jax") + jax = pytest.importorskip("jax") jx = jax.numpy.arange(4) jax_zero = jax.vmap(jax.grad(jax.numpy.float32, allow_int=True))(jx) assert array_namespace(jax_zero) is array_namespace(jx) + def test_array_namespace_errors(): pytest.raises(TypeError, lambda: array_namespace([1])) pytest.raises(TypeError, lambda: array_namespace()) @@ -92,43 +93,31 @@ def test_array_namespace_errors(): pytest.raises(TypeError, lambda: array_namespace((x, x))) pytest.raises(TypeError, lambda: array_namespace(x, (x, x))) -def test_array_namespace_errors_torch(): - torch = import_("torch") - y = torch.asarray([1, 2]) - x = np.asarray([1, 2]) - pytest.raises(TypeError, lambda: array_namespace(x, y)) -def test_api_version_torch(): - torch = import_("torch") - x = torch.asarray([1, 2]) - torch_ = import_("torch", wrapper=True) - with warnings.catch_warnings(): - warnings.simplefilter('ignore', UserWarning) - assert array_namespace(x, api_version="2023.12") == torch_ - assert array_namespace(x, api_version=None) == torch_ - assert array_namespace(x) == torch_ - # Should issue a warning - with warnings.catch_warnings(record=True) as w: - assert array_namespace(x, api_version="2021.12") == torch_ - assert len(w) == 1 - assert "2021.12" in str(w[0].message) - - # Should issue a warning - with warnings.catch_warnings(record=True) as w: - assert array_namespace(x, api_version="2022.12") == torch_ - assert len(w) == 1 - assert "2022.12" in str(w[0].message) - - pytest.raises(ValueError, lambda: array_namespace(x, api_version="2020.12")) +@pytest.mark.parametrize("library", all_libraries) +def test_array_namespace_many_args(library): + xp = 
pytest.importorskip(library)
+    a = xp.asarray(1)
+    b = xp.asarray(2)
+    assert array_namespace(a, b) is array_namespace(a)
+
+
+def test_array_namespace_mismatch():
+    xp = pytest.importorskip("array_api_strict")
+    with pytest.raises(TypeError, match="Multiple namespaces"):
+        array_namespace(np.asarray(1), xp.asarray(1))
+

 def test_get_namespace():
     # Backwards compatible wrapper
     assert array_api_compat.get_namespace is array_namespace

-def test_python_scalars():
-    torch = import_("torch")
-    a = torch.asarray([1, 2])
-    xp = import_("torch", wrapper=True)
+
+@pytest.mark.parametrize("library", all_libraries)
+def test_python_scalars(library):
+    xp = pytest.importorskip(library)
+    a = xp.asarray([1, 2])
+    xp = array_namespace(a)

     pytest.raises(TypeError, lambda: array_namespace(1))
     pytest.raises(TypeError, lambda: array_namespace(1.0))
@@ -136,8 +125,8 @@ def test_python_scalars():
     pytest.raises(TypeError, lambda: array_namespace(True))
     pytest.raises(TypeError, lambda: array_namespace(None))

-    assert array_namespace(a, 1) == xp
-    assert array_namespace(a, 1.0) == xp
-    assert array_namespace(a, 1j) == xp
-    assert array_namespace(a, True) == xp
-    assert array_namespace(a, None) == xp
+    assert array_namespace(a, 1) is xp
+    assert array_namespace(a, 1.0) is xp
+    assert array_namespace(a, 1j) is xp
+    assert array_namespace(a, True) is xp
+    assert array_namespace(a, None) is xp

From c9cfc2c9193fcdf0e52a2bbdace54182780839c9 Mon Sep 17 00:00:00 2001
From: Evgeni Burovski
Date: Sun, 1 Jun 2025 11:37:49 +0200
Subject: [PATCH 70/80] TST: add a test that wrapping preserves view/copy
 semantics for unary functions

If a bare library returns a copy, so does the wrapped library;
if the bare library returns a view, so does the wrapped library.

---
 tests/test_copies_or_views.py | 66 +++++++++++++++++++++++++++++++++++
 1 file changed, 66 insertions(+)
 create mode 100644 tests/test_copies_or_views.py

diff --git a/tests/test_copies_or_views.py b/tests/test_copies_or_views.py
new file mode 100644
index 00000000..5b9b9207
--- /dev/null
+++ b/tests/test_copies_or_views.py
@@ -0,0 +1,66 @@
+"""
+A collection of tests to make sure that wrapped namespaces agree with the bare ones
+on whether to return a view or a copy of inputs.
+""" +import pytest +from ._helpers import import_ + + +LIB_NAMES = ['numpy', 'cupy', 'torch', 'dask.array', 'array_api_strict'] + +FUNC_INPUTS = [ + # func_name, arr_input, dtype, scalar_value + ('abs', [1, 2], 'int8', 3), + ('abs', [1, 2], 'float32', 3.), + ('ceil', [1, 2], 'int8', 3), + ('clip', [1, 2], 'int8', 3), + ('conj', [1, 2], 'int8', 3), + ('floor', [1, 2], 'int8', 3), + ('imag', [1j, 2j], 'complex64', 3), + ('positive', [1, 2], 'int8', 3), + ('real', [1., 2.], 'float32', 3.), + ('round', [1, 2], 'int8', 3), + ('sign', [0, 0], 'float32', 3), + ('trunc', [1, 2], 'int8', 3), + ('trunc', [1, 2], 'float32', 3), +] + + +def ensure_unary(func, arr): + """Make a trivial unary function from func.""" + if func.__name__ == 'clip': + return lambda x: func(x, arr[0], arr[1]) + return func + + +def is_view(func, a, value): + """Apply `func`, mutate the output; does the input change?""" + b = func(a) + b[0] = value + return a[0] == value + + +@pytest.mark.parametrize('xp_name', LIB_NAMES) +@pytest.mark.parametrize('inputs', FUNC_INPUTS, ids=[inp[0] for inp in FUNC_INPUTS]) +def test_view_or_copy(inputs, xp_name): + bare_xp = import_(xp_name, wrapper=False) + wrapped_xp = import_(xp_name, wrapper=True) + + func_name, arr_input, dtype_str, value = inputs + dtype = getattr(bare_xp, dtype_str) + + bare_func = getattr(bare_xp, func_name) + bare_func = ensure_unary(bare_func, arr_input) + + wrapped_func = getattr(wrapped_xp, func_name) + wrapped_func = ensure_unary(wrapped_func, arr_input) + + # bare namespace: mutate the output, does the input change? + a = bare_xp.asarray(arr_input, dtype=dtype) + is_view_bare = is_view(bare_func, a, value) + + # wrapped namespace: mutate the output, does the input change? + a1 = wrapped_xp.asarray(arr_input, dtype=dtype) + is_view_wrapped = is_view(wrapped_func, a1, value) + + assert is_view_bare == is_view_wrapped From 1facc3526414926b2d123e88c16f7d517d9d2558 Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Sun, 1 Jun 2025 14:30:54 +0200 Subject: [PATCH 71/80] BUG: make ceil,trunc,floor always respect view/copy semantics Remove these functions from common/_aliases.py, add specific implementations for numpy < 2 and cupy. 
--- array_api_compat/common/_aliases.py | 24 -------------------- array_api_compat/cupy/_aliases.py | 24 ++++++++++++++++---- array_api_compat/dask/array/_aliases.py | 3 --- array_api_compat/numpy/_aliases.py | 29 ++++++++++++++++++++++--- 4 files changed, 46 insertions(+), 34 deletions(-) diff --git a/array_api_compat/common/_aliases.py b/array_api_compat/common/_aliases.py index 8ea9162a..39d10860 100644 --- a/array_api_compat/common/_aliases.py +++ b/array_api_compat/common/_aliases.py @@ -524,27 +524,6 @@ def nonzero(x: Array, /, xp: Namespace, **kwargs: object) -> tuple[Array, ...]: return xp.nonzero(x, **kwargs) -# ceil, floor, and trunc return integers for integer inputs - - -def ceil(x: Array, /, xp: Namespace, **kwargs: object) -> Array: - if xp.issubdtype(x.dtype, xp.integer): - return x - return xp.ceil(x, **kwargs) - - -def floor(x: Array, /, xp: Namespace, **kwargs: object) -> Array: - if xp.issubdtype(x.dtype, xp.integer): - return x - return xp.floor(x, **kwargs) - - -def trunc(x: Array, /, xp: Namespace, **kwargs: object) -> Array: - if xp.issubdtype(x.dtype, xp.integer): - return x - return xp.trunc(x, **kwargs) - - # linear algebra functions @@ -707,9 +686,6 @@ def iinfo(type_: DType | Array, /, xp: Namespace) -> Any: "argsort", "sort", "nonzero", - "ceil", - "floor", - "trunc", "matmul", "matrix_transpose", "tensordot", diff --git a/array_api_compat/cupy/_aliases.py b/array_api_compat/cupy/_aliases.py index 90b48f05..e000602e 100644 --- a/array_api_compat/cupy/_aliases.py +++ b/array_api_compat/cupy/_aliases.py @@ -54,9 +54,6 @@ argsort = get_xp(cp)(_aliases.argsort) sort = get_xp(cp)(_aliases.sort) nonzero = get_xp(cp)(_aliases.nonzero) -ceil = get_xp(cp)(_aliases.ceil) -floor = get_xp(cp)(_aliases.floor) -trunc = get_xp(cp)(_aliases.trunc) matmul = get_xp(cp)(_aliases.matmul) matrix_transpose = get_xp(cp)(_aliases.matrix_transpose) tensordot = get_xp(cp)(_aliases.tensordot) @@ -123,6 +120,25 @@ def count_nonzero( return cp.expand_dims(result, axis) return result +# ceil, floor, and trunc return integers for integer inputs + +def ceil(x: Array, /) -> Array: + if cp.issubdtype(x.dtype, cp.integer): + return x.copy() + return cp.ceil(x) + + +def floor(x: Array, /) -> Array: + if cp.issubdtype(x.dtype, cp.integer): + return x.copy() + return cp.floor(x) + + +def trunc(x: Array, /) -> Array: + if cp.issubdtype(x.dtype, cp.integer): + return x.copy() + return cp.trunc(x) + # take_along_axis: axis defaults to -1 but in cupy (and numpy) axis is a required arg def take_along_axis(x: Array, indices: Array, /, *, axis: int = -1): @@ -151,6 +167,6 @@ def take_along_axis(x: Array, indices: Array, /, *, axis: int = -1): 'atan2', 'atanh', 'bitwise_left_shift', 'bitwise_invert', 'bitwise_right_shift', 'bool', 'concat', 'count_nonzero', 'pow', 'sign', - 'take_along_axis'] + 'ceil', 'floor', 'trunc', 'take_along_axis'] _all_ignore = ['cp', 'get_xp'] diff --git a/array_api_compat/dask/array/_aliases.py b/array_api_compat/dask/array/_aliases.py index d43881ab..0bb5d227 100644 --- a/array_api_compat/dask/array/_aliases.py +++ b/array_api_compat/dask/array/_aliases.py @@ -134,9 +134,6 @@ def arange( matrix_transpose = get_xp(da)(_aliases.matrix_transpose) vecdot = get_xp(da)(_aliases.vecdot) nonzero = get_xp(da)(_aliases.nonzero) -ceil = get_xp(np)(_aliases.ceil) -floor = get_xp(np)(_aliases.floor) -trunc = get_xp(np)(_aliases.trunc) matmul = get_xp(np)(_aliases.matmul) tensordot = get_xp(np)(_aliases.tensordot) sign = get_xp(np)(_aliases.sign) diff --git 
a/array_api_compat/numpy/_aliases.py b/array_api_compat/numpy/_aliases.py index a1aee5c0..502dfb3a 100644 --- a/array_api_compat/numpy/_aliases.py +++ b/array_api_compat/numpy/_aliases.py @@ -63,9 +63,6 @@ argsort = get_xp(np)(_aliases.argsort) sort = get_xp(np)(_aliases.sort) nonzero = get_xp(np)(_aliases.nonzero) -ceil = get_xp(np)(_aliases.ceil) -floor = get_xp(np)(_aliases.floor) -trunc = get_xp(np)(_aliases.trunc) matmul = get_xp(np)(_aliases.matmul) matrix_transpose = get_xp(np)(_aliases.matrix_transpose) tensordot = get_xp(np)(_aliases.tensordot) @@ -145,6 +142,29 @@ def take_along_axis(x: Array, indices: Array, /, *, axis: int = -1): return np.take_along_axis(x, indices, axis=axis) +# ceil, floor, and trunc return integers for integer inputs in NumPy < 2 + +def ceil(x: Array, /) -> Array: + if np.issubdtype(x.dtype, np.integer): + if np.__version__ < '2': + return x.copy() + return np.ceil(x) + + +def floor(x: Array, /) -> Array: + if np.issubdtype(x.dtype, np.integer): + if np.__version__ < '2': + return x.copy() + return np.floor(x) + + +def trunc(x: Array, /) -> Array: + if np.issubdtype(x.dtype, np.integer): + if np.__version__ < '2': + return x.copy() + return np.trunc(x) + + # These functions are completely new here. If the library already has them # (i.e., numpy 2.0), use the library version instead of our wrapper. if hasattr(np, "vecdot"): @@ -173,6 +193,9 @@ def take_along_axis(x: Array, indices: Array, /, *, axis: int = -1): "atan", "atan2", "atanh", + "ceil", + "floor", + "trunc", "bitwise_left_shift", "bitwise_invert", "bitwise_right_shift", From 0ad664bdfde03ec3f21d82b1048616ae5d0fb6b7 Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Wed, 4 Jun 2025 10:49:43 +0200 Subject: [PATCH 72/80] Apply suggestions from code review Co-authored-by: Guido Imperiale --- tests/test_copies_or_views.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_copies_or_views.py b/tests/test_copies_or_views.py index 5b9b9207..24d03547 100644 --- a/tests/test_copies_or_views.py +++ b/tests/test_copies_or_views.py @@ -3,7 +3,7 @@ on whether to return a view or a copy of inputs. 
""" import pytest -from ._helpers import import_ +from ._helpers import import_, wrapped_libraries LIB_NAMES = ['numpy', 'cupy', 'torch', 'dask.array', 'array_api_strict'] @@ -40,7 +40,7 @@ def is_view(func, a, value): return a[0] == value -@pytest.mark.parametrize('xp_name', LIB_NAMES) +@pytest.mark.parametrize('xp_name', wrapped_libraries) @pytest.mark.parametrize('inputs', FUNC_INPUTS, ids=[inp[0] for inp in FUNC_INPUTS]) def test_view_or_copy(inputs, xp_name): bare_xp = import_(xp_name, wrapper=False) From 118ae2d0428be763abf1e31b2827a4800398e901 Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Wed, 4 Jun 2025 10:52:46 +0200 Subject: [PATCH 73/80] TST: test views vs copies on array-api-strict, too --- tests/test_copies_or_views.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/tests/test_copies_or_views.py b/tests/test_copies_or_views.py index 24d03547..ec8995f7 100644 --- a/tests/test_copies_or_views.py +++ b/tests/test_copies_or_views.py @@ -6,8 +6,6 @@ from ._helpers import import_, wrapped_libraries -LIB_NAMES = ['numpy', 'cupy', 'torch', 'dask.array', 'array_api_strict'] - FUNC_INPUTS = [ # func_name, arr_input, dtype, scalar_value ('abs', [1, 2], 'int8', 3), @@ -40,7 +38,7 @@ def is_view(func, a, value): return a[0] == value -@pytest.mark.parametrize('xp_name', wrapped_libraries) +@pytest.mark.parametrize('xp_name', wrapped_libraries + ['array_api_strict']) @pytest.mark.parametrize('inputs', FUNC_INPUTS, ids=[inp[0] for inp in FUNC_INPUTS]) def test_view_or_copy(inputs, xp_name): bare_xp = import_(xp_name, wrapper=False) From b0eed557d6dba8c87d9693ff82360b33c1af3480 Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Wed, 4 Jun 2025 13:05:08 +0200 Subject: [PATCH 74/80] Apply suggestions from code review Co-authored-by: Guido Imperiale --- array_api_compat/numpy/_aliases.py | 15 ++++++--------- 1 file changed, 6 insertions(+), 9 deletions(-) diff --git a/array_api_compat/numpy/_aliases.py b/array_api_compat/numpy/_aliases.py index 502dfb3a..f04837de 100644 --- a/array_api_compat/numpy/_aliases.py +++ b/array_api_compat/numpy/_aliases.py @@ -145,23 +145,20 @@ def take_along_axis(x: Array, indices: Array, /, *, axis: int = -1): # ceil, floor, and trunc return integers for integer inputs in NumPy < 2 def ceil(x: Array, /) -> Array: - if np.issubdtype(x.dtype, np.integer): - if np.__version__ < '2': - return x.copy() + if np.__version__ < '2' and np.issubdtype(x.dtype, np.integer): + return x.copy() return np.ceil(x) def floor(x: Array, /) -> Array: - if np.issubdtype(x.dtype, np.integer): - if np.__version__ < '2': - return x.copy() + if np.__version__ < '2' and np.issubdtype(x.dtype, np.integer): + return x.copy() return np.floor(x) def trunc(x: Array, /) -> Array: - if np.issubdtype(x.dtype, np.integer): - if np.__version__ < '2': - return x.copy() + if np.__version__ < '2' and np.issubdtype(x.dtype, np.integer): + return x.copy() return np.trunc(x) From 2b559e62e05ebea3dd3ab631aee47b270109eaa1 Mon Sep 17 00:00:00 2001 From: Guido Imperiale Date: Wed, 4 Jun 2025 15:36:16 +0100 Subject: [PATCH 75/80] TYP: Type annotations, part 4 (#313) * Type annotations, part 4 * Fix CopyMode * revert * Revert `_all_ignore` * code review * code review * JustInt mypy ignores * lint * fix merge * lint * Reverts and tweaks * Fix test_all * Revert batmobile --- array_api_compat/_internal.py | 4 +- array_api_compat/common/_aliases.py | 13 +- array_api_compat/common/_helpers.py | 34 ++--- array_api_compat/common/_linalg.py | 6 +- array_api_compat/common/_typing.py | 15 +- 
array_api_compat/cupy/_aliases.py | 31 ++-- array_api_compat/cupy/fft.py | 9 +- array_api_compat/cupy/linalg.py | 2 +- array_api_compat/dask/array/__init__.py | 2 +- array_api_compat/dask/array/_aliases.py | 2 +- array_api_compat/dask/array/_info.py | 47 +++--- array_api_compat/dask/array/fft.py | 2 +- array_api_compat/dask/array/linalg.py | 19 +-- array_api_compat/numpy/__init__.py | 2 +- array_api_compat/numpy/_aliases.py | 33 ++-- array_api_compat/numpy/_info.py | 3 +- array_api_compat/numpy/linalg.py | 4 +- array_api_compat/torch/_aliases.py | 192 ++++++++++++------------ array_api_compat/torch/fft.py | 19 +-- array_api_compat/torch/linalg.py | 12 +- pyproject.toml | 56 ++++--- 21 files changed, 246 insertions(+), 261 deletions(-) diff --git a/array_api_compat/_internal.py b/array_api_compat/_internal.py index cd8d939f..b1925492 100644 --- a/array_api_compat/_internal.py +++ b/array_api_compat/_internal.py @@ -46,8 +46,8 @@ def wrapped_f(*args: object, **kwargs: object) -> object: specification for more details. """ - wrapped_f.__signature__ = new_sig # pyright: ignore[reportAttributeAccessIssue] - return wrapped_f # pyright: ignore[reportReturnType] + wrapped_f.__signature__ = new_sig # type: ignore[attr-defined] # pyright: ignore[reportAttributeAccessIssue] + return wrapped_f # type: ignore[return-value] # pyright: ignore[reportReturnType] return inner diff --git a/array_api_compat/common/_aliases.py b/array_api_compat/common/_aliases.py index 8ea9162a..51732b91 100644 --- a/array_api_compat/common/_aliases.py +++ b/array_api_compat/common/_aliases.py @@ -5,11 +5,12 @@ from __future__ import annotations import inspect -from typing import TYPE_CHECKING, Any, NamedTuple, Optional, Sequence, cast +from collections.abc import Sequence +from typing import TYPE_CHECKING, Any, NamedTuple, cast from ._helpers import _check_device, array_namespace from ._helpers import device as _get_device -from ._helpers import is_cupy_namespace as _is_cupy_namespace +from ._helpers import is_cupy_namespace from ._typing import Array, Device, DType, Namespace if TYPE_CHECKING: @@ -381,8 +382,8 @@ def clip( # TODO: np.clip has other ufunc kwargs out: Array | None = None, ) -> Array: - def _isscalar(a: object) -> TypeIs[int | float | None]: - return isinstance(a, (int, float, type(None))) + def _isscalar(a: object) -> TypeIs[float | None]: + return isinstance(a, int | float) or a is None min_shape = () if _isscalar(min) else min.shape max_shape = () if _isscalar(max) else max.shape @@ -450,7 +451,7 @@ def reshape( shape: tuple[int, ...], xp: Namespace, *, - copy: Optional[bool] = None, + copy: bool | None = None, **kwargs: object, ) -> Array: if copy is True: @@ -657,7 +658,7 @@ def sign(x: Array, /, xp: Namespace, **kwargs: object) -> Array: out = xp.sign(x, **kwargs) # CuPy sign() does not propagate nans. 
See # https://github.com/data-apis/array-api-compat/issues/136 - if _is_cupy_namespace(xp) and isdtype(x.dtype, "real floating", xp=xp): + if is_cupy_namespace(xp) and isdtype(x.dtype, "real floating", xp=xp): out[xp.isnan(x)] = xp.nan return out[()] diff --git a/array_api_compat/common/_helpers.py b/array_api_compat/common/_helpers.py index a152e4c0..cae0ee0b 100644 --- a/array_api_compat/common/_helpers.py +++ b/array_api_compat/common/_helpers.py @@ -23,7 +23,6 @@ SupportsIndex, TypeAlias, TypeGuard, - TypeVar, cast, overload, ) @@ -31,32 +30,29 @@ from ._typing import Array, Device, HasShape, Namespace, SupportsArrayNamespace if TYPE_CHECKING: - + import cupy as cp import dask.array as da import jax import ndonnx as ndx import numpy as np import numpy.typing as npt - import sparse # pyright: ignore[reportMissingTypeStubs] + import sparse import torch # TODO: import from typing (requires Python >=3.13) - from typing_extensions import TypeIs, TypeVar - - _SizeT = TypeVar("_SizeT", bound = int | None) + from typing_extensions import TypeIs _ZeroGradientArray: TypeAlias = npt.NDArray[np.void] - _CupyArray: TypeAlias = Any # cupy has no py.typed _ArrayApiObj: TypeAlias = ( npt.NDArray[Any] + | cp.ndarray | da.Array | jax.Array | ndx.Array | sparse.SparseArray | torch.Tensor | SupportsArrayNamespace[Any] - | _CupyArray ) _API_VERSIONS_OLD: Final = frozenset({"2021.12", "2022.12", "2023.12"}) @@ -96,7 +92,7 @@ def _is_jax_zero_gradient_array(x: object) -> TypeGuard[_ZeroGradientArray]: return dtype == jax.float0 -def is_numpy_array(x: object) -> TypeGuard[npt.NDArray[Any]]: +def is_numpy_array(x: object) -> TypeIs[npt.NDArray[Any]]: """ Return True if `x` is a NumPy array. @@ -267,7 +263,7 @@ def is_pydata_sparse_array(x: object) -> TypeIs[sparse.SparseArray]: return _issubclass_fast(cls, "sparse", "SparseArray") -def is_array_api_obj(x: object) -> TypeIs[_ArrayApiObj]: # pyright: ignore[reportUnknownParameterType] +def is_array_api_obj(x: object) -> TypeGuard[_ArrayApiObj]: """ Return True if `x` is an array API compatible array object. 
@@ -748,7 +744,7 @@ def device(x: _ArrayApiObj, /) -> Device: return "cpu" elif is_dask_array(x): # Peek at the metadata of the Dask array to determine type - if is_numpy_array(x._meta): # pyright: ignore + if is_numpy_array(x._meta): # Must be on CPU since backed by numpy return "cpu" return _DASK_DEVICE @@ -777,7 +773,7 @@ def device(x: _ArrayApiObj, /) -> Device: return "cpu" # Return the device of the constituent array return device(inner) # pyright: ignore - return x.device # pyright: ignore + return x.device # type: ignore # pyright: ignore # Prevent shadowing, used below @@ -786,11 +782,11 @@ def device(x: _ArrayApiObj, /) -> Device: # Based on cupy.array_api.Array.to_device def _cupy_to_device( - x: _CupyArray, + x: cp.ndarray, device: Device, /, stream: int | Any | None = None, -) -> _CupyArray: +) -> cp.ndarray: import cupy as cp if device == "cpu": @@ -819,7 +815,7 @@ def _torch_to_device( x: torch.Tensor, device: torch.device | str | int, /, - stream: None = None, + stream: int | Any | None = None, ) -> torch.Tensor: if stream is not None: raise NotImplementedError @@ -885,7 +881,7 @@ def to_device(x: Array, device: Device, /, *, stream: int | Any | None = None) - # cupy does not yet have to_device return _cupy_to_device(x, device, stream=stream) elif is_torch_array(x): - return _torch_to_device(x, device, stream=stream) # pyright: ignore[reportArgumentType] + return _torch_to_device(x, device, stream=stream) elif is_dask_array(x): if stream is not None: raise ValueError("The stream argument to to_device() is not supported") @@ -912,8 +908,6 @@ def to_device(x: Array, device: Device, /, *, stream: int | Any | None = None) - @overload def size(x: HasShape[Collection[SupportsIndex]]) -> int: ... @overload -def size(x: HasShape[Collection[None]]) -> None: ... -@overload def size(x: HasShape[Collection[SupportsIndex | None]]) -> int | None: ... def size(x: HasShape[Collection[SupportsIndex | None]]) -> int | None: """ @@ -948,7 +942,7 @@ def _is_writeable_cls(cls: type) -> bool | None: return None -def is_writeable_array(x: object) -> bool: +def is_writeable_array(x: object) -> TypeGuard[_ArrayApiObj]: """ Return False if ``x.__setitem__`` is expected to raise; True otherwise. Return False if `x` is not an array API compatible object. @@ -986,7 +980,7 @@ def _is_lazy_cls(cls: type) -> bool | None: return None -def is_lazy_array(x: object) -> bool: +def is_lazy_array(x: object) -> TypeGuard[_ArrayApiObj]: """Return True if x is potentially a future or it may be otherwise impossible or expensive to eagerly read its contents, regardless of their size, e.g. by calling ``bool(x)`` or ``float(x)``. diff --git a/array_api_compat/common/_linalg.py b/array_api_compat/common/_linalg.py index 7ad87a1b..3fd9d860 100644 --- a/array_api_compat/common/_linalg.py +++ b/array_api_compat/common/_linalg.py @@ -8,7 +8,7 @@ if np.__version__[0] == "2": from numpy.lib.array_utils import normalize_axis_tuple else: - from numpy.core.numeric import normalize_axis_tuple + from numpy.core.numeric import normalize_axis_tuple # type: ignore[no-redef] from .._internal import get_xp from ._aliases import isdtype, matmul, matrix_transpose, tensordot, vecdot @@ -187,14 +187,14 @@ def vector_norm( # We can't reuse xp.linalg.norm(keepdims) because of the reshape hacks # above to avoid matrix norm logic. 
shape = list(x.shape) - _axis = cast( + axes = cast( "tuple[int, ...]", normalize_axis_tuple( # pyright: ignore[reportCallIssue] range(x.ndim) if axis is None else axis, x.ndim, ), ) - for i in _axis: + for i in axes: shape[i] = 1 res = xp.reshape(res, tuple(shape)) diff --git a/array_api_compat/common/_typing.py b/array_api_compat/common/_typing.py index cd26feeb..11b00bd1 100644 --- a/array_api_compat/common/_typing.py +++ b/array_api_compat/common/_typing.py @@ -34,32 +34,29 @@ # - docs: https://github.com/jorenham/optype/blob/master/README.md#just # - code: https://github.com/jorenham/optype/blob/master/optype/_core/_just.py @final -class JustInt(Protocol): - @property +class JustInt(Protocol): # type: ignore[misc] + @property # type: ignore[override] def __class__(self, /) -> type[int]: ... @__class__.setter def __class__(self, value: type[int], /) -> None: ... # pyright: ignore[reportIncompatibleMethodOverride] @final -class JustFloat(Protocol): - @property +class JustFloat(Protocol): # type: ignore[misc] + @property # type: ignore[override] def __class__(self, /) -> type[float]: ... @__class__.setter def __class__(self, value: type[float], /) -> None: ... # pyright: ignore[reportIncompatibleMethodOverride] @final -class JustComplex(Protocol): - @property +class JustComplex(Protocol): # type: ignore[misc] + @property # type: ignore[override] def __class__(self, /) -> type[complex]: ... @__class__.setter def __class__(self, value: type[complex], /) -> None: ... # pyright: ignore[reportIncompatibleMethodOverride] -# - - class NestedSequence(Protocol[_T_co]): def __getitem__(self, key: int, /) -> _T_co | NestedSequence[_T_co]: ... def __len__(self, /) -> int: ... diff --git a/array_api_compat/cupy/_aliases.py b/array_api_compat/cupy/_aliases.py index 90b48f05..c0473ca4 100644 --- a/array_api_compat/cupy/_aliases.py +++ b/array_api_compat/cupy/_aliases.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Optional +from builtins import bool as py_bool import cupy as cp @@ -67,18 +67,13 @@ # asarray also adds the copy keyword, which is not present in numpy 1.0. def asarray( - obj: ( - Array - | bool | int | float | complex - | NestedSequence[bool | int | float | complex] - | SupportsBufferProtocol - ), + obj: Array | complex | NestedSequence[complex] | SupportsBufferProtocol, /, *, - dtype: Optional[DType] = None, - device: Optional[Device] = None, - copy: Optional[bool] = None, - **kwargs, + dtype: DType | None = None, + device: Device | None = None, + copy: py_bool | None = None, + **kwargs: object, ) -> Array: """ Array API compatibility wrapper for asarray(). @@ -101,8 +96,8 @@ def astype( dtype: DType, /, *, - copy: bool = True, - device: Optional[Device] = None, + copy: py_bool = True, + device: Device | None = None, ) -> Array: if device is None: return x.astype(dtype=dtype, copy=copy) @@ -113,8 +108,8 @@ def astype( # cupy.count_nonzero does not have keepdims def count_nonzero( x: Array, - axis=None, - keepdims=False + axis: int | tuple[int, ...] 
| None = None, + keepdims: py_bool = False, ) -> Array: result = cp.count_nonzero(x, axis) if keepdims: @@ -125,7 +120,7 @@ def count_nonzero( # take_along_axis: axis defaults to -1 but in cupy (and numpy) axis is a required arg -def take_along_axis(x: Array, indices: Array, /, *, axis: int = -1): +def take_along_axis(x: Array, indices: Array, /, *, axis: int = -1) -> Array: return cp.take_along_axis(x, indices, axis=axis) @@ -153,4 +148,6 @@ def take_along_axis(x: Array, indices: Array, /, *, axis: int = -1): 'bool', 'concat', 'count_nonzero', 'pow', 'sign', 'take_along_axis'] -_all_ignore = ['cp', 'get_xp'] + +def __dir__() -> list[str]: + return __all__ diff --git a/array_api_compat/cupy/fft.py b/array_api_compat/cupy/fft.py index 307e0f72..2bd11940 100644 --- a/array_api_compat/cupy/fft.py +++ b/array_api_compat/cupy/fft.py @@ -1,10 +1,11 @@ -from cupy.fft import * # noqa: F403 +from cupy.fft import * # noqa: F403 + # cupy.fft doesn't have __all__. If it is added, replace this with # # from cupy.fft import __all__ as linalg_all -_n = {} -exec('from cupy.fft import *', _n) -del _n['__builtins__'] +_n: dict[str, object] = {} +exec("from cupy.fft import *", _n) +del _n["__builtins__"] fft_all = list(_n) del _n diff --git a/array_api_compat/cupy/linalg.py b/array_api_compat/cupy/linalg.py index 7fcdd498..7bc3536e 100644 --- a/array_api_compat/cupy/linalg.py +++ b/array_api_compat/cupy/linalg.py @@ -2,7 +2,7 @@ # cupy.linalg doesn't have __all__. If it is added, replace this with # # from cupy.linalg import __all__ as linalg_all -_n = {} +_n: dict[str, object] = {} exec('from cupy.linalg import *', _n) del _n['__builtins__'] linalg_all = list(_n) diff --git a/array_api_compat/dask/array/__init__.py b/array_api_compat/dask/array/__init__.py index 1e47b960..6d2ea7cd 100644 --- a/array_api_compat/dask/array/__init__.py +++ b/array_api_compat/dask/array/__init__.py @@ -3,7 +3,7 @@ from dask.array import * # noqa: F403 # These imports may overwrite names from the import * above. -from ._aliases import * # noqa: F403 +from ._aliases import * # type: ignore[assignment] # noqa: F403 __array_api_version__: Final = "2024.12" diff --git a/array_api_compat/dask/array/_aliases.py b/array_api_compat/dask/array/_aliases.py index d43881ab..bc0302fe 100644 --- a/array_api_compat/dask/array/_aliases.py +++ b/array_api_compat/dask/array/_aliases.py @@ -146,7 +146,7 @@ def arange( # asarray also adds the copy keyword, which is not present in numpy 1.0. 
def asarray( - obj: complex | NestedSequence[complex] | Array | SupportsBufferProtocol, + obj: Array | complex | NestedSequence[complex] | SupportsBufferProtocol, /, *, dtype: DType | None = None, diff --git a/array_api_compat/dask/array/_info.py b/array_api_compat/dask/array/_info.py index 9e4d736f..2f39fc4b 100644 --- a/array_api_compat/dask/array/_info.py +++ b/array_api_compat/dask/array/_info.py @@ -12,9 +12,9 @@ from __future__ import annotations -from typing import Literal as L -from typing import TypeAlias, overload +from typing import Literal, TypeAlias, overload +import dask.array as da from numpy import bool_ as bool from numpy import ( complex64, @@ -33,7 +33,7 @@ uint64, ) -from ...common._helpers import _DASK_DEVICE, _dask_device +from ...common._helpers import _DASK_DEVICE, _check_device, _dask_device from ...common._typing import ( Capabilities, DefaultDTypes, @@ -49,8 +49,7 @@ DTypesSigned, DTypesUnsigned, ) - -_Device: TypeAlias = L["cpu"] | _dask_device +Device: TypeAlias = Literal["cpu"] | _dask_device class __array_namespace_info__: @@ -142,7 +141,7 @@ def capabilities(self) -> Capabilities: "max dimensions": 64, } - def default_device(self) -> L["cpu"]: + def default_device(self) -> Device: """ The default device used for new Dask arrays. @@ -169,7 +168,7 @@ def default_device(self) -> L["cpu"]: """ return "cpu" - def default_dtypes(self, /, *, device: _Device | None = None) -> DefaultDTypes: + def default_dtypes(self, /, *, device: Device | None = None) -> DefaultDTypes: """ The default data types used for new Dask arrays. @@ -208,11 +207,7 @@ def default_dtypes(self, /, *, device: _Device | None = None) -> DefaultDTypes: 'indexing': dask.int64} """ - if device not in ["cpu", _DASK_DEVICE, None]: - raise ValueError( - f'Device not understood. Only "cpu" or _DASK_DEVICE is allowed, ' - f"but received: {device!r}" - ) + _check_device(da, device) return { "real floating": dtype(float64), "complex floating": dtype(complex128), @@ -222,38 +217,38 @@ def default_dtypes(self, /, *, device: _Device | None = None) -> DefaultDTypes: @overload def dtypes( - self, /, *, device: _Device | None = None, kind: None = None + self, /, *, device: Device | None = None, kind: None = None ) -> DTypesAll: ... @overload def dtypes( - self, /, *, device: _Device | None = None, kind: L["bool"] + self, /, *, device: Device | None = None, kind: Literal["bool"] ) -> DTypesBool: ... @overload def dtypes( - self, /, *, device: _Device | None = None, kind: L["signed integer"] + self, /, *, device: Device | None = None, kind: Literal["signed integer"] ) -> DTypesSigned: ... @overload def dtypes( - self, /, *, device: _Device | None = None, kind: L["unsigned integer"] + self, /, *, device: Device | None = None, kind: Literal["unsigned integer"] ) -> DTypesUnsigned: ... @overload def dtypes( - self, /, *, device: _Device | None = None, kind: L["integral"] + self, /, *, device: Device | None = None, kind: Literal["integral"] ) -> DTypesIntegral: ... @overload def dtypes( - self, /, *, device: _Device | None = None, kind: L["real floating"] + self, /, *, device: Device | None = None, kind: Literal["real floating"] ) -> DTypesReal: ... @overload def dtypes( - self, /, *, device: _Device | None = None, kind: L["complex floating"] + self, /, *, device: Device | None = None, kind: Literal["complex floating"] ) -> DTypesComplex: ... @overload def dtypes( - self, /, *, device: _Device | None = None, kind: L["numeric"] + self, /, *, device: Device | None = None, kind: Literal["numeric"] ) -> DTypesNumeric: ... 
def dtypes( - self, /, *, device: _Device | None = None, kind: DTypeKind | None = None + self, /, *, device: Device | None = None, kind: DTypeKind | None = None ) -> DTypesAny: """ The array API data types supported by Dask. @@ -308,11 +303,7 @@ def dtypes( 'int64': dask.int64} """ - if device not in ["cpu", _DASK_DEVICE, None]: - raise ValueError( - 'Device not understood. Only "cpu" or _DASK_DEVICE is allowed, but received:' - f" {device}" - ) + _check_device(da, device) if kind is None: return { "bool": dtype(bool), @@ -381,14 +372,14 @@ def dtypes( "complex64": dtype(complex64), "complex128": dtype(complex128), } - if isinstance(kind, tuple): # type: ignore[reportUnnecessaryIsinstanceCall] + if isinstance(kind, tuple): res: dict[str, DType] = {} for k in kind: res.update(self.dtypes(kind=k)) return res raise ValueError(f"unsupported kind: {kind!r}") - def devices(self) -> list[_Device]: + def devices(self) -> list[Device]: """ The devices supported by Dask. diff --git a/array_api_compat/dask/array/fft.py b/array_api_compat/dask/array/fft.py index 3f40dffe..68c4280e 100644 --- a/array_api_compat/dask/array/fft.py +++ b/array_api_compat/dask/array/fft.py @@ -2,7 +2,7 @@ # dask.array.fft doesn't have __all__. If it is added, replace this with # # from dask.array.fft import __all__ as linalg_all -_n = {} +_n: dict[str, object] = {} exec('from dask.array.fft import *', _n) for k in ("__builtins__", "Sequence", "annotations", "warnings"): _n.pop(k, None) diff --git a/array_api_compat/dask/array/linalg.py b/array_api_compat/dask/array/linalg.py index 0825386e..06f596bc 100644 --- a/array_api_compat/dask/array/linalg.py +++ b/array_api_compat/dask/array/linalg.py @@ -4,21 +4,22 @@ import dask.array as da -# The `matmul` and `tensordot` functions are in both the main and linalg namespaces -from dask.array import matmul, outer, tensordot - # Exports from dask.array.linalg import * # noqa: F403 +from dask.array import outer +# The `matmul` and `tensordot` functions are in both the main and linalg namespaces +from dask.array import matmul, tensordot + from ..._internal import get_xp from ...common import _linalg -from ...common._typing import Array as _Array +from ...common._typing import Array from ._aliases import matrix_transpose, vecdot # dask.array.linalg doesn't have __all__. 
If it is added, replace this with # # from dask.array.linalg import __all__ as linalg_all -_n = {} +_n: dict[str, object] = {} exec('from dask.array.linalg import *', _n) for k in ('__builtins__', 'annotations', 'operator', 'warnings', 'Array'): _n.pop(k, None) @@ -33,8 +34,8 @@ # supports the mode keyword on QR # https://github.com/dask/dask/issues/10388 #qr = get_xp(da)(_linalg.qr) -def qr( - x: _Array, +def qr( # type: ignore[no-redef] + x: Array, mode: Literal["reduced", "complete"] = "reduced", **kwargs: object, ) -> QRResult: @@ -50,12 +51,12 @@ def qr( # Wrap the svd functions to not pass full_matrices to dask # when full_matrices=False (as that is the default behavior for dask), # and dask doesn't have the full_matrices keyword -def svd(x: _Array, full_matrices: bool = True, **kwargs) -> SVDResult: +def svd(x: Array, full_matrices: bool = True, **kwargs: object) -> SVDResult: # type: ignore[no-redef] if full_matrices: raise ValueError("full_matrics=True is not supported by dask.") return da.linalg.svd(x, coerce_signs=False, **kwargs) -def svdvals(x: _Array) -> _Array: +def svdvals(x: Array) -> Array: # TODO: can't avoid computing U or V for dask _, s, _ = svd(x) return s diff --git a/array_api_compat/numpy/__init__.py b/array_api_compat/numpy/__init__.py index 3e138f53..bf43fe61 100644 --- a/array_api_compat/numpy/__init__.py +++ b/array_api_compat/numpy/__init__.py @@ -10,7 +10,7 @@ from numpy import round as round # These imports may overwrite names from the import * above. -from ._aliases import * # noqa: F403 +from ._aliases import * # type: ignore[assignment,no-redef] # noqa: F403 # Don't know why, but we have to do an absolute import to import linalg. If we # instead do diff --git a/array_api_compat/numpy/_aliases.py b/array_api_compat/numpy/_aliases.py index a1aee5c0..5a05a820 100644 --- a/array_api_compat/numpy/_aliases.py +++ b/array_api_compat/numpy/_aliases.py @@ -2,7 +2,7 @@ from __future__ import annotations from builtins import bool as py_bool -from typing import TYPE_CHECKING, Any, Literal, TypeAlias, cast +from typing import Any, cast import numpy as np @@ -12,13 +12,6 @@ from ._info import __array_namespace_info__ from ._typing import Array, Device, DType -if TYPE_CHECKING: - from typing_extensions import Buffer, TypeIs - -# The values of the `_CopyMode` enum can be either `False`, `True`, or `2`: -# https://github.com/numpy/numpy/blob/5a8a6a79d9c2fff8f07dcab5d41e14f8508d673f/numpy/_globals.pyi#L7-L10 -_Copy: TypeAlias = py_bool | Literal[2] | np._CopyMode - bool = np.bool_ # Basic renames @@ -74,14 +67,6 @@ iinfo = get_xp(np)(_aliases.iinfo) -def _supports_buffer_protocol(obj: object) -> TypeIs[Buffer]: # pyright: ignore[reportUnusedFunction] - try: - memoryview(obj) # pyright: ignore[reportArgumentType] - except TypeError: - return False - return True - - # asarray also adds the copy keyword, which is not present in numpy 1.0. 
# asarray() is different enough between numpy, cupy, and dask, the logic # complicated enough that it's easier to define it separately for each module @@ -92,7 +77,7 @@ def asarray( *, dtype: DType | None = None, device: Device | None = None, - copy: _Copy | None = None, + copy: py_bool | None = None, **kwargs: Any, ) -> Array: """ @@ -103,14 +88,14 @@ def asarray( """ _helpers._check_device(np, device) + # None is unsupported in NumPy 1.0, but we can use an internal enum + # False in NumPy 1.0 means None in NumPy 2.0 and in the Array API if copy is None: - copy = np._CopyMode.IF_NEEDED + copy = np._CopyMode.IF_NEEDED # type: ignore[assignment,attr-defined] elif copy is False: - copy = np._CopyMode.NEVER - elif copy is True: - copy = np._CopyMode.ALWAYS + copy = np._CopyMode.NEVER # type: ignore[assignment,attr-defined] - return np.array(obj, copy=copy, dtype=dtype, **kwargs) # pyright: ignore + return np.array(obj, copy=copy, dtype=dtype, **kwargs) def astype( @@ -141,7 +126,7 @@ def count_nonzero( # take_along_axis: axis defaults to -1 but in numpy axis is a required arg -def take_along_axis(x: Array, indices: Array, /, *, axis: int = -1): +def take_along_axis(x: Array, indices: Array, /, *, axis: int = -1) -> Array: return np.take_along_axis(x, indices, axis=axis) @@ -150,7 +135,7 @@ def take_along_axis(x: Array, indices: Array, /, *, axis: int = -1): if hasattr(np, "vecdot"): vecdot = np.vecdot else: - vecdot = get_xp(np)(_aliases.vecdot) + vecdot = get_xp(np)(_aliases.vecdot) # type: ignore[assignment] if hasattr(np, "isdtype"): isdtype = np.isdtype diff --git a/array_api_compat/numpy/_info.py b/array_api_compat/numpy/_info.py index f307f62c..c625c13e 100644 --- a/array_api_compat/numpy/_info.py +++ b/array_api_compat/numpy/_info.py @@ -27,6 +27,7 @@ uint64, ) +from ..common._typing import DefaultDTypes from ._typing import Device, DType @@ -139,7 +140,7 @@ def default_dtypes( self, *, device: Device | None = None, - ) -> dict[str, dtype[intp | float64 | complex128]]: + ) -> DefaultDTypes: """ The default data types used for new NumPy arrays. diff --git a/array_api_compat/numpy/linalg.py b/array_api_compat/numpy/linalg.py index 2d3e731d..9a618be9 100644 --- a/array_api_compat/numpy/linalg.py +++ b/array_api_compat/numpy/linalg.py @@ -65,7 +65,7 @@ # https://github.com/cupy/cupy/blob/main/cupy/cublas.py#L43). 
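The `copy` translation in the `asarray` hunk above is easy to get backwards: on NumPy 1.x, a plain `copy=False` already means "copy only if needed", so the strict "never copy" semantics of the standard require NumPy's internal `np._CopyMode` enum (the same one used in the hunk). A minimal standalone sketch of the mapping; the helper name `_copy_to_numpy1` is hypothetical, not part of the patch:

import numpy as np

def _copy_to_numpy1(copy: bool | None) -> object:
    # Array API semantics: None = copy if needed, False = never copy,
    # True = always copy. NumPy 1.x has no public spelling for
    # "never copy", hence the internal enum.
    if copy is None:
        return np._CopyMode.IF_NEEDED
    if copy is False:
        return np._CopyMode.NEVER
    return True

With this mapping, np.array(obj, copy=_copy_to_numpy1(False)) raises instead of silently copying whenever a copy would be required.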
def solve(x1: Array, x2: Array, /) -> Array: try: - from numpy.linalg._linalg import ( + from numpy.linalg._linalg import ( # type: ignore[attr-defined] _assert_stacked_2d, _assert_stacked_square, _commonType, @@ -74,7 +74,7 @@ def solve(x1: Array, x2: Array, /) -> Array: isComplexType, ) except ImportError: - from numpy.linalg.linalg import ( + from numpy.linalg.linalg import ( # type: ignore[attr-defined] _assert_stacked_2d, _assert_stacked_square, _commonType, diff --git a/array_api_compat/torch/_aliases.py b/array_api_compat/torch/_aliases.py index de5d1a5d..7a449001 100644 --- a/array_api_compat/torch/_aliases.py +++ b/array_api_compat/torch/_aliases.py @@ -1,8 +1,9 @@ from __future__ import annotations +from collections.abc import Sequence from functools import reduce as _reduce, wraps as _wraps from builtins import all as _builtin_all, any as _builtin_any -from typing import Any, List, Optional, Sequence, Tuple, Union, Literal +from typing import Any, Literal import torch @@ -96,9 +97,7 @@ def _fix_promotion(x1, x2, only_scalar=True): _py_scalars = (bool, int, float, complex) -def result_type( - *arrays_and_dtypes: Array | DType | bool | int | float | complex -) -> DType: +def result_type(*arrays_and_dtypes: Array | DType | complex) -> DType: num = len(arrays_and_dtypes) if num == 0: @@ -129,10 +128,7 @@ def result_type( return _reduce(_result_type, others + scalars) -def _result_type( - x: Array | DType | bool | int | float | complex, - y: Array | DType | bool | int | float | complex, -) -> DType: +def _result_type(x: Array | DType | complex, y: Array | DType | complex) -> DType: if not (isinstance(x, _py_scalars) or isinstance(y, _py_scalars)): xdt = x if isinstance(x, torch.dtype) else x.dtype ydt = y if isinstance(y, torch.dtype) else y.dtype @@ -150,7 +146,7 @@ def _result_type( return torch.result_type(x, y) -def can_cast(from_: Union[DType, Array], to: DType, /) -> bool: +def can_cast(from_: DType | Array, to: DType, /) -> bool: if not isinstance(from_, torch.dtype): from_ = from_.dtype return torch.can_cast(from_, to) @@ -194,12 +190,7 @@ def can_cast(from_: Union[DType, Array], to: DType, /) -> bool: def asarray( - obj: ( - Array - | bool | int | float | complex - | NestedSequence[bool | int | float | complex] - | SupportsBufferProtocol - ), + obj: Array | complex | NestedSequence[complex] | SupportsBufferProtocol, /, *, dtype: DType | None = None, @@ -218,13 +209,13 @@ def asarray( # of 'axis'. # torch.min and torch.max return a tuple and don't support multiple axes https://github.com/pytorch/pytorch/issues/58745 -def max(x: Array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False) -> Array: +def max(x: Array, /, *, axis: int | tuple[int, ...] | None = None, keepdims: bool = False) -> Array: # https://github.com/pytorch/pytorch/issues/29137 if axis == (): return torch.clone(x) return torch.amax(x, axis, keepdims=keepdims) -def min(x: Array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False) -> Array: +def min(x: Array, /, *, axis: int | tuple[int, ...] 
|None = None, keepdims: bool = False) -> Array: # https://github.com/pytorch/pytorch/issues/29137 if axis == (): return torch.clone(x) @@ -240,7 +231,15 @@ def min(x: Array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, keep # torch.sort also returns a tuple # https://github.com/pytorch/pytorch/issues/70921 -def sort(x: Array, /, *, axis: int = -1, descending: bool = False, stable: bool = True, **kwargs) -> Array: +def sort( + x: Array, + /, + *, + axis: int = -1, + descending: bool = False, + stable: bool = True, + **kwargs: object, +) -> Array: return torch.sort(x, dim=axis, descending=descending, stable=stable, **kwargs).values def _normalize_axes(axis, ndim): @@ -307,10 +306,10 @@ def _sum_prod_no_axis(x: Array, dtype: DType | None) -> Array: def prod(x: Array, /, *, - axis: Optional[Union[int, Tuple[int, ...]]] = None, - dtype: Optional[DType] = None, + axis: int | tuple[int, ...] | None = None, + dtype: DType | None = None, keepdims: bool = False, - **kwargs) -> Array: + **kwargs: object) -> Array: if axis == (): return _sum_prod_no_axis(x, dtype) @@ -331,10 +330,10 @@ def prod(x: Array, def sum(x: Array, /, *, - axis: Optional[Union[int, Tuple[int, ...]]] = None, - dtype: Optional[DType] = None, + axis: int | tuple[int, ...] | None = None, + dtype: DType | None = None, keepdims: bool = False, - **kwargs) -> Array: + **kwargs: object) -> Array: if axis == (): return _sum_prod_no_axis(x, dtype) @@ -350,9 +349,9 @@ def sum(x: Array, def any(x: Array, /, *, - axis: Optional[Union[int, Tuple[int, ...]]] = None, + axis: int | tuple[int, ...] | None = None, keepdims: bool = False, - **kwargs) -> Array: + **kwargs: object) -> Array: if axis == (): return x.to(torch.bool) @@ -374,9 +373,9 @@ def any(x: Array, def all(x: Array, /, *, - axis: Optional[Union[int, Tuple[int, ...]]] = None, + axis: int | tuple[int, ...] | None = None, keepdims: bool = False, - **kwargs) -> Array: + **kwargs: object) -> Array: if axis == (): return x.to(torch.bool) @@ -398,9 +397,9 @@ def all(x: Array, def mean(x: Array, /, *, - axis: Optional[Union[int, Tuple[int, ...]]] = None, + axis: int | tuple[int, ...] | None = None, keepdims: bool = False, - **kwargs) -> Array: + **kwargs: object) -> Array: # https://github.com/pytorch/pytorch/issues/29137 if axis == (): return torch.clone(x) @@ -415,10 +414,10 @@ def mean(x: Array, def std(x: Array, /, *, - axis: Optional[Union[int, Tuple[int, ...]]] = None, - correction: Union[int, float] = 0.0, + axis: int | tuple[int, ...] | None = None, + correction: float = 0.0, keepdims: bool = False, - **kwargs) -> Array: + **kwargs: object) -> Array: # Note, float correction is not supported # https://github.com/pytorch/pytorch/issues/61492. We don't try to # implement it here for now. @@ -446,10 +445,10 @@ def std(x: Array, def var(x: Array, /, *, - axis: Optional[Union[int, Tuple[int, ...]]] = None, - correction: Union[int, float] = 0.0, + axis: int | tuple[int, ...] | None = None, + correction: float = 0.0, keepdims: bool = False, - **kwargs) -> Array: + **kwargs: object) -> Array: # Note, float correction is not supported # https://github.com/pytorch/pytorch/issues/61492. We don't try to # implement it here for now. @@ -472,11 +471,11 @@ def var(x: Array, # torch.concat doesn't support dim=None # https://github.com/pytorch/pytorch/issues/70925 -def concat(arrays: Union[Tuple[Array, ...], List[Array]], +def concat(arrays: tuple[Array, ...] 
| list[Array], /, *, - axis: Optional[int] = 0, - **kwargs) -> Array: + axis: int | None = 0, + **kwargs: object) -> Array: if axis is None: arrays = tuple(ar.flatten() for ar in arrays) axis = 0 @@ -485,7 +484,7 @@ def concat(arrays: Union[Tuple[Array, ...], List[Array]], # torch.squeeze only accepts int dim and doesn't require it # https://github.com/pytorch/pytorch/issues/70924. Support for tuple dim was # added at https://github.com/pytorch/pytorch/pull/89017. -def squeeze(x: Array, /, axis: Union[int, Tuple[int, ...]]) -> Array: +def squeeze(x: Array, /, axis: int | tuple[int, ...]) -> Array: if isinstance(axis, int): axis = (axis,) for a in axis: @@ -499,27 +498,27 @@ def squeeze(x: Array, /, axis: Union[int, Tuple[int, ...]]) -> Array: return x # torch.broadcast_to uses size instead of shape -def broadcast_to(x: Array, /, shape: Tuple[int, ...], **kwargs) -> Array: +def broadcast_to(x: Array, /, shape: tuple[int, ...], **kwargs: object) -> Array: return torch.broadcast_to(x, shape, **kwargs) # torch.permute uses dims instead of axes -def permute_dims(x: Array, /, axes: Tuple[int, ...]) -> Array: +def permute_dims(x: Array, /, axes: tuple[int, ...]) -> Array: return torch.permute(x, axes) # The axis parameter doesn't work for flip() and roll() # https://github.com/pytorch/pytorch/issues/71210. Also torch.flip() doesn't # accept axis=None -def flip(x: Array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, **kwargs) -> Array: +def flip(x: Array, /, *, axis: int | tuple[int, ...] | None = None, **kwargs: object) -> Array: if axis is None: axis = tuple(range(x.ndim)) # torch.flip doesn't accept dim as an int but the method does # https://github.com/pytorch/pytorch/issues/18095 return x.flip(axis, **kwargs) -def roll(x: Array, /, shift: Union[int, Tuple[int, ...]], *, axis: Optional[Union[int, Tuple[int, ...]]] = None, **kwargs) -> Array: +def roll(x: Array, /, shift: int | tuple[int, ...], *, axis: int | tuple[int, ...] | None = None, **kwargs: object) -> Array: return torch.roll(x, shift, axis, **kwargs) -def nonzero(x: Array, /, **kwargs) -> Tuple[Array, ...]: +def nonzero(x: Array, /, **kwargs: object) -> tuple[Array, ...]: if x.ndim == 0: raise ValueError("nonzero() does not support zero-dimensional arrays") return torch.nonzero(x, as_tuple=True, **kwargs) @@ -532,8 +531,8 @@ def diff( *, axis: int = -1, n: int = 1, - prepend: Optional[Array] = None, - append: Optional[Array] = None, + prepend: Array | None = None, + append: Array | None = None, ) -> Array: return torch.diff(x, dim=axis, n=n, prepend=prepend, append=append) @@ -543,7 +542,7 @@ def count_nonzero( x: Array, /, *, - axis: Optional[Union[int, Tuple[int, ...]]] = None, + axis: int | tuple[int, ...] 
| None = None, keepdims: bool = False, ) -> Array: result = torch.count_nonzero(x, dim=axis) @@ -564,12 +563,7 @@ def repeat(x: Array, repeats: int | Array, /, *, axis: int | None = None) -> Arr return torch.repeat_interleave(x, repeats, axis) -def where( - condition: Array, - x1: Array | bool | int | float | complex, - x2: Array | bool | int | float | complex, - /, -) -> Array: +def where(condition: Array, x1: Array | complex, x2: Array | complex, /) -> Array: x1, x2 = _fix_promotion(x1, x2) return torch.where(condition, x1, x2) @@ -577,10 +571,10 @@ def where( # torch.reshape doesn't have the copy keyword def reshape(x: Array, /, - shape: Tuple[int, ...], + shape: tuple[int, ...], *, - copy: Optional[bool] = None, - **kwargs) -> Array: + copy: bool | None = None, + **kwargs: object) -> Array: if copy is not None: raise NotImplementedError("torch.reshape doesn't yet support the copy keyword") return torch.reshape(x, shape, **kwargs) @@ -589,14 +583,14 @@ def reshape(x: Array, # (https://github.com/pytorch/pytorch/issues/70915), and doesn't support some # keyword argument combinations # (https://github.com/pytorch/pytorch/issues/70914) -def arange(start: Union[int, float], +def arange(start: float, /, - stop: Optional[Union[int, float]] = None, - step: Union[int, float] = 1, + stop: float | None = None, + step: float = 1, *, - dtype: Optional[DType] = None, - device: Optional[Device] = None, - **kwargs) -> Array: + dtype: DType | None = None, + device: Device | None = None, + **kwargs: object) -> Array: if stop is None: start, stop = 0, start if step > 0 and stop <= start or step < 0 and stop >= start: @@ -611,13 +605,13 @@ def arange(start: Union[int, float], # torch.eye does not accept None as a default for the second argument and # doesn't support off-diagonals (https://github.com/pytorch/pytorch/issues/70910) def eye(n_rows: int, - n_cols: Optional[int] = None, + n_cols: int | None = None, /, *, k: int = 0, - dtype: Optional[DType] = None, - device: Optional[Device] = None, - **kwargs) -> Array: + dtype: DType | None = None, + device: Device | None = None, + **kwargs: object) -> Array: if n_cols is None: n_cols = n_rows z = torch.zeros(n_rows, n_cols, dtype=dtype, device=device, **kwargs) @@ -626,52 +620,52 @@ def eye(n_rows: int, return z # torch.linspace doesn't have the endpoint parameter -def linspace(start: Union[int, float], - stop: Union[int, float], +def linspace(start: float, + stop: float, /, num: int, *, - dtype: Optional[DType] = None, - device: Optional[Device] = None, + dtype: DType | None = None, + device: Device | None = None, endpoint: bool = True, - **kwargs) -> Array: + **kwargs: object) -> Array: if not endpoint: return torch.linspace(start, stop, num+1, dtype=dtype, device=device, **kwargs)[:-1] return torch.linspace(start, stop, num, dtype=dtype, device=device, **kwargs) # torch.full does not accept an int size # https://github.com/pytorch/pytorch/issues/70906 -def full(shape: Union[int, Tuple[int, ...]], - fill_value: bool | int | float | complex, +def full(shape: int | tuple[int, ...], + fill_value: complex, *, - dtype: Optional[DType] = None, - device: Optional[Device] = None, - **kwargs) -> Array: + dtype: DType | None = None, + device: Device | None = None, + **kwargs: object) -> Array: if isinstance(shape, int): shape = (shape,) return torch.full(shape, fill_value, dtype=dtype, device=device, **kwargs) # ones, zeros, and empty do not accept shape as a keyword argument -def ones(shape: Union[int, Tuple[int, ...]], +def ones(shape: int | tuple[int, ...], *, - 
dtype: Optional[DType] = None, - device: Optional[Device] = None, - **kwargs) -> Array: + dtype: DType | None = None, + device: Device | None = None, + **kwargs: object) -> Array: return torch.ones(shape, dtype=dtype, device=device, **kwargs) -def zeros(shape: Union[int, Tuple[int, ...]], +def zeros(shape: int | tuple[int, ...], *, - dtype: Optional[DType] = None, - device: Optional[Device] = None, - **kwargs) -> Array: + dtype: DType | None = None, + device: Device | None = None, + **kwargs: object) -> Array: return torch.zeros(shape, dtype=dtype, device=device, **kwargs) -def empty(shape: Union[int, Tuple[int, ...]], +def empty(shape: int | tuple[int, ...], *, - dtype: Optional[DType] = None, - device: Optional[Device] = None, - **kwargs) -> Array: + dtype: DType | None = None, + device: Device | None = None, + **kwargs: object) -> Array: return torch.empty(shape, dtype=dtype, device=device, **kwargs) # tril and triu do not call the keyword argument k @@ -693,14 +687,14 @@ def astype( /, *, copy: bool = True, - device: Optional[Device] = None, + device: Device | None = None, ) -> Array: if device is not None: return x.to(device, dtype=dtype, copy=copy) return x.to(dtype=dtype, copy=copy) -def broadcast_arrays(*arrays: Array) -> List[Array]: +def broadcast_arrays(*arrays: Array) -> list[Array]: shape = torch.broadcast_shapes(*[a.shape for a in arrays]) return [torch.broadcast_to(a, shape) for a in arrays] @@ -738,7 +732,7 @@ def unique_inverse(x: Array) -> UniqueInverseResult: def unique_values(x: Array) -> Array: return torch.unique(x) -def matmul(x1: Array, x2: Array, /, **kwargs) -> Array: +def matmul(x1: Array, x2: Array, /, **kwargs: object) -> Array: # torch.matmul doesn't type promote (but differently from _fix_promotion) x1, x2 = _fix_promotion(x1, x2, only_scalar=False) return torch.matmul(x1, x2, **kwargs) @@ -756,8 +750,8 @@ def tensordot( x2: Array, /, *, - axes: Union[int, Tuple[Sequence[int], Sequence[int]]] = 2, - **kwargs, + axes: int | tuple[Sequence[int], Sequence[int]] = 2, + **kwargs: object, ) -> Array: # Note: torch.tensordot fails with integer dtypes when there is only 1 # element in the axis (https://github.com/pytorch/pytorch/issues/84530). @@ -766,8 +760,10 @@ def tensordot( def isdtype( - dtype: DType, kind: Union[DType, str, Tuple[Union[DType, str], ...]], - *, _tuple=True, # Disallow nested tuples + dtype: DType, + kind: DType | str | tuple[DType | str, ...], + *, + _tuple: bool = True, # Disallow nested tuples ) -> bool: """ Returns a boolean indicating whether a provided dtype is of a specified data type ``kind``. 
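Since `kind` above accepts a single dtype, a string kind, or a flat tuple of either (nested tuples are deliberately rejected via the private `_tuple` flag), a short usage sketch against the wrapped torch namespace may help; the behaviour shown follows the Array API definition of dtype kinds:

import array_api_compat.torch as xp

assert xp.isdtype(xp.float32, "real floating")
assert xp.isdtype(xp.int8, "integral")
# A tuple of kinds matches if any of its members match.
assert xp.isdtype(xp.complex64, ("real floating", "complex floating"))
# "numeric" covers integer, real and complex dtypes, but not bool.
assert not xp.isdtype(xp.bool, "numeric")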
@@ -801,7 +797,7 @@ def isdtype( else: return dtype == kind -def take(x: Array, indices: Array, /, *, axis: Optional[int] = None, **kwargs) -> Array: +def take(x: Array, indices: Array, /, *, axis: int | None = None, **kwargs: object) -> Array: if axis is None: if x.ndim != 1: raise ValueError("axis must be specified when ndim > 1") @@ -828,7 +824,7 @@ def sign(x: Array, /) -> Array: return out -def meshgrid(*arrays: Array, indexing: Literal['xy', 'ij'] = 'xy') -> List[Array]: +def meshgrid(*arrays: Array, indexing: Literal['xy', 'ij'] = 'xy') -> list[Array]: # enforce the default of 'xy' # TODO: is the return type a list or a tuple return list(torch.meshgrid(*arrays, indexing='xy')) diff --git a/array_api_compat/torch/fft.py b/array_api_compat/torch/fft.py index 50e6a0d0..ddf87c65 100644 --- a/array_api_compat/torch/fft.py +++ b/array_api_compat/torch/fft.py @@ -1,6 +1,7 @@ from __future__ import annotations -from typing import Union, Sequence, Literal +from collections.abc import Sequence +from typing import Literal import torch import torch.fft @@ -17,7 +18,7 @@ def fftn( s: Sequence[int] = None, axes: Sequence[int] = None, norm: Literal["backward", "ortho", "forward"] = "backward", - **kwargs, + **kwargs: object, ) -> Array: return torch.fft.fftn(x, s=s, dim=axes, norm=norm, **kwargs) @@ -28,7 +29,7 @@ def ifftn( s: Sequence[int] = None, axes: Sequence[int] = None, norm: Literal["backward", "ortho", "forward"] = "backward", - **kwargs, + **kwargs: object, ) -> Array: return torch.fft.ifftn(x, s=s, dim=axes, norm=norm, **kwargs) @@ -39,7 +40,7 @@ def rfftn( s: Sequence[int] = None, axes: Sequence[int] = None, norm: Literal["backward", "ortho", "forward"] = "backward", - **kwargs, + **kwargs: object, ) -> Array: return torch.fft.rfftn(x, s=s, dim=axes, norm=norm, **kwargs) @@ -50,7 +51,7 @@ def irfftn( s: Sequence[int] = None, axes: Sequence[int] = None, norm: Literal["backward", "ortho", "forward"] = "backward", - **kwargs, + **kwargs: object, ) -> Array: return torch.fft.irfftn(x, s=s, dim=axes, norm=norm, **kwargs) @@ -58,8 +59,8 @@ def fftshift( x: Array, /, *, - axes: Union[int, Sequence[int]] = None, - **kwargs, + axes: int | Sequence[int] = None, + **kwargs: object, ) -> Array: return torch.fft.fftshift(x, dim=axes, **kwargs) @@ -67,8 +68,8 @@ def ifftshift( x: Array, /, *, - axes: Union[int, Sequence[int]] = None, - **kwargs, + axes: int | Sequence[int] = None, + **kwargs: object, ) -> Array: return torch.fft.ifftshift(x, dim=axes, **kwargs) diff --git a/array_api_compat/torch/linalg.py b/array_api_compat/torch/linalg.py index 70d72405..558cfe7b 100644 --- a/array_api_compat/torch/linalg.py +++ b/array_api_compat/torch/linalg.py @@ -1,8 +1,6 @@ from __future__ import annotations import torch -from typing import Optional, Union, Tuple - from torch.linalg import * # noqa: F403 # torch.linalg doesn't define __all__ @@ -32,7 +30,7 @@ def cross(x1: Array, x2: Array, /, *, axis: int = -1) -> Array: x1, x2 = torch.broadcast_tensors(x1, x2) return torch_linalg.cross(x1, x2, dim=axis) -def vecdot(x1: Array, x2: Array, /, *, axis: int = -1, **kwargs) -> Array: +def vecdot(x1: Array, x2: Array, /, *, axis: int = -1, **kwargs: object) -> Array: from ._aliases import isdtype x1, x2 = _fix_promotion(x1, x2, only_scalar=False) @@ -54,7 +52,7 @@ def vecdot(x1: Array, x2: Array, /, *, axis: int = -1, **kwargs) -> Array: return res[..., 0, 0] return torch.linalg.vecdot(x1, x2, dim=axis, **kwargs) -def solve(x1: Array, x2: Array, /, **kwargs) -> Array: +def solve(x1: Array, x2: Array, /, **kwargs: 
object) -> Array: x1, x2 = _fix_promotion(x1, x2, only_scalar=False) # Torch tries to emulate NumPy 1 solve behavior by using batched 1-D solve # whenever @@ -75,7 +73,7 @@ def solve(x1: Array, x2: Array, /, **kwargs) -> Array: return torch.linalg.solve(x1, x2, **kwargs) # torch.trace doesn't support the offset argument and doesn't support stacking -def trace(x: Array, /, *, offset: int = 0, dtype: Optional[DType] = None) -> Array: +def trace(x: Array, /, *, offset: int = 0, dtype: DType | None = None) -> Array: # Use our wrapped sum to make sure it does upcasting correctly return sum(torch.diagonal(x, offset=offset, dim1=-2, dim2=-1), axis=-1, dtype=dtype) @@ -83,11 +81,11 @@ def vector_norm( x: Array, /, *, - axis: Optional[Union[int, Tuple[int, ...]]] = None, + axis: int | tuple[int, ...] | None = None, keepdims: bool = False, # JustFloat stands for inf | -inf, which are not valid for Literal ord: JustInt | JustFloat = 2, - **kwargs, + **kwargs: object, ) -> Array: # torch.vector_norm incorrectly treats axis=() the same as axis=None if axis == (): diff --git a/pyproject.toml b/pyproject.toml index aacebd11..ec054417 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -43,11 +43,11 @@ dev = [ "array-api-strict", "dask[array]>=2024.9.0", "jax[cpu]", + "ndonnx", "numpy>=1.22", "pytest", "torch", "sparse>=0.15.1", - "ndonnx" ] [project.urls] @@ -61,7 +61,7 @@ version = {attr = "array_api_compat.__version__"} include = ["array_api_compat*"] namespaces = false -[toolint] +[tool.ruff.lint] preview = true select = [ # Defaults @@ -79,20 +79,42 @@ ignore = [ "E722" ] -[tool.ruff.lint] -preview = true -select = [ -# Defaults -"E4", "E7", "E9", "F", -# Undefined export -"F822", -# Useless import alias -"PLC0414" -] -ignore = [ - # Module import not at top of file - "E402", - # Do not use bare `except` - "E722" +[tool.mypy] +files = ["array_api_compat"] +disallow_incomplete_defs = true +disallow_untyped_decorators = true +disallow_untyped_defs = false # TODO +ignore_missing_imports = false +no_implicit_optional = true +show_error_codes = true +warn_redundant_casts = true +warn_unused_ignores = true +warn_unreachable = true + +[[tool.mypy.overrides]] +module = ["cupy.*", "cupy_backends.*", "dask.*", "jax.*", "ndonnx.*", "sparse.*", "torch.*"] +ignore_missing_imports = true + + +[tool.pyright] +include = ["src", "tests"] +pythonPlatform = "All" + +reportAny = false +reportExplicitAny = false +# missing type stubs +reportAttributeAccessIssue = false +reportUnknownMemberType = false +reportUnknownVariableType = false +# Redundant with mypy checks +reportMissingImports = false +reportMissingTypeStubs = false +# false positives for input validation +reportUnreachable = false +# ruff handles this +reportUnusedParameter = false + +executionEnvironments = [ + { root = "array_api_compat" }, ] From cddc9ef8a19b453b09884987ca6a0626408a1478 Mon Sep 17 00:00:00 2001 From: Guido Imperiale Date: Fri, 6 Jun 2025 12:04:01 +0100 Subject: [PATCH 76/80] ENH: Review exported symbols; redesign `test_all` (#315) Review and discussion at https://github.com/data-apis/array-api-compat/pull/315 --- array_api_compat/_internal.py | 20 +- array_api_compat/common/_aliases.py | 2 - array_api_compat/common/_helpers.py | 2 - array_api_compat/common/_linalg.py | 2 - array_api_compat/cupy/__init__.py | 13 +- array_api_compat/cupy/_aliases.py | 3 +- array_api_compat/cupy/_typing.py | 1 - array_api_compat/cupy/fft.py | 7 +- array_api_compat/cupy/linalg.py | 6 +- array_api_compat/dask/array/__init__.py | 16 +- 
array_api_compat/dask/array/_aliases.py | 4 - array_api_compat/dask/array/fft.py | 19 +- array_api_compat/dask/array/linalg.py | 35 +-- array_api_compat/numpy/__init__.py | 22 +- array_api_compat/numpy/_aliases.py | 6 +- array_api_compat/numpy/_typing.py | 1 - array_api_compat/numpy/fft.py | 15 +- array_api_compat/numpy/linalg.py | 29 +- array_api_compat/torch/__init__.py | 29 +- array_api_compat/torch/_aliases.py | 5 +- array_api_compat/torch/fft.py | 16 +- array_api_compat/torch/linalg.py | 19 +- tests/test_all.py | 360 ++++++++++++++++++++---- 23 files changed, 435 insertions(+), 197 deletions(-) diff --git a/array_api_compat/_internal.py b/array_api_compat/_internal.py index b1925492..baa39ded 100644 --- a/array_api_compat/_internal.py +++ b/array_api_compat/_internal.py @@ -2,6 +2,7 @@ Internal helpers """ +import importlib from collections.abc import Callable from functools import wraps from inspect import signature @@ -52,8 +53,25 @@ def wrapped_f(*args: object, **kwargs: object) -> object: return inner -__all__ = ["get_xp"] +def clone_module(mod_name: str, globals_: dict[str, object]) -> list[str]: + """Import everything from module, updating globals(). + Returns __all__. + """ + mod = importlib.import_module(mod_name) + # Neither of these two methods is sufficient by itself, + # depending on various idiosyncrasies of the libraries we're wrapping. + objs = {} + exec(f"from {mod.__name__} import *", objs) + + for n in dir(mod): + if not n.startswith("_") and hasattr(mod, n): + objs[n] = getattr(mod, n) + + globals_.update(objs) + return list(objs) + +__all__ = ["get_xp", "clone_module"] def __dir__() -> list[str]: return __all__ diff --git a/array_api_compat/common/_aliases.py b/array_api_compat/common/_aliases.py index 51732b91..27b2604b 100644 --- a/array_api_compat/common/_aliases.py +++ b/array_api_compat/common/_aliases.py @@ -721,8 +721,6 @@ def iinfo(type_: DType | Array, /, xp: Namespace) -> Any: "finfo", "iinfo", ] -_all_ignore = ["inspect", "array_namespace", "NamedTuple"] - def __dir__() -> list[str]: return __all__ diff --git a/array_api_compat/common/_helpers.py b/array_api_compat/common/_helpers.py index cae0ee0b..37f31ec2 100644 --- a/array_api_compat/common/_helpers.py +++ b/array_api_compat/common/_helpers.py @@ -1062,7 +1062,5 @@ def is_lazy_array(x: object) -> TypeGuard[_ArrayApiObj]: "to_device", ] -_all_ignore = ['lru_cache', 'sys', 'math', 'inspect', 'warnings'] - def __dir__() -> list[str]: return __all__ diff --git a/array_api_compat/common/_linalg.py b/array_api_compat/common/_linalg.py index 3fd9d860..69672af7 100644 --- a/array_api_compat/common/_linalg.py +++ b/array_api_compat/common/_linalg.py @@ -225,8 +225,6 @@ def trace( 'matrix_transpose', 'svdvals', 'vecdot', 'vector_norm', 'diagonal', 'trace'] -_all_ignore = ['math', 'normalize_axis_tuple', 'get_xp', 'np', 'isdtype'] - def __dir__() -> list[str]: return __all__ diff --git a/array_api_compat/cupy/__init__.py b/array_api_compat/cupy/__init__.py index 9a30f95d..af003c5a 100644 --- a/array_api_compat/cupy/__init__.py +++ b/array_api_compat/cupy/__init__.py @@ -1,3 +1,4 @@ +from typing import Final from cupy import * # noqa: F403 # from cupy import * doesn't overwrite these builtin names @@ -5,9 +6,19 @@ # These imports may overwrite names from the import * above. 
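`clone_module` above replaces the old exec-into-a-dict dance for re-exporting a wrapped namespace. A rough illustration of the intended call pattern in a wrapper's `__init__.py`, with the stdlib `math` module standing in for a wrapped array library (a sketch, not one of the shipped wrappers):

from array_api_compat._internal import clone_module

# Pull every public name from the wrapped module into this namespace
# and record the resulting export list.
__all__ = clone_module("math", globals())

# Wrapper-specific names are then appended, mirroring the __init__.py
# changes in this patch.
__array_api_version__ = "2024.12"
__all__ = sorted(set(__all__) | {"__array_api_version__"})

def __dir__() -> list[str]:
    return __all__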
from ._aliases import * # noqa: F403 +from ._info import __array_namespace_info__ # noqa: F401 # See the comment in the numpy __init__.py __import__(__package__ + '.linalg') __import__(__package__ + '.fft') -__array_api_version__ = '2024.12' +__array_api_version__: Final = '2024.12' + +__all__ = sorted( + {name for name in globals() if not name.startswith("__")} + - {"Final", "_aliases", "_info", "_typing"} + | {"__array_api_version__", "__array_namespace_info__", "linalg", "fft"} +) + +def __dir__() -> list[str]: + return __all__ diff --git a/array_api_compat/cupy/_aliases.py b/array_api_compat/cupy/_aliases.py index c0473ca4..2752bd98 100644 --- a/array_api_compat/cupy/_aliases.py +++ b/array_api_compat/cupy/_aliases.py @@ -7,7 +7,6 @@ from ..common import _aliases, _helpers from ..common._typing import NestedSequence, SupportsBufferProtocol from .._internal import get_xp -from ._info import __array_namespace_info__ from ._typing import Array, Device, DType bool = cp.bool_ @@ -141,7 +140,7 @@ def take_along_axis(x: Array, indices: Array, /, *, axis: int = -1) -> Array: else: unstack = get_xp(cp)(_aliases.unstack) -__all__ = _aliases.__all__ + ['__array_namespace_info__', 'asarray', 'astype', +__all__ = _aliases.__all__ + ['asarray', 'astype', 'acos', 'acosh', 'asin', 'asinh', 'atan', 'atan2', 'atanh', 'bitwise_left_shift', 'bitwise_invert', 'bitwise_right_shift', diff --git a/array_api_compat/cupy/_typing.py b/array_api_compat/cupy/_typing.py index d8e49ca7..e5c202dc 100644 --- a/array_api_compat/cupy/_typing.py +++ b/array_api_compat/cupy/_typing.py @@ -1,7 +1,6 @@ from __future__ import annotations __all__ = ["Array", "DType", "Device"] -_all_ignore = ["cp"] from typing import TYPE_CHECKING diff --git a/array_api_compat/cupy/fft.py b/array_api_compat/cupy/fft.py index 2bd11940..53a9a454 100644 --- a/array_api_compat/cupy/fft.py +++ b/array_api_compat/cupy/fft.py @@ -31,7 +31,6 @@ __all__ = fft_all + _fft.__all__ -del get_xp -del cp -del fft_all -del _fft +def __dir__() -> list[str]: + return __all__ + diff --git a/array_api_compat/cupy/linalg.py b/array_api_compat/cupy/linalg.py index 7bc3536e..da301574 100644 --- a/array_api_compat/cupy/linalg.py +++ b/array_api_compat/cupy/linalg.py @@ -43,7 +43,5 @@ __all__ = linalg_all + _linalg.__all__ -del get_xp -del cp -del linalg_all -del _linalg +def __dir__() -> list[str]: + return __all__ diff --git a/array_api_compat/dask/array/__init__.py b/array_api_compat/dask/array/__init__.py index 6d2ea7cd..f78aa8b3 100644 --- a/array_api_compat/dask/array/__init__.py +++ b/array_api_compat/dask/array/__init__.py @@ -1,12 +1,26 @@ from typing import Final -from dask.array import * # noqa: F403 +from ..._internal import clone_module + +__all__ = clone_module("dask.array", globals()) # These imports may overwrite names from the import * above. +from . 
import _aliases from ._aliases import * # type: ignore[assignment] # noqa: F403 +from ._info import __array_namespace_info__ # noqa: F401 __array_api_version__: Final = "2024.12" +del Final # See the comment in the numpy __init__.py __import__(__package__ + '.linalg') __import__(__package__ + '.fft') + +__all__ = sorted( + set(__all__) + | set(_aliases.__all__) + | {"__array_api_version__", "__array_namespace_info__", "linalg", "fft"} +) + +def __dir__() -> list[str]: + return __all__ diff --git a/array_api_compat/dask/array/_aliases.py b/array_api_compat/dask/array/_aliases.py index bc0302fe..4d1e7341 100644 --- a/array_api_compat/dask/array/_aliases.py +++ b/array_api_compat/dask/array/_aliases.py @@ -41,7 +41,6 @@ NestedSequence, SupportsBufferProtocol, ) -from ._info import __array_namespace_info__ isdtype = get_xp(np)(_aliases.isdtype) unstack = get_xp(da)(_aliases.unstack) @@ -355,7 +354,6 @@ def count_nonzero( __all__ = [ - "__array_namespace_info__", "count_nonzero", "bool", "int8", "int16", "int32", "int64", @@ -369,8 +367,6 @@ def count_nonzero( "bitwise_left_shift", "bitwise_right_shift", "bitwise_invert", ] # fmt: skip __all__ += _aliases.__all__ -_all_ignore = ["array_namespace", "get_xp", "da", "np"] - def __dir__() -> list[str]: return __all__ diff --git a/array_api_compat/dask/array/fft.py b/array_api_compat/dask/array/fft.py index 68c4280e..44b68e73 100644 --- a/array_api_compat/dask/array/fft.py +++ b/array_api_compat/dask/array/fft.py @@ -1,13 +1,6 @@ -from dask.array.fft import * # noqa: F403 -# dask.array.fft doesn't have __all__. If it is added, replace this with -# -# from dask.array.fft import __all__ as linalg_all -_n: dict[str, object] = {} -exec('from dask.array.fft import *', _n) -for k in ("__builtins__", "Sequence", "annotations", "warnings"): - _n.pop(k, None) -fft_all = list(_n) -del _n, k +from ..._internal import clone_module + +__all__ = clone_module("dask.array.fft", globals()) from ...common import _fft from ..._internal import get_xp @@ -17,5 +10,7 @@ fftfreq = get_xp(da)(_fft.fftfreq) rfftfreq = get_xp(da)(_fft.rfftfreq) -__all__ = fft_all + ["fftfreq", "rfftfreq"] -_all_ignore = ["da", "fft_all", "get_xp", "warnings"] +__all__ += ["fftfreq", "rfftfreq"] + +def __dir__() -> list[str]: + return __all__ diff --git a/array_api_compat/dask/array/linalg.py b/array_api_compat/dask/array/linalg.py index 06f596bc..6b3c1011 100644 --- a/array_api_compat/dask/array/linalg.py +++ b/array_api_compat/dask/array/linalg.py @@ -4,27 +4,17 @@ import dask.array as da -# Exports -from dask.array.linalg import * # noqa: F403 -from dask.array import outer # The `matmul` and `tensordot` functions are in both the main and linalg namespaces -from dask.array import matmul, tensordot - +from dask.array import matmul, outer, tensordot -from ..._internal import get_xp +# Exports +from ..._internal import clone_module, get_xp from ...common import _linalg from ...common._typing import Array -from ._aliases import matrix_transpose, vecdot -# dask.array.linalg doesn't have __all__. 
If it is added, replace this with -# -# from dask.array.linalg import __all__ as linalg_all -_n: dict[str, object] = {} -exec('from dask.array.linalg import *', _n) -for k in ('__builtins__', 'annotations', 'operator', 'warnings', 'Array'): - _n.pop(k, None) -linalg_all = list(_n) -del _n, k +__all__ = clone_module("dask.array.linalg", globals()) + +from ._aliases import matrix_transpose, vecdot EighResult = _linalg.EighResult QRResult = _linalg.QRResult @@ -64,10 +54,11 @@ def svdvals(x: Array) -> Array: vector_norm = get_xp(da)(_linalg.vector_norm) diagonal = get_xp(da)(_linalg.diagonal) -__all__ = linalg_all + ["trace", "outer", "matmul", "tensordot", - "matrix_transpose", "vecdot", "EighResult", - "QRResult", "SlogdetResult", "SVDResult", "qr", - "cholesky", "matrix_rank", "matrix_norm", "svdvals", - "vector_norm", "diagonal"] +__all__ += ["trace", "outer", "matmul", "tensordot", + "matrix_transpose", "vecdot", "EighResult", + "QRResult", "SlogdetResult", "SVDResult", "qr", + "cholesky", "matrix_rank", "matrix_norm", "svdvals", + "vector_norm", "diagonal"] -_all_ignore = ['get_xp', 'da', 'linalg_all', 'warnings'] +def __dir__() -> list[str]: + return __all__ diff --git a/array_api_compat/numpy/__init__.py b/array_api_compat/numpy/__init__.py index bf43fe61..23379e44 100644 --- a/array_api_compat/numpy/__init__.py +++ b/array_api_compat/numpy/__init__.py @@ -1,16 +1,17 @@ # ruff: noqa: PLC0414 from typing import Final -from numpy import * # noqa: F403 # pyright: ignore[reportWildcardImportFromLibrary] +from .._internal import clone_module -# from numpy import * doesn't overwrite these builtin names -from numpy import abs as abs -from numpy import max as max -from numpy import min as min -from numpy import round as round +# This needs to be loaded explicitly before cloning +import numpy.typing # noqa: F401 + +__all__ = clone_module("numpy", globals()) # These imports may overwrite names from the import * above. +from . import _aliases from ._aliases import * # type: ignore[assignment,no-redef] # noqa: F403 +from ._info import __array_namespace_info__ # noqa: F401 # Don't know why, but we have to do an absolute import to import linalg. 
If we # instead do @@ -26,3 +27,12 @@ from .linalg import matrix_transpose, vecdot # type: ignore[no-redef] # noqa: F401 __array_api_version__: Final = "2024.12" + +__all__ = sorted( + set(__all__) + | set(_aliases.__all__) + | {"__array_api_version__", "__array_namespace_info__", "linalg", "fft"} +) + +def __dir__() -> list[str]: + return __all__ diff --git a/array_api_compat/numpy/_aliases.py b/array_api_compat/numpy/_aliases.py index 5a05a820..5bb8869a 100644 --- a/array_api_compat/numpy/_aliases.py +++ b/array_api_compat/numpy/_aliases.py @@ -9,7 +9,6 @@ from .._internal import get_xp from ..common import _aliases, _helpers from ..common._typing import NestedSequence, SupportsBufferProtocol -from ._info import __array_namespace_info__ from ._typing import Array, Device, DType bool = np.bool_ @@ -147,8 +146,7 @@ def take_along_axis(x: Array, indices: Array, /, *, axis: int = -1) -> Array: else: unstack = get_xp(np)(_aliases.unstack) -__all__ = [ - "__array_namespace_info__", +__all__ = _aliases.__all__ + [ "asarray", "astype", "acos", @@ -167,8 +165,6 @@ def take_along_axis(x: Array, indices: Array, /, *, axis: int = -1) -> Array: "pow", "take_along_axis" ] -__all__ += _aliases.__all__ -_all_ignore = ["np", "get_xp"] def __dir__() -> list[str]: diff --git a/array_api_compat/numpy/_typing.py b/array_api_compat/numpy/_typing.py index e771c788..b5fa188c 100644 --- a/array_api_compat/numpy/_typing.py +++ b/array_api_compat/numpy/_typing.py @@ -23,7 +23,6 @@ Array: TypeAlias = np.ndarray __all__ = ["Array", "DType", "Device"] -_all_ignore = ["np"] def __dir__() -> list[str]: diff --git a/array_api_compat/numpy/fft.py b/array_api_compat/numpy/fft.py index 06875f00..a492feb8 100644 --- a/array_api_compat/numpy/fft.py +++ b/array_api_compat/numpy/fft.py @@ -1,6 +1,8 @@ import numpy as np -from numpy.fft import __all__ as fft_all -from numpy.fft import fft2, ifft2, irfft2, rfft2 + +from .._internal import clone_module + +__all__ = clone_module("numpy.fft", globals()) from .._internal import get_xp from ..common import _fft @@ -21,15 +23,8 @@ ifftshift = get_xp(np)(_fft.ifftshift) -__all__ = ["rfft2", "irfft2", "fft2", "ifft2"] -__all__ += _fft.__all__ - +__all__ = sorted(set(__all__) | set(_fft.__all__)) def __dir__() -> list[str]: return __all__ - -del get_xp -del np -del fft_all -del _fft diff --git a/array_api_compat/numpy/linalg.py b/array_api_compat/numpy/linalg.py index 9a618be9..7168441c 100644 --- a/array_api_compat/numpy/linalg.py +++ b/array_api_compat/numpy/linalg.py @@ -7,26 +7,11 @@ import numpy as np -# intersection of `np.linalg.__all__` on numpy 1.22 and 2.2, minus `_linalg.__all__` -from numpy.linalg import ( - LinAlgError, - cond, - det, - eig, - eigvals, - eigvalsh, - inv, - lstsq, - matrix_power, - multi_dot, - norm, - tensorinv, - tensorsolve, -) - -from .._internal import get_xp +from .._internal import clone_module, get_xp from ..common import _linalg +__all__ = clone_module("numpy.linalg", globals()) + # These functions are in both the main and linalg namespaces from ._aliases import matmul, matrix_transpose, tensordot, vecdot # noqa: F401 from ._typing import Array @@ -120,7 +105,7 @@ def solve(x1: Array, x2: Array, /) -> Array: vector_norm = get_xp(np)(_linalg.vector_norm) -__all__ = [ +_all = [ "LinAlgError", "cond", "det", @@ -132,12 +117,12 @@ def solve(x1: Array, x2: Array, /) -> Array: "matrix_power", "multi_dot", "norm", + "solve", "tensorinv", "tensorsolve", + "vector_norm", ] -__all__ += _linalg.__all__ -__all__ += ["solve", "vector_norm"] - +__all__ = 
sorted(set(__all__) | set(_linalg.__all__) | set(_all)) def __dir__() -> list[str]: return __all__ diff --git a/array_api_compat/torch/__init__.py b/array_api_compat/torch/__init__.py index 69fd19ce..6cbb6ec2 100644 --- a/array_api_compat/torch/__init__.py +++ b/array_api_compat/torch/__init__.py @@ -1,22 +1,25 @@ -from torch import * # noqa: F403 +from typing import Final -# Several names are not included in the above import * -import torch -for n in dir(torch): - if (n.startswith('_') - or n.endswith('_') - or 'cuda' in n - or 'cpu' in n - or 'backward' in n): - continue - exec(f"{n} = torch.{n}") -del n +from .._internal import clone_module + +__all__ = clone_module("torch", globals()) # These imports may overwrite names from the import * above. +from . import _aliases from ._aliases import * # noqa: F403 +from ._info import __array_namespace_info__ # noqa: F401 # See the comment in the numpy __init__.py __import__(__package__ + '.linalg') __import__(__package__ + '.fft') -__array_api_version__ = '2024.12' +__array_api_version__: Final = '2024.12' + +__all__ = sorted( + set(__all__) + | set(_aliases.__all__) + | {"__array_api_version__", "__array_namespace_info__", "linalg", "fft"} +) + +def __dir__() -> list[str]: + return __all__ diff --git a/array_api_compat/torch/_aliases.py b/array_api_compat/torch/_aliases.py index 7a449001..91161d24 100644 --- a/array_api_compat/torch/_aliases.py +++ b/array_api_compat/torch/_aliases.py @@ -10,7 +10,6 @@ from .._internal import get_xp from ..common import _aliases from ..common._typing import NestedSequence, SupportsBufferProtocol -from ._info import __array_namespace_info__ from ._typing import Array, Device, DType _int_dtypes = { @@ -830,7 +829,7 @@ def meshgrid(*arrays: Array, indexing: Literal['xy', 'ij'] = 'xy') -> list[Array return list(torch.meshgrid(*arrays, indexing='xy')) -__all__ = ['__array_namespace_info__', 'asarray', 'result_type', 'can_cast', +__all__ = ['asarray', 'result_type', 'can_cast', 'permute_dims', 'bitwise_invert', 'newaxis', 'conj', 'add', 'atan2', 'bitwise_and', 'bitwise_left_shift', 'bitwise_or', 'bitwise_right_shift', 'bitwise_xor', 'copysign', 'count_nonzero', @@ -847,5 +846,3 @@ def meshgrid(*arrays: Array, indexing: Literal['xy', 'ij'] = 'xy') -> list[Array 'unique_all', 'unique_counts', 'unique_inverse', 'unique_values', 'matmul', 'matrix_transpose', 'vecdot', 'tensordot', 'isdtype', 'take', 'take_along_axis', 'sign', 'finfo', 'iinfo', 'repeat', 'meshgrid'] - -_all_ignore = ['torch', 'get_xp'] diff --git a/array_api_compat/torch/fft.py b/array_api_compat/torch/fft.py index ddf87c65..76342980 100644 --- a/array_api_compat/torch/fft.py +++ b/array_api_compat/torch/fft.py @@ -5,9 +5,11 @@ import torch import torch.fft -from torch.fft import * # noqa: F403 from ._typing import Array +from .._internal import clone_module + +__all__ = clone_module("torch.fft", globals()) # Several torch fft functions do not map axes to dim @@ -74,13 +76,7 @@ def ifftshift( return torch.fft.ifftshift(x, dim=axes, **kwargs) -__all__ = torch.fft.__all__ + [ - "fftn", - "ifftn", - "rfftn", - "irfftn", - "fftshift", - "ifftshift", -] +__all__ += ["fftn", "ifftn", "rfftn", "irfftn", "fftshift", "ifftshift"] -_all_ignore = ['torch'] +def __dir__() -> list[str]: + return __all__ diff --git a/array_api_compat/torch/linalg.py b/array_api_compat/torch/linalg.py index 558cfe7b..08271d22 100644 --- a/array_api_compat/torch/linalg.py +++ b/array_api_compat/torch/linalg.py @@ -1,12 +1,11 @@ from __future__ import annotations import torch -from 
torch.linalg import * # noqa: F403
+import torch.linalg
-# torch.linalg doesn't define __all__
-# from torch.linalg import __all__ as linalg_all
-from torch import linalg as torch_linalg
-linalg_all = [i for i in dir(torch_linalg) if not i.startswith('_')]
+from .._internal import clone_module
+
+__all__ = clone_module("torch.linalg", globals())

 # outer is implemented in torch but isn't in the linalg namespace
 from torch import outer
@@ -28,7 +27,7 @@ def cross(x1: Array, x2: Array, /, *, axis: int = -1) -> Array:
     if not (x1.shape[axis] == x2.shape[axis] == 3):
         raise ValueError(f"cross product axis must have size 3, got {x1.shape[axis]} and {x2.shape[axis]}")
     x1, x2 = torch.broadcast_tensors(x1, x2)
-    return torch_linalg.cross(x1, x2, dim=axis)
+    return torch.linalg.cross(x1, x2, dim=axis)

 def vecdot(x1: Array, x2: Array, /, *, axis: int = -1, **kwargs: object) -> Array:
     from ._aliases import isdtype
@@ -108,12 +107,8 @@ def vector_norm(
         return out
     return torch.linalg.vector_norm(x, ord=ord, axis=axis, keepdim=keepdims, **kwargs)

-__all__ = linalg_all + ['outer', 'matmul', 'matrix_transpose', 'tensordot',
-                        'cross', 'vecdot', 'solve', 'trace', 'vector_norm']
-
-_all_ignore = ['torch_linalg', 'sum']
-
-del linalg_all
+__all__ += ['outer', 'matmul', 'matrix_transpose', 'tensordot',
+            'cross', 'vecdot', 'solve', 'trace', 'vector_norm']

 def __dir__() -> list[str]:
     return __all__
diff --git a/tests/test_all.py b/tests/test_all.py
index 271cd189..c36aef67 100644
--- a/tests/test_all.py
+++ b/tests/test_all.py
@@ -1,63 +1,311 @@
-"""
-Test that files that define __all__ aren't missing any exports.
+"""Test exported names"""

-You can add names that shouldn't be exported to _all_ignore, like
+import builtins

-_all_ignore = ['sys']
+import numpy as np
+import pytest

-This is preferable to del-ing the names as this will break any name that is
-used inside of a function. Note that names starting with an underscore are
-automatically ignored.
-""" +from array_api_compat._internal import clone_module +from ._helpers import wrapped_libraries -import sys +NAMES = { + "": [ + # Inspection + "__array_api_version__", + "__array_namespace_info__", + # Submodules + "fft", + "linalg", + # Constants + "e", + "inf", + "nan", + "newaxis", + "pi", + # Creation Functions + "arange", + "asarray", + "empty", + "empty_like", + "eye", + "from_dlpack", + "full", + "full_like", + "linspace", + "meshgrid", + "ones", + "ones_like", + "tril", + "triu", + "zeros", + "zeros_like", + # Data Type Functions + "astype", + "can_cast", + "finfo", + "iinfo", + "isdtype", + "result_type", + # Data Types + "bool", + "int8", + "int16", + "int32", + "int64", + "uint8", + "uint16", + "uint32", + "uint64", + "float32", + "float64", + "complex64", + "complex128", + # Elementwise Functions + "abs", + "acos", + "acosh", + "add", + "asin", + "asinh", + "atan", + "atan2", + "atanh", + "bitwise_and", + "bitwise_left_shift", + "bitwise_invert", + "bitwise_or", + "bitwise_right_shift", + "bitwise_xor", + "ceil", + "clip", + "conj", + "copysign", + "cos", + "cosh", + "divide", + "equal", + "exp", + "expm1", + "floor", + "floor_divide", + "greater", + "greater_equal", + "hypot", + "imag", + "isfinite", + "isinf", + "isnan", + "less", + "less_equal", + "log", + "log1p", + "log2", + "log10", + "logaddexp", + "logical_and", + "logical_not", + "logical_or", + "logical_xor", + "maximum", + "minimum", + "multiply", + "negative", + "nextafter", + "not_equal", + "positive", + "pow", + "real", + "reciprocal", + "remainder", + "round", + "sign", + "signbit", + "sin", + "sinh", + "square", + "sqrt", + "subtract", + "tan", + "tanh", + "trunc", + # Indexing Functions + "take", + "take_along_axis", + # Linear Algebra Functions + "matmul", + "matrix_transpose", + "tensordot", + "vecdot", + # Manipulation Functions + "broadcast_arrays", + "broadcast_to", + "concat", + "expand_dims", + "flip", + "moveaxis", + "permute_dims", + "repeat", + "reshape", + "roll", + "squeeze", + "stack", + "tile", + "unstack", + # Searching Functions + "argmax", + "argmin", + "count_nonzero", + "nonzero", + "searchsorted", + "where", + # Set Functions + "unique_all", + "unique_counts", + "unique_inverse", + "unique_values", + # Sorting Functions + "argsort", + "sort", + # Statistical Functions + "cumulative_prod", + "cumulative_sum", + "max", + "mean", + "min", + "prod", + "std", + "sum", + "var", + # Utility Functions + "all", + "any", + "diff", + ], + "fft": [ + "fft", + "ifft", + "fftn", + "ifftn", + "rfft", + "irfft", + "rfftn", + "irfftn", + "hfft", + "ihfft", + "fftfreq", + "rfftfreq", + "fftshift", + "ifftshift", + ], + "linalg": [ + "cholesky", + "cross", + "det", + "diagonal", + "eigh", + "eigvalsh", + "inv", + "matmul", + "matrix_norm", + "matrix_power", + "matrix_rank", + "matrix_transpose", + "outer", + "pinv", + "qr", + "slogdet", + "solve", + "svd", + "svdvals", + "tensordot", + "trace", + "vecdot", + "vector_norm", + ], +} -from ._helpers import import_, wrapped_libraries +XFAILS = { + ("numpy", ""): ["from_dlpack"] if np.__version__ < "1.23" else [], + ("dask.array", ""): ["from_dlpack", "take_along_axis"], + ("dask.array", "linalg"): [ + "cross", + "det", + "eigh", + "eigvalsh", + "matrix_power", + "pinv", + "slogdet", + ], +} -import pytest -import typing - -TYPING_NAMES = frozenset(( - "Array", - "Device", - "DType", - "Namespace", - "NestedSequence", - "SupportsBufferProtocol", -)) - -@pytest.mark.parametrize("library", ["common"] + wrapped_libraries) -def test_all(library): - if library == 
"common": - import array_api_compat.common # noqa: F401 - else: - import_(library, wrapper=True) - - # NB: iterate over a copy to avoid a "dictionary size changed" error - for mod_name in sys.modules.copy(): - if not mod_name.startswith('array_api_compat.' + library): - continue - - module = sys.modules[mod_name] - - # TODO: We should define __all__ in the __init__.py files and test it - # there too. - if not hasattr(module, '__all__'): - continue - - dir_names = [n for n in dir(module) if not n.startswith('_')] - if '__array_namespace_info__' in dir(module): - dir_names.append('__array_namespace_info__') - ignore_all_names = set(getattr(module, '_all_ignore', ())) - ignore_all_names |= set(dir(typing)) - ignore_all_names |= {"annotations"} - if not module.__name__.endswith("._typing"): - ignore_all_names |= TYPING_NAMES - dir_names = set(dir_names) - set(ignore_all_names) - all_names = module.__all__ - - if set(dir_names) != set(all_names): - extra_dir = set(dir_names) - set(all_names) - extra_all = set(all_names) - set(dir_names) - assert not extra_dir, f"Some dir() names not included in __all__ for {mod_name}: {extra_dir}" - assert not extra_all, f"Some __all__ names not in dir() for {mod_name}: {extra_all}" + +def all_names(mod): + """Return all names available in a module.""" + objs = {} + clone_module(mod.__name__, objs) + return set(objs) + + +def get_mod(library, module, *, compat): + if compat: + library = f"array_api_compat.{library}" + xp = pytest.importorskip(library) + return getattr(xp, module) if module else xp + + +@pytest.mark.parametrize("module", list(NAMES)) +@pytest.mark.parametrize("library", wrapped_libraries) +def test_array_api_names(library, module): + """Test that __all__ isn't missing any exports + dictated by the Standard. + """ + mod = get_mod(library, module, compat=True) + missing = set(NAMES[module]) - all_names(mod) + xfail = set(XFAILS.get((library, module), [])) + xpass = xfail - missing + fails = missing - xfail + assert not xpass, f"Names in XFAILS are defined: {xpass}" + assert not fails, f"Missing exports: {fails}" + + +@pytest.mark.parametrize("module", list(NAMES)) +@pytest.mark.parametrize("library", wrapped_libraries) +def test_compat_doesnt_hide_names(library, module): + """The base namespace can have more names than the ones explicitly exported + by array-api-compat. Test that we're not suppressing them. + """ + bare_mod = get_mod(library, module, compat=False) + compat_mod = get_mod(library, module, compat=True) + + missing = all_names(bare_mod) - all_names(compat_mod) + missing = {name for name in missing if not name.startswith("_")} + assert not missing, f"Non-Array API names have been hidden: {missing}" + + +@pytest.mark.parametrize("module", list(NAMES)) +@pytest.mark.parametrize("library", wrapped_libraries) +def test_compat_doesnt_add_names(library, module): + """Test that array-api-compat isn't adding names to the namespace + besides those defined by the Array API Standard. 
+    """
+    bare_mod = get_mod(library, module, compat=False)
+    compat_mod = get_mod(library, module, compat=True)
+
+    aapi_names = set(NAMES[module])
+    spurious = all_names(compat_mod) - all_names(bare_mod) - aapi_names
+    # Quietly ignore *Result dataclasses
+    spurious = {name for name in spurious if not name.endswith("Result")}
+    assert not spurious, (
+        f"array-api-compat is adding non-Array API names: {spurious}"
+    )
+
+
+@pytest.mark.parametrize(
+    "name", [name for name in NAMES[""] if hasattr(builtins, name)]
+)
+@pytest.mark.parametrize("library", wrapped_libraries)
+def test_builtins_collision(library, name):
+    """Test that xp.bool is not accidentally builtins.bool, etc."""
+    xp = pytest.importorskip(f"array_api_compat.{library}")
+    assert getattr(xp, name) is not getattr(builtins, name)

From 1d1178d33f7af737abf697a76fb161901faa075d Mon Sep 17 00:00:00 2001
From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com>
Date: Mon, 16 Jun 2025 10:16:28 +0000
Subject: [PATCH 77/80] Bump dawidd6/action-download-artifact from 10 to 11 in
 the actions group

Bumps the actions group with 1 update: [dawidd6/action-download-artifact](https://github.com/dawidd6/action-download-artifact).

Updates `dawidd6/action-download-artifact` from 10 to 11
- [Release notes](https://github.com/dawidd6/action-download-artifact/releases)
- [Commits](https://github.com/dawidd6/action-download-artifact/compare/v10...v11)

---
updated-dependencies:
- dependency-name: dawidd6/action-download-artifact
  dependency-version: '11'
  dependency-type: direct:production
  update-type: version-update:semver-major
  dependency-group: actions
...

Signed-off-by: dependabot[bot] <support@github.com>
---
 .github/workflows/docs-deploy.yml | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/.github/workflows/docs-deploy.yml b/.github/workflows/docs-deploy.yml
index 4e3efb39..ed90b29d 100644
--- a/.github/workflows/docs-deploy.yml
+++ b/.github/workflows/docs-deploy.yml
@@ -13,7 +13,7 @@ jobs:
     steps:
       - uses: actions/checkout@v4
       - name: Download Artifact
-        uses: dawidd6/action-download-artifact@v10
+        uses: dawidd6/action-download-artifact@v11
         with:
           workflow: docs-build.yml
           name: docs-build

From 4bafa4cc8a455a301f3688fd3fa7404a4fe00974 Mon Sep 17 00:00:00 2001
From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com>
Date: Mon, 11 Aug 2025 14:10:23 +0000
Subject: [PATCH 78/80] Bump actions/download-artifact from 4 to 5 in the
 actions group

Bumps the actions group with 1 update: [actions/download-artifact](https://github.com/actions/download-artifact).

Updates `actions/download-artifact` from 4 to 5
- [Release notes](https://github.com/actions/download-artifact/releases)
- [Commits](https://github.com/actions/download-artifact/compare/v4...v5)

---
updated-dependencies:
- dependency-name: actions/download-artifact
  dependency-version: '5'
  dependency-type: direct:production
  update-type: version-update:semver-major
  dependency-group: actions
...

Signed-off-by: dependabot[bot] <support@github.com>
---
 .github/workflows/publish-package.yml | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/.github/workflows/publish-package.yml b/.github/workflows/publish-package.yml
index 6d88066d..1e28689c 100644
--- a/.github/workflows/publish-package.yml
+++ b/.github/workflows/publish-package.yml
@@ -81,7 +81,7 @@ jobs:

     steps:
     - name: Download distribution artifact
-      uses: actions/download-artifact@v4
+      uses: actions/download-artifact@v5
       with:
         name: dist-artifact
         path: dist

From edd9072c296827d0e4eccf02ae87920eb2481b9c Mon Sep 17 00:00:00 2001
From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com>
Date: Mon, 18 Aug 2025 13:07:35 +0000
Subject: [PATCH 79/80] Bump actions/checkout from 4 to 5 in the actions group

Bumps the actions group with 1 update: [actions/checkout](https://github.com/actions/checkout).

Updates `actions/checkout` from 4 to 5
- [Release notes](https://github.com/actions/checkout/releases)
- [Changelog](https://github.com/actions/checkout/blob/main/CHANGELOG.md)
- [Commits](https://github.com/actions/checkout/compare/v4...v5)

---
updated-dependencies:
- dependency-name: actions/checkout
  dependency-version: '5'
  dependency-type: direct:production
  update-type: version-update:semver-major
  dependency-group: actions
...

Signed-off-by: dependabot[bot] <support@github.com>
---
 .github/workflows/array-api-tests.yml | 4 ++--
 .github/workflows/docs-build.yml      | 2 +-
 .github/workflows/docs-deploy.yml     | 2 +-
 .github/workflows/publish-package.yml | 2 +-
 .github/workflows/ruff.yml            | 2 +-
 .github/workflows/tests.yml           | 2 +-
 6 files changed, 7 insertions(+), 7 deletions(-)

diff --git a/.github/workflows/array-api-tests.yml b/.github/workflows/array-api-tests.yml
index e832f870..5c3cc7d9 100644
--- a/.github/workflows/array-api-tests.yml
+++ b/.github/workflows/array-api-tests.yml
@@ -49,12 +49,12 @@ jobs:

     steps:
     - name: Checkout array-api-compat
-      uses: actions/checkout@v4
+      uses: actions/checkout@v5
       with:
         path: array-api-compat

     - name: Checkout array-api-tests
-      uses: actions/checkout@v4
+      uses: actions/checkout@v5
       with:
         repository: data-apis/array-api-tests
         submodules: 'true'

diff --git a/.github/workflows/docs-build.yml b/.github/workflows/docs-build.yml
index 34b9cbc6..778d20e2 100644
--- a/.github/workflows/docs-build.yml
+++ b/.github/workflows/docs-build.yml
@@ -6,7 +6,7 @@ jobs:
   docs-build:
     runs-on: ubuntu-latest
     steps:
-      - uses: actions/checkout@v4
+      - uses: actions/checkout@v5
       - uses: actions/setup-python@v5
       - name: Install Dependencies
         run: |

diff --git a/.github/workflows/docs-deploy.yml b/.github/workflows/docs-deploy.yml
index ed90b29d..42a3598f 100644
--- a/.github/workflows/docs-deploy.yml
+++ b/.github/workflows/docs-deploy.yml
@@ -11,7 +11,7 @@ jobs:
     environment:
       name: docs-deploy
     steps:
-      - uses: actions/checkout@v4
+      - uses: actions/checkout@v5
       - name: Download Artifact
         uses: dawidd6/action-download-artifact@v11
         with:

diff --git a/.github/workflows/publish-package.yml b/.github/workflows/publish-package.yml
index 1e28689c..03cae174 100644
--- a/.github/workflows/publish-package.yml
+++ b/.github/workflows/publish-package.yml
@@ -30,7 +30,7 @@ jobs:
     runs-on: ubuntu-latest

     steps:
-    - uses: actions/checkout@v4
+    - uses: actions/checkout@v5
       with:
         fetch-depth: 0

diff --git a/.github/workflows/ruff.yml b/.github/workflows/ruff.yml
index a9f0fd4b..68f68a14 100644
--- a/.github/workflows/ruff.yml
+++ b/.github/workflows/ruff.yml
@@ -5,7 +5,7 @@ jobs:
     runs-on: ubuntu-latest
     continue-on-error: true
     steps:
-      - uses: actions/checkout@v4
+      - uses: actions/checkout@v5
       - name: Install Python
         uses: actions/setup-python@v5
         with:

diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml
index c995b370..d2e768eb 100644
--- a/.github/workflows/tests.yml
+++ b/.github/workflows/tests.yml
@@ -23,7 +23,7 @@ jobs:
         python-version: '3.13'

     steps:
-      - uses: actions/checkout@v4
+      - uses: actions/checkout@v5
       - uses: actions/setup-python@v5
         with:
           python-version: ${{ matrix.python-version }}

From c2b7a51c85d037fba4ea7dea7d0efe74a13bb550 Mon Sep 17 00:00:00 2001
From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com>
Date: Mon, 8 Sep 2025 09:12:57 +0000
Subject: [PATCH 80/80] Bump the actions group with 2 updates

Bumps the actions group with 2 updates: [actions/setup-python](https://github.com/actions/setup-python) and [pypa/gh-action-pypi-publish](https://github.com/pypa/gh-action-pypi-publish).

Updates `actions/setup-python` from 5 to 6
- [Release notes](https://github.com/actions/setup-python/releases)
- [Commits](https://github.com/actions/setup-python/compare/v5...v6)

Updates `pypa/gh-action-pypi-publish` from 1.12.4 to 1.13.0
- [Release notes](https://github.com/pypa/gh-action-pypi-publish/releases)
- [Commits](https://github.com/pypa/gh-action-pypi-publish/compare/v1.12.4...v1.13.0)

---
updated-dependencies:
- dependency-name: actions/setup-python
  dependency-version: '6'
  dependency-type: direct:production
  update-type: version-update:semver-major
  dependency-group: actions
- dependency-name: pypa/gh-action-pypi-publish
  dependency-version: 1.13.0
  dependency-type: direct:production
  update-type: version-update:semver-minor
  dependency-group: actions
...

Signed-off-by: dependabot[bot] <support@github.com>
---
 .github/workflows/array-api-tests.yml | 2 +-
 .github/workflows/docs-build.yml      | 2 +-
 .github/workflows/publish-package.yml | 6 +++---
 .github/workflows/ruff.yml            | 2 +-
 .github/workflows/tests.yml           | 2 +-
 5 files changed, 7 insertions(+), 7 deletions(-)

diff --git a/.github/workflows/array-api-tests.yml b/.github/workflows/array-api-tests.yml
index 5c3cc7d9..e3c0c9e0 100644
--- a/.github/workflows/array-api-tests.yml
+++ b/.github/workflows/array-api-tests.yml
@@ -61,7 +61,7 @@ jobs:
         path: array-api-tests

     - name: Set up Python ${{ matrix.python-version }}
-      uses: actions/setup-python@v5
+      uses: actions/setup-python@v6
       with:
         python-version: ${{ matrix.python-version }}

diff --git a/.github/workflows/docs-build.yml b/.github/workflows/docs-build.yml
index 778d20e2..305a9003 100644
--- a/.github/workflows/docs-build.yml
+++ b/.github/workflows/docs-build.yml
@@ -7,7 +7,7 @@ jobs:
     runs-on: ubuntu-latest
     steps:
       - uses: actions/checkout@v5
-      - uses: actions/setup-python@v5
+      - uses: actions/setup-python@v6
       - name: Install Dependencies
         run: |
           python -m pip install .[docs]

diff --git a/.github/workflows/publish-package.yml b/.github/workflows/publish-package.yml
index 03cae174..bbfb2e80 100644
--- a/.github/workflows/publish-package.yml
+++ b/.github/workflows/publish-package.yml
@@ -35,7 +35,7 @@ jobs:
         fetch-depth: 0

     - name: Set up Python
-      uses: actions/setup-python@v5
+      uses: actions/setup-python@v6
       with:
         python-version: '3.x'

@@ -95,14 +95,14 @@ jobs:
     #   if: >-
     #     (github.event_name == 'push' && startsWith(github.ref, 'refs/tags'))
     #     || (github.event_name == 'workflow_dispatch' && github.event.inputs.publish == 'true')
-    #   uses: pypa/gh-action-pypi-publish@v1.12.4
+    #   uses: pypa/gh-action-pypi-publish@v1.13.0
     #   with:
     #     repository-url: https://test.pypi.org/legacy/
     #     print-hash: true

     - name: Publish distribution 📦 to PyPI
       if: github.event_name == 'push' && startsWith(github.ref, 'refs/tags')
-      uses: pypa/gh-action-pypi-publish@v1.12.4
+      uses: pypa/gh-action-pypi-publish@v1.13.0
       with:
         print-hash: true

diff --git a/.github/workflows/ruff.yml b/.github/workflows/ruff.yml
index 68f68a14..4a2ffcff 100644
--- a/.github/workflows/ruff.yml
+++ b/.github/workflows/ruff.yml
@@ -7,7 +7,7 @@ jobs:
     steps:
       - uses: actions/checkout@v5
       - name: Install Python
-        uses: actions/setup-python@v5
+        uses: actions/setup-python@v6
         with:
           python-version: "3.11"
       - name: Install dependencies

diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml
index d2e768eb..cfbb875f 100644
--- a/.github/workflows/tests.yml
+++ b/.github/workflows/tests.yml
@@ -24,7 +24,7 @@ jobs:

     steps:
       - uses: actions/checkout@v5
-      - uses: actions/setup-python@v5
+      - uses: actions/setup-python@v6
         with:
           python-version: ${{ matrix.python-version }}
       - name: Install Dependencies
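
A note on the grouped Dependabot patches above (77-80): Dependabot opens a
single "Bump the actions group" pull request, rather than one PR per action,
only when the repository's `.github/dependabot.yml` defines a group for the
`github-actions` ecosystem. That config file is not part of this patch series,
so the following is a minimal assumed sketch; the weekly interval and the
catch-all pattern are illustrative, not taken from this repo.

# .github/dependabot.yml -- hypothetical sketch, not from this repository.
# The "groups" entry is what produces subjects like
# "Bump the actions group with 2 updates" in patches 77-80.
version: 2
updates:
  - package-ecosystem: "github-actions"  # watch actions pinned in .github/workflows/
    directory: "/"
    schedule:
      interval: "weekly"                 # assumed cadence
    groups:
      actions:                           # group name; matches "the actions group"
        patterns:
          - "*"                          # fold every action update into one PR

With a config along these lines, grouped major bumps such as checkout v4->v5
and setup-python v5->v6 land as single PRs per update run, which is what the
last four patches in this series show.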