From 5597ec755d44cb005f01601b3c2193f9f56b604f Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Sun, 11 May 2025 12:56:49 +0200 Subject: [PATCH 01/21] CI: use ARRAY_API_TESTS_XFAIL_MARK on CI --- .github/workflows/array-api-tests-dask.yml | 2 ++ .github/workflows/array-api-tests-numpy-1-22.yml | 2 ++ .github/workflows/array-api-tests-numpy-1-26.yml | 2 ++ .github/workflows/array-api-tests-numpy-dev.yml | 2 ++ .github/workflows/array-api-tests-numpy-latest.yml | 2 ++ .github/workflows/array-api-tests-torch.yml | 1 + 6 files changed, 11 insertions(+) diff --git a/.github/workflows/array-api-tests-dask.yml b/.github/workflows/array-api-tests-dask.yml index afc67975..964fb52d 100644 --- a/.github/workflows/array-api-tests-dask.yml +++ b/.github/workflows/array-api-tests-dask.yml @@ -16,3 +16,5 @@ jobs: # the full test suite with at least 200 examples. pytest-extra-args: --max-examples=5 python-versions: '[''3.10'', ''3.13'']' + extra-env-vars: | + ARRAY_API_TESTS_XFAIL_MARK=skip diff --git a/.github/workflows/array-api-tests-numpy-1-22.yml b/.github/workflows/array-api-tests-numpy-1-22.yml index d8f60432..1cf6e26d 100644 --- a/.github/workflows/array-api-tests-numpy-1-22.yml +++ b/.github/workflows/array-api-tests-numpy-1-22.yml @@ -10,3 +10,5 @@ jobs: package-version: '== 1.22.*' xfails-file-extra: '-1-22' python-versions: '[''3.10'']' + extra-env-vars: | + ARRAY_API_TESTS_XFAIL_MARK=skip diff --git a/.github/workflows/array-api-tests-numpy-1-26.yml b/.github/workflows/array-api-tests-numpy-1-26.yml index 33780760..a2788d2f 100644 --- a/.github/workflows/array-api-tests-numpy-1-26.yml +++ b/.github/workflows/array-api-tests-numpy-1-26.yml @@ -10,3 +10,5 @@ jobs: package-version: '== 1.26.*' xfails-file-extra: '-1-26' python-versions: '[''3.10'', ''3.12'']' + extra-env-vars: | + ARRAY_API_TESTS_XFAIL_MARK=skip diff --git a/.github/workflows/array-api-tests-numpy-dev.yml b/.github/workflows/array-api-tests-numpy-dev.yml index d6de1a53..dce0813f 100644 --- a/.github/workflows/array-api-tests-numpy-dev.yml +++ b/.github/workflows/array-api-tests-numpy-dev.yml @@ -10,3 +10,5 @@ jobs: extra-requires: '--pre --extra-index https://pypi.anaconda.org/scientific-python-nightly-wheels/simple' xfails-file-extra: '-dev' python-versions: '[''3.11'', ''3.13'']' + extra-env-vars: | + ARRAY_API_TESTS_XFAIL_MARK=skip diff --git a/.github/workflows/array-api-tests-numpy-latest.yml b/.github/workflows/array-api-tests-numpy-latest.yml index 4d3667f6..54b21a25 100644 --- a/.github/workflows/array-api-tests-numpy-latest.yml +++ b/.github/workflows/array-api-tests-numpy-latest.yml @@ -8,3 +8,5 @@ jobs: with: package-name: numpy python-versions: '[''3.10'', ''3.13'']' + extra-env-vars: | + ARRAY_API_TESTS_XFAIL_MARK=skip diff --git a/.github/workflows/array-api-tests-torch.yml b/.github/workflows/array-api-tests-torch.yml index ac20df25..4dcb3347 100644 --- a/.github/workflows/array-api-tests-torch.yml +++ b/.github/workflows/array-api-tests-torch.yml @@ -10,4 +10,5 @@ jobs: extra-requires: '--index-url https://download.pytorch.org/whl/cpu' extra-env-vars: | ARRAY_API_TESTS_SKIP_DTYPES=uint16,uint32,uint64 + ARRAY_API_TESTS_XFAIL_MARK=skip python-versions: '[''3.10'', ''3.13'']' From 91dd626ce8b2612979e513af235be3809791f94b Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Fri, 16 May 2025 10:58:43 +0200 Subject: [PATCH 02/21] MAINT: bump version to 1.13.0.dev0 --- array_api_compat/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/array_api_compat/__init__.py b/array_api_compat/__init__.py index 653cb40a..a00e8cbc 100644 --- a/array_api_compat/__init__.py +++ b/array_api_compat/__init__.py @@ -17,6 +17,6 @@ this implementation for the default when working with NumPy arrays. """ -__version__ = '1.12.0' +__version__ = '1.13.0.dev0' from .common import * # noqa: F401, F403 From 3350f670e1b67a37888228c102e0e560f43077bd Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Sun, 11 May 2025 14:12:51 +0200 Subject: [PATCH 03/21] CI: run 500 examples on NumPy and PyTorch; 50 on Dask --- .github/workflows/array-api-tests-dask.yml | 2 +- .github/workflows/array-api-tests.yml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/array-api-tests-dask.yml b/.github/workflows/array-api-tests-dask.yml index 964fb52d..a60b28a4 100644 --- a/.github/workflows/array-api-tests-dask.yml +++ b/.github/workflows/array-api-tests-dask.yml @@ -14,7 +14,7 @@ jobs: # workflow is barely more than a smoke test, and one should expect extreme # flakiness. Before changes to dask-xfails.txt or dask-skips.txt, please run # the full test suite with at least 200 examples. - pytest-extra-args: --max-examples=5 + pytest-extra-args: --max-examples=50 python-versions: '[''3.10'', ''3.13'']' extra-env-vars: | ARRAY_API_TESTS_XFAIL_MARK=skip diff --git a/.github/workflows/array-api-tests.yml b/.github/workflows/array-api-tests.yml index 31bedde6..f652438b 100644 --- a/.github/workflows/array-api-tests.yml +++ b/.github/workflows/array-api-tests.yml @@ -37,7 +37,7 @@ on: description: Multiline string of environment variables to set for the test run. env: - PYTEST_ARGS: "--max-examples 200 -v -rxXfE --ci ${{ inputs.pytest-extra-args }} --hypothesis-disable-deadline --durations 10" + PYTEST_ARGS: "--max-examples 500 -v -rxXfE --ci ${{ inputs.pytest-extra-args }} --hypothesis-disable-deadline --durations 20" jobs: tests: From 5b1ece468fb9b9b789304b57035de6801a39c7b1 Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Mon, 12 May 2025 14:18:30 +0200 Subject: [PATCH 04/21] CI: use 4 workers; bump the # of examples to 1000 (np/torch), 200 (dask) --- .github/workflows/array-api-tests-dask.yml | 2 +- .github/workflows/array-api-tests-numpy-1-22.yml | 1 + .github/workflows/array-api-tests-numpy-1-26.yml | 1 + .github/workflows/array-api-tests-numpy-dev.yml | 1 + .github/workflows/array-api-tests-numpy-latest.yml | 1 + .github/workflows/array-api-tests-torch.yml | 1 + .github/workflows/array-api-tests.yml | 3 ++- 7 files changed, 8 insertions(+), 2 deletions(-) diff --git a/.github/workflows/array-api-tests-dask.yml b/.github/workflows/array-api-tests-dask.yml index a60b28a4..ef430d9c 100644 --- a/.github/workflows/array-api-tests-dask.yml +++ b/.github/workflows/array-api-tests-dask.yml @@ -14,7 +14,7 @@ jobs: # workflow is barely more than a smoke test, and one should expect extreme # flakiness. Before changes to dask-xfails.txt or dask-skips.txt, please run # the full test suite with at least 200 examples. - pytest-extra-args: --max-examples=50 + pytest-extra-args: --max-examples=200 -n 4 python-versions: '[''3.10'', ''3.13'']' extra-env-vars: | ARRAY_API_TESTS_XFAIL_MARK=skip diff --git a/.github/workflows/array-api-tests-numpy-1-22.yml b/.github/workflows/array-api-tests-numpy-1-22.yml index 1cf6e26d..83d4cf1d 100644 --- a/.github/workflows/array-api-tests-numpy-1-22.yml +++ b/.github/workflows/array-api-tests-numpy-1-22.yml @@ -10,5 +10,6 @@ jobs: package-version: '== 1.22.*' xfails-file-extra: '-1-22' python-versions: '[''3.10'']' + pytest-extra-args: -n 4 extra-env-vars: | ARRAY_API_TESTS_XFAIL_MARK=skip diff --git a/.github/workflows/array-api-tests-numpy-1-26.yml b/.github/workflows/array-api-tests-numpy-1-26.yml index a2788d2f..13124644 100644 --- a/.github/workflows/array-api-tests-numpy-1-26.yml +++ b/.github/workflows/array-api-tests-numpy-1-26.yml @@ -10,5 +10,6 @@ jobs: package-version: '== 1.26.*' xfails-file-extra: '-1-26' python-versions: '[''3.10'', ''3.12'']' + pytest-extra-args: -n 4 extra-env-vars: | ARRAY_API_TESTS_XFAIL_MARK=skip diff --git a/.github/workflows/array-api-tests-numpy-dev.yml b/.github/workflows/array-api-tests-numpy-dev.yml index dce0813f..dec4c7ae 100644 --- a/.github/workflows/array-api-tests-numpy-dev.yml +++ b/.github/workflows/array-api-tests-numpy-dev.yml @@ -10,5 +10,6 @@ jobs: extra-requires: '--pre --extra-index https://pypi.anaconda.org/scientific-python-nightly-wheels/simple' xfails-file-extra: '-dev' python-versions: '[''3.11'', ''3.13'']' + pytest-extra-args: -n 4 extra-env-vars: | ARRAY_API_TESTS_XFAIL_MARK=skip diff --git a/.github/workflows/array-api-tests-numpy-latest.yml b/.github/workflows/array-api-tests-numpy-latest.yml index 54b21a25..65bbc9a2 100644 --- a/.github/workflows/array-api-tests-numpy-latest.yml +++ b/.github/workflows/array-api-tests-numpy-latest.yml @@ -8,5 +8,6 @@ jobs: with: package-name: numpy python-versions: '[''3.10'', ''3.13'']' + pytest-extra-args: -n 4 extra-env-vars: | ARRAY_API_TESTS_XFAIL_MARK=skip diff --git a/.github/workflows/array-api-tests-torch.yml b/.github/workflows/array-api-tests-torch.yml index 4dcb3347..4b4b945e 100644 --- a/.github/workflows/array-api-tests-torch.yml +++ b/.github/workflows/array-api-tests-torch.yml @@ -12,3 +12,4 @@ jobs: ARRAY_API_TESTS_SKIP_DTYPES=uint16,uint32,uint64 ARRAY_API_TESTS_XFAIL_MARK=skip python-versions: '[''3.10'', ''3.13'']' + pytest-extra-args: -n 4 diff --git a/.github/workflows/array-api-tests.yml b/.github/workflows/array-api-tests.yml index f652438b..53c1474d 100644 --- a/.github/workflows/array-api-tests.yml +++ b/.github/workflows/array-api-tests.yml @@ -37,7 +37,7 @@ on: description: Multiline string of environment variables to set for the test run. env: - PYTEST_ARGS: "--max-examples 500 -v -rxXfE --ci ${{ inputs.pytest-extra-args }} --hypothesis-disable-deadline --durations 20" + PYTEST_ARGS: "--max-examples 1000 -v -rxXfE --ci ${{ inputs.pytest-extra-args }} --hypothesis-disable-deadline --durations 20" jobs: tests: @@ -76,6 +76,7 @@ jobs: python -m pip install --upgrade pip python -m pip install '${{ inputs.package-name }} ${{ inputs.package-version }}' ${{ inputs.extra-requires }} python -m pip install -r ${GITHUB_WORKSPACE}/array-api-tests/requirements.txt + python -m pip install pytest-xdist - name: Dump pip environment run: pip freeze From 37d5d668674f10ae41709a7ebc3d2ab2ae6b25c4 Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Mon, 12 May 2025 16:42:01 +0200 Subject: [PATCH 05/21] MAINT: update numpy-1.22 xfails --- numpy-1-22-xfails.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/numpy-1-22-xfails.txt b/numpy-1-22-xfails.txt index cacb95b7..d4022b31 100644 --- a/numpy-1-22-xfails.txt +++ b/numpy-1-22-xfails.txt @@ -133,7 +133,7 @@ array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_sc array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[maximum] array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[copysign] array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[pow] -array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[pow] +array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[minimum] array_api_tests/test_operators_and_elementwise_functions.py::test_where_with_scalars From 43435808041951df2c7b7cae28204b3ce61f6e46 Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Mon, 12 May 2025 16:42:29 +0200 Subject: [PATCH 06/21] MAINT: remove --ci pytest switch The warning says it's deprecated. --- .github/workflows/array-api-tests.yml | 2 +- numpy-1-22-xfails.txt | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/.github/workflows/array-api-tests.yml b/.github/workflows/array-api-tests.yml index 53c1474d..e832f870 100644 --- a/.github/workflows/array-api-tests.yml +++ b/.github/workflows/array-api-tests.yml @@ -37,7 +37,7 @@ on: description: Multiline string of environment variables to set for the test run. env: - PYTEST_ARGS: "--max-examples 1000 -v -rxXfE --ci ${{ inputs.pytest-extra-args }} --hypothesis-disable-deadline --durations 20" + PYTEST_ARGS: "--max-examples 1000 -v -rxXfE ${{ inputs.pytest-extra-args }} --hypothesis-disable-deadline --durations 20" jobs: tests: diff --git a/numpy-1-22-xfails.txt b/numpy-1-22-xfails.txt index d4022b31..e1c4f832 100644 --- a/numpy-1-22-xfails.txt +++ b/numpy-1-22-xfails.txt @@ -134,6 +134,7 @@ array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_sc array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[copysign] array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[pow] array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[minimum] +array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[atan2] array_api_tests/test_operators_and_elementwise_functions.py::test_where_with_scalars From c03daa36c09d51162d240b77e223a49cc8a6076e Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Sun, 1 Jun 2025 13:36:11 +0200 Subject: [PATCH 07/21] CI: install jax/sparse/torch in more jobs Also, `ndonnx` has wheels for all python versions now; And we do not bother with jax or dask numpy < 1. --- .github/workflows/tests.yml | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 81a05b3f..c995b370 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -32,20 +32,20 @@ jobs: python -m pip install --upgrade pip python -m pip install pytest + # Don't `pip install .[dev]` as it would pull in the whole torch cuda stack + python -m pip install array-api-strict + python -m pip install torch --index-url https://download.pytorch.org/whl/cpu + if [ "${{ matrix.numpy-version }}" == "dev" ]; then python -m pip install numpy --pre --extra-index-url https://pypi.anaconda.org/scientific-python-nightly-wheels/simple + python -m pip install dask[array] jax[cpu] sparse ndonnx elif [ "${{ matrix.numpy-version }}" == "1.22" ]; then python -m pip install 'numpy==1.22.*' elif [ "${{ matrix.numpy-version }}" == "1.26" ]; then python -m pip install 'numpy==1.26.*' else - # Don't `pip install .[dev]` as it would pull in the whole torch cuda stack - python -m pip install array-api-strict dask[array] jax[cpu] numpy sparse - python -m pip install torch --index-url https://download.pytorch.org/whl/cpu - if [ "${{ matrix.python-version }}" != "3.13" ]; then - # onnx wheels are not available on Python 3.13 at the moment of writing - python -m pip install ndonnx - fi + python -m pip install numpy + python -m pip install dask[array] jax[cpu] sparse ndonnx fi - name: Dump pip environment From a8e19835092335ab8e1846f1e3dda335d8eb4c4a Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Sun, 1 Jun 2025 13:38:31 +0200 Subject: [PATCH 08/21] TST: xfail test_device_to_device with numpy < 2 It assumes that asarray has the copy kwarg, and this is not true in NumPy < 2. --- tests/test_common.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/test_common.py b/tests/test_common.py index 54b5ed69..85ed032e 100644 --- a/tests/test_common.py +++ b/tests/test_common.py @@ -195,6 +195,9 @@ def test_device_to_device(library, request): xfail(request, reason="Stub raises ValueError") if library == "sparse": xfail(request, reason="No __array_namespace_info__()") + if library == "array_api_strict": + if np.__version__ < "2": + xfail(request, reason="no copy argument of np.asarray") xp = import_(library, wrapper=True) devices = xp.__array_namespace_info__().devices() From 8e3ab3e7c5c6794f66196ec435d2f6bdd1492404 Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Sun, 1 Jun 2025 13:39:28 +0200 Subject: [PATCH 09/21] MAINT: filter out some warning noise --- tests/test_array_namespace.py | 19 +++++++++++++------ 1 file changed, 13 insertions(+), 6 deletions(-) diff --git a/tests/test_array_namespace.py b/tests/test_array_namespace.py index cdb80007..2fbb0339 100644 --- a/tests/test_array_namespace.py +++ b/tests/test_array_namespace.py @@ -23,7 +23,9 @@ def test_array_namespace(library, api_version, use_compat): if library == "ndonnx" and api_version in ("2021.12", "2022.12"): pytest.skip("Unsupported API version") - namespace = array_namespace(array, api_version=api_version, use_compat=use_compat) + with warnings.catch_warnings(): + warnings.simplefilter('ignore', UserWarning) + namespace = array_namespace(array, api_version=api_version, use_compat=use_compat) if use_compat is False or use_compat is None and library not in wrapped_libraries: if library == "jax.numpy" and use_compat is None: @@ -45,10 +47,13 @@ def test_array_namespace(library, api_version, use_compat): if library == "numpy": # check that the same namespace is returned for NumPy scalars - scalar_namespace = array_namespace( - xp.float64(0.0), api_version=api_version, use_compat=use_compat - ) - assert scalar_namespace == namespace + with warnings.catch_warnings(): + warnings.simplefilter('ignore', UserWarning) + + scalar_namespace = array_namespace( + xp.float64(0.0), api_version=api_version, use_compat=use_compat + ) + assert scalar_namespace == namespace # Check that array_namespace works even if jax.experimental.array_api # hasn't been imported yet (it monkeypatches __array_namespace__ @@ -97,7 +102,9 @@ def test_api_version_torch(): torch = import_("torch") x = torch.asarray([1, 2]) torch_ = import_("torch", wrapper=True) - assert array_namespace(x, api_version="2023.12") == torch_ + with warnings.catch_warnings(): + warnings.simplefilter('ignore', UserWarning) + assert array_namespace(x, api_version="2023.12") == torch_ assert array_namespace(x, api_version=None) == torch_ assert array_namespace(x) == torch_ # Should issue a warning From 9959873e351ecab696538650893afc2faef17a38 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 2 Jun 2025 10:10:33 +0000 Subject: [PATCH 10/21] Bump dawidd6/action-download-artifact from 9 to 10 in the actions group Bumps the actions group with 1 update: [dawidd6/action-download-artifact](https://github.com/dawidd6/action-download-artifact). Updates `dawidd6/action-download-artifact` from 9 to 10 - [Release notes](https://github.com/dawidd6/action-download-artifact/releases) - [Commits](https://github.com/dawidd6/action-download-artifact/compare/v9...v10) --- updated-dependencies: - dependency-name: dawidd6/action-download-artifact dependency-version: '10' dependency-type: direct:production update-type: version-update:semver-major dependency-group: actions ... Signed-off-by: dependabot[bot] --- .github/workflows/docs-deploy.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/docs-deploy.yml b/.github/workflows/docs-deploy.yml index fc612588..4e3efb39 100644 --- a/.github/workflows/docs-deploy.yml +++ b/.github/workflows/docs-deploy.yml @@ -13,7 +13,7 @@ jobs: steps: - uses: actions/checkout@v4 - name: Download Artifact - uses: dawidd6/action-download-artifact@v9 + uses: dawidd6/action-download-artifact@v10 with: workflow: docs-build.yml name: docs-build From 6ae28ee9538820ae09ba45d8ef3d15d4a6570900 Mon Sep 17 00:00:00 2001 From: Guido Imperiale Date: Mon, 2 Jun 2025 18:38:34 +0100 Subject: [PATCH 11/21] ENH: speed up `array_namespace` * ENH: speed up `array_namespace` * jax 0.6.1 Reviewed at https://github.com/data-apis/array-api-compat/pull/329 --- array_api_compat/common/_helpers.py | 196 +++++++++++++++------------- tests/test_array_namespace.py | 121 ++++++++--------- 2 files changed, 161 insertions(+), 156 deletions(-) diff --git a/array_api_compat/common/_helpers.py b/array_api_compat/common/_helpers.py index 77175d0d..a152e4c0 100644 --- a/array_api_compat/common/_helpers.py +++ b/array_api_compat/common/_helpers.py @@ -8,6 +8,7 @@ from __future__ import annotations +import enum import inspect import math import sys @@ -485,6 +486,86 @@ def _check_api_version(api_version: str | None) -> None: ) +class _ClsToXPInfo(enum.Enum): + SCALAR = 0 + MAYBE_JAX_ZERO_GRADIENT = 1 + + +@lru_cache(100) +def _cls_to_namespace( + cls: type, + api_version: str | None, + use_compat: bool | None, +) -> tuple[Namespace | None, _ClsToXPInfo | None]: + if use_compat not in (None, True, False): + raise ValueError("use_compat must be None, True, or False") + _use_compat = use_compat in (None, True) + cls_ = cast(Hashable, cls) # Make mypy happy + + if ( + _issubclass_fast(cls_, "numpy", "ndarray") + or _issubclass_fast(cls_, "numpy", "generic") + ): + if use_compat is True: + _check_api_version(api_version) + from .. import numpy as xp + elif use_compat is False: + import numpy as xp # type: ignore[no-redef] + else: + # NumPy 2.0+ have __array_namespace__; however they are not + # yet fully array API compatible. + from .. import numpy as xp # type: ignore[no-redef] + return xp, _ClsToXPInfo.MAYBE_JAX_ZERO_GRADIENT + + # Note: this must happen _after_ the test for np.generic, + # because np.float64 and np.complex128 are subclasses of float and complex. + if issubclass(cls, int | float | complex | type(None)): + return None, _ClsToXPInfo.SCALAR + + if _issubclass_fast(cls_, "cupy", "ndarray"): + if _use_compat: + _check_api_version(api_version) + from .. import cupy as xp # type: ignore[no-redef] + else: + import cupy as xp # type: ignore[no-redef] + return xp, None + + if _issubclass_fast(cls_, "torch", "Tensor"): + if _use_compat: + _check_api_version(api_version) + from .. import torch as xp # type: ignore[no-redef] + else: + import torch as xp # type: ignore[no-redef] + return xp, None + + if _issubclass_fast(cls_, "dask.array", "Array"): + if _use_compat: + _check_api_version(api_version) + from ..dask import array as xp # type: ignore[no-redef] + else: + import dask.array as xp # type: ignore[no-redef] + return xp, None + + # Backwards compatibility for jax<0.4.32 + if _issubclass_fast(cls_, "jax", "Array"): + return _jax_namespace(api_version, use_compat), None + + return None, None + + +def _jax_namespace(api_version: str | None, use_compat: bool | None) -> Namespace: + if use_compat: + raise ValueError("JAX does not have an array-api-compat wrapper") + import jax.numpy as jnp + if not hasattr(jnp, "__array_namespace_info__"): + # JAX v0.4.32 and newer implements the array API directly in jax.numpy. + # For older JAX versions, it is available via jax.experimental.array_api. + # jnp.Array objects gain the __array_namespace__ method. + import jax.experimental.array_api # noqa: F401 + # Test api_version + return jnp.empty(0).__array_namespace__(api_version=api_version) + + def array_namespace( *xs: Array | complex | None, api_version: str | None = None, @@ -553,105 +634,40 @@ def your_function(x, y): is_pydata_sparse_array """ - if use_compat not in [None, True, False]: - raise ValueError("use_compat must be None, True, or False") - - _use_compat = use_compat in [None, True] - namespaces: set[Namespace] = set() for x in xs: - if is_numpy_array(x): - import numpy as np - - from .. import numpy as numpy_namespace - - if use_compat is True: - _check_api_version(api_version) - namespaces.add(numpy_namespace) - elif use_compat is False: - namespaces.add(np) - else: - # numpy 2.0+ have __array_namespace__, however, they are not yet fully array API - # compatible. - namespaces.add(numpy_namespace) - elif is_cupy_array(x): - if _use_compat: - _check_api_version(api_version) - from .. import cupy as cupy_namespace - - namespaces.add(cupy_namespace) - else: - import cupy as cp # pyright: ignore[reportMissingTypeStubs] - - namespaces.add(cp) - elif is_torch_array(x): - if _use_compat: - _check_api_version(api_version) - from .. import torch as torch_namespace - - namespaces.add(torch_namespace) - else: - import torch - - namespaces.add(torch) - elif is_dask_array(x): - if _use_compat: - _check_api_version(api_version) - from ..dask import array as dask_namespace - - namespaces.add(dask_namespace) - else: - import dask.array as da - - namespaces.add(da) - elif is_jax_array(x): - if use_compat is True: - _check_api_version(api_version) - raise ValueError("JAX does not have an array-api-compat wrapper") - elif use_compat is False: - import jax.numpy as jnp - else: - # JAX v0.4.32 and newer implements the array API directly in jax.numpy. - # For older JAX versions, it is available via jax.experimental.array_api. - import jax.numpy - - if hasattr(jax.numpy, "__array_api_version__"): - jnp = jax.numpy - else: - import jax.experimental.array_api as jnp # pyright: ignore[reportMissingImports] - namespaces.add(jnp) - elif is_pydata_sparse_array(x): - if use_compat is True: - _check_api_version(api_version) - raise ValueError("`sparse` does not have an array-api-compat wrapper") - else: - import sparse # pyright: ignore[reportMissingTypeStubs] - # `sparse` is already an array namespace. We do not have a wrapper - # submodule for it. - namespaces.add(sparse) - elif hasattr(x, "__array_namespace__"): - if use_compat is True: + xp, info = _cls_to_namespace(cast(Hashable, type(x)), api_version, use_compat) + if info is _ClsToXPInfo.SCALAR: + continue + + if ( + info is _ClsToXPInfo.MAYBE_JAX_ZERO_GRADIENT + and _is_jax_zero_gradient_array(x) + ): + xp = _jax_namespace(api_version, use_compat) + + if xp is None: + get_ns = getattr(x, "__array_namespace__", None) + if get_ns is None: + raise TypeError(f"{type(x).__name__} is not a supported array type") + if use_compat: raise ValueError( "The given array does not have an array-api-compat wrapper" ) - x = cast("SupportsArrayNamespace[Any]", x) - namespaces.add(x.__array_namespace__(api_version=api_version)) - elif isinstance(x, (bool, int, float, complex, type(None))): - continue - else: - # TODO: Support Python scalars? - raise TypeError(f"{type(x).__name__} is not a supported array type") + xp = get_ns(api_version=api_version) - if not namespaces: - raise TypeError("Unrecognized array input") + namespaces.add(xp) - if len(namespaces) != 1: + try: + (xp,) = namespaces + return xp + except ValueError: + if not namespaces: + raise TypeError( + "array_namespace requires at least one non-scalar array input" + ) raise TypeError(f"Multiple namespaces for array inputs: {namespaces}") - (xp,) = namespaces - - return xp - # backwards compatibility alias get_namespace = array_namespace diff --git a/tests/test_array_namespace.py b/tests/test_array_namespace.py index 2fbb0339..311efc37 100644 --- a/tests/test_array_namespace.py +++ b/tests/test_array_namespace.py @@ -8,42 +8,41 @@ import array_api_compat from array_api_compat import array_namespace -from ._helpers import import_, all_libraries, wrapped_libraries +from ._helpers import all_libraries, wrapped_libraries, xfail + @pytest.mark.parametrize("use_compat", [True, False, None]) -@pytest.mark.parametrize("api_version", [None, "2021.12", "2022.12", "2023.12"]) +@pytest.mark.parametrize( + "api_version", [None, "2021.12", "2022.12", "2023.12", "2024.12"] +) @pytest.mark.parametrize("library", all_libraries) -def test_array_namespace(library, api_version, use_compat): - xp = import_(library) +def test_array_namespace(request, library, api_version, use_compat): + xp = pytest.importorskip(library) array = xp.asarray([1.0, 2.0, 3.0]) if use_compat and library not in wrapped_libraries: pytest.raises(ValueError, lambda: array_namespace(array, use_compat=use_compat)) return - if library == "ndonnx" and api_version in ("2021.12", "2022.12"): - pytest.skip("Unsupported API version") + if (library == "sparse" and api_version in ("2023.12", "2024.12")) or ( + library == "jax.numpy" and api_version in ("2021.12", "2022.12", "2023.12") + ): + xfail(request, "Unsupported API version") with warnings.catch_warnings(): warnings.simplefilter('ignore', UserWarning) namespace = array_namespace(array, api_version=api_version, use_compat=use_compat) if use_compat is False or use_compat is None and library not in wrapped_libraries: - if library == "jax.numpy" and use_compat is None: - import jax.numpy - if hasattr(jax.numpy, "__array_api_version__"): - # JAX v0.4.32 or later uses jax.numpy directly - assert namespace == jax.numpy - else: - # JAX v0.4.31 or earlier uses jax.experimental.array_api - import jax.experimental.array_api - assert namespace == jax.experimental.array_api + if library == "jax.numpy" and not hasattr(xp, "__array_api_version__"): + # Backwards compatibility for JAX <0.4.32 + import jax.experimental.array_api + assert namespace == jax.experimental.array_api else: assert namespace == xp + elif library == "dask.array": + assert namespace == array_api_compat.dask.array else: - if library == "dask.array": - assert namespace == array_api_compat.dask.array - else: - assert namespace == getattr(array_api_compat, library) + assert namespace == getattr(array_api_compat, library) if library == "numpy": # check that the same namespace is returned for NumPy scalars @@ -55,20 +54,20 @@ def test_array_namespace(library, api_version, use_compat): ) assert scalar_namespace == namespace - # Check that array_namespace works even if jax.experimental.array_api - # hasn't been imported yet (it monkeypatches __array_namespace__ - # onto JAX arrays, but we should support them regardless). The only way to - # do this is to use a subprocess, since we cannot un-import it and another - # test probably already imported it. - if library == "jax.numpy" and sys.version_info >= (3, 9): - code = f"""\ + +def test_jax_backwards_compat(): + """On JAX <0.4.32, test that array_namespace works even if + jax.experimental.array_api has not been imported yet. + """ + pytest.importorskip("jax") + code = """\ import sys import jax.numpy import array_api_compat -array = jax.numpy.asarray([1.0, 2.0, 3.0]) +array = jax.numpy.asarray([1.0, 2.0, 3.0]) assert 'jax.experimental.array_api' not in sys.modules -namespace = array_api_compat.array_namespace(array, api_version={api_version!r}) +namespace = array_api_compat.array_namespace(array) if hasattr(jax.numpy, '__array_api_version__'): assert namespace == jax.numpy @@ -76,14 +75,16 @@ def test_array_namespace(library, api_version, use_compat): import jax.experimental.array_api assert namespace == jax.experimental.array_api """ - subprocess.run([sys.executable, "-c", code], check=True) + subprocess.check_call([sys.executable, "-c", code]) + def test_jax_zero_gradient(): - jax = import_("jax") + jax = pytest.importorskip("jax") jx = jax.numpy.arange(4) jax_zero = jax.vmap(jax.grad(jax.numpy.float32, allow_int=True))(jx) assert array_namespace(jax_zero) is array_namespace(jx) + def test_array_namespace_errors(): pytest.raises(TypeError, lambda: array_namespace([1])) pytest.raises(TypeError, lambda: array_namespace()) @@ -92,43 +93,31 @@ def test_array_namespace_errors(): pytest.raises(TypeError, lambda: array_namespace((x, x))) pytest.raises(TypeError, lambda: array_namespace(x, (x, x))) -def test_array_namespace_errors_torch(): - torch = import_("torch") - y = torch.asarray([1, 2]) - x = np.asarray([1, 2]) - pytest.raises(TypeError, lambda: array_namespace(x, y)) -def test_api_version_torch(): - torch = import_("torch") - x = torch.asarray([1, 2]) - torch_ = import_("torch", wrapper=True) - with warnings.catch_warnings(): - warnings.simplefilter('ignore', UserWarning) - assert array_namespace(x, api_version="2023.12") == torch_ - assert array_namespace(x, api_version=None) == torch_ - assert array_namespace(x) == torch_ - # Should issue a warning - with warnings.catch_warnings(record=True) as w: - assert array_namespace(x, api_version="2021.12") == torch_ - assert len(w) == 1 - assert "2021.12" in str(w[0].message) - - # Should issue a warning - with warnings.catch_warnings(record=True) as w: - assert array_namespace(x, api_version="2022.12") == torch_ - assert len(w) == 1 - assert "2022.12" in str(w[0].message) - - pytest.raises(ValueError, lambda: array_namespace(x, api_version="2020.12")) +@pytest.mark.parametrize("library", all_libraries) +def test_array_namespace_many_args(library): + xp = pytest.importorskip(library) + a = xp.asarray(1) + b = xp.asarray(2) + assert array_namespace(a, b) is array_namespace(a) + + +def test_array_namespace_mismatch(): + xp = pytest.importorskip("array_api_strict") + with pytest.raises(TypeError, match="Multiple namespaces"): + array_namespace(np.asarray(1), xp.asarray(1)) + def test_get_namespace(): # Backwards compatible wrapper assert array_api_compat.get_namespace is array_namespace -def test_python_scalars(): - torch = import_("torch") - a = torch.asarray([1, 2]) - xp = import_("torch", wrapper=True) + +@pytest.mark.parametrize("library", all_libraries) +def test_python_scalars(library): + xp = pytest.importorskip(library) + a = xp.asarray([1, 2]) + xp = array_namespace(a) pytest.raises(TypeError, lambda: array_namespace(1)) pytest.raises(TypeError, lambda: array_namespace(1.0)) @@ -136,8 +125,8 @@ def test_python_scalars(): pytest.raises(TypeError, lambda: array_namespace(True)) pytest.raises(TypeError, lambda: array_namespace(None)) - assert array_namespace(a, 1) == xp - assert array_namespace(a, 1.0) == xp - assert array_namespace(a, 1j) == xp - assert array_namespace(a, True) == xp - assert array_namespace(a, None) == xp + assert array_namespace(a, 1) is xp + assert array_namespace(a, 1.0) is xp + assert array_namespace(a, 1j) is xp + assert array_namespace(a, True) is xp + assert array_namespace(a, None) is xp From c9cfc2c9193fcdf0e52a2bbdace54182780839c9 Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Sun, 1 Jun 2025 11:37:49 +0200 Subject: [PATCH 12/21] TST: add a test that wrapping preserves a view/copy semantics for unary functions If a bare library returns a copy, so does the wrapped library; if the bare library returns a view, so does the wrapped library. --- tests/test_copies_or_views.py | 66 +++++++++++++++++++++++++++++++++++ 1 file changed, 66 insertions(+) create mode 100644 tests/test_copies_or_views.py diff --git a/tests/test_copies_or_views.py b/tests/test_copies_or_views.py new file mode 100644 index 00000000..5b9b9207 --- /dev/null +++ b/tests/test_copies_or_views.py @@ -0,0 +1,66 @@ +""" +A collection of tests to make sure that wrapped namespaces agree with the bare ones +on whether to return a view or a copy of inputs. +""" +import pytest +from ._helpers import import_ + + +LIB_NAMES = ['numpy', 'cupy', 'torch', 'dask.array', 'array_api_strict'] + +FUNC_INPUTS = [ + # func_name, arr_input, dtype, scalar_value + ('abs', [1, 2], 'int8', 3), + ('abs', [1, 2], 'float32', 3.), + ('ceil', [1, 2], 'int8', 3), + ('clip', [1, 2], 'int8', 3), + ('conj', [1, 2], 'int8', 3), + ('floor', [1, 2], 'int8', 3), + ('imag', [1j, 2j], 'complex64', 3), + ('positive', [1, 2], 'int8', 3), + ('real', [1., 2.], 'float32', 3.), + ('round', [1, 2], 'int8', 3), + ('sign', [0, 0], 'float32', 3), + ('trunc', [1, 2], 'int8', 3), + ('trunc', [1, 2], 'float32', 3), +] + + +def ensure_unary(func, arr): + """Make a trivial unary function from func.""" + if func.__name__ == 'clip': + return lambda x: func(x, arr[0], arr[1]) + return func + + +def is_view(func, a, value): + """Apply `func`, mutate the output; does the input change?""" + b = func(a) + b[0] = value + return a[0] == value + + +@pytest.mark.parametrize('xp_name', LIB_NAMES) +@pytest.mark.parametrize('inputs', FUNC_INPUTS, ids=[inp[0] for inp in FUNC_INPUTS]) +def test_view_or_copy(inputs, xp_name): + bare_xp = import_(xp_name, wrapper=False) + wrapped_xp = import_(xp_name, wrapper=True) + + func_name, arr_input, dtype_str, value = inputs + dtype = getattr(bare_xp, dtype_str) + + bare_func = getattr(bare_xp, func_name) + bare_func = ensure_unary(bare_func, arr_input) + + wrapped_func = getattr(wrapped_xp, func_name) + wrapped_func = ensure_unary(wrapped_func, arr_input) + + # bare namespace: mutate the output, does the input change? + a = bare_xp.asarray(arr_input, dtype=dtype) + is_view_bare = is_view(bare_func, a, value) + + # wrapped namespace: mutate the output, does the input change? + a1 = wrapped_xp.asarray(arr_input, dtype=dtype) + is_view_wrapped = is_view(wrapped_func, a1, value) + + assert is_view_bare == is_view_wrapped From 1facc3526414926b2d123e88c16f7d517d9d2558 Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Sun, 1 Jun 2025 14:30:54 +0200 Subject: [PATCH 13/21] BUG: make ceil,trunc,floor always respect view/copy semantics Remove these functions from common/_aliases.py, add specific implementations for numpy < 2 and cupy. --- array_api_compat/common/_aliases.py | 24 -------------------- array_api_compat/cupy/_aliases.py | 24 ++++++++++++++++---- array_api_compat/dask/array/_aliases.py | 3 --- array_api_compat/numpy/_aliases.py | 29 ++++++++++++++++++++++--- 4 files changed, 46 insertions(+), 34 deletions(-) diff --git a/array_api_compat/common/_aliases.py b/array_api_compat/common/_aliases.py index 8ea9162a..39d10860 100644 --- a/array_api_compat/common/_aliases.py +++ b/array_api_compat/common/_aliases.py @@ -524,27 +524,6 @@ def nonzero(x: Array, /, xp: Namespace, **kwargs: object) -> tuple[Array, ...]: return xp.nonzero(x, **kwargs) -# ceil, floor, and trunc return integers for integer inputs - - -def ceil(x: Array, /, xp: Namespace, **kwargs: object) -> Array: - if xp.issubdtype(x.dtype, xp.integer): - return x - return xp.ceil(x, **kwargs) - - -def floor(x: Array, /, xp: Namespace, **kwargs: object) -> Array: - if xp.issubdtype(x.dtype, xp.integer): - return x - return xp.floor(x, **kwargs) - - -def trunc(x: Array, /, xp: Namespace, **kwargs: object) -> Array: - if xp.issubdtype(x.dtype, xp.integer): - return x - return xp.trunc(x, **kwargs) - - # linear algebra functions @@ -707,9 +686,6 @@ def iinfo(type_: DType | Array, /, xp: Namespace) -> Any: "argsort", "sort", "nonzero", - "ceil", - "floor", - "trunc", "matmul", "matrix_transpose", "tensordot", diff --git a/array_api_compat/cupy/_aliases.py b/array_api_compat/cupy/_aliases.py index 90b48f05..e000602e 100644 --- a/array_api_compat/cupy/_aliases.py +++ b/array_api_compat/cupy/_aliases.py @@ -54,9 +54,6 @@ argsort = get_xp(cp)(_aliases.argsort) sort = get_xp(cp)(_aliases.sort) nonzero = get_xp(cp)(_aliases.nonzero) -ceil = get_xp(cp)(_aliases.ceil) -floor = get_xp(cp)(_aliases.floor) -trunc = get_xp(cp)(_aliases.trunc) matmul = get_xp(cp)(_aliases.matmul) matrix_transpose = get_xp(cp)(_aliases.matrix_transpose) tensordot = get_xp(cp)(_aliases.tensordot) @@ -123,6 +120,25 @@ def count_nonzero( return cp.expand_dims(result, axis) return result +# ceil, floor, and trunc return integers for integer inputs + +def ceil(x: Array, /) -> Array: + if cp.issubdtype(x.dtype, cp.integer): + return x.copy() + return cp.ceil(x) + + +def floor(x: Array, /) -> Array: + if cp.issubdtype(x.dtype, cp.integer): + return x.copy() + return cp.floor(x) + + +def trunc(x: Array, /) -> Array: + if cp.issubdtype(x.dtype, cp.integer): + return x.copy() + return cp.trunc(x) + # take_along_axis: axis defaults to -1 but in cupy (and numpy) axis is a required arg def take_along_axis(x: Array, indices: Array, /, *, axis: int = -1): @@ -151,6 +167,6 @@ def take_along_axis(x: Array, indices: Array, /, *, axis: int = -1): 'atan2', 'atanh', 'bitwise_left_shift', 'bitwise_invert', 'bitwise_right_shift', 'bool', 'concat', 'count_nonzero', 'pow', 'sign', - 'take_along_axis'] + 'ceil', 'floor', 'trunc', 'take_along_axis'] _all_ignore = ['cp', 'get_xp'] diff --git a/array_api_compat/dask/array/_aliases.py b/array_api_compat/dask/array/_aliases.py index d43881ab..0bb5d227 100644 --- a/array_api_compat/dask/array/_aliases.py +++ b/array_api_compat/dask/array/_aliases.py @@ -134,9 +134,6 @@ def arange( matrix_transpose = get_xp(da)(_aliases.matrix_transpose) vecdot = get_xp(da)(_aliases.vecdot) nonzero = get_xp(da)(_aliases.nonzero) -ceil = get_xp(np)(_aliases.ceil) -floor = get_xp(np)(_aliases.floor) -trunc = get_xp(np)(_aliases.trunc) matmul = get_xp(np)(_aliases.matmul) tensordot = get_xp(np)(_aliases.tensordot) sign = get_xp(np)(_aliases.sign) diff --git a/array_api_compat/numpy/_aliases.py b/array_api_compat/numpy/_aliases.py index a1aee5c0..502dfb3a 100644 --- a/array_api_compat/numpy/_aliases.py +++ b/array_api_compat/numpy/_aliases.py @@ -63,9 +63,6 @@ argsort = get_xp(np)(_aliases.argsort) sort = get_xp(np)(_aliases.sort) nonzero = get_xp(np)(_aliases.nonzero) -ceil = get_xp(np)(_aliases.ceil) -floor = get_xp(np)(_aliases.floor) -trunc = get_xp(np)(_aliases.trunc) matmul = get_xp(np)(_aliases.matmul) matrix_transpose = get_xp(np)(_aliases.matrix_transpose) tensordot = get_xp(np)(_aliases.tensordot) @@ -145,6 +142,29 @@ def take_along_axis(x: Array, indices: Array, /, *, axis: int = -1): return np.take_along_axis(x, indices, axis=axis) +# ceil, floor, and trunc return integers for integer inputs in NumPy < 2 + +def ceil(x: Array, /) -> Array: + if np.issubdtype(x.dtype, np.integer): + if np.__version__ < '2': + return x.copy() + return np.ceil(x) + + +def floor(x: Array, /) -> Array: + if np.issubdtype(x.dtype, np.integer): + if np.__version__ < '2': + return x.copy() + return np.floor(x) + + +def trunc(x: Array, /) -> Array: + if np.issubdtype(x.dtype, np.integer): + if np.__version__ < '2': + return x.copy() + return np.trunc(x) + + # These functions are completely new here. If the library already has them # (i.e., numpy 2.0), use the library version instead of our wrapper. if hasattr(np, "vecdot"): @@ -173,6 +193,9 @@ def take_along_axis(x: Array, indices: Array, /, *, axis: int = -1): "atan", "atan2", "atanh", + "ceil", + "floor", + "trunc", "bitwise_left_shift", "bitwise_invert", "bitwise_right_shift", From 0ad664bdfde03ec3f21d82b1048616ae5d0fb6b7 Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Wed, 4 Jun 2025 10:49:43 +0200 Subject: [PATCH 14/21] Apply suggestions from code review Co-authored-by: Guido Imperiale --- tests/test_copies_or_views.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_copies_or_views.py b/tests/test_copies_or_views.py index 5b9b9207..24d03547 100644 --- a/tests/test_copies_or_views.py +++ b/tests/test_copies_or_views.py @@ -3,7 +3,7 @@ on whether to return a view or a copy of inputs. """ import pytest -from ._helpers import import_ +from ._helpers import import_, wrapped_libraries LIB_NAMES = ['numpy', 'cupy', 'torch', 'dask.array', 'array_api_strict'] @@ -40,7 +40,7 @@ def is_view(func, a, value): return a[0] == value -@pytest.mark.parametrize('xp_name', LIB_NAMES) +@pytest.mark.parametrize('xp_name', wrapped_libraries) @pytest.mark.parametrize('inputs', FUNC_INPUTS, ids=[inp[0] for inp in FUNC_INPUTS]) def test_view_or_copy(inputs, xp_name): bare_xp = import_(xp_name, wrapper=False) From 118ae2d0428be763abf1e31b2827a4800398e901 Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Wed, 4 Jun 2025 10:52:46 +0200 Subject: [PATCH 15/21] TST: test views vs copies on array-api-strict, too --- tests/test_copies_or_views.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/tests/test_copies_or_views.py b/tests/test_copies_or_views.py index 24d03547..ec8995f7 100644 --- a/tests/test_copies_or_views.py +++ b/tests/test_copies_or_views.py @@ -6,8 +6,6 @@ from ._helpers import import_, wrapped_libraries -LIB_NAMES = ['numpy', 'cupy', 'torch', 'dask.array', 'array_api_strict'] - FUNC_INPUTS = [ # func_name, arr_input, dtype, scalar_value ('abs', [1, 2], 'int8', 3), @@ -40,7 +38,7 @@ def is_view(func, a, value): return a[0] == value -@pytest.mark.parametrize('xp_name', wrapped_libraries) +@pytest.mark.parametrize('xp_name', wrapped_libraries + ['array_api_strict']) @pytest.mark.parametrize('inputs', FUNC_INPUTS, ids=[inp[0] for inp in FUNC_INPUTS]) def test_view_or_copy(inputs, xp_name): bare_xp = import_(xp_name, wrapper=False) From b0eed557d6dba8c87d9693ff82360b33c1af3480 Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Wed, 4 Jun 2025 13:05:08 +0200 Subject: [PATCH 16/21] Apply suggestions from code review Co-authored-by: Guido Imperiale --- array_api_compat/numpy/_aliases.py | 15 ++++++--------- 1 file changed, 6 insertions(+), 9 deletions(-) diff --git a/array_api_compat/numpy/_aliases.py b/array_api_compat/numpy/_aliases.py index 502dfb3a..f04837de 100644 --- a/array_api_compat/numpy/_aliases.py +++ b/array_api_compat/numpy/_aliases.py @@ -145,23 +145,20 @@ def take_along_axis(x: Array, indices: Array, /, *, axis: int = -1): # ceil, floor, and trunc return integers for integer inputs in NumPy < 2 def ceil(x: Array, /) -> Array: - if np.issubdtype(x.dtype, np.integer): - if np.__version__ < '2': - return x.copy() + if np.__version__ < '2' and np.issubdtype(x.dtype, np.integer): + return x.copy() return np.ceil(x) def floor(x: Array, /) -> Array: - if np.issubdtype(x.dtype, np.integer): - if np.__version__ < '2': - return x.copy() + if np.__version__ < '2' and np.issubdtype(x.dtype, np.integer): + return x.copy() return np.floor(x) def trunc(x: Array, /) -> Array: - if np.issubdtype(x.dtype, np.integer): - if np.__version__ < '2': - return x.copy() + if np.__version__ < '2' and np.issubdtype(x.dtype, np.integer): + return x.copy() return np.trunc(x) From 2b559e62e05ebea3dd3ab631aee47b270109eaa1 Mon Sep 17 00:00:00 2001 From: Guido Imperiale Date: Wed, 4 Jun 2025 15:36:16 +0100 Subject: [PATCH 17/21] TYP: Type annotations, part 4 (#313) * Type annotations, part 4 * Fix CopyMode * revert * Revert `_all_ignore` * code review * code review * JustInt mypy ignores * lint * fix merge * lint * Reverts and tweaks * Fix test_all * Revert batmobile --- array_api_compat/_internal.py | 4 +- array_api_compat/common/_aliases.py | 13 +- array_api_compat/common/_helpers.py | 34 ++--- array_api_compat/common/_linalg.py | 6 +- array_api_compat/common/_typing.py | 15 +- array_api_compat/cupy/_aliases.py | 31 ++-- array_api_compat/cupy/fft.py | 9 +- array_api_compat/cupy/linalg.py | 2 +- array_api_compat/dask/array/__init__.py | 2 +- array_api_compat/dask/array/_aliases.py | 2 +- array_api_compat/dask/array/_info.py | 47 +++--- array_api_compat/dask/array/fft.py | 2 +- array_api_compat/dask/array/linalg.py | 19 +-- array_api_compat/numpy/__init__.py | 2 +- array_api_compat/numpy/_aliases.py | 33 ++-- array_api_compat/numpy/_info.py | 3 +- array_api_compat/numpy/linalg.py | 4 +- array_api_compat/torch/_aliases.py | 192 ++++++++++++------------ array_api_compat/torch/fft.py | 19 +-- array_api_compat/torch/linalg.py | 12 +- pyproject.toml | 56 ++++--- 21 files changed, 246 insertions(+), 261 deletions(-) diff --git a/array_api_compat/_internal.py b/array_api_compat/_internal.py index cd8d939f..b1925492 100644 --- a/array_api_compat/_internal.py +++ b/array_api_compat/_internal.py @@ -46,8 +46,8 @@ def wrapped_f(*args: object, **kwargs: object) -> object: specification for more details. """ - wrapped_f.__signature__ = new_sig # pyright: ignore[reportAttributeAccessIssue] - return wrapped_f # pyright: ignore[reportReturnType] + wrapped_f.__signature__ = new_sig # type: ignore[attr-defined] # pyright: ignore[reportAttributeAccessIssue] + return wrapped_f # type: ignore[return-value] # pyright: ignore[reportReturnType] return inner diff --git a/array_api_compat/common/_aliases.py b/array_api_compat/common/_aliases.py index 8ea9162a..51732b91 100644 --- a/array_api_compat/common/_aliases.py +++ b/array_api_compat/common/_aliases.py @@ -5,11 +5,12 @@ from __future__ import annotations import inspect -from typing import TYPE_CHECKING, Any, NamedTuple, Optional, Sequence, cast +from collections.abc import Sequence +from typing import TYPE_CHECKING, Any, NamedTuple, cast from ._helpers import _check_device, array_namespace from ._helpers import device as _get_device -from ._helpers import is_cupy_namespace as _is_cupy_namespace +from ._helpers import is_cupy_namespace from ._typing import Array, Device, DType, Namespace if TYPE_CHECKING: @@ -381,8 +382,8 @@ def clip( # TODO: np.clip has other ufunc kwargs out: Array | None = None, ) -> Array: - def _isscalar(a: object) -> TypeIs[int | float | None]: - return isinstance(a, (int, float, type(None))) + def _isscalar(a: object) -> TypeIs[float | None]: + return isinstance(a, int | float) or a is None min_shape = () if _isscalar(min) else min.shape max_shape = () if _isscalar(max) else max.shape @@ -450,7 +451,7 @@ def reshape( shape: tuple[int, ...], xp: Namespace, *, - copy: Optional[bool] = None, + copy: bool | None = None, **kwargs: object, ) -> Array: if copy is True: @@ -657,7 +658,7 @@ def sign(x: Array, /, xp: Namespace, **kwargs: object) -> Array: out = xp.sign(x, **kwargs) # CuPy sign() does not propagate nans. See # https://github.com/data-apis/array-api-compat/issues/136 - if _is_cupy_namespace(xp) and isdtype(x.dtype, "real floating", xp=xp): + if is_cupy_namespace(xp) and isdtype(x.dtype, "real floating", xp=xp): out[xp.isnan(x)] = xp.nan return out[()] diff --git a/array_api_compat/common/_helpers.py b/array_api_compat/common/_helpers.py index a152e4c0..cae0ee0b 100644 --- a/array_api_compat/common/_helpers.py +++ b/array_api_compat/common/_helpers.py @@ -23,7 +23,6 @@ SupportsIndex, TypeAlias, TypeGuard, - TypeVar, cast, overload, ) @@ -31,32 +30,29 @@ from ._typing import Array, Device, HasShape, Namespace, SupportsArrayNamespace if TYPE_CHECKING: - + import cupy as cp import dask.array as da import jax import ndonnx as ndx import numpy as np import numpy.typing as npt - import sparse # pyright: ignore[reportMissingTypeStubs] + import sparse import torch # TODO: import from typing (requires Python >=3.13) - from typing_extensions import TypeIs, TypeVar - - _SizeT = TypeVar("_SizeT", bound = int | None) + from typing_extensions import TypeIs _ZeroGradientArray: TypeAlias = npt.NDArray[np.void] - _CupyArray: TypeAlias = Any # cupy has no py.typed _ArrayApiObj: TypeAlias = ( npt.NDArray[Any] + | cp.ndarray | da.Array | jax.Array | ndx.Array | sparse.SparseArray | torch.Tensor | SupportsArrayNamespace[Any] - | _CupyArray ) _API_VERSIONS_OLD: Final = frozenset({"2021.12", "2022.12", "2023.12"}) @@ -96,7 +92,7 @@ def _is_jax_zero_gradient_array(x: object) -> TypeGuard[_ZeroGradientArray]: return dtype == jax.float0 -def is_numpy_array(x: object) -> TypeGuard[npt.NDArray[Any]]: +def is_numpy_array(x: object) -> TypeIs[npt.NDArray[Any]]: """ Return True if `x` is a NumPy array. @@ -267,7 +263,7 @@ def is_pydata_sparse_array(x: object) -> TypeIs[sparse.SparseArray]: return _issubclass_fast(cls, "sparse", "SparseArray") -def is_array_api_obj(x: object) -> TypeIs[_ArrayApiObj]: # pyright: ignore[reportUnknownParameterType] +def is_array_api_obj(x: object) -> TypeGuard[_ArrayApiObj]: """ Return True if `x` is an array API compatible array object. @@ -748,7 +744,7 @@ def device(x: _ArrayApiObj, /) -> Device: return "cpu" elif is_dask_array(x): # Peek at the metadata of the Dask array to determine type - if is_numpy_array(x._meta): # pyright: ignore + if is_numpy_array(x._meta): # Must be on CPU since backed by numpy return "cpu" return _DASK_DEVICE @@ -777,7 +773,7 @@ def device(x: _ArrayApiObj, /) -> Device: return "cpu" # Return the device of the constituent array return device(inner) # pyright: ignore - return x.device # pyright: ignore + return x.device # type: ignore # pyright: ignore # Prevent shadowing, used below @@ -786,11 +782,11 @@ def device(x: _ArrayApiObj, /) -> Device: # Based on cupy.array_api.Array.to_device def _cupy_to_device( - x: _CupyArray, + x: cp.ndarray, device: Device, /, stream: int | Any | None = None, -) -> _CupyArray: +) -> cp.ndarray: import cupy as cp if device == "cpu": @@ -819,7 +815,7 @@ def _torch_to_device( x: torch.Tensor, device: torch.device | str | int, /, - stream: None = None, + stream: int | Any | None = None, ) -> torch.Tensor: if stream is not None: raise NotImplementedError @@ -885,7 +881,7 @@ def to_device(x: Array, device: Device, /, *, stream: int | Any | None = None) - # cupy does not yet have to_device return _cupy_to_device(x, device, stream=stream) elif is_torch_array(x): - return _torch_to_device(x, device, stream=stream) # pyright: ignore[reportArgumentType] + return _torch_to_device(x, device, stream=stream) elif is_dask_array(x): if stream is not None: raise ValueError("The stream argument to to_device() is not supported") @@ -912,8 +908,6 @@ def to_device(x: Array, device: Device, /, *, stream: int | Any | None = None) - @overload def size(x: HasShape[Collection[SupportsIndex]]) -> int: ... @overload -def size(x: HasShape[Collection[None]]) -> None: ... -@overload def size(x: HasShape[Collection[SupportsIndex | None]]) -> int | None: ... def size(x: HasShape[Collection[SupportsIndex | None]]) -> int | None: """ @@ -948,7 +942,7 @@ def _is_writeable_cls(cls: type) -> bool | None: return None -def is_writeable_array(x: object) -> bool: +def is_writeable_array(x: object) -> TypeGuard[_ArrayApiObj]: """ Return False if ``x.__setitem__`` is expected to raise; True otherwise. Return False if `x` is not an array API compatible object. @@ -986,7 +980,7 @@ def _is_lazy_cls(cls: type) -> bool | None: return None -def is_lazy_array(x: object) -> bool: +def is_lazy_array(x: object) -> TypeGuard[_ArrayApiObj]: """Return True if x is potentially a future or it may be otherwise impossible or expensive to eagerly read its contents, regardless of their size, e.g. by calling ``bool(x)`` or ``float(x)``. diff --git a/array_api_compat/common/_linalg.py b/array_api_compat/common/_linalg.py index 7ad87a1b..3fd9d860 100644 --- a/array_api_compat/common/_linalg.py +++ b/array_api_compat/common/_linalg.py @@ -8,7 +8,7 @@ if np.__version__[0] == "2": from numpy.lib.array_utils import normalize_axis_tuple else: - from numpy.core.numeric import normalize_axis_tuple + from numpy.core.numeric import normalize_axis_tuple # type: ignore[no-redef] from .._internal import get_xp from ._aliases import isdtype, matmul, matrix_transpose, tensordot, vecdot @@ -187,14 +187,14 @@ def vector_norm( # We can't reuse xp.linalg.norm(keepdims) because of the reshape hacks # above to avoid matrix norm logic. shape = list(x.shape) - _axis = cast( + axes = cast( "tuple[int, ...]", normalize_axis_tuple( # pyright: ignore[reportCallIssue] range(x.ndim) if axis is None else axis, x.ndim, ), ) - for i in _axis: + for i in axes: shape[i] = 1 res = xp.reshape(res, tuple(shape)) diff --git a/array_api_compat/common/_typing.py b/array_api_compat/common/_typing.py index cd26feeb..11b00bd1 100644 --- a/array_api_compat/common/_typing.py +++ b/array_api_compat/common/_typing.py @@ -34,32 +34,29 @@ # - docs: https://github.com/jorenham/optype/blob/master/README.md#just # - code: https://github.com/jorenham/optype/blob/master/optype/_core/_just.py @final -class JustInt(Protocol): - @property +class JustInt(Protocol): # type: ignore[misc] + @property # type: ignore[override] def __class__(self, /) -> type[int]: ... @__class__.setter def __class__(self, value: type[int], /) -> None: ... # pyright: ignore[reportIncompatibleMethodOverride] @final -class JustFloat(Protocol): - @property +class JustFloat(Protocol): # type: ignore[misc] + @property # type: ignore[override] def __class__(self, /) -> type[float]: ... @__class__.setter def __class__(self, value: type[float], /) -> None: ... # pyright: ignore[reportIncompatibleMethodOverride] @final -class JustComplex(Protocol): - @property +class JustComplex(Protocol): # type: ignore[misc] + @property # type: ignore[override] def __class__(self, /) -> type[complex]: ... @__class__.setter def __class__(self, value: type[complex], /) -> None: ... # pyright: ignore[reportIncompatibleMethodOverride] -# - - class NestedSequence(Protocol[_T_co]): def __getitem__(self, key: int, /) -> _T_co | NestedSequence[_T_co]: ... def __len__(self, /) -> int: ... diff --git a/array_api_compat/cupy/_aliases.py b/array_api_compat/cupy/_aliases.py index 90b48f05..c0473ca4 100644 --- a/array_api_compat/cupy/_aliases.py +++ b/array_api_compat/cupy/_aliases.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Optional +from builtins import bool as py_bool import cupy as cp @@ -67,18 +67,13 @@ # asarray also adds the copy keyword, which is not present in numpy 1.0. def asarray( - obj: ( - Array - | bool | int | float | complex - | NestedSequence[bool | int | float | complex] - | SupportsBufferProtocol - ), + obj: Array | complex | NestedSequence[complex] | SupportsBufferProtocol, /, *, - dtype: Optional[DType] = None, - device: Optional[Device] = None, - copy: Optional[bool] = None, - **kwargs, + dtype: DType | None = None, + device: Device | None = None, + copy: py_bool | None = None, + **kwargs: object, ) -> Array: """ Array API compatibility wrapper for asarray(). @@ -101,8 +96,8 @@ def astype( dtype: DType, /, *, - copy: bool = True, - device: Optional[Device] = None, + copy: py_bool = True, + device: Device | None = None, ) -> Array: if device is None: return x.astype(dtype=dtype, copy=copy) @@ -113,8 +108,8 @@ def astype( # cupy.count_nonzero does not have keepdims def count_nonzero( x: Array, - axis=None, - keepdims=False + axis: int | tuple[int, ...] | None = None, + keepdims: py_bool = False, ) -> Array: result = cp.count_nonzero(x, axis) if keepdims: @@ -125,7 +120,7 @@ def count_nonzero( # take_along_axis: axis defaults to -1 but in cupy (and numpy) axis is a required arg -def take_along_axis(x: Array, indices: Array, /, *, axis: int = -1): +def take_along_axis(x: Array, indices: Array, /, *, axis: int = -1) -> Array: return cp.take_along_axis(x, indices, axis=axis) @@ -153,4 +148,6 @@ def take_along_axis(x: Array, indices: Array, /, *, axis: int = -1): 'bool', 'concat', 'count_nonzero', 'pow', 'sign', 'take_along_axis'] -_all_ignore = ['cp', 'get_xp'] + +def __dir__() -> list[str]: + return __all__ diff --git a/array_api_compat/cupy/fft.py b/array_api_compat/cupy/fft.py index 307e0f72..2bd11940 100644 --- a/array_api_compat/cupy/fft.py +++ b/array_api_compat/cupy/fft.py @@ -1,10 +1,11 @@ -from cupy.fft import * # noqa: F403 +from cupy.fft import * # noqa: F403 + # cupy.fft doesn't have __all__. If it is added, replace this with # # from cupy.fft import __all__ as linalg_all -_n = {} -exec('from cupy.fft import *', _n) -del _n['__builtins__'] +_n: dict[str, object] = {} +exec("from cupy.fft import *", _n) +del _n["__builtins__"] fft_all = list(_n) del _n diff --git a/array_api_compat/cupy/linalg.py b/array_api_compat/cupy/linalg.py index 7fcdd498..7bc3536e 100644 --- a/array_api_compat/cupy/linalg.py +++ b/array_api_compat/cupy/linalg.py @@ -2,7 +2,7 @@ # cupy.linalg doesn't have __all__. If it is added, replace this with # # from cupy.linalg import __all__ as linalg_all -_n = {} +_n: dict[str, object] = {} exec('from cupy.linalg import *', _n) del _n['__builtins__'] linalg_all = list(_n) diff --git a/array_api_compat/dask/array/__init__.py b/array_api_compat/dask/array/__init__.py index 1e47b960..6d2ea7cd 100644 --- a/array_api_compat/dask/array/__init__.py +++ b/array_api_compat/dask/array/__init__.py @@ -3,7 +3,7 @@ from dask.array import * # noqa: F403 # These imports may overwrite names from the import * above. -from ._aliases import * # noqa: F403 +from ._aliases import * # type: ignore[assignment] # noqa: F403 __array_api_version__: Final = "2024.12" diff --git a/array_api_compat/dask/array/_aliases.py b/array_api_compat/dask/array/_aliases.py index d43881ab..bc0302fe 100644 --- a/array_api_compat/dask/array/_aliases.py +++ b/array_api_compat/dask/array/_aliases.py @@ -146,7 +146,7 @@ def arange( # asarray also adds the copy keyword, which is not present in numpy 1.0. def asarray( - obj: complex | NestedSequence[complex] | Array | SupportsBufferProtocol, + obj: Array | complex | NestedSequence[complex] | SupportsBufferProtocol, /, *, dtype: DType | None = None, diff --git a/array_api_compat/dask/array/_info.py b/array_api_compat/dask/array/_info.py index 9e4d736f..2f39fc4b 100644 --- a/array_api_compat/dask/array/_info.py +++ b/array_api_compat/dask/array/_info.py @@ -12,9 +12,9 @@ from __future__ import annotations -from typing import Literal as L -from typing import TypeAlias, overload +from typing import Literal, TypeAlias, overload +import dask.array as da from numpy import bool_ as bool from numpy import ( complex64, @@ -33,7 +33,7 @@ uint64, ) -from ...common._helpers import _DASK_DEVICE, _dask_device +from ...common._helpers import _DASK_DEVICE, _check_device, _dask_device from ...common._typing import ( Capabilities, DefaultDTypes, @@ -49,8 +49,7 @@ DTypesSigned, DTypesUnsigned, ) - -_Device: TypeAlias = L["cpu"] | _dask_device +Device: TypeAlias = Literal["cpu"] | _dask_device class __array_namespace_info__: @@ -142,7 +141,7 @@ def capabilities(self) -> Capabilities: "max dimensions": 64, } - def default_device(self) -> L["cpu"]: + def default_device(self) -> Device: """ The default device used for new Dask arrays. @@ -169,7 +168,7 @@ def default_device(self) -> L["cpu"]: """ return "cpu" - def default_dtypes(self, /, *, device: _Device | None = None) -> DefaultDTypes: + def default_dtypes(self, /, *, device: Device | None = None) -> DefaultDTypes: """ The default data types used for new Dask arrays. @@ -208,11 +207,7 @@ def default_dtypes(self, /, *, device: _Device | None = None) -> DefaultDTypes: 'indexing': dask.int64} """ - if device not in ["cpu", _DASK_DEVICE, None]: - raise ValueError( - f'Device not understood. Only "cpu" or _DASK_DEVICE is allowed, ' - f"but received: {device!r}" - ) + _check_device(da, device) return { "real floating": dtype(float64), "complex floating": dtype(complex128), @@ -222,38 +217,38 @@ def default_dtypes(self, /, *, device: _Device | None = None) -> DefaultDTypes: @overload def dtypes( - self, /, *, device: _Device | None = None, kind: None = None + self, /, *, device: Device | None = None, kind: None = None ) -> DTypesAll: ... @overload def dtypes( - self, /, *, device: _Device | None = None, kind: L["bool"] + self, /, *, device: Device | None = None, kind: Literal["bool"] ) -> DTypesBool: ... @overload def dtypes( - self, /, *, device: _Device | None = None, kind: L["signed integer"] + self, /, *, device: Device | None = None, kind: Literal["signed integer"] ) -> DTypesSigned: ... @overload def dtypes( - self, /, *, device: _Device | None = None, kind: L["unsigned integer"] + self, /, *, device: Device | None = None, kind: Literal["unsigned integer"] ) -> DTypesUnsigned: ... @overload def dtypes( - self, /, *, device: _Device | None = None, kind: L["integral"] + self, /, *, device: Device | None = None, kind: Literal["integral"] ) -> DTypesIntegral: ... @overload def dtypes( - self, /, *, device: _Device | None = None, kind: L["real floating"] + self, /, *, device: Device | None = None, kind: Literal["real floating"] ) -> DTypesReal: ... @overload def dtypes( - self, /, *, device: _Device | None = None, kind: L["complex floating"] + self, /, *, device: Device | None = None, kind: Literal["complex floating"] ) -> DTypesComplex: ... @overload def dtypes( - self, /, *, device: _Device | None = None, kind: L["numeric"] + self, /, *, device: Device | None = None, kind: Literal["numeric"] ) -> DTypesNumeric: ... def dtypes( - self, /, *, device: _Device | None = None, kind: DTypeKind | None = None + self, /, *, device: Device | None = None, kind: DTypeKind | None = None ) -> DTypesAny: """ The array API data types supported by Dask. @@ -308,11 +303,7 @@ def dtypes( 'int64': dask.int64} """ - if device not in ["cpu", _DASK_DEVICE, None]: - raise ValueError( - 'Device not understood. Only "cpu" or _DASK_DEVICE is allowed, but received:' - f" {device}" - ) + _check_device(da, device) if kind is None: return { "bool": dtype(bool), @@ -381,14 +372,14 @@ def dtypes( "complex64": dtype(complex64), "complex128": dtype(complex128), } - if isinstance(kind, tuple): # type: ignore[reportUnnecessaryIsinstanceCall] + if isinstance(kind, tuple): res: dict[str, DType] = {} for k in kind: res.update(self.dtypes(kind=k)) return res raise ValueError(f"unsupported kind: {kind!r}") - def devices(self) -> list[_Device]: + def devices(self) -> list[Device]: """ The devices supported by Dask. diff --git a/array_api_compat/dask/array/fft.py b/array_api_compat/dask/array/fft.py index 3f40dffe..68c4280e 100644 --- a/array_api_compat/dask/array/fft.py +++ b/array_api_compat/dask/array/fft.py @@ -2,7 +2,7 @@ # dask.array.fft doesn't have __all__. If it is added, replace this with # # from dask.array.fft import __all__ as linalg_all -_n = {} +_n: dict[str, object] = {} exec('from dask.array.fft import *', _n) for k in ("__builtins__", "Sequence", "annotations", "warnings"): _n.pop(k, None) diff --git a/array_api_compat/dask/array/linalg.py b/array_api_compat/dask/array/linalg.py index 0825386e..06f596bc 100644 --- a/array_api_compat/dask/array/linalg.py +++ b/array_api_compat/dask/array/linalg.py @@ -4,21 +4,22 @@ import dask.array as da -# The `matmul` and `tensordot` functions are in both the main and linalg namespaces -from dask.array import matmul, outer, tensordot - # Exports from dask.array.linalg import * # noqa: F403 +from dask.array import outer +# The `matmul` and `tensordot` functions are in both the main and linalg namespaces +from dask.array import matmul, tensordot + from ..._internal import get_xp from ...common import _linalg -from ...common._typing import Array as _Array +from ...common._typing import Array from ._aliases import matrix_transpose, vecdot # dask.array.linalg doesn't have __all__. If it is added, replace this with # # from dask.array.linalg import __all__ as linalg_all -_n = {} +_n: dict[str, object] = {} exec('from dask.array.linalg import *', _n) for k in ('__builtins__', 'annotations', 'operator', 'warnings', 'Array'): _n.pop(k, None) @@ -33,8 +34,8 @@ # supports the mode keyword on QR # https://github.com/dask/dask/issues/10388 #qr = get_xp(da)(_linalg.qr) -def qr( - x: _Array, +def qr( # type: ignore[no-redef] + x: Array, mode: Literal["reduced", "complete"] = "reduced", **kwargs: object, ) -> QRResult: @@ -50,12 +51,12 @@ def qr( # Wrap the svd functions to not pass full_matrices to dask # when full_matrices=False (as that is the default behavior for dask), # and dask doesn't have the full_matrices keyword -def svd(x: _Array, full_matrices: bool = True, **kwargs) -> SVDResult: +def svd(x: Array, full_matrices: bool = True, **kwargs: object) -> SVDResult: # type: ignore[no-redef] if full_matrices: raise ValueError("full_matrics=True is not supported by dask.") return da.linalg.svd(x, coerce_signs=False, **kwargs) -def svdvals(x: _Array) -> _Array: +def svdvals(x: Array) -> Array: # TODO: can't avoid computing U or V for dask _, s, _ = svd(x) return s diff --git a/array_api_compat/numpy/__init__.py b/array_api_compat/numpy/__init__.py index 3e138f53..bf43fe61 100644 --- a/array_api_compat/numpy/__init__.py +++ b/array_api_compat/numpy/__init__.py @@ -10,7 +10,7 @@ from numpy import round as round # These imports may overwrite names from the import * above. -from ._aliases import * # noqa: F403 +from ._aliases import * # type: ignore[assignment,no-redef] # noqa: F403 # Don't know why, but we have to do an absolute import to import linalg. If we # instead do diff --git a/array_api_compat/numpy/_aliases.py b/array_api_compat/numpy/_aliases.py index a1aee5c0..5a05a820 100644 --- a/array_api_compat/numpy/_aliases.py +++ b/array_api_compat/numpy/_aliases.py @@ -2,7 +2,7 @@ from __future__ import annotations from builtins import bool as py_bool -from typing import TYPE_CHECKING, Any, Literal, TypeAlias, cast +from typing import Any, cast import numpy as np @@ -12,13 +12,6 @@ from ._info import __array_namespace_info__ from ._typing import Array, Device, DType -if TYPE_CHECKING: - from typing_extensions import Buffer, TypeIs - -# The values of the `_CopyMode` enum can be either `False`, `True`, or `2`: -# https://github.com/numpy/numpy/blob/5a8a6a79d9c2fff8f07dcab5d41e14f8508d673f/numpy/_globals.pyi#L7-L10 -_Copy: TypeAlias = py_bool | Literal[2] | np._CopyMode - bool = np.bool_ # Basic renames @@ -74,14 +67,6 @@ iinfo = get_xp(np)(_aliases.iinfo) -def _supports_buffer_protocol(obj: object) -> TypeIs[Buffer]: # pyright: ignore[reportUnusedFunction] - try: - memoryview(obj) # pyright: ignore[reportArgumentType] - except TypeError: - return False - return True - - # asarray also adds the copy keyword, which is not present in numpy 1.0. # asarray() is different enough between numpy, cupy, and dask, the logic # complicated enough that it's easier to define it separately for each module @@ -92,7 +77,7 @@ def asarray( *, dtype: DType | None = None, device: Device | None = None, - copy: _Copy | None = None, + copy: py_bool | None = None, **kwargs: Any, ) -> Array: """ @@ -103,14 +88,14 @@ def asarray( """ _helpers._check_device(np, device) + # None is unsupported in NumPy 1.0, but we can use an internal enum + # False in NumPy 1.0 means None in NumPy 2.0 and in the Array API if copy is None: - copy = np._CopyMode.IF_NEEDED + copy = np._CopyMode.IF_NEEDED # type: ignore[assignment,attr-defined] elif copy is False: - copy = np._CopyMode.NEVER - elif copy is True: - copy = np._CopyMode.ALWAYS + copy = np._CopyMode.NEVER # type: ignore[assignment,attr-defined] - return np.array(obj, copy=copy, dtype=dtype, **kwargs) # pyright: ignore + return np.array(obj, copy=copy, dtype=dtype, **kwargs) def astype( @@ -141,7 +126,7 @@ def count_nonzero( # take_along_axis: axis defaults to -1 but in numpy axis is a required arg -def take_along_axis(x: Array, indices: Array, /, *, axis: int = -1): +def take_along_axis(x: Array, indices: Array, /, *, axis: int = -1) -> Array: return np.take_along_axis(x, indices, axis=axis) @@ -150,7 +135,7 @@ def take_along_axis(x: Array, indices: Array, /, *, axis: int = -1): if hasattr(np, "vecdot"): vecdot = np.vecdot else: - vecdot = get_xp(np)(_aliases.vecdot) + vecdot = get_xp(np)(_aliases.vecdot) # type: ignore[assignment] if hasattr(np, "isdtype"): isdtype = np.isdtype diff --git a/array_api_compat/numpy/_info.py b/array_api_compat/numpy/_info.py index f307f62c..c625c13e 100644 --- a/array_api_compat/numpy/_info.py +++ b/array_api_compat/numpy/_info.py @@ -27,6 +27,7 @@ uint64, ) +from ..common._typing import DefaultDTypes from ._typing import Device, DType @@ -139,7 +140,7 @@ def default_dtypes( self, *, device: Device | None = None, - ) -> dict[str, dtype[intp | float64 | complex128]]: + ) -> DefaultDTypes: """ The default data types used for new NumPy arrays. diff --git a/array_api_compat/numpy/linalg.py b/array_api_compat/numpy/linalg.py index 2d3e731d..9a618be9 100644 --- a/array_api_compat/numpy/linalg.py +++ b/array_api_compat/numpy/linalg.py @@ -65,7 +65,7 @@ # https://github.com/cupy/cupy/blob/main/cupy/cublas.py#L43). def solve(x1: Array, x2: Array, /) -> Array: try: - from numpy.linalg._linalg import ( + from numpy.linalg._linalg import ( # type: ignore[attr-defined] _assert_stacked_2d, _assert_stacked_square, _commonType, @@ -74,7 +74,7 @@ def solve(x1: Array, x2: Array, /) -> Array: isComplexType, ) except ImportError: - from numpy.linalg.linalg import ( + from numpy.linalg.linalg import ( # type: ignore[attr-defined] _assert_stacked_2d, _assert_stacked_square, _commonType, diff --git a/array_api_compat/torch/_aliases.py b/array_api_compat/torch/_aliases.py index de5d1a5d..7a449001 100644 --- a/array_api_compat/torch/_aliases.py +++ b/array_api_compat/torch/_aliases.py @@ -1,8 +1,9 @@ from __future__ import annotations +from collections.abc import Sequence from functools import reduce as _reduce, wraps as _wraps from builtins import all as _builtin_all, any as _builtin_any -from typing import Any, List, Optional, Sequence, Tuple, Union, Literal +from typing import Any, Literal import torch @@ -96,9 +97,7 @@ def _fix_promotion(x1, x2, only_scalar=True): _py_scalars = (bool, int, float, complex) -def result_type( - *arrays_and_dtypes: Array | DType | bool | int | float | complex -) -> DType: +def result_type(*arrays_and_dtypes: Array | DType | complex) -> DType: num = len(arrays_and_dtypes) if num == 0: @@ -129,10 +128,7 @@ def result_type( return _reduce(_result_type, others + scalars) -def _result_type( - x: Array | DType | bool | int | float | complex, - y: Array | DType | bool | int | float | complex, -) -> DType: +def _result_type(x: Array | DType | complex, y: Array | DType | complex) -> DType: if not (isinstance(x, _py_scalars) or isinstance(y, _py_scalars)): xdt = x if isinstance(x, torch.dtype) else x.dtype ydt = y if isinstance(y, torch.dtype) else y.dtype @@ -150,7 +146,7 @@ def _result_type( return torch.result_type(x, y) -def can_cast(from_: Union[DType, Array], to: DType, /) -> bool: +def can_cast(from_: DType | Array, to: DType, /) -> bool: if not isinstance(from_, torch.dtype): from_ = from_.dtype return torch.can_cast(from_, to) @@ -194,12 +190,7 @@ def can_cast(from_: Union[DType, Array], to: DType, /) -> bool: def asarray( - obj: ( - Array - | bool | int | float | complex - | NestedSequence[bool | int | float | complex] - | SupportsBufferProtocol - ), + obj: Array | complex | NestedSequence[complex] | SupportsBufferProtocol, /, *, dtype: DType | None = None, @@ -218,13 +209,13 @@ def asarray( # of 'axis'. # torch.min and torch.max return a tuple and don't support multiple axes https://github.com/pytorch/pytorch/issues/58745 -def max(x: Array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False) -> Array: +def max(x: Array, /, *, axis: int | tuple[int, ...] | None = None, keepdims: bool = False) -> Array: # https://github.com/pytorch/pytorch/issues/29137 if axis == (): return torch.clone(x) return torch.amax(x, axis, keepdims=keepdims) -def min(x: Array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False) -> Array: +def min(x: Array, /, *, axis: int | tuple[int, ...] |None = None, keepdims: bool = False) -> Array: # https://github.com/pytorch/pytorch/issues/29137 if axis == (): return torch.clone(x) @@ -240,7 +231,15 @@ def min(x: Array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, keep # torch.sort also returns a tuple # https://github.com/pytorch/pytorch/issues/70921 -def sort(x: Array, /, *, axis: int = -1, descending: bool = False, stable: bool = True, **kwargs) -> Array: +def sort( + x: Array, + /, + *, + axis: int = -1, + descending: bool = False, + stable: bool = True, + **kwargs: object, +) -> Array: return torch.sort(x, dim=axis, descending=descending, stable=stable, **kwargs).values def _normalize_axes(axis, ndim): @@ -307,10 +306,10 @@ def _sum_prod_no_axis(x: Array, dtype: DType | None) -> Array: def prod(x: Array, /, *, - axis: Optional[Union[int, Tuple[int, ...]]] = None, - dtype: Optional[DType] = None, + axis: int | tuple[int, ...] | None = None, + dtype: DType | None = None, keepdims: bool = False, - **kwargs) -> Array: + **kwargs: object) -> Array: if axis == (): return _sum_prod_no_axis(x, dtype) @@ -331,10 +330,10 @@ def prod(x: Array, def sum(x: Array, /, *, - axis: Optional[Union[int, Tuple[int, ...]]] = None, - dtype: Optional[DType] = None, + axis: int | tuple[int, ...] | None = None, + dtype: DType | None = None, keepdims: bool = False, - **kwargs) -> Array: + **kwargs: object) -> Array: if axis == (): return _sum_prod_no_axis(x, dtype) @@ -350,9 +349,9 @@ def sum(x: Array, def any(x: Array, /, *, - axis: Optional[Union[int, Tuple[int, ...]]] = None, + axis: int | tuple[int, ...] | None = None, keepdims: bool = False, - **kwargs) -> Array: + **kwargs: object) -> Array: if axis == (): return x.to(torch.bool) @@ -374,9 +373,9 @@ def any(x: Array, def all(x: Array, /, *, - axis: Optional[Union[int, Tuple[int, ...]]] = None, + axis: int | tuple[int, ...] | None = None, keepdims: bool = False, - **kwargs) -> Array: + **kwargs: object) -> Array: if axis == (): return x.to(torch.bool) @@ -398,9 +397,9 @@ def all(x: Array, def mean(x: Array, /, *, - axis: Optional[Union[int, Tuple[int, ...]]] = None, + axis: int | tuple[int, ...] | None = None, keepdims: bool = False, - **kwargs) -> Array: + **kwargs: object) -> Array: # https://github.com/pytorch/pytorch/issues/29137 if axis == (): return torch.clone(x) @@ -415,10 +414,10 @@ def mean(x: Array, def std(x: Array, /, *, - axis: Optional[Union[int, Tuple[int, ...]]] = None, - correction: Union[int, float] = 0.0, + axis: int | tuple[int, ...] | None = None, + correction: float = 0.0, keepdims: bool = False, - **kwargs) -> Array: + **kwargs: object) -> Array: # Note, float correction is not supported # https://github.com/pytorch/pytorch/issues/61492. We don't try to # implement it here for now. @@ -446,10 +445,10 @@ def std(x: Array, def var(x: Array, /, *, - axis: Optional[Union[int, Tuple[int, ...]]] = None, - correction: Union[int, float] = 0.0, + axis: int | tuple[int, ...] | None = None, + correction: float = 0.0, keepdims: bool = False, - **kwargs) -> Array: + **kwargs: object) -> Array: # Note, float correction is not supported # https://github.com/pytorch/pytorch/issues/61492. We don't try to # implement it here for now. @@ -472,11 +471,11 @@ def var(x: Array, # torch.concat doesn't support dim=None # https://github.com/pytorch/pytorch/issues/70925 -def concat(arrays: Union[Tuple[Array, ...], List[Array]], +def concat(arrays: tuple[Array, ...] | list[Array], /, *, - axis: Optional[int] = 0, - **kwargs) -> Array: + axis: int | None = 0, + **kwargs: object) -> Array: if axis is None: arrays = tuple(ar.flatten() for ar in arrays) axis = 0 @@ -485,7 +484,7 @@ def concat(arrays: Union[Tuple[Array, ...], List[Array]], # torch.squeeze only accepts int dim and doesn't require it # https://github.com/pytorch/pytorch/issues/70924. Support for tuple dim was # added at https://github.com/pytorch/pytorch/pull/89017. -def squeeze(x: Array, /, axis: Union[int, Tuple[int, ...]]) -> Array: +def squeeze(x: Array, /, axis: int | tuple[int, ...]) -> Array: if isinstance(axis, int): axis = (axis,) for a in axis: @@ -499,27 +498,27 @@ def squeeze(x: Array, /, axis: Union[int, Tuple[int, ...]]) -> Array: return x # torch.broadcast_to uses size instead of shape -def broadcast_to(x: Array, /, shape: Tuple[int, ...], **kwargs) -> Array: +def broadcast_to(x: Array, /, shape: tuple[int, ...], **kwargs: object) -> Array: return torch.broadcast_to(x, shape, **kwargs) # torch.permute uses dims instead of axes -def permute_dims(x: Array, /, axes: Tuple[int, ...]) -> Array: +def permute_dims(x: Array, /, axes: tuple[int, ...]) -> Array: return torch.permute(x, axes) # The axis parameter doesn't work for flip() and roll() # https://github.com/pytorch/pytorch/issues/71210. Also torch.flip() doesn't # accept axis=None -def flip(x: Array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, **kwargs) -> Array: +def flip(x: Array, /, *, axis: int | tuple[int, ...] | None = None, **kwargs: object) -> Array: if axis is None: axis = tuple(range(x.ndim)) # torch.flip doesn't accept dim as an int but the method does # https://github.com/pytorch/pytorch/issues/18095 return x.flip(axis, **kwargs) -def roll(x: Array, /, shift: Union[int, Tuple[int, ...]], *, axis: Optional[Union[int, Tuple[int, ...]]] = None, **kwargs) -> Array: +def roll(x: Array, /, shift: int | tuple[int, ...], *, axis: int | tuple[int, ...] | None = None, **kwargs: object) -> Array: return torch.roll(x, shift, axis, **kwargs) -def nonzero(x: Array, /, **kwargs) -> Tuple[Array, ...]: +def nonzero(x: Array, /, **kwargs: object) -> tuple[Array, ...]: if x.ndim == 0: raise ValueError("nonzero() does not support zero-dimensional arrays") return torch.nonzero(x, as_tuple=True, **kwargs) @@ -532,8 +531,8 @@ def diff( *, axis: int = -1, n: int = 1, - prepend: Optional[Array] = None, - append: Optional[Array] = None, + prepend: Array | None = None, + append: Array | None = None, ) -> Array: return torch.diff(x, dim=axis, n=n, prepend=prepend, append=append) @@ -543,7 +542,7 @@ def count_nonzero( x: Array, /, *, - axis: Optional[Union[int, Tuple[int, ...]]] = None, + axis: int | tuple[int, ...] | None = None, keepdims: bool = False, ) -> Array: result = torch.count_nonzero(x, dim=axis) @@ -564,12 +563,7 @@ def repeat(x: Array, repeats: int | Array, /, *, axis: int | None = None) -> Arr return torch.repeat_interleave(x, repeats, axis) -def where( - condition: Array, - x1: Array | bool | int | float | complex, - x2: Array | bool | int | float | complex, - /, -) -> Array: +def where(condition: Array, x1: Array | complex, x2: Array | complex, /) -> Array: x1, x2 = _fix_promotion(x1, x2) return torch.where(condition, x1, x2) @@ -577,10 +571,10 @@ def where( # torch.reshape doesn't have the copy keyword def reshape(x: Array, /, - shape: Tuple[int, ...], + shape: tuple[int, ...], *, - copy: Optional[bool] = None, - **kwargs) -> Array: + copy: bool | None = None, + **kwargs: object) -> Array: if copy is not None: raise NotImplementedError("torch.reshape doesn't yet support the copy keyword") return torch.reshape(x, shape, **kwargs) @@ -589,14 +583,14 @@ def reshape(x: Array, # (https://github.com/pytorch/pytorch/issues/70915), and doesn't support some # keyword argument combinations # (https://github.com/pytorch/pytorch/issues/70914) -def arange(start: Union[int, float], +def arange(start: float, /, - stop: Optional[Union[int, float]] = None, - step: Union[int, float] = 1, + stop: float | None = None, + step: float = 1, *, - dtype: Optional[DType] = None, - device: Optional[Device] = None, - **kwargs) -> Array: + dtype: DType | None = None, + device: Device | None = None, + **kwargs: object) -> Array: if stop is None: start, stop = 0, start if step > 0 and stop <= start or step < 0 and stop >= start: @@ -611,13 +605,13 @@ def arange(start: Union[int, float], # torch.eye does not accept None as a default for the second argument and # doesn't support off-diagonals (https://github.com/pytorch/pytorch/issues/70910) def eye(n_rows: int, - n_cols: Optional[int] = None, + n_cols: int | None = None, /, *, k: int = 0, - dtype: Optional[DType] = None, - device: Optional[Device] = None, - **kwargs) -> Array: + dtype: DType | None = None, + device: Device | None = None, + **kwargs: object) -> Array: if n_cols is None: n_cols = n_rows z = torch.zeros(n_rows, n_cols, dtype=dtype, device=device, **kwargs) @@ -626,52 +620,52 @@ def eye(n_rows: int, return z # torch.linspace doesn't have the endpoint parameter -def linspace(start: Union[int, float], - stop: Union[int, float], +def linspace(start: float, + stop: float, /, num: int, *, - dtype: Optional[DType] = None, - device: Optional[Device] = None, + dtype: DType | None = None, + device: Device | None = None, endpoint: bool = True, - **kwargs) -> Array: + **kwargs: object) -> Array: if not endpoint: return torch.linspace(start, stop, num+1, dtype=dtype, device=device, **kwargs)[:-1] return torch.linspace(start, stop, num, dtype=dtype, device=device, **kwargs) # torch.full does not accept an int size # https://github.com/pytorch/pytorch/issues/70906 -def full(shape: Union[int, Tuple[int, ...]], - fill_value: bool | int | float | complex, +def full(shape: int | tuple[int, ...], + fill_value: complex, *, - dtype: Optional[DType] = None, - device: Optional[Device] = None, - **kwargs) -> Array: + dtype: DType | None = None, + device: Device | None = None, + **kwargs: object) -> Array: if isinstance(shape, int): shape = (shape,) return torch.full(shape, fill_value, dtype=dtype, device=device, **kwargs) # ones, zeros, and empty do not accept shape as a keyword argument -def ones(shape: Union[int, Tuple[int, ...]], +def ones(shape: int | tuple[int, ...], *, - dtype: Optional[DType] = None, - device: Optional[Device] = None, - **kwargs) -> Array: + dtype: DType | None = None, + device: Device | None = None, + **kwargs: object) -> Array: return torch.ones(shape, dtype=dtype, device=device, **kwargs) -def zeros(shape: Union[int, Tuple[int, ...]], +def zeros(shape: int | tuple[int, ...], *, - dtype: Optional[DType] = None, - device: Optional[Device] = None, - **kwargs) -> Array: + dtype: DType | None = None, + device: Device | None = None, + **kwargs: object) -> Array: return torch.zeros(shape, dtype=dtype, device=device, **kwargs) -def empty(shape: Union[int, Tuple[int, ...]], +def empty(shape: int | tuple[int, ...], *, - dtype: Optional[DType] = None, - device: Optional[Device] = None, - **kwargs) -> Array: + dtype: DType | None = None, + device: Device | None = None, + **kwargs: object) -> Array: return torch.empty(shape, dtype=dtype, device=device, **kwargs) # tril and triu do not call the keyword argument k @@ -693,14 +687,14 @@ def astype( /, *, copy: bool = True, - device: Optional[Device] = None, + device: Device | None = None, ) -> Array: if device is not None: return x.to(device, dtype=dtype, copy=copy) return x.to(dtype=dtype, copy=copy) -def broadcast_arrays(*arrays: Array) -> List[Array]: +def broadcast_arrays(*arrays: Array) -> list[Array]: shape = torch.broadcast_shapes(*[a.shape for a in arrays]) return [torch.broadcast_to(a, shape) for a in arrays] @@ -738,7 +732,7 @@ def unique_inverse(x: Array) -> UniqueInverseResult: def unique_values(x: Array) -> Array: return torch.unique(x) -def matmul(x1: Array, x2: Array, /, **kwargs) -> Array: +def matmul(x1: Array, x2: Array, /, **kwargs: object) -> Array: # torch.matmul doesn't type promote (but differently from _fix_promotion) x1, x2 = _fix_promotion(x1, x2, only_scalar=False) return torch.matmul(x1, x2, **kwargs) @@ -756,8 +750,8 @@ def tensordot( x2: Array, /, *, - axes: Union[int, Tuple[Sequence[int], Sequence[int]]] = 2, - **kwargs, + axes: int | tuple[Sequence[int], Sequence[int]] = 2, + **kwargs: object, ) -> Array: # Note: torch.tensordot fails with integer dtypes when there is only 1 # element in the axis (https://github.com/pytorch/pytorch/issues/84530). @@ -766,8 +760,10 @@ def tensordot( def isdtype( - dtype: DType, kind: Union[DType, str, Tuple[Union[DType, str], ...]], - *, _tuple=True, # Disallow nested tuples + dtype: DType, + kind: DType | str | tuple[DType | str, ...], + *, + _tuple: bool = True, # Disallow nested tuples ) -> bool: """ Returns a boolean indicating whether a provided dtype is of a specified data type ``kind``. @@ -801,7 +797,7 @@ def isdtype( else: return dtype == kind -def take(x: Array, indices: Array, /, *, axis: Optional[int] = None, **kwargs) -> Array: +def take(x: Array, indices: Array, /, *, axis: int | None = None, **kwargs: object) -> Array: if axis is None: if x.ndim != 1: raise ValueError("axis must be specified when ndim > 1") @@ -828,7 +824,7 @@ def sign(x: Array, /) -> Array: return out -def meshgrid(*arrays: Array, indexing: Literal['xy', 'ij'] = 'xy') -> List[Array]: +def meshgrid(*arrays: Array, indexing: Literal['xy', 'ij'] = 'xy') -> list[Array]: # enforce the default of 'xy' # TODO: is the return type a list or a tuple return list(torch.meshgrid(*arrays, indexing='xy')) diff --git a/array_api_compat/torch/fft.py b/array_api_compat/torch/fft.py index 50e6a0d0..ddf87c65 100644 --- a/array_api_compat/torch/fft.py +++ b/array_api_compat/torch/fft.py @@ -1,6 +1,7 @@ from __future__ import annotations -from typing import Union, Sequence, Literal +from collections.abc import Sequence +from typing import Literal import torch import torch.fft @@ -17,7 +18,7 @@ def fftn( s: Sequence[int] = None, axes: Sequence[int] = None, norm: Literal["backward", "ortho", "forward"] = "backward", - **kwargs, + **kwargs: object, ) -> Array: return torch.fft.fftn(x, s=s, dim=axes, norm=norm, **kwargs) @@ -28,7 +29,7 @@ def ifftn( s: Sequence[int] = None, axes: Sequence[int] = None, norm: Literal["backward", "ortho", "forward"] = "backward", - **kwargs, + **kwargs: object, ) -> Array: return torch.fft.ifftn(x, s=s, dim=axes, norm=norm, **kwargs) @@ -39,7 +40,7 @@ def rfftn( s: Sequence[int] = None, axes: Sequence[int] = None, norm: Literal["backward", "ortho", "forward"] = "backward", - **kwargs, + **kwargs: object, ) -> Array: return torch.fft.rfftn(x, s=s, dim=axes, norm=norm, **kwargs) @@ -50,7 +51,7 @@ def irfftn( s: Sequence[int] = None, axes: Sequence[int] = None, norm: Literal["backward", "ortho", "forward"] = "backward", - **kwargs, + **kwargs: object, ) -> Array: return torch.fft.irfftn(x, s=s, dim=axes, norm=norm, **kwargs) @@ -58,8 +59,8 @@ def fftshift( x: Array, /, *, - axes: Union[int, Sequence[int]] = None, - **kwargs, + axes: int | Sequence[int] = None, + **kwargs: object, ) -> Array: return torch.fft.fftshift(x, dim=axes, **kwargs) @@ -67,8 +68,8 @@ def ifftshift( x: Array, /, *, - axes: Union[int, Sequence[int]] = None, - **kwargs, + axes: int | Sequence[int] = None, + **kwargs: object, ) -> Array: return torch.fft.ifftshift(x, dim=axes, **kwargs) diff --git a/array_api_compat/torch/linalg.py b/array_api_compat/torch/linalg.py index 70d72405..558cfe7b 100644 --- a/array_api_compat/torch/linalg.py +++ b/array_api_compat/torch/linalg.py @@ -1,8 +1,6 @@ from __future__ import annotations import torch -from typing import Optional, Union, Tuple - from torch.linalg import * # noqa: F403 # torch.linalg doesn't define __all__ @@ -32,7 +30,7 @@ def cross(x1: Array, x2: Array, /, *, axis: int = -1) -> Array: x1, x2 = torch.broadcast_tensors(x1, x2) return torch_linalg.cross(x1, x2, dim=axis) -def vecdot(x1: Array, x2: Array, /, *, axis: int = -1, **kwargs) -> Array: +def vecdot(x1: Array, x2: Array, /, *, axis: int = -1, **kwargs: object) -> Array: from ._aliases import isdtype x1, x2 = _fix_promotion(x1, x2, only_scalar=False) @@ -54,7 +52,7 @@ def vecdot(x1: Array, x2: Array, /, *, axis: int = -1, **kwargs) -> Array: return res[..., 0, 0] return torch.linalg.vecdot(x1, x2, dim=axis, **kwargs) -def solve(x1: Array, x2: Array, /, **kwargs) -> Array: +def solve(x1: Array, x2: Array, /, **kwargs: object) -> Array: x1, x2 = _fix_promotion(x1, x2, only_scalar=False) # Torch tries to emulate NumPy 1 solve behavior by using batched 1-D solve # whenever @@ -75,7 +73,7 @@ def solve(x1: Array, x2: Array, /, **kwargs) -> Array: return torch.linalg.solve(x1, x2, **kwargs) # torch.trace doesn't support the offset argument and doesn't support stacking -def trace(x: Array, /, *, offset: int = 0, dtype: Optional[DType] = None) -> Array: +def trace(x: Array, /, *, offset: int = 0, dtype: DType | None = None) -> Array: # Use our wrapped sum to make sure it does upcasting correctly return sum(torch.diagonal(x, offset=offset, dim1=-2, dim2=-1), axis=-1, dtype=dtype) @@ -83,11 +81,11 @@ def vector_norm( x: Array, /, *, - axis: Optional[Union[int, Tuple[int, ...]]] = None, + axis: int | tuple[int, ...] | None = None, keepdims: bool = False, # JustFloat stands for inf | -inf, which are not valid for Literal ord: JustInt | JustFloat = 2, - **kwargs, + **kwargs: object, ) -> Array: # torch.vector_norm incorrectly treats axis=() the same as axis=None if axis == (): diff --git a/pyproject.toml b/pyproject.toml index aacebd11..ec054417 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -43,11 +43,11 @@ dev = [ "array-api-strict", "dask[array]>=2024.9.0", "jax[cpu]", + "ndonnx", "numpy>=1.22", "pytest", "torch", "sparse>=0.15.1", - "ndonnx" ] [project.urls] @@ -61,7 +61,7 @@ version = {attr = "array_api_compat.__version__"} include = ["array_api_compat*"] namespaces = false -[toolint] +[tool.ruff.lint] preview = true select = [ # Defaults @@ -79,20 +79,42 @@ ignore = [ "E722" ] -[tool.ruff.lint] -preview = true -select = [ -# Defaults -"E4", "E7", "E9", "F", -# Undefined export -"F822", -# Useless import alias -"PLC0414" -] -ignore = [ - # Module import not at top of file - "E402", - # Do not use bare `except` - "E722" +[tool.mypy] +files = ["array_api_compat"] +disallow_incomplete_defs = true +disallow_untyped_decorators = true +disallow_untyped_defs = false # TODO +ignore_missing_imports = false +no_implicit_optional = true +show_error_codes = true +warn_redundant_casts = true +warn_unused_ignores = true +warn_unreachable = true + +[[tool.mypy.overrides]] +module = ["cupy.*", "cupy_backends.*", "dask.*", "jax.*", "ndonnx.*", "sparse.*", "torch.*"] +ignore_missing_imports = true + + +[tool.pyright] +include = ["src", "tests"] +pythonPlatform = "All" + +reportAny = false +reportExplicitAny = false +# missing type stubs +reportAttributeAccessIssue = false +reportUnknownMemberType = false +reportUnknownVariableType = false +# Redundant with mypy checks +reportMissingImports = false +reportMissingTypeStubs = false +# false positives for input validation +reportUnreachable = false +# ruff handles this +reportUnusedParameter = false + +executionEnvironments = [ + { root = "array_api_compat" }, ] From cddc9ef8a19b453b09884987ca6a0626408a1478 Mon Sep 17 00:00:00 2001 From: Guido Imperiale Date: Fri, 6 Jun 2025 12:04:01 +0100 Subject: [PATCH 18/21] ENH: Review exported symbols; redesign `test_all` (#315) Review and discussion at https://github.com/data-apis/array-api-compat/pull/315 --- array_api_compat/_internal.py | 20 +- array_api_compat/common/_aliases.py | 2 - array_api_compat/common/_helpers.py | 2 - array_api_compat/common/_linalg.py | 2 - array_api_compat/cupy/__init__.py | 13 +- array_api_compat/cupy/_aliases.py | 3 +- array_api_compat/cupy/_typing.py | 1 - array_api_compat/cupy/fft.py | 7 +- array_api_compat/cupy/linalg.py | 6 +- array_api_compat/dask/array/__init__.py | 16 +- array_api_compat/dask/array/_aliases.py | 4 - array_api_compat/dask/array/fft.py | 19 +- array_api_compat/dask/array/linalg.py | 35 +-- array_api_compat/numpy/__init__.py | 22 +- array_api_compat/numpy/_aliases.py | 6 +- array_api_compat/numpy/_typing.py | 1 - array_api_compat/numpy/fft.py | 15 +- array_api_compat/numpy/linalg.py | 29 +- array_api_compat/torch/__init__.py | 29 +- array_api_compat/torch/_aliases.py | 5 +- array_api_compat/torch/fft.py | 16 +- array_api_compat/torch/linalg.py | 19 +- tests/test_all.py | 360 ++++++++++++++++++++---- 23 files changed, 435 insertions(+), 197 deletions(-) diff --git a/array_api_compat/_internal.py b/array_api_compat/_internal.py index b1925492..baa39ded 100644 --- a/array_api_compat/_internal.py +++ b/array_api_compat/_internal.py @@ -2,6 +2,7 @@ Internal helpers """ +import importlib from collections.abc import Callable from functools import wraps from inspect import signature @@ -52,8 +53,25 @@ def wrapped_f(*args: object, **kwargs: object) -> object: return inner -__all__ = ["get_xp"] +def clone_module(mod_name: str, globals_: dict[str, object]) -> list[str]: + """Import everything from module, updating globals(). + Returns __all__. + """ + mod = importlib.import_module(mod_name) + # Neither of these two methods is sufficient by itself, + # depending on various idiosyncrasies of the libraries we're wrapping. + objs = {} + exec(f"from {mod.__name__} import *", objs) + + for n in dir(mod): + if not n.startswith("_") and hasattr(mod, n): + objs[n] = getattr(mod, n) + + globals_.update(objs) + return list(objs) + +__all__ = ["get_xp", "clone_module"] def __dir__() -> list[str]: return __all__ diff --git a/array_api_compat/common/_aliases.py b/array_api_compat/common/_aliases.py index 51732b91..27b2604b 100644 --- a/array_api_compat/common/_aliases.py +++ b/array_api_compat/common/_aliases.py @@ -721,8 +721,6 @@ def iinfo(type_: DType | Array, /, xp: Namespace) -> Any: "finfo", "iinfo", ] -_all_ignore = ["inspect", "array_namespace", "NamedTuple"] - def __dir__() -> list[str]: return __all__ diff --git a/array_api_compat/common/_helpers.py b/array_api_compat/common/_helpers.py index cae0ee0b..37f31ec2 100644 --- a/array_api_compat/common/_helpers.py +++ b/array_api_compat/common/_helpers.py @@ -1062,7 +1062,5 @@ def is_lazy_array(x: object) -> TypeGuard[_ArrayApiObj]: "to_device", ] -_all_ignore = ['lru_cache', 'sys', 'math', 'inspect', 'warnings'] - def __dir__() -> list[str]: return __all__ diff --git a/array_api_compat/common/_linalg.py b/array_api_compat/common/_linalg.py index 3fd9d860..69672af7 100644 --- a/array_api_compat/common/_linalg.py +++ b/array_api_compat/common/_linalg.py @@ -225,8 +225,6 @@ def trace( 'matrix_transpose', 'svdvals', 'vecdot', 'vector_norm', 'diagonal', 'trace'] -_all_ignore = ['math', 'normalize_axis_tuple', 'get_xp', 'np', 'isdtype'] - def __dir__() -> list[str]: return __all__ diff --git a/array_api_compat/cupy/__init__.py b/array_api_compat/cupy/__init__.py index 9a30f95d..af003c5a 100644 --- a/array_api_compat/cupy/__init__.py +++ b/array_api_compat/cupy/__init__.py @@ -1,3 +1,4 @@ +from typing import Final from cupy import * # noqa: F403 # from cupy import * doesn't overwrite these builtin names @@ -5,9 +6,19 @@ # These imports may overwrite names from the import * above. from ._aliases import * # noqa: F403 +from ._info import __array_namespace_info__ # noqa: F401 # See the comment in the numpy __init__.py __import__(__package__ + '.linalg') __import__(__package__ + '.fft') -__array_api_version__ = '2024.12' +__array_api_version__: Final = '2024.12' + +__all__ = sorted( + {name for name in globals() if not name.startswith("__")} + - {"Final", "_aliases", "_info", "_typing"} + | {"__array_api_version__", "__array_namespace_info__", "linalg", "fft"} +) + +def __dir__() -> list[str]: + return __all__ diff --git a/array_api_compat/cupy/_aliases.py b/array_api_compat/cupy/_aliases.py index c0473ca4..2752bd98 100644 --- a/array_api_compat/cupy/_aliases.py +++ b/array_api_compat/cupy/_aliases.py @@ -7,7 +7,6 @@ from ..common import _aliases, _helpers from ..common._typing import NestedSequence, SupportsBufferProtocol from .._internal import get_xp -from ._info import __array_namespace_info__ from ._typing import Array, Device, DType bool = cp.bool_ @@ -141,7 +140,7 @@ def take_along_axis(x: Array, indices: Array, /, *, axis: int = -1) -> Array: else: unstack = get_xp(cp)(_aliases.unstack) -__all__ = _aliases.__all__ + ['__array_namespace_info__', 'asarray', 'astype', +__all__ = _aliases.__all__ + ['asarray', 'astype', 'acos', 'acosh', 'asin', 'asinh', 'atan', 'atan2', 'atanh', 'bitwise_left_shift', 'bitwise_invert', 'bitwise_right_shift', diff --git a/array_api_compat/cupy/_typing.py b/array_api_compat/cupy/_typing.py index d8e49ca7..e5c202dc 100644 --- a/array_api_compat/cupy/_typing.py +++ b/array_api_compat/cupy/_typing.py @@ -1,7 +1,6 @@ from __future__ import annotations __all__ = ["Array", "DType", "Device"] -_all_ignore = ["cp"] from typing import TYPE_CHECKING diff --git a/array_api_compat/cupy/fft.py b/array_api_compat/cupy/fft.py index 2bd11940..53a9a454 100644 --- a/array_api_compat/cupy/fft.py +++ b/array_api_compat/cupy/fft.py @@ -31,7 +31,6 @@ __all__ = fft_all + _fft.__all__ -del get_xp -del cp -del fft_all -del _fft +def __dir__() -> list[str]: + return __all__ + diff --git a/array_api_compat/cupy/linalg.py b/array_api_compat/cupy/linalg.py index 7bc3536e..da301574 100644 --- a/array_api_compat/cupy/linalg.py +++ b/array_api_compat/cupy/linalg.py @@ -43,7 +43,5 @@ __all__ = linalg_all + _linalg.__all__ -del get_xp -del cp -del linalg_all -del _linalg +def __dir__() -> list[str]: + return __all__ diff --git a/array_api_compat/dask/array/__init__.py b/array_api_compat/dask/array/__init__.py index 6d2ea7cd..f78aa8b3 100644 --- a/array_api_compat/dask/array/__init__.py +++ b/array_api_compat/dask/array/__init__.py @@ -1,12 +1,26 @@ from typing import Final -from dask.array import * # noqa: F403 +from ..._internal import clone_module + +__all__ = clone_module("dask.array", globals()) # These imports may overwrite names from the import * above. +from . import _aliases from ._aliases import * # type: ignore[assignment] # noqa: F403 +from ._info import __array_namespace_info__ # noqa: F401 __array_api_version__: Final = "2024.12" +del Final # See the comment in the numpy __init__.py __import__(__package__ + '.linalg') __import__(__package__ + '.fft') + +__all__ = sorted( + set(__all__) + | set(_aliases.__all__) + | {"__array_api_version__", "__array_namespace_info__", "linalg", "fft"} +) + +def __dir__() -> list[str]: + return __all__ diff --git a/array_api_compat/dask/array/_aliases.py b/array_api_compat/dask/array/_aliases.py index bc0302fe..4d1e7341 100644 --- a/array_api_compat/dask/array/_aliases.py +++ b/array_api_compat/dask/array/_aliases.py @@ -41,7 +41,6 @@ NestedSequence, SupportsBufferProtocol, ) -from ._info import __array_namespace_info__ isdtype = get_xp(np)(_aliases.isdtype) unstack = get_xp(da)(_aliases.unstack) @@ -355,7 +354,6 @@ def count_nonzero( __all__ = [ - "__array_namespace_info__", "count_nonzero", "bool", "int8", "int16", "int32", "int64", @@ -369,8 +367,6 @@ def count_nonzero( "bitwise_left_shift", "bitwise_right_shift", "bitwise_invert", ] # fmt: skip __all__ += _aliases.__all__ -_all_ignore = ["array_namespace", "get_xp", "da", "np"] - def __dir__() -> list[str]: return __all__ diff --git a/array_api_compat/dask/array/fft.py b/array_api_compat/dask/array/fft.py index 68c4280e..44b68e73 100644 --- a/array_api_compat/dask/array/fft.py +++ b/array_api_compat/dask/array/fft.py @@ -1,13 +1,6 @@ -from dask.array.fft import * # noqa: F403 -# dask.array.fft doesn't have __all__. If it is added, replace this with -# -# from dask.array.fft import __all__ as linalg_all -_n: dict[str, object] = {} -exec('from dask.array.fft import *', _n) -for k in ("__builtins__", "Sequence", "annotations", "warnings"): - _n.pop(k, None) -fft_all = list(_n) -del _n, k +from ..._internal import clone_module + +__all__ = clone_module("dask.array.fft", globals()) from ...common import _fft from ..._internal import get_xp @@ -17,5 +10,7 @@ fftfreq = get_xp(da)(_fft.fftfreq) rfftfreq = get_xp(da)(_fft.rfftfreq) -__all__ = fft_all + ["fftfreq", "rfftfreq"] -_all_ignore = ["da", "fft_all", "get_xp", "warnings"] +__all__ += ["fftfreq", "rfftfreq"] + +def __dir__() -> list[str]: + return __all__ diff --git a/array_api_compat/dask/array/linalg.py b/array_api_compat/dask/array/linalg.py index 06f596bc..6b3c1011 100644 --- a/array_api_compat/dask/array/linalg.py +++ b/array_api_compat/dask/array/linalg.py @@ -4,27 +4,17 @@ import dask.array as da -# Exports -from dask.array.linalg import * # noqa: F403 -from dask.array import outer # The `matmul` and `tensordot` functions are in both the main and linalg namespaces -from dask.array import matmul, tensordot - +from dask.array import matmul, outer, tensordot -from ..._internal import get_xp +# Exports +from ..._internal import clone_module, get_xp from ...common import _linalg from ...common._typing import Array -from ._aliases import matrix_transpose, vecdot -# dask.array.linalg doesn't have __all__. If it is added, replace this with -# -# from dask.array.linalg import __all__ as linalg_all -_n: dict[str, object] = {} -exec('from dask.array.linalg import *', _n) -for k in ('__builtins__', 'annotations', 'operator', 'warnings', 'Array'): - _n.pop(k, None) -linalg_all = list(_n) -del _n, k +__all__ = clone_module("dask.array.linalg", globals()) + +from ._aliases import matrix_transpose, vecdot EighResult = _linalg.EighResult QRResult = _linalg.QRResult @@ -64,10 +54,11 @@ def svdvals(x: Array) -> Array: vector_norm = get_xp(da)(_linalg.vector_norm) diagonal = get_xp(da)(_linalg.diagonal) -__all__ = linalg_all + ["trace", "outer", "matmul", "tensordot", - "matrix_transpose", "vecdot", "EighResult", - "QRResult", "SlogdetResult", "SVDResult", "qr", - "cholesky", "matrix_rank", "matrix_norm", "svdvals", - "vector_norm", "diagonal"] +__all__ += ["trace", "outer", "matmul", "tensordot", + "matrix_transpose", "vecdot", "EighResult", + "QRResult", "SlogdetResult", "SVDResult", "qr", + "cholesky", "matrix_rank", "matrix_norm", "svdvals", + "vector_norm", "diagonal"] -_all_ignore = ['get_xp', 'da', 'linalg_all', 'warnings'] +def __dir__() -> list[str]: + return __all__ diff --git a/array_api_compat/numpy/__init__.py b/array_api_compat/numpy/__init__.py index bf43fe61..23379e44 100644 --- a/array_api_compat/numpy/__init__.py +++ b/array_api_compat/numpy/__init__.py @@ -1,16 +1,17 @@ # ruff: noqa: PLC0414 from typing import Final -from numpy import * # noqa: F403 # pyright: ignore[reportWildcardImportFromLibrary] +from .._internal import clone_module -# from numpy import * doesn't overwrite these builtin names -from numpy import abs as abs -from numpy import max as max -from numpy import min as min -from numpy import round as round +# This needs to be loaded explicitly before cloning +import numpy.typing # noqa: F401 + +__all__ = clone_module("numpy", globals()) # These imports may overwrite names from the import * above. +from . import _aliases from ._aliases import * # type: ignore[assignment,no-redef] # noqa: F403 +from ._info import __array_namespace_info__ # noqa: F401 # Don't know why, but we have to do an absolute import to import linalg. If we # instead do @@ -26,3 +27,12 @@ from .linalg import matrix_transpose, vecdot # type: ignore[no-redef] # noqa: F401 __array_api_version__: Final = "2024.12" + +__all__ = sorted( + set(__all__) + | set(_aliases.__all__) + | {"__array_api_version__", "__array_namespace_info__", "linalg", "fft"} +) + +def __dir__() -> list[str]: + return __all__ diff --git a/array_api_compat/numpy/_aliases.py b/array_api_compat/numpy/_aliases.py index 5a05a820..5bb8869a 100644 --- a/array_api_compat/numpy/_aliases.py +++ b/array_api_compat/numpy/_aliases.py @@ -9,7 +9,6 @@ from .._internal import get_xp from ..common import _aliases, _helpers from ..common._typing import NestedSequence, SupportsBufferProtocol -from ._info import __array_namespace_info__ from ._typing import Array, Device, DType bool = np.bool_ @@ -147,8 +146,7 @@ def take_along_axis(x: Array, indices: Array, /, *, axis: int = -1) -> Array: else: unstack = get_xp(np)(_aliases.unstack) -__all__ = [ - "__array_namespace_info__", +__all__ = _aliases.__all__ + [ "asarray", "astype", "acos", @@ -167,8 +165,6 @@ def take_along_axis(x: Array, indices: Array, /, *, axis: int = -1) -> Array: "pow", "take_along_axis" ] -__all__ += _aliases.__all__ -_all_ignore = ["np", "get_xp"] def __dir__() -> list[str]: diff --git a/array_api_compat/numpy/_typing.py b/array_api_compat/numpy/_typing.py index e771c788..b5fa188c 100644 --- a/array_api_compat/numpy/_typing.py +++ b/array_api_compat/numpy/_typing.py @@ -23,7 +23,6 @@ Array: TypeAlias = np.ndarray __all__ = ["Array", "DType", "Device"] -_all_ignore = ["np"] def __dir__() -> list[str]: diff --git a/array_api_compat/numpy/fft.py b/array_api_compat/numpy/fft.py index 06875f00..a492feb8 100644 --- a/array_api_compat/numpy/fft.py +++ b/array_api_compat/numpy/fft.py @@ -1,6 +1,8 @@ import numpy as np -from numpy.fft import __all__ as fft_all -from numpy.fft import fft2, ifft2, irfft2, rfft2 + +from .._internal import clone_module + +__all__ = clone_module("numpy.fft", globals()) from .._internal import get_xp from ..common import _fft @@ -21,15 +23,8 @@ ifftshift = get_xp(np)(_fft.ifftshift) -__all__ = ["rfft2", "irfft2", "fft2", "ifft2"] -__all__ += _fft.__all__ - +__all__ = sorted(set(__all__) | set(_fft.__all__)) def __dir__() -> list[str]: return __all__ - -del get_xp -del np -del fft_all -del _fft diff --git a/array_api_compat/numpy/linalg.py b/array_api_compat/numpy/linalg.py index 9a618be9..7168441c 100644 --- a/array_api_compat/numpy/linalg.py +++ b/array_api_compat/numpy/linalg.py @@ -7,26 +7,11 @@ import numpy as np -# intersection of `np.linalg.__all__` on numpy 1.22 and 2.2, minus `_linalg.__all__` -from numpy.linalg import ( - LinAlgError, - cond, - det, - eig, - eigvals, - eigvalsh, - inv, - lstsq, - matrix_power, - multi_dot, - norm, - tensorinv, - tensorsolve, -) - -from .._internal import get_xp +from .._internal import clone_module, get_xp from ..common import _linalg +__all__ = clone_module("numpy.linalg", globals()) + # These functions are in both the main and linalg namespaces from ._aliases import matmul, matrix_transpose, tensordot, vecdot # noqa: F401 from ._typing import Array @@ -120,7 +105,7 @@ def solve(x1: Array, x2: Array, /) -> Array: vector_norm = get_xp(np)(_linalg.vector_norm) -__all__ = [ +_all = [ "LinAlgError", "cond", "det", @@ -132,12 +117,12 @@ def solve(x1: Array, x2: Array, /) -> Array: "matrix_power", "multi_dot", "norm", + "solve", "tensorinv", "tensorsolve", + "vector_norm", ] -__all__ += _linalg.__all__ -__all__ += ["solve", "vector_norm"] - +__all__ = sorted(set(__all__) | set(_linalg.__all__) | set(_all)) def __dir__() -> list[str]: return __all__ diff --git a/array_api_compat/torch/__init__.py b/array_api_compat/torch/__init__.py index 69fd19ce..6cbb6ec2 100644 --- a/array_api_compat/torch/__init__.py +++ b/array_api_compat/torch/__init__.py @@ -1,22 +1,25 @@ -from torch import * # noqa: F403 +from typing import Final -# Several names are not included in the above import * -import torch -for n in dir(torch): - if (n.startswith('_') - or n.endswith('_') - or 'cuda' in n - or 'cpu' in n - or 'backward' in n): - continue - exec(f"{n} = torch.{n}") -del n +from .._internal import clone_module + +__all__ = clone_module("torch", globals()) # These imports may overwrite names from the import * above. +from . import _aliases from ._aliases import * # noqa: F403 +from ._info import __array_namespace_info__ # noqa: F401 # See the comment in the numpy __init__.py __import__(__package__ + '.linalg') __import__(__package__ + '.fft') -__array_api_version__ = '2024.12' +__array_api_version__: Final = '2024.12' + +__all__ = sorted( + set(__all__) + | set(_aliases.__all__) + | {"__array_api_version__", "__array_namespace_info__", "linalg", "fft"} +) + +def __dir__() -> list[str]: + return __all__ diff --git a/array_api_compat/torch/_aliases.py b/array_api_compat/torch/_aliases.py index 7a449001..91161d24 100644 --- a/array_api_compat/torch/_aliases.py +++ b/array_api_compat/torch/_aliases.py @@ -10,7 +10,6 @@ from .._internal import get_xp from ..common import _aliases from ..common._typing import NestedSequence, SupportsBufferProtocol -from ._info import __array_namespace_info__ from ._typing import Array, Device, DType _int_dtypes = { @@ -830,7 +829,7 @@ def meshgrid(*arrays: Array, indexing: Literal['xy', 'ij'] = 'xy') -> list[Array return list(torch.meshgrid(*arrays, indexing='xy')) -__all__ = ['__array_namespace_info__', 'asarray', 'result_type', 'can_cast', +__all__ = ['asarray', 'result_type', 'can_cast', 'permute_dims', 'bitwise_invert', 'newaxis', 'conj', 'add', 'atan2', 'bitwise_and', 'bitwise_left_shift', 'bitwise_or', 'bitwise_right_shift', 'bitwise_xor', 'copysign', 'count_nonzero', @@ -847,5 +846,3 @@ def meshgrid(*arrays: Array, indexing: Literal['xy', 'ij'] = 'xy') -> list[Array 'unique_all', 'unique_counts', 'unique_inverse', 'unique_values', 'matmul', 'matrix_transpose', 'vecdot', 'tensordot', 'isdtype', 'take', 'take_along_axis', 'sign', 'finfo', 'iinfo', 'repeat', 'meshgrid'] - -_all_ignore = ['torch', 'get_xp'] diff --git a/array_api_compat/torch/fft.py b/array_api_compat/torch/fft.py index ddf87c65..76342980 100644 --- a/array_api_compat/torch/fft.py +++ b/array_api_compat/torch/fft.py @@ -5,9 +5,11 @@ import torch import torch.fft -from torch.fft import * # noqa: F403 from ._typing import Array +from .._internal import clone_module + +__all__ = clone_module("torch.fft", globals()) # Several torch fft functions do not map axes to dim @@ -74,13 +76,7 @@ def ifftshift( return torch.fft.ifftshift(x, dim=axes, **kwargs) -__all__ = torch.fft.__all__ + [ - "fftn", - "ifftn", - "rfftn", - "irfftn", - "fftshift", - "ifftshift", -] +__all__ += ["fftn", "ifftn", "rfftn", "irfftn", "fftshift", "ifftshift"] -_all_ignore = ['torch'] +def __dir__() -> list[str]: + return __all__ diff --git a/array_api_compat/torch/linalg.py b/array_api_compat/torch/linalg.py index 558cfe7b..08271d22 100644 --- a/array_api_compat/torch/linalg.py +++ b/array_api_compat/torch/linalg.py @@ -1,12 +1,11 @@ from __future__ import annotations import torch -from torch.linalg import * # noqa: F403 +import torch.linalg -# torch.linalg doesn't define __all__ -# from torch.linalg import __all__ as linalg_all -from torch import linalg as torch_linalg -linalg_all = [i for i in dir(torch_linalg) if not i.startswith('_')] +from .._internal import clone_module + +__all__ = clone_module("torch.linalg", globals()) # outer is implemented in torch but aren't in the linalg namespace from torch import outer @@ -28,7 +27,7 @@ def cross(x1: Array, x2: Array, /, *, axis: int = -1) -> Array: if not (x1.shape[axis] == x2.shape[axis] == 3): raise ValueError(f"cross product axis must have size 3, got {x1.shape[axis]} and {x2.shape[axis]}") x1, x2 = torch.broadcast_tensors(x1, x2) - return torch_linalg.cross(x1, x2, dim=axis) + return torch.linalg.cross(x1, x2, dim=axis) def vecdot(x1: Array, x2: Array, /, *, axis: int = -1, **kwargs: object) -> Array: from ._aliases import isdtype @@ -108,12 +107,8 @@ def vector_norm( return out return torch.linalg.vector_norm(x, ord=ord, axis=axis, keepdim=keepdims, **kwargs) -__all__ = linalg_all + ['outer', 'matmul', 'matrix_transpose', 'tensordot', - 'cross', 'vecdot', 'solve', 'trace', 'vector_norm'] - -_all_ignore = ['torch_linalg', 'sum'] - -del linalg_all +__all__ += ['outer', 'matmul', 'matrix_transpose', 'tensordot', + 'cross', 'vecdot', 'solve', 'trace', 'vector_norm'] def __dir__() -> list[str]: return __all__ diff --git a/tests/test_all.py b/tests/test_all.py index 271cd189..c36aef67 100644 --- a/tests/test_all.py +++ b/tests/test_all.py @@ -1,63 +1,311 @@ -""" -Test that files that define __all__ aren't missing any exports. +"""Test exported names""" -You can add names that shouldn't be exported to _all_ignore, like +import builtins -_all_ignore = ['sys'] +import numpy as np +import pytest -This is preferable to del-ing the names as this will break any name that is -used inside of a function. Note that names starting with an underscore are automatically ignored. -""" +from array_api_compat._internal import clone_module +from ._helpers import wrapped_libraries -import sys +NAMES = { + "": [ + # Inspection + "__array_api_version__", + "__array_namespace_info__", + # Submodules + "fft", + "linalg", + # Constants + "e", + "inf", + "nan", + "newaxis", + "pi", + # Creation Functions + "arange", + "asarray", + "empty", + "empty_like", + "eye", + "from_dlpack", + "full", + "full_like", + "linspace", + "meshgrid", + "ones", + "ones_like", + "tril", + "triu", + "zeros", + "zeros_like", + # Data Type Functions + "astype", + "can_cast", + "finfo", + "iinfo", + "isdtype", + "result_type", + # Data Types + "bool", + "int8", + "int16", + "int32", + "int64", + "uint8", + "uint16", + "uint32", + "uint64", + "float32", + "float64", + "complex64", + "complex128", + # Elementwise Functions + "abs", + "acos", + "acosh", + "add", + "asin", + "asinh", + "atan", + "atan2", + "atanh", + "bitwise_and", + "bitwise_left_shift", + "bitwise_invert", + "bitwise_or", + "bitwise_right_shift", + "bitwise_xor", + "ceil", + "clip", + "conj", + "copysign", + "cos", + "cosh", + "divide", + "equal", + "exp", + "expm1", + "floor", + "floor_divide", + "greater", + "greater_equal", + "hypot", + "imag", + "isfinite", + "isinf", + "isnan", + "less", + "less_equal", + "log", + "log1p", + "log2", + "log10", + "logaddexp", + "logical_and", + "logical_not", + "logical_or", + "logical_xor", + "maximum", + "minimum", + "multiply", + "negative", + "nextafter", + "not_equal", + "positive", + "pow", + "real", + "reciprocal", + "remainder", + "round", + "sign", + "signbit", + "sin", + "sinh", + "square", + "sqrt", + "subtract", + "tan", + "tanh", + "trunc", + # Indexing Functions + "take", + "take_along_axis", + # Linear Algebra Functions + "matmul", + "matrix_transpose", + "tensordot", + "vecdot", + # Manipulation Functions + "broadcast_arrays", + "broadcast_to", + "concat", + "expand_dims", + "flip", + "moveaxis", + "permute_dims", + "repeat", + "reshape", + "roll", + "squeeze", + "stack", + "tile", + "unstack", + # Searching Functions + "argmax", + "argmin", + "count_nonzero", + "nonzero", + "searchsorted", + "where", + # Set Functions + "unique_all", + "unique_counts", + "unique_inverse", + "unique_values", + # Sorting Functions + "argsort", + "sort", + # Statistical Functions + "cumulative_prod", + "cumulative_sum", + "max", + "mean", + "min", + "prod", + "std", + "sum", + "var", + # Utility Functions + "all", + "any", + "diff", + ], + "fft": [ + "fft", + "ifft", + "fftn", + "ifftn", + "rfft", + "irfft", + "rfftn", + "irfftn", + "hfft", + "ihfft", + "fftfreq", + "rfftfreq", + "fftshift", + "ifftshift", + ], + "linalg": [ + "cholesky", + "cross", + "det", + "diagonal", + "eigh", + "eigvalsh", + "inv", + "matmul", + "matrix_norm", + "matrix_power", + "matrix_rank", + "matrix_transpose", + "outer", + "pinv", + "qr", + "slogdet", + "solve", + "svd", + "svdvals", + "tensordot", + "trace", + "vecdot", + "vector_norm", + ], +} -from ._helpers import import_, wrapped_libraries +XFAILS = { + ("numpy", ""): ["from_dlpack"] if np.__version__ < "1.23" else [], + ("dask.array", ""): ["from_dlpack", "take_along_axis"], + ("dask.array", "linalg"): [ + "cross", + "det", + "eigh", + "eigvalsh", + "matrix_power", + "pinv", + "slogdet", + ], +} -import pytest -import typing - -TYPING_NAMES = frozenset(( - "Array", - "Device", - "DType", - "Namespace", - "NestedSequence", - "SupportsBufferProtocol", -)) - -@pytest.mark.parametrize("library", ["common"] + wrapped_libraries) -def test_all(library): - if library == "common": - import array_api_compat.common # noqa: F401 - else: - import_(library, wrapper=True) - - # NB: iterate over a copy to avoid a "dictionary size changed" error - for mod_name in sys.modules.copy(): - if not mod_name.startswith('array_api_compat.' + library): - continue - - module = sys.modules[mod_name] - - # TODO: We should define __all__ in the __init__.py files and test it - # there too. - if not hasattr(module, '__all__'): - continue - - dir_names = [n for n in dir(module) if not n.startswith('_')] - if '__array_namespace_info__' in dir(module): - dir_names.append('__array_namespace_info__') - ignore_all_names = set(getattr(module, '_all_ignore', ())) - ignore_all_names |= set(dir(typing)) - ignore_all_names |= {"annotations"} - if not module.__name__.endswith("._typing"): - ignore_all_names |= TYPING_NAMES - dir_names = set(dir_names) - set(ignore_all_names) - all_names = module.__all__ - - if set(dir_names) != set(all_names): - extra_dir = set(dir_names) - set(all_names) - extra_all = set(all_names) - set(dir_names) - assert not extra_dir, f"Some dir() names not included in __all__ for {mod_name}: {extra_dir}" - assert not extra_all, f"Some __all__ names not in dir() for {mod_name}: {extra_all}" + +def all_names(mod): + """Return all names available in a module.""" + objs = {} + clone_module(mod.__name__, objs) + return set(objs) + + +def get_mod(library, module, *, compat): + if compat: + library = f"array_api_compat.{library}" + xp = pytest.importorskip(library) + return getattr(xp, module) if module else xp + + +@pytest.mark.parametrize("module", list(NAMES)) +@pytest.mark.parametrize("library", wrapped_libraries) +def test_array_api_names(library, module): + """Test that __all__ isn't missing any exports + dictated by the Standard. + """ + mod = get_mod(library, module, compat=True) + missing = set(NAMES[module]) - all_names(mod) + xfail = set(XFAILS.get((library, module), [])) + xpass = xfail - missing + fails = missing - xfail + assert not xpass, f"Names in XFAILS are defined: {xpass}" + assert not fails, f"Missing exports: {fails}" + + +@pytest.mark.parametrize("module", list(NAMES)) +@pytest.mark.parametrize("library", wrapped_libraries) +def test_compat_doesnt_hide_names(library, module): + """The base namespace can have more names than the ones explicitly exported + by array-api-compat. Test that we're not suppressing them. + """ + bare_mod = get_mod(library, module, compat=False) + compat_mod = get_mod(library, module, compat=True) + + missing = all_names(bare_mod) - all_names(compat_mod) + missing = {name for name in missing if not name.startswith("_")} + assert not missing, f"Non-Array API names have been hidden: {missing}" + + +@pytest.mark.parametrize("module", list(NAMES)) +@pytest.mark.parametrize("library", wrapped_libraries) +def test_compat_doesnt_add_names(library, module): + """Test that array-api-compat isn't adding names to the namespace + besides those defined by the Array API Standard. + """ + bare_mod = get_mod(library, module, compat=False) + compat_mod = get_mod(library, module, compat=True) + + aapi_names = set(NAMES[module]) + spurious = all_names(compat_mod) - all_names(bare_mod) - aapi_names + # Quietly ignore *Result dataclasses + spurious = {name for name in spurious if not name.endswith("Result")} + assert not spurious, ( + f"array-api-compat is adding non-Array API names: {spurious}" + ) + + +@pytest.mark.parametrize( + "name", [name for name in NAMES[""] if hasattr(builtins, name)] +) +@pytest.mark.parametrize("library", wrapped_libraries) +def test_builtins_collision(library, name): + """Test that xp.bool is not accidentally builtins.bool, etc.""" + xp = pytest.importorskip(f"array_api_compat.{library}") + assert getattr(xp, name) is not getattr(builtins, name) From 1d1178d33f7af737abf697a76fb161901faa075d Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 16 Jun 2025 10:16:28 +0000 Subject: [PATCH 19/21] Bump dawidd6/action-download-artifact from 10 to 11 in the actions group Bumps the actions group with 1 update: [dawidd6/action-download-artifact](https://github.com/dawidd6/action-download-artifact). Updates `dawidd6/action-download-artifact` from 10 to 11 - [Release notes](https://github.com/dawidd6/action-download-artifact/releases) - [Commits](https://github.com/dawidd6/action-download-artifact/compare/v10...v11) --- updated-dependencies: - dependency-name: dawidd6/action-download-artifact dependency-version: '11' dependency-type: direct:production update-type: version-update:semver-major dependency-group: actions ... Signed-off-by: dependabot[bot] --- .github/workflows/docs-deploy.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/docs-deploy.yml b/.github/workflows/docs-deploy.yml index 4e3efb39..ed90b29d 100644 --- a/.github/workflows/docs-deploy.yml +++ b/.github/workflows/docs-deploy.yml @@ -13,7 +13,7 @@ jobs: steps: - uses: actions/checkout@v4 - name: Download Artifact - uses: dawidd6/action-download-artifact@v10 + uses: dawidd6/action-download-artifact@v11 with: workflow: docs-build.yml name: docs-build From 4bafa4cc8a455a301f3688fd3fa7404a4fe00974 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 11 Aug 2025 14:10:23 +0000 Subject: [PATCH 20/21] Bump actions/download-artifact from 4 to 5 in the actions group Bumps the actions group with 1 update: [actions/download-artifact](https://github.com/actions/download-artifact). Updates `actions/download-artifact` from 4 to 5 - [Release notes](https://github.com/actions/download-artifact/releases) - [Commits](https://github.com/actions/download-artifact/compare/v4...v5) --- updated-dependencies: - dependency-name: actions/download-artifact dependency-version: '5' dependency-type: direct:production update-type: version-update:semver-major dependency-group: actions ... Signed-off-by: dependabot[bot] --- .github/workflows/publish-package.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/publish-package.yml b/.github/workflows/publish-package.yml index 6d88066d..1e28689c 100644 --- a/.github/workflows/publish-package.yml +++ b/.github/workflows/publish-package.yml @@ -81,7 +81,7 @@ jobs: steps: - name: Download distribution artifact - uses: actions/download-artifact@v4 + uses: actions/download-artifact@v5 with: name: dist-artifact path: dist From edd9072c296827d0e4eccf02ae87920eb2481b9c Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 18 Aug 2025 13:07:35 +0000 Subject: [PATCH 21/21] Bump actions/checkout from 4 to 5 in the actions group Bumps the actions group with 1 update: [actions/checkout](https://github.com/actions/checkout). Updates `actions/checkout` from 4 to 5 - [Release notes](https://github.com/actions/checkout/releases) - [Changelog](https://github.com/actions/checkout/blob/main/CHANGELOG.md) - [Commits](https://github.com/actions/checkout/compare/v4...v5) --- updated-dependencies: - dependency-name: actions/checkout dependency-version: '5' dependency-type: direct:production update-type: version-update:semver-major dependency-group: actions ... Signed-off-by: dependabot[bot] --- .github/workflows/array-api-tests.yml | 4 ++-- .github/workflows/docs-build.yml | 2 +- .github/workflows/docs-deploy.yml | 2 +- .github/workflows/publish-package.yml | 2 +- .github/workflows/ruff.yml | 2 +- .github/workflows/tests.yml | 2 +- 6 files changed, 7 insertions(+), 7 deletions(-) diff --git a/.github/workflows/array-api-tests.yml b/.github/workflows/array-api-tests.yml index e832f870..5c3cc7d9 100644 --- a/.github/workflows/array-api-tests.yml +++ b/.github/workflows/array-api-tests.yml @@ -49,12 +49,12 @@ jobs: steps: - name: Checkout array-api-compat - uses: actions/checkout@v4 + uses: actions/checkout@v5 with: path: array-api-compat - name: Checkout array-api-tests - uses: actions/checkout@v4 + uses: actions/checkout@v5 with: repository: data-apis/array-api-tests submodules: 'true' diff --git a/.github/workflows/docs-build.yml b/.github/workflows/docs-build.yml index 34b9cbc6..778d20e2 100644 --- a/.github/workflows/docs-build.yml +++ b/.github/workflows/docs-build.yml @@ -6,7 +6,7 @@ jobs: docs-build: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@v5 - uses: actions/setup-python@v5 - name: Install Dependencies run: | diff --git a/.github/workflows/docs-deploy.yml b/.github/workflows/docs-deploy.yml index ed90b29d..42a3598f 100644 --- a/.github/workflows/docs-deploy.yml +++ b/.github/workflows/docs-deploy.yml @@ -11,7 +11,7 @@ jobs: environment: name: docs-deploy steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@v5 - name: Download Artifact uses: dawidd6/action-download-artifact@v11 with: diff --git a/.github/workflows/publish-package.yml b/.github/workflows/publish-package.yml index 1e28689c..03cae174 100644 --- a/.github/workflows/publish-package.yml +++ b/.github/workflows/publish-package.yml @@ -30,7 +30,7 @@ jobs: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@v5 with: fetch-depth: 0 diff --git a/.github/workflows/ruff.yml b/.github/workflows/ruff.yml index a9f0fd4b..68f68a14 100644 --- a/.github/workflows/ruff.yml +++ b/.github/workflows/ruff.yml @@ -5,7 +5,7 @@ jobs: runs-on: ubuntu-latest continue-on-error: true steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@v5 - name: Install Python uses: actions/setup-python@v5 with: diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index c995b370..d2e768eb 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -23,7 +23,7 @@ jobs: python-version: '3.13' steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@v5 - uses: actions/setup-python@v5 with: python-version: ${{ matrix.python-version }}