diff --git a/.github/workflows/array-api-tests-dask.yml b/.github/workflows/array-api-tests-dask.yml index b0ce007e..ef430d9c 100644 --- a/.github/workflows/array-api-tests-dask.yml +++ b/.github/workflows/array-api-tests-dask.yml @@ -9,4 +9,12 @@ jobs: package-name: dask module-name: dask.array extra-requires: numpy - pytest-extra-args: --disable-deadline --max-examples=5 + # Dask is substantially slower than other libraries on unit tests. + # Reduce the number of examples to speed up CI, even though this means that this + # workflow is barely more than a smoke test, and one should expect extreme + # flakiness. Before changes to dask-xfails.txt or dask-skips.txt, please run + # the full test suite with at least 200 examples. + pytest-extra-args: --max-examples=5 -n 4 + python-versions: '[''3.10'', ''3.13'']' + extra-env-vars: | + ARRAY_API_TESTS_XFAIL_MARK=skip diff --git a/.github/workflows/array-api-tests-numpy-1-21.yml b/.github/workflows/array-api-tests-numpy-1-21.yml deleted file mode 100644 index 2d81c3cd..00000000 --- a/.github/workflows/array-api-tests-numpy-1-21.yml +++ /dev/null @@ -1,11 +0,0 @@ -name: Array API Tests (NumPy 1.21) - -on: [push, pull_request] - -jobs: - array-api-tests-numpy-1-21: - uses: ./.github/workflows/array-api-tests.yml - with: - package-name: numpy - package-version: '== 1.21.*' - xfails-file-extra: '-1-21' diff --git a/.github/workflows/array-api-tests-numpy-1-22.yml b/.github/workflows/array-api-tests-numpy-1-22.yml new file mode 100644 index 00000000..83d4cf1d --- /dev/null +++ b/.github/workflows/array-api-tests-numpy-1-22.yml @@ -0,0 +1,15 @@ +name: Array API Tests (NumPy 1.22) + +on: [push, pull_request] + +jobs: + array-api-tests-numpy-1-22: + uses: ./.github/workflows/array-api-tests.yml + with: + package-name: numpy + package-version: '== 1.22.*' + xfails-file-extra: '-1-22' + python-versions: '[''3.10'']' + pytest-extra-args: -n 4 + extra-env-vars: | + ARRAY_API_TESTS_XFAIL_MARK=skip diff --git a/.github/workflows/array-api-tests-numpy-1-26.yml b/.github/workflows/array-api-tests-numpy-1-26.yml index 660935f0..13124644 100644 --- a/.github/workflows/array-api-tests-numpy-1-26.yml +++ b/.github/workflows/array-api-tests-numpy-1-26.yml @@ -9,3 +9,7 @@ jobs: package-name: numpy package-version: '== 1.26.*' xfails-file-extra: '-1-26' + python-versions: '[''3.10'', ''3.12'']' + pytest-extra-args: -n 4 + extra-env-vars: | + ARRAY_API_TESTS_XFAIL_MARK=skip diff --git a/.github/workflows/array-api-tests-numpy-dev.yml b/.github/workflows/array-api-tests-numpy-dev.yml index eef4269d..dec4c7ae 100644 --- a/.github/workflows/array-api-tests-numpy-dev.yml +++ b/.github/workflows/array-api-tests-numpy-dev.yml @@ -9,3 +9,7 @@ jobs: package-name: numpy extra-requires: '--pre --extra-index https://pypi.anaconda.org/scientific-python-nightly-wheels/simple' xfails-file-extra: '-dev' + python-versions: '[''3.11'', ''3.13'']' + pytest-extra-args: -n 4 + extra-env-vars: | + ARRAY_API_TESTS_XFAIL_MARK=skip diff --git a/.github/workflows/array-api-tests-numpy-latest.yml b/.github/workflows/array-api-tests-numpy-latest.yml index 36984345..65bbc9a2 100644 --- a/.github/workflows/array-api-tests-numpy-latest.yml +++ b/.github/workflows/array-api-tests-numpy-latest.yml @@ -1,4 +1,4 @@ -name: Array API Tests (NumPy Latest) +name: Array API Tests (NumPy latest) on: [push, pull_request] @@ -7,3 +7,7 @@ jobs: uses: ./.github/workflows/array-api-tests.yml with: package-name: numpy + python-versions: '[''3.10'', ''3.13'']' + pytest-extra-args: -n 4 + extra-env-vars: | + 
ARRAY_API_TESTS_XFAIL_MARK=skip diff --git a/.github/workflows/array-api-tests-torch.yml b/.github/workflows/array-api-tests-torch.yml index 98234ae2..4b4b945e 100644 --- a/.github/workflows/array-api-tests-torch.yml +++ b/.github/workflows/array-api-tests-torch.yml @@ -1,4 +1,4 @@ -name: Array API Tests (PyTorch Latest) +name: Array API Tests (PyTorch CPU) on: [push, pull_request] @@ -7,8 +7,9 @@ jobs: uses: ./.github/workflows/array-api-tests.yml with: package-name: torch - # Proper linalg testing will require - # https://github.com/data-apis/array-api-tests/pull/101 - pytest-extra-args: "--disable-extension linalg" + extra-requires: '--index-url https://download.pytorch.org/whl/cpu' extra-env-vars: | ARRAY_API_TESTS_SKIP_DTYPES=uint16,uint32,uint64 + ARRAY_API_TESTS_XFAIL_MARK=skip + python-versions: '[''3.10'', ''3.13'']' + pytest-extra-args: -n 4 diff --git a/.github/workflows/array-api-tests.yml b/.github/workflows/array-api-tests.yml index 254e4e61..e3c0c9e0 100644 --- a/.github/workflows/array-api-tests.yml +++ b/.github/workflows/array-api-tests.yml @@ -16,6 +16,10 @@ on: required: false type: string default: '>= 0' + python-versions: + required: true + type: string + description: JSON array of Python versions to test against. pytest-extra-args: required: false type: string @@ -30,51 +34,57 @@ on: extra-env-vars: required: false type: string - description: "Multiline string of environment variables to set for the test run." + description: Multiline string of environment variables to set for the test run. env: - PYTEST_ARGS: "--max-examples 200 -v -rxXfE --ci ${{ inputs.pytest-extra-args }} --hypothesis-disable-deadline" + PYTEST_ARGS: "--max-examples 1000 -v -rxXfE ${{ inputs.pytest-extra-args }} --hypothesis-disable-deadline --durations 20" jobs: tests: runs-on: ubuntu-latest strategy: + fail-fast: false matrix: - python-version: ['3.9', '3.10', '3.11', '3.12'] + python-version: ${{ fromJson(inputs.python-versions) }} steps: - name: Checkout array-api-compat - uses: actions/checkout@v4 + uses: actions/checkout@v5 with: path: array-api-compat + - name: Checkout array-api-tests - uses: actions/checkout@v4 + uses: actions/checkout@v5 with: repository: data-apis/array-api-tests submodules: 'true' path: array-api-tests + - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v5 + uses: actions/setup-python@v6 with: python-version: ${{ matrix.python-version }} + - name: Set Extra Environment Variables # Set additional environment variables if provided if: inputs.extra-env-vars run: | echo "${{ inputs.extra-env-vars }}" >> $GITHUB_ENV + - name: Install dependencies - # NumPy 1.21 doesn't support Python 3.11. There doesn't seem to be a way - # to put this in the numpy 1.21 config file. - if: "! ((matrix.python-version == '3.11' || matrix.python-version == '3.12') && inputs.package-name == 'numpy' && contains(inputs.package-version, '1.21'))" run: | python -m pip install --upgrade pip python -m pip install '${{ inputs.package-name }} ${{ inputs.package-version }}' ${{ inputs.extra-requires }} python -m pip install -r ${GITHUB_WORKSPACE}/array-api-tests/requirements.txt + python -m pip install pytest-xdist + + - name: Dump pip environment + run: pip freeze + - name: Run the array API testsuite (${{ inputs.package-name }}) - if: "! 
((matrix.python-version == '3.11' || matrix.python-version == '3.12') && inputs.package-name == 'numpy' && contains(inputs.package-version, '1.21'))" env: ARRAY_API_TESTS_MODULE: array_api_compat.${{ inputs.module-name || inputs.package-name }} - ARRAY_API_TESTS_VERSION: 2023.12 + ARRAY_API_TESTS_VERSION: 2024.12 # This enables the NEP 50 type promotion behavior (without it a lot of # tests fail on bad scalar type promotion behavior) NPY_PROMOTION_STATE: weak diff --git a/.github/workflows/docs-build.yml b/.github/workflows/docs-build.yml index 04c3aa66..305a9003 100644 --- a/.github/workflows/docs-build.yml +++ b/.github/workflows/docs-build.yml @@ -6,11 +6,11 @@ jobs: docs-build: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v4 - - uses: actions/setup-python@v5 + - uses: actions/checkout@v5 + - uses: actions/setup-python@v6 - name: Install Dependencies run: | - python -m pip install -r docs/requirements.txt + python -m pip install .[docs] - name: Build Docs run: | cd docs diff --git a/.github/workflows/docs-deploy.yml b/.github/workflows/docs-deploy.yml index 9aa379de..42a3598f 100644 --- a/.github/workflows/docs-deploy.yml +++ b/.github/workflows/docs-deploy.yml @@ -11,9 +11,9 @@ jobs: environment: name: docs-deploy steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@v5 - name: Download Artifact - uses: dawidd6/action-download-artifact@v6 + uses: dawidd6/action-download-artifact@v11 with: workflow: docs-build.yml name: docs-build diff --git a/.github/workflows/publish-package.yml b/.github/workflows/publish-package.yml index e391ca46..bbfb2e80 100644 --- a/.github/workflows/publish-package.yml +++ b/.github/workflows/publish-package.yml @@ -30,24 +30,25 @@ jobs: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@v5 with: fetch-depth: 0 - name: Set up Python - uses: actions/setup-python@v5 + uses: actions/setup-python@v6 with: python-version: '3.x' - name: Install python-build and twine run: | - python -m pip install --upgrade pip setuptools + python -m pip install --upgrade pip "setuptools<=67" python -m pip install build twine python -m pip list - name: Build a wheel and a sdist run: | - PYTHONWARNINGS=error,default::DeprecationWarning python -m build . + #PYTHONWARNINGS=error,default::DeprecationWarning python -m build . + python -m build . 
- name: Verify the distribution run: twine check --strict dist/* @@ -80,7 +81,7 @@ jobs: steps: - name: Download distribution artifact - uses: actions/download-artifact@v4 + uses: actions/download-artifact@v5 with: name: dist-artifact path: dist @@ -88,15 +89,21 @@ jobs: - name: List all files run: ls -lh dist - - name: Publish distribution 📦 to Test PyPI - # Publish to TestPyPI on tag events of if manually triggered - # Compare to 'true' string as booleans get turned into strings in the console - if: >- - (github.event_name == 'push' && startsWith(github.ref, 'refs/tags')) - || (github.event_name == 'workflow_dispatch' && github.event.inputs.publish == 'true') - uses: pypa/gh-action-pypi-publish@v1.10.1 + # - name: Publish distribution 📦 to Test PyPI + # # Publish to TestPyPI on tag events or if manually triggered + # # Compare to 'true' string as booleans get turned into strings in the console + # if: >- + # (github.event_name == 'push' && startsWith(github.ref, 'refs/tags')) + # || (github.event_name == 'workflow_dispatch' && github.event.inputs.publish == 'true') + # uses: pypa/gh-action-pypi-publish@v1.13.0 + # with: + # repository-url: https://test.pypi.org/legacy/ + # print-hash: true + + - name: Publish distribution 📦 to PyPI + if: github.event_name == 'push' && startsWith(github.ref, 'refs/tags') + uses: pypa/gh-action-pypi-publish@v1.13.0 with: - repository-url: https://test.pypi.org/legacy/ print-hash: true - name: Create GitHub Release from a Tag @@ -104,9 +111,3 @@ if: startsWith(github.ref, 'refs/tags/') with: files: dist/* - - - name: Publish distribution 📦 to PyPI - if: github.event_name == 'push' && startsWith(github.ref, 'refs/tags') - uses: pypa/gh-action-pypi-publish@v1.10.1 - with: - print-hash: true diff --git a/.github/workflows/ruff.yml b/.github/workflows/ruff.yml index a9f0fd4b..4a2ffcff 100644 @@ -5,9 +5,9 @@ jobs: runs-on: ubuntu-latest continue-on-error: true steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@v5 - name: Install Python - uses: actions/setup-python@v5 + uses: actions/setup-python@v6 with: python-version: "3.11" - name: Install dependencies diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index fcd43367..cfbb875f 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -4,43 +4,55 @@ jobs: tests: runs-on: ubuntu-latest strategy: + fail-fast: false matrix: - python-version: ['3.9', '3.10', '3.11', '3.12'] - numpy-version: ['1.21', '1.26', '2.0', 'dev'] - exclude: - - python-version: '3.11' - numpy-version: '1.21' - - python-version: '3.12' - numpy-version: '1.21' - fail-fast: true + include: + - numpy-version: '1.22' + python-version: '3.10' + - numpy-version: '1.26' + python-version: '3.10' + - numpy-version: '1.26' + python-version: '3.12' + - numpy-version: 'latest' + python-version: '3.10' + - numpy-version: 'latest' + python-version: '3.13' + - numpy-version: 'dev' + python-version: '3.11' + - numpy-version: 'dev' + python-version: '3.13' + steps: - - uses: actions/checkout@v4 - - uses: actions/setup-python@v5 + - uses: actions/checkout@v5 + - uses: actions/setup-python@v6 with: python-version: ${{ matrix.python-version }} - name: Install Dependencies run: | python -m pip install --upgrade pip + python -m pip install pytest + + # Don't `pip install .[dev]` as it would pull in the whole torch CUDA stack + python -m pip install array-api-strict + python -m pip install torch --index-url 
https://download.pytorch.org/whl/cpu + if [ "${{ matrix.numpy-version }}" == "dev" ]; then - PIP_EXTRA='numpy --pre --extra-index-url https://pypi.anaconda.org/scientific-python-nightly-wheels/simple' - elif [ "${{ matrix.numpy-version }}" == "1.21" ]; then - PIP_EXTRA='numpy==1.21.*' + python -m pip install numpy --pre --extra-index-url https://pypi.anaconda.org/scientific-python-nightly-wheels/simple + python -m pip install dask[array] jax[cpu] sparse ndonnx + elif [ "${{ matrix.numpy-version }}" == "1.22" ]; then + python -m pip install 'numpy==1.22.*' + elif [ "${{ matrix.numpy-version }}" == "1.26" ]; then + python -m pip install 'numpy==1.26.*' else - PIP_EXTRA='numpy==1.26.*' + python -m pip install numpy + python -m pip install dask[array] jax[cpu] sparse ndonnx fi - if [ "${{ matrix.python-version }}" == "3.9" ]; then - sed -i '/^ndonnx/d' requirements-dev.txt - fi + - name: Dump pip environment + run: pip freeze - python -m pip install -r requirements-dev.txt $PIP_EXTRA + - name: Test it installs + run: python -m pip install . - name: Run Tests - run: | - if [[ "${{ matrix.numpy-version }}" == "1.21" || "${{ matrix.numpy-version }}" == "dev" ]]; then - PYTEST_EXTRA=(-k "numpy and not jax and not torch and not dask and not sparse") - fi - pytest -v "${PYTEST_EXTRA[@]}" - - # Make sure it installs - python -m pip install . + run: pytest -v diff --git a/array_api_compat/__init__.py b/array_api_compat/__init__.py index 5815fb27..a00e8cbc 100644 --- a/array_api_compat/__init__.py +++ b/array_api_compat/__init__.py @@ -1,9 +1,9 @@ """ NumPy Array API compatibility library -This is a small wrapper around NumPy and CuPy that is compatible with the -Array API standard https://data-apis.org/array-api/latest/. See also NEP 47 -https://numpy.org/neps/nep-0047-array-api-standard.html. +This is a small wrapper around NumPy, CuPy, JAX, sparse and others that is +compatible with the Array API standard https://data-apis.org/array-api/latest/. +See also NEP 47 https://numpy.org/neps/nep-0047-array-api-standard.html. Unlike array_api_strict, this is not a strict minimal implementation of the Array API, but rather just an extension of the main NumPy namespace with @@ -17,6 +17,6 @@ this implementation for the default when working with NumPy arrays. """ -__version__ = '1.9' +__version__ = '1.13.0.dev0' from .common import * # noqa: F401, F403 diff --git a/array_api_compat/_internal.py b/array_api_compat/_internal.py index 170a1ff9..baa39ded 100644 --- a/array_api_compat/_internal.py +++ b/array_api_compat/_internal.py @@ -2,10 +2,17 @@ Internal helpers """ +import importlib +from collections.abc import Callable from functools import wraps from inspect import signature +from types import ModuleType +from typing import TypeVar -def get_xp(xp): +_T = TypeVar("_T") + + +def get_xp(xp: ModuleType) -> Callable[[Callable[..., _T]], Callable[..., _T]]: """ Decorator to automatically replace xp with the corresponding array module. @@ -22,14 +29,14 @@ def func(x, /, xp, kwarg=None): """ - def inner(f): + def inner(f: Callable[..., _T], /) -> Callable[..., _T]: @wraps(f) - def wrapped_f(*args, **kwargs): + def wrapped_f(*args: object, **kwargs: object) -> object: return f(*args, xp=xp, **kwargs) sig = signature(f) new_sig = sig.replace( - parameters=[sig.parameters[i] for i in sig.parameters if i != "xp"] + parameters=[par for i, par in sig.parameters.items() if i != "xp"] ) if wrapped_f.__doc__ is None: @@ -40,7 +47,31 @@ def wrapped_f(*args, **kwargs): specification for more details. 
""" - wrapped_f.__signature__ = new_sig - return wrapped_f + wrapped_f.__signature__ = new_sig # type: ignore[attr-defined] # pyright: ignore[reportAttributeAccessIssue] + return wrapped_f # type: ignore[return-value] # pyright: ignore[reportReturnType] return inner + + +def clone_module(mod_name: str, globals_: dict[str, object]) -> list[str]: + """Import everything from module, updating globals(). + Returns __all__. + """ + mod = importlib.import_module(mod_name) + # Neither of these two methods is sufficient by itself, + # depending on various idiosyncrasies of the libraries we're wrapping. + objs = {} + exec(f"from {mod.__name__} import *", objs) + + for n in dir(mod): + if not n.startswith("_") and hasattr(mod, n): + objs[n] = getattr(mod, n) + + globals_.update(objs) + return list(objs) + + +__all__ = ["get_xp", "clone_module"] + +def __dir__() -> list[str]: + return __all__ diff --git a/array_api_compat/common/__init__.py b/array_api_compat/common/__init__.py index 91ab1c40..82360807 100644 --- a/array_api_compat/common/__init__.py +++ b/array_api_compat/common/__init__.py @@ -1 +1 @@ -from ._helpers import * # noqa: F403 +from ._helpers import * # noqa: F403 diff --git a/array_api_compat/common/_aliases.py b/array_api_compat/common/_aliases.py index 91c4d9a7..3587ef16 100644 --- a/array_api_compat/common/_aliases.py +++ b/array_api_compat/common/_aliases.py @@ -4,142 +4,172 @@ from __future__ import annotations -from typing import TYPE_CHECKING -if TYPE_CHECKING: - from typing import Optional, Sequence, Tuple, Union - from ._typing import ndarray, Device, Dtype - -from typing import NamedTuple import inspect +from collections.abc import Sequence +from typing import TYPE_CHECKING, Any, NamedTuple, cast + +from ._helpers import _check_device, array_namespace +from ._helpers import device as _get_device +from ._helpers import is_cupy_namespace +from ._typing import Array, Device, DType, Namespace -from ._helpers import array_namespace, _check_device, device, is_torch_array +if TYPE_CHECKING: + # TODO: import from typing (requires Python >=3.13) + from typing_extensions import TypeIs # These functions are modified from the NumPy versions. 
-# Creation functions add the device keyword (which does nothing for NumPy) +# Creation functions add the device keyword (which does nothing for NumPy and Dask) + def arange( - start: Union[int, float], + start: float, /, - stop: Optional[Union[int, float]] = None, - step: Union[int, float] = 1, + stop: float | None = None, + step: float = 1, *, - xp, - dtype: Optional[Dtype] = None, - device: Optional[Device] = None, - **kwargs -) -> ndarray: + xp: Namespace, + dtype: DType | None = None, + device: Device | None = None, + **kwargs: object, +) -> Array: _check_device(xp, device) return xp.arange(start, stop=stop, step=step, dtype=dtype, **kwargs) + def empty( - shape: Union[int, Tuple[int, ...]], - xp, + shape: int | tuple[int, ...], + xp: Namespace, *, - dtype: Optional[Dtype] = None, - device: Optional[Device] = None, - **kwargs -) -> ndarray: + dtype: DType | None = None, + device: Device | None = None, + **kwargs: object, +) -> Array: _check_device(xp, device) return xp.empty(shape, dtype=dtype, **kwargs) + def empty_like( - x: ndarray, /, xp, *, dtype: Optional[Dtype] = None, device: Optional[Device] = None, - **kwargs -) -> ndarray: + x: Array, + /, + xp: Namespace, + *, + dtype: DType | None = None, + device: Device | None = None, + **kwargs: object, +) -> Array: _check_device(xp, device) return xp.empty_like(x, dtype=dtype, **kwargs) + def eye( n_rows: int, - n_cols: Optional[int] = None, + n_cols: int | None = None, /, *, - xp, + xp: Namespace, k: int = 0, - dtype: Optional[Dtype] = None, - device: Optional[Device] = None, - **kwargs, -) -> ndarray: + dtype: DType | None = None, + device: Device | None = None, + **kwargs: object, +) -> Array: _check_device(xp, device) return xp.eye(n_rows, M=n_cols, k=k, dtype=dtype, **kwargs) + def full( - shape: Union[int, Tuple[int, ...]], - fill_value: Union[int, float], - xp, + shape: int | tuple[int, ...], + fill_value: complex, + xp: Namespace, *, - dtype: Optional[Dtype] = None, - device: Optional[Device] = None, - **kwargs, -) -> ndarray: + dtype: DType | None = None, + device: Device | None = None, + **kwargs: object, +) -> Array: _check_device(xp, device) return xp.full(shape, fill_value, dtype=dtype, **kwargs) + def full_like( - x: ndarray, + x: Array, /, - fill_value: Union[int, float], + fill_value: complex, *, - xp, - dtype: Optional[Dtype] = None, - device: Optional[Device] = None, - **kwargs, -) -> ndarray: + xp: Namespace, + dtype: DType | None = None, + device: Device | None = None, + **kwargs: object, +) -> Array: _check_device(xp, device) return xp.full_like(x, fill_value, dtype=dtype, **kwargs) + def linspace( - start: Union[int, float], - stop: Union[int, float], + start: float, + stop: float, /, num: int, *, - xp, - dtype: Optional[Dtype] = None, - device: Optional[Device] = None, + xp: Namespace, + dtype: DType | None = None, + device: Device | None = None, endpoint: bool = True, - **kwargs, -) -> ndarray: + **kwargs: object, +) -> Array: _check_device(xp, device) return xp.linspace(start, stop, num, dtype=dtype, endpoint=endpoint, **kwargs) + def ones( - shape: Union[int, Tuple[int, ...]], - xp, + shape: int | tuple[int, ...], + xp: Namespace, *, - dtype: Optional[Dtype] = None, - device: Optional[Device] = None, - **kwargs, -) -> ndarray: + dtype: DType | None = None, + device: Device | None = None, + **kwargs: object, +) -> Array: _check_device(xp, device) return xp.ones(shape, dtype=dtype, **kwargs) + def ones_like( - x: ndarray, /, xp, *, dtype: Optional[Dtype] = None, device: Optional[Device] = None, - **kwargs, -) 
-> ndarray: + x: Array, + /, + xp: Namespace, + *, + dtype: DType | None = None, + device: Device | None = None, + **kwargs: object, +) -> Array: _check_device(xp, device) return xp.ones_like(x, dtype=dtype, **kwargs) + def zeros( - shape: Union[int, Tuple[int, ...]], - xp, + shape: int | tuple[int, ...], + xp: Namespace, *, - dtype: Optional[Dtype] = None, - device: Optional[Device] = None, - **kwargs, -) -> ndarray: + dtype: DType | None = None, + device: Device | None = None, + **kwargs: object, +) -> Array: _check_device(xp, device) return xp.zeros(shape, dtype=dtype, **kwargs) + def zeros_like( - x: ndarray, /, xp, *, dtype: Optional[Dtype] = None, device: Optional[Device] = None, - **kwargs, -) -> ndarray: + x: Array, + /, + xp: Namespace, + *, + dtype: DType | None = None, + device: Device | None = None, + **kwargs: object, +) -> Array: _check_device(xp, device) return xp.zeros_like(x, dtype=dtype, **kwargs) + # np.unique() is split into four functions in the array API: # unique_all, unique_counts, unique_inverse, and unique_values (this is done # to remove polymorphic return types). @@ -147,35 +177,37 @@ def zeros_like( # The functions here return namedtuples (np.unique() returns a normal # tuple). + # Note that these named tuples aren't actually part of the standard namespace, # but I don't see any issue with exporting the names here regardless. class UniqueAllResult(NamedTuple): - values: ndarray - indices: ndarray - inverse_indices: ndarray - counts: ndarray + values: Array + indices: Array + inverse_indices: Array + counts: Array class UniqueCountsResult(NamedTuple): - values: ndarray - counts: ndarray + values: Array + counts: Array class UniqueInverseResult(NamedTuple): - values: ndarray - inverse_indices: ndarray + values: Array + inverse_indices: Array -def _unique_kwargs(xp): +def _unique_kwargs(xp: Namespace) -> dict[str, bool]: # Older versions of NumPy and CuPy do not have equal_nan. Rather than # trying to parse version numbers, just check if equal_nan is in the # signature. 
s = inspect.signature(xp.unique) - if 'equal_nan' in s.parameters: - return {'equal_nan': False} + if "equal_nan" in s.parameters: + return {"equal_nan": False} return {} -def unique_all(x: ndarray, /, xp) -> UniqueAllResult: + +def unique_all(x: Array, /, xp: Namespace) -> UniqueAllResult: kwargs = _unique_kwargs(xp) values, indices, inverse_indices, counts = xp.unique( x, @@ -195,20 +227,16 @@ def unique_all(x: ndarray, /, xp) -> UniqueAllResult: ) -def unique_counts(x: ndarray, /, xp) -> UniqueCountsResult: +def unique_counts(x: Array, /, xp: Namespace) -> UniqueCountsResult: kwargs = _unique_kwargs(xp) res = xp.unique( - x, - return_counts=True, - return_index=False, - return_inverse=False, - **kwargs + x, return_counts=True, return_index=False, return_inverse=False, **kwargs ) return UniqueCountsResult(*res) -def unique_inverse(x: ndarray, /, xp) -> UniqueInverseResult: +def unique_inverse(x: Array, /, xp: Namespace) -> UniqueInverseResult: kwargs = _unique_kwargs(xp) values, inverse_indices = xp.unique( x, @@ -223,7 +251,7 @@ def unique_inverse(x: ndarray, /, xp) -> UniqueInverseResult: return UniqueInverseResult(values, inverse_indices) -def unique_values(x: ndarray, /, xp) -> ndarray: +def unique_values(x: Array, /, xp: Namespace) -> Array: kwargs = _unique_kwargs(xp) return xp.unique( x, @@ -233,56 +261,58 @@ def unique_values(x: ndarray, /, xp) -> ndarray: **kwargs, ) -def astype(x: ndarray, dtype: Dtype, /, *, copy: bool = True) -> ndarray: - if not copy and dtype == x.dtype: - return x - return x.astype(dtype=dtype, copy=copy) # These functions have different keyword argument names + def std( - x: ndarray, + x: Array, /, - xp, + xp: Namespace, *, - axis: Optional[Union[int, Tuple[int, ...]]] = None, - correction: Union[int, float] = 0.0, # correction instead of ddof + axis: int | tuple[int, ...] | None = None, + correction: float = 0.0, # correction instead of ddof keepdims: bool = False, - **kwargs, -) -> ndarray: + **kwargs: object, +) -> Array: return xp.std(x, axis=axis, ddof=correction, keepdims=keepdims, **kwargs) + def var( - x: ndarray, + x: Array, /, - xp, + xp: Namespace, *, - axis: Optional[Union[int, Tuple[int, ...]]] = None, - correction: Union[int, float] = 0.0, # correction instead of ddof + axis: int | tuple[int, ...] | None = None, + correction: float = 0.0, # correction instead of ddof keepdims: bool = False, - **kwargs, -) -> ndarray: + **kwargs: object, +) -> Array: return xp.var(x, axis=axis, ddof=correction, keepdims=keepdims, **kwargs) + # cumulative_sum is renamed from cumsum, and adds the include_initial keyword # argument + def cumulative_sum( - x: ndarray, + x: Array, /, - xp, + xp: Namespace, *, - axis: Optional[int] = None, - dtype: Optional[Dtype] = None, + axis: int | None = None, + dtype: DType | None = None, include_initial: bool = False, - **kwargs -) -> ndarray: + **kwargs: object, +) -> Array: wrapped_xp = array_namespace(x) # TODO: The standard is not clear about what should happen when x.ndim == 0. 
if axis is None: if x.ndim > 1: - raise ValueError("axis must be specified in cumulative_sum for more than one dimension") + raise ValueError( + "axis must be specified in cumulative_sum for more than one dimension" + ) axis = 0 res = xp.cumsum(x, axis=axis, dtype=dtype, **kwargs) @@ -292,25 +322,69 @@ def cumulative_sum( initial_shape = list(x.shape) initial_shape[axis] = 1 res = xp.concatenate( - [wrapped_xp.zeros(shape=initial_shape, dtype=res.dtype, device=device(res)), res], + [ + wrapped_xp.zeros( + shape=initial_shape, dtype=res.dtype, device=_get_device(res) + ), + res, + ], axis=axis, ) return res + +def cumulative_prod( + x: Array, + /, + xp: Namespace, + *, + axis: int | None = None, + dtype: DType | None = None, + include_initial: bool = False, + **kwargs: object, +) -> Array: + wrapped_xp = array_namespace(x) + + if axis is None: + if x.ndim > 1: + raise ValueError( + "axis must be specified in cumulative_prod for more than one dimension" + ) + axis = 0 + + res = xp.cumprod(x, axis=axis, dtype=dtype, **kwargs) + + # np.cumprod does not support include_initial + if include_initial: + initial_shape = list(x.shape) + initial_shape[axis] = 1 + res = xp.concatenate( + [ + wrapped_xp.ones( + shape=initial_shape, dtype=res.dtype, device=_get_device(res) + ), + res, + ], + axis=axis, + ) + return res + + # The min and max argument names in clip are different and not optional in numpy, and type # promotion behavior is different. def clip( - x: ndarray, + x: Array, /, - min: Optional[Union[int, float, ndarray]] = None, - max: Optional[Union[int, float, ndarray]] = None, + min: float | Array | None = None, + max: float | Array | None = None, *, - xp, + xp: Namespace, # TODO: np.clip has other ufunc kwargs - out: Optional[ndarray] = None, -) -> ndarray: - def _isscalar(a): - return isinstance(a, (int, float, type(None))) + out: Array | None = None, +) -> Array: + def _isscalar(a: object) -> TypeIs[float | None]: + return isinstance(a, int | float) or a is None + min_shape = () if _isscalar(min) else min.shape max_shape = () if _isscalar(max) else max.shape @@ -335,44 +409,51 @@ def _isscalar(a): # but an answer of 0 might be preferred. See # https://github.com/numpy/numpy/issues/24976 for more discussion on this issue. - # At least handle the case of Python integers correctly (see # https://github.com/numpy/numpy/pull/26892). 
- if type(min) is int and min <= wrapped_xp.iinfo(x.dtype).min: - min = None - if type(max) is int and max >= wrapped_xp.iinfo(x.dtype).max: - max = None + if wrapped_xp.isdtype(x.dtype, "integral"): + if type(min) is int and min <= wrapped_xp.iinfo(x.dtype).min: + min = None + if type(max) is int and max >= wrapped_xp.iinfo(x.dtype).max: + max = None + dev = _get_device(x) if out is None: - out = wrapped_xp.asarray(xp.broadcast_to(x, result_shape), - copy=True, device=device(x)) + out = wrapped_xp.empty(result_shape, dtype=x.dtype, device=dev) + assert out is not None # workaround for a type-narrowing issue in pyright + out[()] = x + if min is not None: - if is_torch_array(x) and x.dtype == xp.float64 and _isscalar(min): - # Avoid loss of precision due to torch defaulting to float32 - min = wrapped_xp.asarray(min, dtype=xp.float64) - a = xp.broadcast_to(wrapped_xp.asarray(min, device=device(x)), result_shape) + a = wrapped_xp.asarray(min, dtype=x.dtype, device=dev) + a = xp.broadcast_to(a, result_shape) ia = (out < a) | xp.isnan(a) - # torch requires an explicit cast here - out[ia] = wrapped_xp.astype(a[ia], out.dtype) + out[ia] = a[ia] + if max is not None: - if is_torch_array(x) and x.dtype == xp.float64 and _isscalar(max): - max = wrapped_xp.asarray(max, dtype=xp.float64) - b = xp.broadcast_to(wrapped_xp.asarray(max, device=device(x)), result_shape) + b = wrapped_xp.asarray(max, dtype=x.dtype, device=dev) + b = xp.broadcast_to(b, result_shape) ib = (out > b) | xp.isnan(b) - out[ib] = wrapped_xp.astype(b[ib], out.dtype) + out[ib] = b[ib] + # Return a scalar for 0-D return out[()] + # Unlike transpose(), the axes argument to permute_dims() is required. -def permute_dims(x: ndarray, /, axes: Tuple[int, ...], xp) -> ndarray: +def permute_dims(x: Array, /, axes: tuple[int, ...], xp: Namespace) -> Array: return xp.transpose(x, axes) + # np.reshape calls the keyword argument 'newshape' instead of 'shape' -def reshape(x: ndarray, - /, - shape: Tuple[int, ...], - xp, copy: Optional[bool] = None, - **kwargs) -> ndarray: +def reshape( + x: Array, + /, + shape: tuple[int, ...], + xp: Namespace, + *, + copy: bool | None = None, + **kwargs: object, +) -> Array: if copy is True: x = x.copy() elif copy is False: @@ -381,17 +462,24 @@ def reshape(x: ndarray, return y return xp.reshape(x, shape, **kwargs) + # The descending keyword is new in sort and argsort, and 'kind' replaced with # 'stable' def argsort( - x: ndarray, /, xp, *, axis: int = -1, descending: bool = False, stable: bool = True, - **kwargs, -) -> ndarray: + x: Array, + /, + xp: Namespace, + *, + axis: int = -1, + descending: bool = False, + stable: bool = True, + **kwargs: object, +) -> Array: # Note: this keyword argument is different, and the default is different. # We set it in kwargs like this because numpy.sort uses kind='quicksort' # as the default whereas cupy.sort uses kind=None. if stable: - kwargs['kind'] = "stable" + kwargs["kind"] = "stable" if not descending: res = xp.argsort(x, axis=axis, **kwargs) else: @@ -408,69 +496,66 @@ def argsort( res = max_i - res return res + def sort( - x: ndarray, /, xp, *, axis: int = -1, descending: bool = False, stable: bool = True, - **kwargs, -) -> ndarray: + x: Array, + /, + xp: Namespace, + *, + axis: int = -1, + descending: bool = False, + stable: bool = True, + **kwargs: object, +) -> Array: # Note: this keyword argument is different, and the default is different. # We set it in kwargs like this because numpy.sort uses kind='quicksort' # as the default whereas cupy.sort uses kind=None. 
if stable: - kwargs['kind'] = "stable" + kwargs["kind"] = "stable" res = xp.sort(x, axis=axis, **kwargs) if descending: res = xp.flip(res, axis=axis) return res + # nonzero should error for zero-dimensional arrays -def nonzero(x: ndarray, /, xp, **kwargs) -> Tuple[ndarray, ...]: +def nonzero(x: Array, /, xp: Namespace, **kwargs: object) -> tuple[Array, ...]: if x.ndim == 0: raise ValueError("nonzero() does not support zero-dimensional arrays") return xp.nonzero(x, **kwargs) -# ceil, floor, and trunc return integers for integer inputs - -def ceil(x: ndarray, /, xp, **kwargs) -> ndarray: - if xp.issubdtype(x.dtype, xp.integer): - return x - return xp.ceil(x, **kwargs) - -def floor(x: ndarray, /, xp, **kwargs) -> ndarray: - if xp.issubdtype(x.dtype, xp.integer): - return x - return xp.floor(x, **kwargs) - -def trunc(x: ndarray, /, xp, **kwargs) -> ndarray: - if xp.issubdtype(x.dtype, xp.integer): - return x - return xp.trunc(x, **kwargs) # linear algebra functions -def matmul(x1: ndarray, x2: ndarray, /, xp, **kwargs) -> ndarray: + +def matmul(x1: Array, x2: Array, /, xp: Namespace, **kwargs: object) -> Array: return xp.matmul(x1, x2, **kwargs) + # Unlike transpose, matrix_transpose only transposes the last two axes. -def matrix_transpose(x: ndarray, /, xp) -> ndarray: +def matrix_transpose(x: Array, /, xp: Namespace) -> Array: if x.ndim < 2: raise ValueError("x must be at least 2-dimensional for matrix_transpose") return xp.swapaxes(x, -1, -2) -def tensordot(x1: ndarray, - x2: ndarray, - /, - xp, - *, - axes: Union[int, Tuple[Sequence[int], Sequence[int]]] = 2, - **kwargs, -) -> ndarray: + +def tensordot( + x1: Array, + x2: Array, + /, + xp: Namespace, + *, + axes: int | tuple[Sequence[int], Sequence[int]] = 2, + **kwargs: object, +) -> Array: return xp.tensordot(x1, x2, axes=axes, **kwargs) -def vecdot(x1: ndarray, x2: ndarray, /, xp, *, axis: int = -1) -> ndarray: + +def vecdot(x1: Array, x2: Array, /, xp: Namespace, *, axis: int = -1) -> Array: if x1.shape[axis] != x2.shape[axis]: raise ValueError("x1 and x2 must have the same size along the given axis") - if hasattr(xp, 'broadcast_tensors'): + if hasattr(xp, "broadcast_tensors"): _broadcast = xp.broadcast_tensors else: _broadcast = xp.broadcast_arrays @@ -479,14 +564,19 @@ def vecdot(x1: ndarray, x2: ndarray, /, xp, *, axis: int = -1) -> ndarray: x2_ = xp.moveaxis(x2, axis, -1) x1_, x2_ = _broadcast(x1_, x2_) - res = x1_[..., None, :] @ x2_[..., None] + res = xp.conj(x1_[..., None, :]) @ x2_[..., None] return res[..., 0, 0] + # isdtype is a new function in the 2022.12 array API specification. + def isdtype( - dtype: Dtype, kind: Union[Dtype, str, Tuple[Union[Dtype, str], ...]], xp, - *, _tuple=True, # Disallow nested tuples + dtype: DType, + kind: DType | str | tuple[DType | str, ...], + xp: Namespace, + *, + _tuple: bool = True, # Disallow nested tuples ) -> bool: """ Returns a boolean indicating whether a provided dtype is of a specified data type ``kind``. 
@@ -499,21 +589,24 @@ def isdtype( for more details """ if isinstance(kind, tuple) and _tuple: - return any(isdtype(dtype, k, xp, _tuple=False) for k in kind) + return any( + isdtype(dtype, k, xp, _tuple=False) + for k in cast("tuple[DType | str, ...]", kind) + ) elif isinstance(kind, str): - if kind == 'bool': + if kind == "bool": return dtype == xp.bool_ - elif kind == 'signed integer': + elif kind == "signed integer": return xp.issubdtype(dtype, xp.signedinteger) - elif kind == 'unsigned integer': + elif kind == "unsigned integer": return xp.issubdtype(dtype, xp.unsignedinteger) - elif kind == 'integral': + elif kind == "integral": return xp.issubdtype(dtype, xp.integer) - elif kind == 'real floating': + elif kind == "real floating": return xp.issubdtype(dtype, xp.floating) - elif kind == 'complex floating': + elif kind == "complex floating": return xp.issubdtype(dtype, xp.complexfloating) - elif kind == 'numeric': + elif kind == "numeric": return xp.issubdtype(dtype, xp.number) else: raise ValueError(f"Unrecognized data type kind: {kind!r}") @@ -524,17 +617,86 @@ def isdtype( # array_api_strict implementation will be very strict. return dtype == kind + # unstack is a new function in the 2023.12 array API standard -def unstack(x: ndarray, /, xp, *, axis: int = 0) -> Tuple[ndarray, ...]: +def unstack(x: Array, /, xp: Namespace, *, axis: int = 0) -> tuple[Array, ...]: if x.ndim == 0: raise ValueError("Input array must be at least 1-d.") return tuple(xp.moveaxis(x, axis, 0)) -__all__ = ['arange', 'empty', 'empty_like', 'eye', 'full', 'full_like', - 'linspace', 'ones', 'ones_like', 'zeros', 'zeros_like', - 'UniqueAllResult', 'UniqueCountsResult', 'UniqueInverseResult', - 'unique_all', 'unique_counts', 'unique_inverse', 'unique_values', - 'astype', 'std', 'var', 'cumulative_sum', 'clip', 'permute_dims', - 'reshape', 'argsort', 'sort', 'nonzero', 'ceil', 'floor', 'trunc', - 'matmul', 'matrix_transpose', 'tensordot', 'vecdot', 'isdtype', - 'unstack'] + +# numpy 1.26 does not use the standard definition for sign on complex numbers + + +def sign(x: Array, /, xp: Namespace, **kwargs: object) -> Array: + if isdtype(x.dtype, "complex floating", xp=xp): + out = (x / xp.abs(x, **kwargs))[...] + # sign(0) = 0 but the above formula would give nan + out[x == 0j] = 0j + else: + out = xp.sign(x, **kwargs) + # CuPy sign() does not propagate nans. See + # https://github.com/data-apis/array-api-compat/issues/136 + if is_cupy_namespace(xp) and isdtype(x.dtype, "real floating", xp=xp): + out[xp.isnan(x)] = xp.nan + return out[()] + + +def finfo(type_: DType | Array, /, xp: Namespace) -> Any: + # It is surprisingly difficult to recognize a dtype apart from an array. + # np.int64 is not the same as np.asarray(1).dtype! 
+ try: + return xp.finfo(type_) + except (ValueError, TypeError): + return xp.finfo(type_.dtype) + + +def iinfo(type_: DType | Array, /, xp: Namespace) -> Any: + try: + return xp.iinfo(type_) + except (ValueError, TypeError): + return xp.iinfo(type_.dtype) + + +__all__ = [ + "arange", + "empty", + "empty_like", + "eye", + "full", + "full_like", + "linspace", + "ones", + "ones_like", + "zeros", + "zeros_like", + "UniqueAllResult", + "UniqueCountsResult", + "UniqueInverseResult", + "unique_all", + "unique_counts", + "unique_inverse", + "unique_values", + "std", + "var", + "cumulative_sum", + "cumulative_prod", + "clip", + "permute_dims", + "reshape", + "argsort", + "sort", + "nonzero", + "matmul", + "matrix_transpose", + "tensordot", + "vecdot", + "isdtype", + "unstack", + "sign", + "finfo", + "iinfo", +] + +def __dir__() -> list[str]: + return __all__ diff --git a/array_api_compat/common/_fft.py b/array_api_compat/common/_fft.py index 666b0b1f..18839d37 100644 --- a/array_api_compat/common/_fft.py +++ b/array_api_compat/common/_fft.py @@ -1,168 +1,195 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Union, Optional, Literal +from collections.abc import Sequence +from typing import Literal, TypeAlias -if TYPE_CHECKING: - from ._typing import Device, ndarray - from collections.abc import Sequence +from ._typing import Array, Device, DType, Namespace + +_Norm: TypeAlias = Literal["backward", "ortho", "forward"] # Note: NumPy fft functions improperly upcast float32 and complex64 to # complex128, which is why we require wrapping them all here. def fft( - x: ndarray, + x: Array, /, - xp, + xp: Namespace, *, - n: Optional[int] = None, + n: int | None = None, axis: int = -1, - norm: Literal["backward", "ortho", "forward"] = "backward", -) -> ndarray: + norm: _Norm = "backward", +) -> Array: res = xp.fft.fft(x, n=n, axis=axis, norm=norm) if x.dtype in [xp.float32, xp.complex64]: return res.astype(xp.complex64) return res def ifft( - x: ndarray, + x: Array, /, - xp, + xp: Namespace, *, - n: Optional[int] = None, + n: int | None = None, axis: int = -1, - norm: Literal["backward", "ortho", "forward"] = "backward", -) -> ndarray: + norm: _Norm = "backward", +) -> Array: res = xp.fft.ifft(x, n=n, axis=axis, norm=norm) if x.dtype in [xp.float32, xp.complex64]: return res.astype(xp.complex64) return res def fftn( - x: ndarray, + x: Array, /, - xp, + xp: Namespace, *, - s: Sequence[int] = None, - axes: Sequence[int] = None, - norm: Literal["backward", "ortho", "forward"] = "backward", -) -> ndarray: + s: Sequence[int] | None = None, + axes: Sequence[int] | None = None, + norm: _Norm = "backward", +) -> Array: res = xp.fft.fftn(x, s=s, axes=axes, norm=norm) if x.dtype in [xp.float32, xp.complex64]: return res.astype(xp.complex64) return res def ifftn( - x: ndarray, + x: Array, /, - xp, + xp: Namespace, *, - s: Sequence[int] = None, - axes: Sequence[int] = None, - norm: Literal["backward", "ortho", "forward"] = "backward", -) -> ndarray: + s: Sequence[int] | None = None, + axes: Sequence[int] | None = None, + norm: _Norm = "backward", +) -> Array: res = xp.fft.ifftn(x, s=s, axes=axes, norm=norm) if x.dtype in [xp.float32, xp.complex64]: return res.astype(xp.complex64) return res def rfft( - x: ndarray, + x: Array, /, - xp, + xp: Namespace, *, - n: Optional[int] = None, + n: int | None = None, axis: int = -1, - norm: Literal["backward", "ortho", "forward"] = "backward", -) -> ndarray: + norm: _Norm = "backward", +) -> Array: res = xp.fft.rfft(x, n=n, axis=axis, norm=norm) if x.dtype == 
xp.float32: return res.astype(xp.complex64) return res def irfft( - x: ndarray, + x: Array, /, - xp, + xp: Namespace, *, - n: Optional[int] = None, + n: int | None = None, axis: int = -1, - norm: Literal["backward", "ortho", "forward"] = "backward", -) -> ndarray: + norm: _Norm = "backward", +) -> Array: res = xp.fft.irfft(x, n=n, axis=axis, norm=norm) if x.dtype == xp.complex64: return res.astype(xp.float32) return res def rfftn( - x: ndarray, + x: Array, /, - xp, + xp: Namespace, *, - s: Sequence[int] = None, - axes: Sequence[int] = None, - norm: Literal["backward", "ortho", "forward"] = "backward", -) -> ndarray: + s: Sequence[int] | None = None, + axes: Sequence[int] | None = None, + norm: _Norm = "backward", +) -> Array: res = xp.fft.rfftn(x, s=s, axes=axes, norm=norm) if x.dtype == xp.float32: return res.astype(xp.complex64) return res def irfftn( - x: ndarray, + x: Array, /, - xp, + xp: Namespace, *, - s: Sequence[int] = None, - axes: Sequence[int] = None, - norm: Literal["backward", "ortho", "forward"] = "backward", -) -> ndarray: + s: Sequence[int] | None = None, + axes: Sequence[int] | None = None, + norm: _Norm = "backward", +) -> Array: res = xp.fft.irfftn(x, s=s, axes=axes, norm=norm) if x.dtype == xp.complex64: return res.astype(xp.float32) return res def hfft( - x: ndarray, + x: Array, /, - xp, + xp: Namespace, *, - n: Optional[int] = None, + n: int | None = None, axis: int = -1, - norm: Literal["backward", "ortho", "forward"] = "backward", -) -> ndarray: + norm: _Norm = "backward", +) -> Array: res = xp.fft.hfft(x, n=n, axis=axis, norm=norm) if x.dtype in [xp.float32, xp.complex64]: return res.astype(xp.float32) return res def ihfft( - x: ndarray, + x: Array, /, - xp, + xp: Namespace, *, - n: Optional[int] = None, + n: int | None = None, axis: int = -1, - norm: Literal["backward", "ortho", "forward"] = "backward", -) -> ndarray: + norm: _Norm = "backward", +) -> Array: res = xp.fft.ihfft(x, n=n, axis=axis, norm=norm) if x.dtype in [xp.float32, xp.complex64]: return res.astype(xp.complex64) return res -def fftfreq(n: int, /, xp, *, d: float = 1.0, device: Optional[Device] = None) -> ndarray: +def fftfreq( + n: int, + /, + xp: Namespace, + *, + d: float = 1.0, + dtype: DType | None = None, + device: Device | None = None, +) -> Array: if device not in ["cpu", None]: raise ValueError(f"Unsupported device {device!r}") - return xp.fft.fftfreq(n, d=d) + res = xp.fft.fftfreq(n, d=d) + if dtype is not None: + return res.astype(dtype) + return res -def rfftfreq(n: int, /, xp, *, d: float = 1.0, device: Optional[Device] = None) -> ndarray: +def rfftfreq( + n: int, + /, + xp: Namespace, + *, + d: float = 1.0, + dtype: DType | None = None, + device: Device | None = None, +) -> Array: if device not in ["cpu", None]: raise ValueError(f"Unsupported device {device!r}") - return xp.fft.rfftfreq(n, d=d) + res = xp.fft.rfftfreq(n, d=d) + if dtype is not None: + return res.astype(dtype) + return res -def fftshift(x: ndarray, /, xp, *, axes: Union[int, Sequence[int]] = None) -> ndarray: +def fftshift( + x: Array, /, xp: Namespace, *, axes: int | Sequence[int] | None = None +) -> Array: return xp.fft.fftshift(x, axes=axes) -def ifftshift(x: ndarray, /, xp, *, axes: Union[int, Sequence[int]] = None) -> ndarray: +def ifftshift( + x: Array, /, xp: Namespace, *, axes: int | Sequence[int] | None = None +) -> Array: return xp.fft.ifftshift(x, axes=axes) __all__ = [ @@ -181,3 +208,6 @@ def ifftshift(x: ndarray, /, xp, *, axes: Union[int, Sequence[int]] = None) -> n "fftshift", "ifftshift", ] + +def 
__dir__() -> list[str]: + return __all__ diff --git a/array_api_compat/common/_helpers.py b/array_api_compat/common/_helpers.py index 2467793c..37f31ec2 100644 --- a/array_api_compat/common/_helpers.py +++ b/array_api_compat/common/_helpers.py @@ -5,34 +5,94 @@ that are in __all__ are intended as additional helper functions for use by end users of the compat library. """ + from __future__ import annotations -from typing import TYPE_CHECKING +import enum +import inspect +import math +import sys +import warnings +from collections.abc import Collection, Hashable +from functools import lru_cache +from typing import ( + TYPE_CHECKING, + Any, + Final, + Literal, + SupportsIndex, + TypeAlias, + TypeGuard, + cast, + overload, +) + +from ._typing import Array, Device, HasShape, Namespace, SupportsArrayNamespace if TYPE_CHECKING: - from typing import Optional, Union, Any - from ._typing import Array, Device + import cupy as cp + import dask.array as da + import jax + import ndonnx as ndx + import numpy as np + import numpy.typing as npt + import sparse + import torch -import sys -import math -import inspect -import warnings + # TODO: import from typing (requires Python >=3.13) + from typing_extensions import TypeIs + + _ZeroGradientArray: TypeAlias = npt.NDArray[np.void] -def _is_jax_zero_gradient_array(x): + _ArrayApiObj: TypeAlias = ( + npt.NDArray[Any] + | cp.ndarray + | da.Array + | jax.Array + | ndx.Array + | sparse.SparseArray + | torch.Tensor + | SupportsArrayNamespace[Any] + ) + +_API_VERSIONS_OLD: Final = frozenset({"2021.12", "2022.12", "2023.12"}) +_API_VERSIONS: Final = _API_VERSIONS_OLD | frozenset({"2024.12"}) + + +@lru_cache(100) +def _issubclass_fast(cls: type, modname: str, clsname: str) -> bool: + try: + mod = sys.modules[modname] + except KeyError: + return False + parent_cls = getattr(mod, clsname) + return issubclass(cls, parent_cls) + + +def _is_jax_zero_gradient_array(x: object) -> TypeGuard[_ZeroGradientArray]: """Return True if `x` is a zero-gradient array. These arrays are a design quirk of Jax that may one day be removed. See https://github.com/google/jax/issues/20620. """ - if 'numpy' not in sys.modules or 'jax' not in sys.modules: + # Fast exit + try: + dtype = x.dtype # type: ignore[attr-defined] + except AttributeError: + return False + cls = cast(Hashable, type(dtype)) + if not _issubclass_fast(cls, "numpy.dtypes", "VoidDType"): + return False + + if "jax" not in sys.modules: return False - import numpy as np import jax + # jax.float0 is a np.dtype([('float0', 'V')]) + return dtype == jax.float0 - return isinstance(x, np.ndarray) and x.dtype == jax.float0 -def is_numpy_array(x): +def is_numpy_array(x: object) -> TypeIs[npt.NDArray[Any]]: """ Return True if `x` is a NumPy array. @@ -53,17 +113,15 @@ def is_numpy_array(x): is_jax_array is_pydata_sparse_array """ - # Avoid importing NumPy if it isn't already - if 'numpy' not in sys.modules: - return False - - import numpy as np - # TODO: Should we reject ndarray subclasses? - return (isinstance(x, (np.ndarray, np.generic)) - and not _is_jax_zero_gradient_array(x)) + cls = cast(Hashable, type(x)) + return ( + _issubclass_fast(cls, "numpy", "ndarray") + or _issubclass_fast(cls, "numpy", "generic") + ) and not _is_jax_zero_gradient_array(x) + -def is_cupy_array(x): +def is_cupy_array(x: object) -> bool: """ Return True if `x` is a CuPy array. 
@@ -84,16 +142,11 @@ def is_cupy_array(x): is_jax_array is_pydata_sparse_array """ - # Avoid importing CuPy if it isn't already - if 'cupy' not in sys.modules: - return False - - import cupy as cp + cls = cast(Hashable, type(x)) + return _issubclass_fast(cls, "cupy", "ndarray") - # TODO: Should we reject ndarray subclasses? - return isinstance(x, (cp.ndarray, cp.generic)) -def is_torch_array(x): +def is_torch_array(x: object) -> TypeIs[torch.Tensor]: """ Return True if `x` is a PyTorch tensor. @@ -111,16 +164,11 @@ def is_torch_array(x): is_jax_array is_pydata_sparse_array """ - # Avoid importing torch if it isn't already - if 'torch' not in sys.modules: - return False - - import torch + cls = cast(Hashable, type(x)) + return _issubclass_fast(cls, "torch", "Tensor") - # TODO: Should we reject ndarray subclasses? - return isinstance(x, torch.Tensor) -def is_ndonnx_array(x): +def is_ndonnx_array(x: object) -> TypeIs[ndx.Array]: """ Return True if `x` is a ndonnx Array. @@ -139,15 +187,11 @@ def is_ndonnx_array(x): is_jax_array is_pydata_sparse_array """ - # Avoid importing torch if it isn't already - if 'ndonnx' not in sys.modules: - return False - - import ndonnx as ndx + cls = cast(Hashable, type(x)) + return _issubclass_fast(cls, "ndonnx", "Array") - return isinstance(x, ndx.Array) -def is_dask_array(x): +def is_dask_array(x: object) -> TypeIs[da.Array]: """ Return True if `x` is a dask.array Array. @@ -166,15 +210,11 @@ def is_dask_array(x): is_jax_array is_pydata_sparse_array """ - # Avoid importing dask if it isn't already - if 'dask.array' not in sys.modules: - return False + cls = cast(Hashable, type(x)) + return _issubclass_fast(cls, "dask.array", "Array") - import dask.array - return isinstance(x, dask.array.Array) - -def is_jax_array(x): +def is_jax_array(x: object) -> TypeIs[jax.Array]: """ Return True if `x` is a JAX array. @@ -194,15 +234,11 @@ def is_jax_array(x): is_dask_array is_pydata_sparse_array """ - # Avoid importing jax if it isn't already - if 'jax' not in sys.modules: - return False - - import jax + cls = cast(Hashable, type(x)) + return _issubclass_fast(cls, "jax", "Array") or _is_jax_zero_gradient_array(x) - return isinstance(x, jax.Array) or _is_jax_zero_gradient_array(x) -def is_pydata_sparse_array(x) -> bool: +def is_pydata_sparse_array(x: object) -> TypeIs[sparse.SparseArray]: """ Return True if `x` is an array from the `sparse` package. @@ -222,16 +258,12 @@ def is_pydata_sparse_array(x) -> bool: is_dask_array is_jax_array """ - # Avoid importing jax if it isn't already - if 'sparse' not in sys.modules: - return False - - import sparse - # TODO: Account for other backends. - return isinstance(x, sparse.SparseArray) + cls = cast(Hashable, type(x)) + return _issubclass_fast(cls, "sparse", "SparseArray") + -def is_array_api_obj(x): +def is_array_api_obj(x: object) -> TypeGuard[_ArrayApiObj]: """ Return True if `x` is an array API compatible array object. 
@@ -246,19 +278,34 @@ def is_array_api_obj(x): is_dask_array is_jax_array """ - return is_numpy_array(x) \ - or is_cupy_array(x) \ - or is_torch_array(x) \ - or is_dask_array(x) \ - or is_jax_array(x) \ - or is_pydata_sparse_array(x) \ - or hasattr(x, '__array_namespace__') - -def _compat_module_name(): - assert __name__.endswith('.common._helpers') - return __name__.removesuffix('.common._helpers') - -def is_numpy_namespace(xp) -> bool: + return ( + hasattr(x, '__array_namespace__') + or _is_array_api_cls(cast(Hashable, type(x))) + ) + + +@lru_cache(100) +def _is_array_api_cls(cls: type) -> bool: + return ( + # TODO: drop support for numpy<2 which didn't have __array_namespace__ + _issubclass_fast(cls, "numpy", "ndarray") + or _issubclass_fast(cls, "numpy", "generic") + or _issubclass_fast(cls, "cupy", "ndarray") + or _issubclass_fast(cls, "torch", "Tensor") + or _issubclass_fast(cls, "dask.array", "Array") + or _issubclass_fast(cls, "sparse", "SparseArray") + # TODO: drop support for jax<0.4.32 which didn't have __array_namespace__ + or _issubclass_fast(cls, "jax", "Array") + ) + + +def _compat_module_name() -> str: + assert __name__.endswith(".common._helpers") + return __name__.removesuffix(".common._helpers") + + +@lru_cache(100) +def is_numpy_namespace(xp: Namespace) -> bool: """ Returns True if `xp` is a NumPy namespace. @@ -276,9 +323,11 @@ def is_numpy_namespace(xp) -> bool: is_pydata_sparse_namespace is_array_api_strict_namespace """ - return xp.__name__ in {'numpy', _compat_module_name() + '.numpy'} + return xp.__name__ in {"numpy", _compat_module_name() + ".numpy"} -def is_cupy_namespace(xp) -> bool: + +@lru_cache(100) +def is_cupy_namespace(xp: Namespace) -> bool: """ Returns True if `xp` is a CuPy namespace. @@ -296,9 +345,11 @@ def is_cupy_namespace(xp) -> bool: is_pydata_sparse_namespace is_array_api_strict_namespace """ - return xp.__name__ in {'cupy', _compat_module_name() + '.cupy'} + return xp.__name__ in {"cupy", _compat_module_name() + ".cupy"} + -def is_torch_namespace(xp) -> bool: +@lru_cache(100) +def is_torch_namespace(xp: Namespace) -> bool: """ Returns True if `xp` is a PyTorch namespace. @@ -316,10 +367,10 @@ def is_torch_namespace(xp) -> bool: is_pydata_sparse_namespace is_array_api_strict_namespace """ - return xp.__name__ in {'torch', _compat_module_name() + '.torch'} - + return xp.__name__ in {"torch", _compat_module_name() + ".torch"} -def is_ndonnx_namespace(xp): + +def is_ndonnx_namespace(xp: Namespace) -> bool: """ Returns True if `xp` is an NDONNX namespace. @@ -335,9 +386,11 @@ def is_ndonnx_namespace(xp): is_pydata_sparse_namespace is_array_api_strict_namespace """ - return xp.__name__ == 'ndonnx' + return xp.__name__ == "ndonnx" + -def is_dask_namespace(xp): +@lru_cache(100) +def is_dask_namespace(xp: Namespace) -> bool: """ Returns True if `xp` is a Dask namespace. @@ -355,9 +408,10 @@ def is_dask_namespace(xp): is_pydata_sparse_namespace is_array_api_strict_namespace """ - return xp.__name__ in {'dask.array', _compat_module_name() + '.dask.array'} + return xp.__name__ in {"dask.array", _compat_module_name() + ".dask.array"} -def is_jax_namespace(xp): + +def is_jax_namespace(xp: Namespace) -> bool: """ Returns True if `xp` is a JAX namespace. 
@@ -376,9 +430,10 @@ def is_jax_namespace(xp): is_pydata_sparse_namespace is_array_api_strict_namespace """ - return xp.__name__ in {'jax.numpy', 'jax.experimental.array_api'} + return xp.__name__ in {"jax.numpy", "jax.experimental.array_api"} + -def is_pydata_sparse_namespace(xp): +def is_pydata_sparse_namespace(xp: Namespace) -> bool: """ Returns True if `xp` is a pydata/sparse namespace. @@ -394,9 +449,10 @@ def is_pydata_sparse_namespace(xp): is_jax_namespace is_array_api_strict_namespace """ - return xp.__name__ == 'sparse' + return xp.__name__ == "sparse" + -def is_array_api_strict_namespace(xp): +def is_array_api_strict_namespace(xp: Namespace) -> bool: """ Returns True if `xp` is an array-api-strict namespace. @@ -412,26 +468,117 @@ def is_array_api_strict_namespace(xp): is_jax_namespace is_pydata_sparse_namespace """ - return xp.__name__ == 'array_api_strict' + return xp.__name__ == "array_api_strict" -def _check_api_version(api_version): - if api_version == '2021.12': - warnings.warn("The 2021.12 version of the array API specification was requested but the returned namespace is actually version 2022.12") - elif api_version is not None and api_version != '2022.12': - raise ValueError("Only the 2022.12 version of the array API specification is currently supported") -def array_namespace(*xs, api_version=None, use_compat=None): +def _check_api_version(api_version: str | None) -> None: + if api_version in _API_VERSIONS_OLD: + warnings.warn( + f"The {api_version} version of the array API specification was requested but the returned namespace is actually version 2024.12" + ) + elif api_version is not None and api_version not in _API_VERSIONS: + raise ValueError( + "Only the 2024.12 version of the array API specification is currently supported" + ) + + +class _ClsToXPInfo(enum.Enum): + SCALAR = 0 + MAYBE_JAX_ZERO_GRADIENT = 1 + + +@lru_cache(100) +def _cls_to_namespace( + cls: type, + api_version: str | None, + use_compat: bool | None, +) -> tuple[Namespace | None, _ClsToXPInfo | None]: + if use_compat not in (None, True, False): + raise ValueError("use_compat must be None, True, or False") + _use_compat = use_compat in (None, True) + cls_ = cast(Hashable, cls) # Make mypy happy + + if ( + _issubclass_fast(cls_, "numpy", "ndarray") + or _issubclass_fast(cls_, "numpy", "generic") + ): + if use_compat is True: + _check_api_version(api_version) + from .. import numpy as xp + elif use_compat is False: + import numpy as xp # type: ignore[no-redef] + else: + # NumPy 2.0+ have __array_namespace__; however they are not + # yet fully array API compatible. + from .. import numpy as xp # type: ignore[no-redef] + return xp, _ClsToXPInfo.MAYBE_JAX_ZERO_GRADIENT + + # Note: this must happen _after_ the test for np.generic, + # because np.float64 and np.complex128 are subclasses of float and complex. + if issubclass(cls, int | float | complex | type(None)): + return None, _ClsToXPInfo.SCALAR + + if _issubclass_fast(cls_, "cupy", "ndarray"): + if _use_compat: + _check_api_version(api_version) + from .. import cupy as xp # type: ignore[no-redef] + else: + import cupy as xp # type: ignore[no-redef] + return xp, None + + if _issubclass_fast(cls_, "torch", "Tensor"): + if _use_compat: + _check_api_version(api_version) + from .. 
import torch as xp # type: ignore[no-redef] + else: + import torch as xp # type: ignore[no-redef] + return xp, None + + if _issubclass_fast(cls_, "dask.array", "Array"): + if _use_compat: + _check_api_version(api_version) + from ..dask import array as xp # type: ignore[no-redef] + else: + import dask.array as xp # type: ignore[no-redef] + return xp, None + + # Backwards compatibility for jax<0.4.32 + if _issubclass_fast(cls_, "jax", "Array"): + return _jax_namespace(api_version, use_compat), None + + return None, None + + +def _jax_namespace(api_version: str | None, use_compat: bool | None) -> Namespace: + if use_compat: + raise ValueError("JAX does not have an array-api-compat wrapper") + import jax.numpy as jnp + if not hasattr(jnp, "__array_namespace_info__"): + # JAX v0.4.32 and newer implements the array API directly in jax.numpy. + # For older JAX versions, it is available via jax.experimental.array_api. + # jnp.Array objects gain the __array_namespace__ method. + import jax.experimental.array_api # noqa: F401 + # Test api_version + return jnp.empty(0).__array_namespace__(api_version=api_version) + + +def array_namespace( + *xs: Array | complex | None, + api_version: str | None = None, + use_compat: bool | None = None, +) -> Namespace: """ Get the array API compatible namespace for the arrays `xs`. Parameters ---------- xs: arrays - one or more arrays. + one or more arrays. xs can also be Python scalars (bool, int, float, + complex, or None), which are ignored. api_version: str The newest version of the spec that you need support for (currently - the compat library wrapped APIs support v2022.12). + the compat library wrapped APIs support v2024.12). use_compat: bool or None If None (the default), the native namespace will be returned if it is @@ -483,115 +630,85 @@ def your_function(x, y): is_pydata_sparse_array """ - if use_compat not in [None, True, False]: - raise ValueError("use_compat must be None, True, or False") - - _use_compat = use_compat in [None, True] - - namespaces = set() + namespaces: set[Namespace] = set() for x in xs: - if is_numpy_array(x): - from .. import numpy as numpy_namespace - import numpy as np - if use_compat is True: - _check_api_version(api_version) - namespaces.add(numpy_namespace) - elif use_compat is False: - namespaces.add(np) - else: - # numpy 2.0+ have __array_namespace__, however, they are not yet fully array API - # compatible. - namespaces.add(numpy_namespace) - elif is_cupy_array(x): - if _use_compat: - _check_api_version(api_version) - from .. import cupy as cupy_namespace - namespaces.add(cupy_namespace) - else: - import cupy as cp - namespaces.add(cp) - elif is_torch_array(x): - if _use_compat: - _check_api_version(api_version) - from .. import torch as torch_namespace - namespaces.add(torch_namespace) - else: - import torch - namespaces.add(torch) - elif is_dask_array(x): - if _use_compat: - _check_api_version(api_version) - from ..dask import array as dask_namespace - namespaces.add(dask_namespace) - else: - import dask.array as da - namespaces.add(da) - elif is_jax_array(x): - if use_compat is True: - _check_api_version(api_version) - raise ValueError("JAX does not have an array-api-compat wrapper") - elif use_compat is False: - import jax.numpy as jnp - else: - # JAX v0.4.32 and newer implements the array API directly in jax.numpy. - # For older JAX versions, it is available via jax.experimental.array_api. 
- import jax.numpy - if hasattr(jax.numpy, "__array_api_version__"): - jnp = jax.numpy - else: - import jax.experimental.array_api as jnp - namespaces.add(jnp) - elif is_pydata_sparse_array(x): - if use_compat is True: - _check_api_version(api_version) - raise ValueError("`sparse` does not have an array-api-compat wrapper") - else: - import sparse - # `sparse` is already an array namespace. We do not have a wrapper - # submodule for it. - namespaces.add(sparse) - elif hasattr(x, '__array_namespace__'): - if use_compat is True: - raise ValueError("The given array does not have an array-api-compat wrapper") - namespaces.add(x.__array_namespace__(api_version=api_version)) - else: - # TODO: Support Python scalars? - raise TypeError(f"{type(x).__name__} is not a supported array type") - - if not namespaces: - raise TypeError("Unrecognized array input") - - if len(namespaces) != 1: + xp, info = _cls_to_namespace(cast(Hashable, type(x)), api_version, use_compat) + if info is _ClsToXPInfo.SCALAR: + continue + + if ( + info is _ClsToXPInfo.MAYBE_JAX_ZERO_GRADIENT + and _is_jax_zero_gradient_array(x) + ): + xp = _jax_namespace(api_version, use_compat) + + if xp is None: + get_ns = getattr(x, "__array_namespace__", None) + if get_ns is None: + raise TypeError(f"{type(x).__name__} is not a supported array type") + if use_compat: + raise ValueError( + "The given array does not have an array-api-compat wrapper" + ) + xp = get_ns(api_version=api_version) + + namespaces.add(xp) + + try: + (xp,) = namespaces + return xp + except ValueError: + if not namespaces: + raise TypeError( + "array_namespace requires at least one non-scalar array input" + ) raise TypeError(f"Multiple namespaces for array inputs: {namespaces}") - xp, = namespaces - - return xp # backwards compatibility alias get_namespace = array_namespace -def _check_device(xp, device): - if xp == sys.modules.get('numpy'): - if device not in ["cpu", None]: + +def _check_device(bare_xp: Namespace, device: Device) -> None: # pyright: ignore[reportUnusedFunction] + """ + Validate dummy device on device-less array backends. + + Notes + ----- + This function is also invoked by CuPy, which does have multiple devices + if there are multiple GPUs available. + However, CuPy multi-device support is currently impossible + without using the global device or a context manager: + + https://github.com/data-apis/array-api-compat/pull/293 + """ + if bare_xp is sys.modules.get("numpy"): + if device not in ("cpu", None): raise ValueError(f"Unsupported device for NumPy: {device!r}") + elif bare_xp is sys.modules.get("dask.array"): + if device not in ("cpu", _DASK_DEVICE, None): + raise ValueError(f"Unsupported device for Dask: {device!r}") + + # Placeholder object to represent the dask device # when the array backend is not the CPU. # (since it is not easy to tell which device a dask array is on) class _dask_device: - def __repr__(self): + def __repr__(self) -> Literal["DASK_DEVICE"]: return "DASK_DEVICE" + _DASK_DEVICE = _dask_device() + # device() is not on numpy.ndarray or dask.array and to_device() is not on numpy.ndarray # or cupy.ndarray. They are not included in array objects of this library # because this library just reuses the respective ndarray classes without # wrapping or subclassing them. These helper functions can be used instead of # the wrapper functions for libraries that need to support both NumPy/CuPy and # other libraries that use devices. 
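As a quick orientation for the helpers defined below, here is a minimal usage sketch (not part of the patch; it assumes only that array-api-compat is importable). It exercises `device()` and `to_device()` on a bare NumPy array, which natively has neither a `.device` attribute nor a `.to_device()` method; the behaviour follows directly from the implementations in this hunk:

    import numpy as np
    from array_api_compat import device, to_device

    x = np.asarray([1.0, 2.0, 3.0])
    assert device(x) == "cpu"          # NumPy arrays always report "cpu"
    assert to_device(x, "cpu") is x    # no-op; any other device raises ValueError
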
-def device(x: Array, /) -> Device: +def device(x: _ArrayApiObj, /) -> Device: """ Hardware device the array data resides on. @@ -626,86 +743,86 @@ def device(x: Array, /) -> Device: if is_numpy_array(x): return "cpu" elif is_dask_array(x): - # Peek at the metadata of the jax array to determine type - try: - import numpy as np - if isinstance(x._meta, np.ndarray): - # Must be on CPU since backed by numpy - return "cpu" - except ImportError: - pass + # Peek at the metadata of the Dask array to determine type + if is_numpy_array(x._meta): + # Must be on CPU since backed by numpy + return "cpu" return _DASK_DEVICE elif is_jax_array(x): - # JAX has .device() as a method, but it is being deprecated so that it - # can become a property, in accordance with the standard. In order for - # this function to not break when JAX makes the flip, we check for - # both here. - if inspect.ismethod(x.device): - return x.device() + # FIXME Jitted JAX arrays do not have a device attribute + # https://github.com/jax-ml/jax/issues/26000 + # Return None in this case. Note that this workaround breaks + # the standard and will result in new arrays being created on the + # default device instead of the same device as the input array(s). + x_device = getattr(x, "device", None) + # Older JAX releases had .device() as a method, which has been replaced + # with a property in accordance with the standard. + if inspect.ismethod(x_device): + return x_device() else: - return x.device + return x_device elif is_pydata_sparse_array(x): # `sparse` will gain `.device`, so check for this first. - x_device = getattr(x, 'device', None) + x_device = getattr(x, "device", None) if x_device is not None: return x_device # Everything but DOK has this attr. try: - inner = x.data + inner = x.data # pyright: ignore except AttributeError: return "cpu" # Return the device of the constituent array - return device(inner) - return x.device + return device(inner) # pyright: ignore + return x.device # type: ignore # pyright: ignore + # Prevent shadowing, used below _device = device + # Based on cupy.array_api.Array.to_device -def _cupy_to_device(x, device, /, stream=None): +def _cupy_to_device( + x: cp.ndarray, + device: Device, + /, + stream: int | Any | None = None, +) -> cp.ndarray: import cupy as cp - from cupy.cuda import Device as _Device - from cupy.cuda import stream as stream_module - from cupy_backends.cuda.api import runtime - if device == x.device: - return x - elif device == "cpu": + if device == "cpu": # allowing us to use `to_device(x, "cpu")` # is useful for portable test swapping between # host and device backends return x.get() - elif not isinstance(device, _Device): - raise ValueError(f"Unsupported device {device!r}") - else: - # see cupy/cupy#5985 for the reason how we handle device/stream here - prev_device = runtime.getDevice() - prev_stream: stream_module.Stream = None - if stream is not None: - prev_stream = stream_module.get_current_stream() - # stream can be an int as specified in __dlpack__, or a CuPy stream - if isinstance(stream, int): - stream = cp.cuda.ExternalStream(stream) - elif isinstance(stream, cp.cuda.Stream): - pass - else: - raise ValueError('the input stream is not recognized') - stream.use() - try: - runtime.setDevice(device.id) - arr = x.copy() - finally: - runtime.setDevice(prev_device) - if stream is not None: - prev_stream.use() - return arr - -def _torch_to_device(x, device, /, stream=None): + if not isinstance(device, cp.cuda.Device): + raise TypeError(f"Unsupported device type {device!r}") + + if 
stream is None:
+        with device:
+            return cp.asarray(x)
+
+    # stream can be an int as specified in __dlpack__, or a CuPy stream
+    if isinstance(stream, int):
+        stream = cp.cuda.ExternalStream(stream)
+    elif not isinstance(stream, cp.cuda.Stream):
+        raise TypeError(f"Unsupported stream type {stream!r}")
+
+    with device, stream:
+        return cp.asarray(x)
+
+
+def _torch_to_device(
+    x: torch.Tensor,
+    device: torch.device | str | int,
+    /,
+    stream: int | Any | None = None,
+) -> torch.Tensor:
     if stream is not None:
         raise NotImplementedError
     return x.to(device)

-def to_device(x: Array, device: Device, /, *, stream: Optional[Union[int, Any]] = None) -> Array:
+
+def to_device(x: Array, device: Device, /, *, stream: int | Any | None = None) -> Array:
     """
     Copy the array from the device on which it currently resides to the
     specified ``device``.

@@ -725,7 +842,7 @@ def to_device(x: Array, device: Device, /, *, stream: Optional[Union[int, Any]]
         a ``device`` object (see the `Device Support
         <https://data-apis.org/array-api/latest/design_topics/device_support.html>`__
         section of the array API specification).
-    stream: Optional[Union[int, Any]]
+    stream: int | Any | None
        stream object to use during copy. In addition to the types supported
        in ``array.__dlpack__``, implementations may choose to support any
        library-specific stream object with the caveat that any code using
@@ -757,7 +874,7 @@ def to_device(x: Array, device: Device, /, *, stream: Optional[Union[int, Any]]
     if is_numpy_array(x):
         if stream is not None:
             raise ValueError("The stream argument to to_device() is not supported")
-        if device == 'cpu':
+        if device == "cpu":
             return x
         raise ValueError(f"Unsupported device {device!r}")
     elif is_cupy_array(x):
@@ -769,33 +886,155 @@ def to_device(x: Array, device: Device, /, *, stream: Optional[Union[int, Any]]
         if stream is not None:
             raise ValueError("The stream argument to to_device() is not supported")
         # TODO: What if our array is on the GPU already?
-        if device == 'cpu':
+        if device == "cpu":
             return x
         raise ValueError(f"Unsupported device {device!r}")
     elif is_jax_array(x):
         if not hasattr(x, "__array_namespace__"):
-            # In JAX v0.4.31 and older, this import adds to_device method to x.
-            import jax.experimental.array_api # noqa: F401
+            # In JAX v0.4.31 and older, this import adds to_device method to x...
+            import jax.experimental.array_api  # noqa: F401  # pyright: ignore
+
+            # ... but only on eager JAX. It won't work inside jax.jit.
+            if not hasattr(x, "to_device"):
+                return x
         return x.to_device(device, stream=stream)
     elif is_pydata_sparse_array(x) and device == _device(x):
        # Perform trivial check to return the same array if
        # device is same instead of err-ing.
        return x
-    return x.to_device(device, stream=stream)
+    return x.to_device(device, stream=stream)  # pyright: ignore
+

-def size(x):
+@overload
+def size(x: HasShape[Collection[SupportsIndex]]) -> int: ...
+@overload
+def size(x: HasShape[Collection[SupportsIndex | None]]) -> int | None: ...
+def size(x: HasShape[Collection[SupportsIndex | None]]) -> int | None:
     """
     Return the total number of elements of x.

     This is equivalent to `x.size` according to the `standard
     <https://data-apis.org/array-api/latest/API_specification/generated/array_api.array.size.html>`__.
+
     This helper is included because PyTorch defines `size` in an
     :external+torch:meth:`incompatible way <torch.Tensor.size>`.
-
+    It also fixes dask.array's behaviour which returns nan for unknown sizes, whereas
+    the standard requires None.
""" + # Lazy API compliant arrays, such as ndonnx, can contain None in their shape if None in x.shape: return None - return math.prod(x.shape) + out = math.prod(cast("Collection[SupportsIndex]", x.shape)) + # dask.array.Array.shape can contain NaN + return None if math.isnan(out) else out + + +@lru_cache(100) +def _is_writeable_cls(cls: type) -> bool | None: + if ( + _issubclass_fast(cls, "numpy", "generic") + or _issubclass_fast(cls, "jax", "Array") + or _issubclass_fast(cls, "sparse", "SparseArray") + ): + return False + if _is_array_api_cls(cls): + return True + return None + + +def is_writeable_array(x: object) -> TypeGuard[_ArrayApiObj]: + """ + Return False if ``x.__setitem__`` is expected to raise; True otherwise. + Return False if `x` is not an array API compatible object. + + Warning + ------- + As there is no standard way to check if an array is writeable without actually + writing to it, this function blindly returns True for all unknown array types. + """ + cls = cast(Hashable, type(x)) + if _issubclass_fast(cls, "numpy", "ndarray"): + return cast("npt.NDArray", x).flags.writeable + res = _is_writeable_cls(cls) + if res is not None: + return res + return hasattr(x, '__array_namespace__') + + +@lru_cache(100) +def _is_lazy_cls(cls: type) -> bool | None: + if ( + _issubclass_fast(cls, "numpy", "ndarray") + or _issubclass_fast(cls, "numpy", "generic") + or _issubclass_fast(cls, "cupy", "ndarray") + or _issubclass_fast(cls, "torch", "Tensor") + or _issubclass_fast(cls, "sparse", "SparseArray") + ): + return False + if ( + _issubclass_fast(cls, "jax", "Array") + or _issubclass_fast(cls, "dask.array", "Array") + or _issubclass_fast(cls, "ndonnx", "Array") + ): + return True + return None + + +def is_lazy_array(x: object) -> TypeGuard[_ArrayApiObj]: + """Return True if x is potentially a future or it may be otherwise impossible or + expensive to eagerly read its contents, regardless of their size, e.g. by + calling ``bool(x)`` or ``float(x)``. + + Return False otherwise; e.g. ``bool(x)`` etc. is guaranteed to succeed and to be + cheap as long as the array has the right dtype and size. + + Note + ---- + This function errs on the side of caution for array types that may or may not be + lazy, e.g. JAX arrays, by always returning True for them. + """ + # **JAX note:** while it is possible to determine if you're inside or outside + # jax.jit by testing the subclass of a jax.Array object, as well as testing bool() + # as we do below for unknown arrays, this is not recommended by JAX best practices. + + # **Dask note:** Dask eagerly computes the graph on __bool__, __float__, and so on. + # This behaviour, while impossible to change without breaking backwards + # compatibility, is highly detrimental to performance as the whole graph will end + # up being computed multiple times. + + # Note: skipping reclassification of JAX zero gradient arrays, as one will + # exclusively get them once they leave a jax.grad JIT context. + cls = cast(Hashable, type(x)) + res = _is_lazy_cls(cls) + if res is not None: + return res + + if not hasattr(x, "__array_namespace__"): + return False + + # Unknown Array API compatible object. Note that this test may have dire consequences + # in terms of performance, e.g. for a lazy object that eagerly computes the graph + # on __bool__ (dask is one such example, which however is special-cased above). 
+
+    # Select a single point of the array
+    s = size(cast("HasShape[Collection[SupportsIndex | None]]", x))
+    if s is None:
+        return True
+    xp = array_namespace(x)
+    if s > 1:
+        x = xp.reshape(x, (-1,))[0]
+    # Cast to dtype=bool and deal with size 0 arrays
+    x = xp.any(x)
+
+    try:
+        bool(x)
+        return False
+    # The Array API standard dictates that __bool__ should raise TypeError if the
+    # output cannot be defined.
+    # Here we allow for it to raise arbitrary exceptions, e.g. like Dask does.
+    except Exception:
+        return True
+

 __all__ = [
     "array_namespace",
@@ -817,8 +1056,11 @@ def size(x):
     "is_ndonnx_namespace",
     "is_pydata_sparse_array",
     "is_pydata_sparse_namespace",
+    "is_writeable_array",
+    "is_lazy_array",
     "size",
     "to_device",
 ]

-_all_ignore = ['sys', 'math', 'inspect', 'warnings']
+def __dir__() -> list[str]:
+    return __all__
diff --git a/array_api_compat/common/_linalg.py b/array_api_compat/common/_linalg.py
index bfa1f1b9..69672af7 100644
--- a/array_api_compat/common/_linalg.py
+++ b/array_api_compat/common/_linalg.py
@@ -1,85 +1,114 @@
 from __future__ import annotations

-from typing import TYPE_CHECKING, NamedTuple
-if TYPE_CHECKING:
-    from typing import Literal, Optional, Tuple, Union
-    from ._typing import ndarray
-
 import math
+from typing import Literal, NamedTuple, cast

 import numpy as np
+
 if np.__version__[0] == "2":
     from numpy.lib.array_utils import normalize_axis_tuple
 else:
-    from numpy.core.numeric import normalize_axis_tuple
+    from numpy.core.numeric import normalize_axis_tuple  # type: ignore[no-redef]

-from ._aliases import matmul, matrix_transpose, tensordot, vecdot, isdtype
 from .._internal import get_xp
+from ._aliases import isdtype, matmul, matrix_transpose, tensordot, vecdot
+from ._typing import Array, DType, JustFloat, JustInt, Namespace
+

 # These are in the main NumPy namespace but not in numpy.linalg
-def cross(x1: ndarray, x2: ndarray, /, xp, *, axis: int = -1, **kwargs) -> ndarray:
+def cross(
+    x1: Array,
+    x2: Array,
+    /,
+    xp: Namespace,
+    *,
+    axis: int = -1,
+    **kwargs: object,
+) -> Array:
     return xp.cross(x1, x2, axis=axis, **kwargs)

-def outer(x1: ndarray, x2: ndarray, /, xp, **kwargs) -> ndarray:
+def outer(x1: Array, x2: Array, /, xp: Namespace, **kwargs: object) -> Array:
     return xp.outer(x1, x2, **kwargs)

 class EighResult(NamedTuple):
-    eigenvalues: ndarray
-    eigenvectors: ndarray
+    eigenvalues: Array
+    eigenvectors: Array

 class QRResult(NamedTuple):
-    Q: ndarray
-    R: ndarray
+    Q: Array
+    R: Array

 class SlogdetResult(NamedTuple):
-    sign: ndarray
-    logabsdet: ndarray
+    sign: Array
+    logabsdet: Array

 class SVDResult(NamedTuple):
-    U: ndarray
-    S: ndarray
-    Vh: ndarray
+    U: Array
+    S: Array
+    Vh: Array

 # These functions are the same as their NumPy counterparts except they return
 # a namedtuple.
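To make the comment above concrete, here is a short hedged sketch (not part of the patch) of how these namedtuple-returning wrappers behave once surfaced through the compat namespace:

    from array_api_compat import numpy as xp

    res = xp.linalg.qr(xp.eye(3))   # QRResult(Q=..., R=...)
    assert res.Q.shape == (3, 3) and res.R.shape == (3, 3)
    Q, R = res                      # still unpacks like the plain tuple np.linalg.qr returns
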
-def eigh(x: ndarray, /, xp, **kwargs) -> EighResult: +def eigh(x: Array, /, xp: Namespace, **kwargs: object) -> EighResult: return EighResult(*xp.linalg.eigh(x, **kwargs)) -def qr(x: ndarray, /, xp, *, mode: Literal['reduced', 'complete'] = 'reduced', - **kwargs) -> QRResult: +def qr( + x: Array, + /, + xp: Namespace, + *, + mode: Literal["reduced", "complete"] = "reduced", + **kwargs: object, +) -> QRResult: return QRResult(*xp.linalg.qr(x, mode=mode, **kwargs)) -def slogdet(x: ndarray, /, xp, **kwargs) -> SlogdetResult: +def slogdet(x: Array, /, xp: Namespace, **kwargs: object) -> SlogdetResult: return SlogdetResult(*xp.linalg.slogdet(x, **kwargs)) -def svd(x: ndarray, /, xp, *, full_matrices: bool = True, **kwargs) -> SVDResult: +def svd( + x: Array, + /, + xp: Namespace, + *, + full_matrices: bool = True, + **kwargs: object, +) -> SVDResult: return SVDResult(*xp.linalg.svd(x, full_matrices=full_matrices, **kwargs)) # These functions have additional keyword arguments # The upper keyword argument is new from NumPy -def cholesky(x: ndarray, /, xp, *, upper: bool = False, **kwargs) -> ndarray: +def cholesky( + x: Array, + /, + xp: Namespace, + *, + upper: bool = False, + **kwargs: object, +) -> Array: L = xp.linalg.cholesky(x, **kwargs) if upper: U = get_xp(xp)(matrix_transpose)(L) if get_xp(xp)(isdtype)(U.dtype, 'complex floating'): - U = xp.conj(U) + U = xp.conj(U) # pyright: ignore[reportConstantRedefinition] return U return L # The rtol keyword argument of matrix_rank() and pinv() is new from NumPy. # Note that it has a different semantic meaning from tol and rcond. -def matrix_rank(x: ndarray, - /, - xp, - *, - rtol: Optional[Union[float, ndarray]] = None, - **kwargs) -> ndarray: +def matrix_rank( + x: Array, + /, + xp: Namespace, + *, + rtol: float | Array | None = None, + **kwargs: object, +) -> Array: # this is different from xp.linalg.matrix_rank, which supports 1 # dimensional arrays. if x.ndim < 2: raise xp.linalg.LinAlgError("1-dimensional array given. Array must be at least two-dimensional") - S = get_xp(xp)(svdvals)(x, **kwargs) + S: Array = get_xp(xp)(svdvals)(x, **kwargs) if rtol is None: tol = S.max(axis=-1, keepdims=True) * max(x.shape[-2:]) * xp.finfo(S.dtype).eps else: @@ -88,7 +117,14 @@ def matrix_rank(x: ndarray, tol = S.max(axis=-1, keepdims=True)*xp.asarray(rtol)[..., xp.newaxis] return xp.count_nonzero(S > tol, axis=-1) -def pinv(x: ndarray, /, xp, *, rtol: Optional[Union[float, ndarray]] = None, **kwargs) -> ndarray: +def pinv( + x: Array, + /, + xp: Namespace, + *, + rtol: float | Array | None = None, + **kwargs: object, +) -> Array: # this is different from xp.linalg.pinv, which does not multiply the # default tolerance by max(M, N). if rtol is None: @@ -97,15 +133,30 @@ def pinv(x: ndarray, /, xp, *, rtol: Optional[Union[float, ndarray]] = None, **k # These functions are new in the array API spec -def matrix_norm(x: ndarray, /, xp, *, keepdims: bool = False, ord: Optional[Union[int, float, Literal['fro', 'nuc']]] = 'fro') -> ndarray: +def matrix_norm( + x: Array, + /, + xp: Namespace, + *, + keepdims: bool = False, + ord: Literal[1, 2, -1, -2] | JustFloat | Literal["fro", "nuc"] | None = "fro", +) -> Array: return xp.linalg.norm(x, axis=(-2, -1), keepdims=keepdims, ord=ord) # svdvals is not in NumPy (but it is in SciPy). It is equivalent to # xp.linalg.svd(compute_uv=False). 
-def svdvals(x: ndarray, /, xp) -> Union[ndarray, Tuple[ndarray, ...]]: +def svdvals(x: Array, /, xp: Namespace) -> Array | tuple[Array, ...]: return xp.linalg.svd(x, compute_uv=False) -def vector_norm(x: ndarray, /, xp, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False, ord: Optional[Union[int, float]] = 2) -> ndarray: +def vector_norm( + x: Array, + /, + xp: Namespace, + *, + axis: int | tuple[int, ...] | None = None, + keepdims: bool = False, + ord: JustInt | JustFloat = 2, +) -> Array: # xp.linalg.norm tries to do a matrix norm whenever axis is a 2-tuple or # when axis=None and the input is 2-D, so to force a vector norm, we make # it so the input is 1-D (for axis=None), or reshape so that norm is done @@ -117,7 +168,10 @@ def vector_norm(x: ndarray, /, xp, *, axis: Optional[Union[int, Tuple[int, ...]] elif isinstance(axis, tuple): # Note: The axis argument supports any number of axes, whereas # xp.linalg.norm() only supports a single axis for vector norm. - normalized_axis = normalize_axis_tuple(axis, x.ndim) + normalized_axis = cast( + "tuple[int, ...]", + normalize_axis_tuple(axis, x.ndim), # pyright: ignore[reportCallIssue] + ) rest = tuple(i for i in range(x.ndim) if i not in normalized_axis) newshape = axis + rest _x = xp.transpose(x, newshape).reshape( @@ -133,8 +187,14 @@ def vector_norm(x: ndarray, /, xp, *, axis: Optional[Union[int, Tuple[int, ...]] # We can't reuse xp.linalg.norm(keepdims) because of the reshape hacks # above to avoid matrix norm logic. shape = list(x.shape) - _axis = normalize_axis_tuple(range(x.ndim) if axis is None else axis, x.ndim) - for i in _axis: + axes = cast( + "tuple[int, ...]", + normalize_axis_tuple( # pyright: ignore[reportCallIssue] + range(x.ndim) if axis is None else axis, + x.ndim, + ), + ) + for i in axes: shape[i] = 1 res = xp.reshape(res, tuple(shape)) @@ -143,14 +203,28 @@ def vector_norm(x: ndarray, /, xp, *, axis: Optional[Union[int, Tuple[int, ...]] # xp.diagonal and xp.trace operate on the first two axes whereas these # operates on the last two -def diagonal(x: ndarray, /, xp, *, offset: int = 0, **kwargs) -> ndarray: +def diagonal(x: Array, /, xp: Namespace, *, offset: int = 0, **kwargs: object) -> Array: return xp.diagonal(x, offset=offset, axis1=-2, axis2=-1, **kwargs) -def trace(x: ndarray, /, xp, *, offset: int = 0, dtype=None, **kwargs) -> ndarray: - return xp.asarray(xp.trace(x, offset=offset, dtype=dtype, axis1=-2, axis2=-1, **kwargs)) +def trace( + x: Array, + /, + xp: Namespace, + *, + offset: int = 0, + dtype: DType | None = None, + **kwargs: object, +) -> Array: + return xp.asarray( + xp.trace(x, offset=offset, dtype=dtype, axis1=-2, axis2=-1, **kwargs) + ) __all__ = ['cross', 'matmul', 'outer', 'tensordot', 'EighResult', 'QRResult', 'SlogdetResult', 'SVDResult', 'eigh', 'qr', 'slogdet', 'svd', 'cholesky', 'matrix_rank', 'pinv', 'matrix_norm', 'matrix_transpose', 'svdvals', 'vecdot', 'vector_norm', 'diagonal', 'trace'] + + +def __dir__() -> list[str]: + return __all__ diff --git a/array_api_compat/common/_typing.py b/array_api_compat/common/_typing.py index 07f3850d..11b00bd1 100644 --- a/array_api_compat/common/_typing.py +++ b/array_api_compat/common/_typing.py @@ -1,23 +1,189 @@ from __future__ import annotations -__all__ = [ - "NestedSequence", - "SupportsBufferProtocol", -] - +from collections.abc import Mapping +from types import ModuleType as Namespace from typing import ( - Any, - TypeVar, + TYPE_CHECKING, + Literal, Protocol, + TypeAlias, + TypedDict, + TypeVar, + final, ) +if 
TYPE_CHECKING:
+    from _typeshed import Incomplete
+
+    SupportsBufferProtocol: TypeAlias = Incomplete
+    Array: TypeAlias = Incomplete
+    Device: TypeAlias = Incomplete
+    DType: TypeAlias = Incomplete
+else:
+    SupportsBufferProtocol = object
+    Array = object
+    Device = object
+    DType = object
+
+
 _T_co = TypeVar("_T_co", covariant=True)
+
+# These "Just" types are equivalent to the `Just` type from the `optype` library,
+# apart from them not being `@runtime_checkable`.
+# - docs: https://github.com/jorenham/optype/blob/master/README.md#just
+# - code: https://github.com/jorenham/optype/blob/master/optype/_core/_just.py
+@final
+class JustInt(Protocol):  # type: ignore[misc]
+    @property  # type: ignore[override]
+    def __class__(self, /) -> type[int]: ...
+    @__class__.setter
+    def __class__(self, value: type[int], /) -> None: ...  # pyright: ignore[reportIncompatibleMethodOverride]
+
+
+@final
+class JustFloat(Protocol):  # type: ignore[misc]
+    @property  # type: ignore[override]
+    def __class__(self, /) -> type[float]: ...
+    @__class__.setter
+    def __class__(self, value: type[float], /) -> None: ...  # pyright: ignore[reportIncompatibleMethodOverride]
+
+
+@final
+class JustComplex(Protocol):  # type: ignore[misc]
+    @property  # type: ignore[override]
+    def __class__(self, /) -> type[complex]: ...
+    @__class__.setter
+    def __class__(self, value: type[complex], /) -> None: ...  # pyright: ignore[reportIncompatibleMethodOverride]
+
+
 class NestedSequence(Protocol[_T_co]):
     def __getitem__(self, key: int, /) -> _T_co | NestedSequence[_T_co]: ...
     def __len__(self, /) -> int: ...

-SupportsBufferProtocol = Any
-Array = Any
-Device = Any
+class SupportsArrayNamespace(Protocol[_T_co]):
+    def __array_namespace__(self, /, *, api_version: str | None) -> _T_co: ...
+
+
+class HasShape(Protocol[_T_co]):
+    @property
+    def shape(self, /) -> _T_co: ...
+
+
+# Return type of `__array_namespace_info__.capabilities`
+Capabilities = TypedDict(
+    "Capabilities",
+    {
+        "boolean indexing": bool,
+        "data-dependent shapes": bool,
+        "max dimensions": int,
+    },
+)
+
+# Return type of `__array_namespace_info__.default_dtypes`
+DefaultDTypes = TypedDict(
+    "DefaultDTypes",
+    {
+        "real floating": DType,
+        "complex floating": DType,
+        "integral": DType,
+        "indexing": DType,
+    },
+)
+
+
+_DTypeKind: TypeAlias = Literal[
+    "bool",
+    "signed integer",
+    "unsigned integer",
+    "integral",
+    "real floating",
+    "complex floating",
+    "numeric",
+]
+# Type of the `kind` parameter in `__array_namespace_info__.dtypes`
+DTypeKind: TypeAlias = _DTypeKind | tuple[_DTypeKind, ...]
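A hedged sketch (not part of the patch) of where these typed dicts appear at runtime; the return values are plain dicts, and the TypedDicts above only describe their static shape. The "max dimensions" entry matches the capabilities() changes elsewhere in this patch:

    from array_api_compat import numpy as xp

    info = xp.__array_namespace_info__()
    caps = info.capabilities()                  # typed as Capabilities
    assert caps["max dimensions"] == 64
    reals = info.dtypes(kind="real floating")   # typed as DTypesReal
    assert set(reals) == {"float32", "float64"}
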
+ + +# `__array_namespace_info__.dtypes(kind="bool")` +class DTypesBool(TypedDict): + bool: DType + + +# `__array_namespace_info__.dtypes(kind="signed integer")` +class DTypesSigned(TypedDict): + int8: DType + int16: DType + int32: DType + int64: DType + + +# `__array_namespace_info__.dtypes(kind="unsigned integer")` +class DTypesUnsigned(TypedDict): + uint8: DType + uint16: DType + uint32: DType + uint64: DType + + +# `__array_namespace_info__.dtypes(kind="integral")` +class DTypesIntegral(DTypesSigned, DTypesUnsigned): + pass + + +# `__array_namespace_info__.dtypes(kind="real floating")` +class DTypesReal(TypedDict): + float32: DType + float64: DType + + +# `__array_namespace_info__.dtypes(kind="complex floating")` +class DTypesComplex(TypedDict): + complex64: DType + complex128: DType + + +# `__array_namespace_info__.dtypes(kind="numeric")` +class DTypesNumeric(DTypesIntegral, DTypesReal, DTypesComplex): + pass + + +# `__array_namespace_info__.dtypes(kind=None)` (default) +class DTypesAll(DTypesBool, DTypesNumeric): + pass + + +# `__array_namespace_info__.dtypes(kind=?)` (fallback) +DTypesAny: TypeAlias = Mapping[str, DType] + + +__all__ = [ + "Array", + "Capabilities", + "DType", + "DTypeKind", + "DTypesAny", + "DTypesAll", + "DTypesBool", + "DTypesNumeric", + "DTypesIntegral", + "DTypesSigned", + "DTypesUnsigned", + "DTypesReal", + "DTypesComplex", + "DefaultDTypes", + "Device", + "HasShape", + "Namespace", + "JustInt", + "JustFloat", + "JustComplex", + "NestedSequence", + "SupportsArrayNamespace", + "SupportsBufferProtocol", +] + + +def __dir__() -> list[str]: + return __all__ diff --git a/array_api_compat/cupy/__init__.py b/array_api_compat/cupy/__init__.py index 7968d68d..af003c5a 100644 --- a/array_api_compat/cupy/__init__.py +++ b/array_api_compat/cupy/__init__.py @@ -1,3 +1,4 @@ +from typing import Final from cupy import * # noqa: F403 # from cupy import * doesn't overwrite these builtin names @@ -5,12 +6,19 @@ # These imports may overwrite names from the import * above. 
from ._aliases import * # noqa: F403 +from ._info import __array_namespace_info__ # noqa: F401 # See the comment in the numpy __init__.py __import__(__package__ + '.linalg') - __import__(__package__ + '.fft') -from ..common._helpers import * # noqa: F401,F403 +__array_api_version__: Final = '2024.12' + +__all__ = sorted( + {name for name in globals() if not name.startswith("__")} + - {"Final", "_aliases", "_info", "_typing"} + | {"__array_api_version__", "__array_namespace_info__", "linalg", "fft"} +) -__array_api_version__ = '2022.12' +def __dir__() -> list[str]: + return __all__ diff --git a/array_api_compat/cupy/_aliases.py b/array_api_compat/cupy/_aliases.py index 30ae2943..2e512fc8 100644 --- a/array_api_compat/cupy/_aliases.py +++ b/array_api_compat/cupy/_aliases.py @@ -1,16 +1,13 @@ from __future__ import annotations +from builtins import bool as py_bool + import cupy as cp -from ..common import _aliases +from ..common import _aliases, _helpers +from ..common._typing import NestedSequence, SupportsBufferProtocol from .._internal import get_xp - -from ._info import __array_namespace_info__ - -from typing import TYPE_CHECKING -if TYPE_CHECKING: - from typing import Optional, Union - from ._typing import ndarray, Device, Dtype, NestedSequence, SupportsBufferProtocol +from ._typing import Array, Device, DType bool = cp.bool_ @@ -46,42 +43,34 @@ unique_counts = get_xp(cp)(_aliases.unique_counts) unique_inverse = get_xp(cp)(_aliases.unique_inverse) unique_values = get_xp(cp)(_aliases.unique_values) -astype = _aliases.astype std = get_xp(cp)(_aliases.std) var = get_xp(cp)(_aliases.var) cumulative_sum = get_xp(cp)(_aliases.cumulative_sum) +cumulative_prod = get_xp(cp)(_aliases.cumulative_prod) clip = get_xp(cp)(_aliases.clip) permute_dims = get_xp(cp)(_aliases.permute_dims) reshape = get_xp(cp)(_aliases.reshape) argsort = get_xp(cp)(_aliases.argsort) sort = get_xp(cp)(_aliases.sort) nonzero = get_xp(cp)(_aliases.nonzero) -ceil = get_xp(cp)(_aliases.ceil) -floor = get_xp(cp)(_aliases.floor) -trunc = get_xp(cp)(_aliases.trunc) matmul = get_xp(cp)(_aliases.matmul) matrix_transpose = get_xp(cp)(_aliases.matrix_transpose) tensordot = get_xp(cp)(_aliases.tensordot) +sign = get_xp(cp)(_aliases.sign) +finfo = get_xp(cp)(_aliases.finfo) +iinfo = get_xp(cp)(_aliases.iinfo) -_copy_default = object() # asarray also adds the copy keyword, which is not present in numpy 1.0. def asarray( - obj: Union[ - ndarray, - bool, - int, - float, - NestedSequence[bool | int | float], - SupportsBufferProtocol, - ], + obj: Array | complex | NestedSequence[complex] | SupportsBufferProtocol, /, *, - dtype: Optional[Dtype] = None, - device: Optional[Device] = None, - copy: Optional[bool] = _copy_default, - **kwargs, -) -> ndarray: + dtype: DType | None = None, + device: Device | None = None, + copy: py_bool | None = None, + **kwargs: object, +) -> Array: """ Array API compatibility wrapper for asarray(). @@ -89,32 +78,66 @@ def asarray( specification for more details. """ with cp.cuda.Device(device): - # cupy is like NumPy 1.26 (except without _CopyMode). See the comments - # in asarray in numpy/_aliases.py. - if copy is not _copy_default: - # A future version of CuPy will change the meaning of copy=False - # to mean no-copy. We don't know for certain what version it will - # be yet, so to avoid breaking that version, we use a different - # default value for copy so asarray(obj) with no copy kwarg will - # always do the copy-if-needed behavior. 
- - # This will still need to be updated to remove the - # NotImplementedError for copy=False, but at least this won't - # break the default or existing behavior. - if copy is None: - copy = False - elif copy is False: - raise NotImplementedError("asarray(copy=False) is not yet supported in cupy") - kwargs['copy'] = copy - - return cp.array(obj, dtype=dtype, **kwargs) - -def sign(x: ndarray, /) -> ndarray: - # CuPy sign() does not propagate nans. See - # https://github.com/data-apis/array-api-compat/issues/136 - out = cp.sign(x) - out[cp.isnan(x)] = cp.nan - return out + if copy is None: + return cp.asarray(obj, dtype=dtype, **kwargs) + else: + res = cp.array(obj, dtype=dtype, copy=copy, **kwargs) + if not copy and res is not obj: + raise ValueError("Unable to avoid copy while creating an array as requested") + return res + + +def astype( + x: Array, + dtype: DType, + /, + *, + copy: py_bool = True, + device: Device | None = None, +) -> Array: + if device is None: + return x.astype(dtype=dtype, copy=copy) + out = _helpers.to_device(x.astype(dtype=dtype, copy=False), device) + return out.copy() if copy and out is x else out + + +# cupy.count_nonzero does not have keepdims +def count_nonzero( + x: Array, + axis: int | tuple[int, ...] | None = None, + keepdims: py_bool = False, +) -> Array: + result = cp.count_nonzero(x, axis) + if keepdims: + if axis is None: + return cp.reshape(result, [1]*x.ndim) + return cp.expand_dims(result, axis) + return result + +# ceil, floor, and trunc return integers for integer inputs + +def ceil(x: Array, /) -> Array: + if cp.issubdtype(x.dtype, cp.integer): + return x.copy() + return cp.ceil(x) + + +def floor(x: Array, /) -> Array: + if cp.issubdtype(x.dtype, cp.integer): + return x.copy() + return cp.floor(x) + + +def trunc(x: Array, /) -> Array: + if cp.issubdtype(x.dtype, cp.integer): + return x.copy() + return cp.trunc(x) + + +# take_along_axis: axis defaults to -1 but in cupy (and numpy) axis is a required arg +def take_along_axis(x: Array, indices: Array, /, *, axis: int = -1) -> Array: + return cp.take_along_axis(x, indices, axis=axis) + # These functions are completely new here. If the library already has them # (i.e., numpy 2.0), use the library version instead of our wrapper. @@ -133,10 +156,13 @@ def sign(x: ndarray, /) -> ndarray: else: unstack = get_xp(cp)(_aliases.unstack) -__all__ = _aliases.__all__ + ['__array_namespace_info__', 'asarray', 'bool', +__all__ = _aliases.__all__ + ['asarray', 'astype', 'acos', 'acosh', 'asin', 'asinh', 'atan', 'atan2', 'atanh', 'bitwise_left_shift', 'bitwise_invert', 'bitwise_right_shift', - 'concat', 'pow', 'sign'] + 'bool', 'concat', 'count_nonzero', 'pow', 'sign', + 'ceil', 'floor', 'trunc', 'take_along_axis'] + -_all_ignore = ['cp', 'get_xp'] +def __dir__() -> list[str]: + return __all__ diff --git a/array_api_compat/cupy/_info.py b/array_api_compat/cupy/_info.py index 4440807d..78e48a33 100644 --- a/array_api_compat/cupy/_info.py +++ b/array_api_compat/cupy/_info.py @@ -26,6 +26,7 @@ complex128, ) + class __array_namespace_info__: """ Get the array API inspection namespace for CuPy. 
@@ -49,7 +50,7 @@ class __array_namespace_info__: Examples -------- - >>> info = np.__array_namespace_info__() + >>> info = xp.__array_namespace_info__() >>> info.default_dtypes() {'real floating': cupy.float64, 'complex floating': cupy.complex128, @@ -94,14 +95,14 @@ def capabilities(self): >>> info = xp.__array_namespace_info__() >>> info.capabilities() {'boolean indexing': True, - 'data-dependent shapes': True} + 'data-dependent shapes': True, + 'max dimensions': 64} """ return { "boolean indexing": True, "data-dependent shapes": True, - # 'max rank' will be part of the 2024.12 standard - # "max rank": 64, + "max dimensions": 64, } def default_device(self): @@ -117,7 +118,7 @@ def default_device(self): Returns ------- - device : str + device : Device The default device used for new CuPy arrays. Examples @@ -126,6 +127,15 @@ def default_device(self): >>> info.default_device() Device(0) + Notes + ----- + This method returns the static default device when CuPy is initialized. + However, the *current* device used by creation functions (``empty`` etc.) + can be changed globally or with a context manager. + + See Also + -------- + https://github.com/data-apis/array-api/issues/835 """ return cuda.Device(0) @@ -312,7 +322,7 @@ def devices(self): Returns ------- - devices : list of str + devices : list[Device] The devices supported by CuPy. See Also diff --git a/array_api_compat/cupy/_typing.py b/array_api_compat/cupy/_typing.py index f3d9aab6..e5c202dc 100644 --- a/array_api_compat/cupy/_typing.py +++ b/array_api_compat/cupy/_typing.py @@ -1,46 +1,30 @@ from __future__ import annotations -__all__ = [ - "ndarray", - "Device", - "Dtype", -] +__all__ = ["Array", "DType", "Device"] -import sys -from typing import ( - Union, - TYPE_CHECKING, -) - -from cupy import ( - ndarray, - dtype, - int8, - int16, - int32, - int64, - uint8, - uint16, - uint32, - uint64, - float32, - float64, -) +from typing import TYPE_CHECKING +import cupy as cp +from cupy import ndarray as Array from cupy.cuda.device import Device -if TYPE_CHECKING or sys.version_info >= (3, 9): - Dtype = dtype[Union[ - int8, - int16, - int32, - int64, - uint8, - uint16, - uint32, - uint64, - float32, - float64, - ]] +if TYPE_CHECKING: + # NumPy 1.x on Python 3.10 fails to parse np.dtype[] + DType = cp.dtype[ + cp.intp + | cp.int8 + | cp.int16 + | cp.int32 + | cp.int64 + | cp.uint8 + | cp.uint16 + | cp.uint32 + | cp.uint64 + | cp.float32 + | cp.float64 + | cp.complex64 + | cp.complex128 + | cp.bool_ + ] else: - Dtype = dtype + DType = cp.dtype diff --git a/array_api_compat/cupy/fft.py b/array_api_compat/cupy/fft.py index 307e0f72..53a9a454 100644 --- a/array_api_compat/cupy/fft.py +++ b/array_api_compat/cupy/fft.py @@ -1,10 +1,11 @@ -from cupy.fft import * # noqa: F403 +from cupy.fft import * # noqa: F403 + # cupy.fft doesn't have __all__. If it is added, replace this with # # from cupy.fft import __all__ as linalg_all -_n = {} -exec('from cupy.fft import *', _n) -del _n['__builtins__'] +_n: dict[str, object] = {} +exec("from cupy.fft import *", _n) +del _n["__builtins__"] fft_all = list(_n) del _n @@ -30,7 +31,6 @@ __all__ = fft_all + _fft.__all__ -del get_xp -del cp -del fft_all -del _fft +def __dir__() -> list[str]: + return __all__ + diff --git a/array_api_compat/cupy/linalg.py b/array_api_compat/cupy/linalg.py index 7fcdd498..da301574 100644 --- a/array_api_compat/cupy/linalg.py +++ b/array_api_compat/cupy/linalg.py @@ -2,7 +2,7 @@ # cupy.linalg doesn't have __all__. 
If it is added, replace this with # # from cupy.linalg import __all__ as linalg_all -_n = {} +_n: dict[str, object] = {} exec('from cupy.linalg import *', _n) del _n['__builtins__'] linalg_all = list(_n) @@ -43,7 +43,5 @@ __all__ = linalg_all + _linalg.__all__ -del get_xp -del cp -del linalg_all -del _linalg +def __dir__() -> list[str]: + return __all__ diff --git a/array_api_compat/dask/array/__init__.py b/array_api_compat/dask/array/__init__.py index 03e0cd72..f78aa8b3 100644 --- a/array_api_compat/dask/array/__init__.py +++ b/array_api_compat/dask/array/__init__.py @@ -1,8 +1,26 @@ -from dask.array import * # noqa: F403 +from typing import Final + +from ..._internal import clone_module + +__all__ = clone_module("dask.array", globals()) # These imports may overwrite names from the import * above. -from ._aliases import * # noqa: F403 +from . import _aliases +from ._aliases import * # type: ignore[assignment] # noqa: F403 +from ._info import __array_namespace_info__ # noqa: F401 -__array_api_version__ = '2022.12' +__array_api_version__: Final = "2024.12" +del Final +# See the comment in the numpy __init__.py __import__(__package__ + '.linalg') +__import__(__package__ + '.fft') + +__all__ = sorted( + set(__all__) + | set(_aliases.__all__) + | {"__array_api_version__", "__array_namespace_info__", "linalg", "fft"} +) + +def __dir__() -> list[str]: + return __all__ diff --git a/array_api_compat/dask/array/_aliases.py b/array_api_compat/dask/array/_aliases.py index cf57c824..54d323b2 100644 --- a/array_api_compat/dask/array/_aliases.py +++ b/array_api_compat/dask/array/_aliases.py @@ -1,72 +1,101 @@ -from __future__ import annotations +# pyright: reportPrivateUsage=false +# pyright: reportUnknownArgumentType=false +# pyright: reportUnknownMemberType=false +# pyright: reportUnknownVariableType=false -from ...common import _aliases -from ...common._helpers import _check_device +from __future__ import annotations -from ..._internal import get_xp +from builtins import bool as py_bool +from collections.abc import Callable +from typing import TYPE_CHECKING, Any -from ._info import __array_namespace_info__ +if TYPE_CHECKING: + from typing_extensions import TypeIs +import dask.array as da import numpy as np +from numpy import bool_ as bool from numpy import ( - # Constants - e, - inf, - nan, - pi, - newaxis, - # Dtypes - bool_ as bool, + can_cast, + complex64, + complex128, float32, float64, int8, int16, int32, int64, + result_type, uint8, uint16, uint32, uint64, - complex64, - complex128, - iinfo, - finfo, - can_cast, - result_type, ) -from typing import TYPE_CHECKING -if TYPE_CHECKING: - from typing import Optional, Union - - from ...common._typing import Device, Dtype, Array, NestedSequence, SupportsBufferProtocol - -import dask.array as da +from ..._internal import get_xp +from ...common import _aliases, _helpers, array_namespace +from ...common._typing import ( + Array, + Device, + DType, + NestedSequence, + SupportsBufferProtocol, +) isdtype = get_xp(np)(_aliases.isdtype) unstack = get_xp(da)(_aliases.unstack) -astype = _aliases.astype + + +# da.astype doesn't respect copy=True +def astype( + x: Array, + dtype: DType, + /, + *, + copy: py_bool = True, + device: Device | None = None, +) -> Array: + """ + Array API compatibility wrapper for astype(). + + See the corresponding documentation in the array library and/or the array API + specification for more details. + """ + # TODO: respect device keyword? 
+ _helpers._check_device(da, device) + + if not copy and dtype == x.dtype: + return x + x = x.astype(dtype) + return x.copy() if copy else x + # Common aliases + # This arange func is modified from the common one to # not pass stop/step as keyword arguments, which will cause # an error with dask - -# TODO: delete the xp stuff, it shouldn't be necessary -def _dask_arange( - start: Union[int, float], +def arange( + start: float, /, - stop: Optional[Union[int, float]] = None, - step: Union[int, float] = 1, + stop: float | None = None, + step: float = 1, *, - xp, - dtype: Optional[Dtype] = None, - device: Optional[Device] = None, - **kwargs, + dtype: DType | None = None, + device: Device | None = None, + **kwargs: object, ) -> Array: - _check_device(xp, device) - args = [start] + """ + Array API compatibility wrapper for arange(). + + See the corresponding documentation in the array library and/or the array API + specification for more details. + """ + # TODO: respect device keyword? + _helpers._check_device(da, device) + + args: list[Any] = [start] if stop is not None: args.append(stop) else: @@ -74,13 +103,12 @@ def _dask_arange( # prepend the default value for start which is 0 args.insert(0, 0) args.append(step) - return xp.arange(*args, dtype=dtype, **kwargs) -arange = get_xp(da)(_dask_arange) -eye = get_xp(da)(_aliases.eye) + return da.arange(*args, dtype=dtype, **kwargs) + -linspace = get_xp(da)(_aliases.linspace) eye = get_xp(da)(_aliases.eye) +linspace = get_xp(da)(_aliases.linspace) UniqueAllResult = get_xp(da)(_aliases.UniqueAllResult) UniqueCountsResult = get_xp(da)(_aliases.UniqueCountsResult) UniqueInverseResult = get_xp(da)(_aliases.UniqueInverseResult) @@ -92,6 +120,7 @@ def _dask_arange( std = get_xp(da)(_aliases.std) var = get_xp(da)(_aliases.var) cumulative_sum = get_xp(da)(_aliases.cumulative_sum) +cumulative_prod = get_xp(da)(_aliases.cumulative_prod) empty = get_xp(da)(_aliases.empty) empty_like = get_xp(da)(_aliases.empty_like) full = get_xp(da)(_aliases.full) @@ -103,31 +132,23 @@ def _dask_arange( reshape = get_xp(da)(_aliases.reshape) matrix_transpose = get_xp(da)(_aliases.matrix_transpose) vecdot = get_xp(da)(_aliases.vecdot) - nonzero = get_xp(da)(_aliases.nonzero) -ceil = get_xp(np)(_aliases.ceil) -floor = get_xp(np)(_aliases.floor) -trunc = get_xp(np)(_aliases.trunc) matmul = get_xp(np)(_aliases.matmul) tensordot = get_xp(np)(_aliases.tensordot) +sign = get_xp(np)(_aliases.sign) +finfo = get_xp(np)(_aliases.finfo) +iinfo = get_xp(np)(_aliases.iinfo) # asarray also adds the copy keyword, which is not present in numpy 1.0. def asarray( - obj: Union[ - Array, - bool, - int, - float, - NestedSequence[bool | int | float], - SupportsBufferProtocol, - ], + obj: Array | complex | NestedSequence[complex] | SupportsBufferProtocol, /, *, - dtype: Optional[Dtype] = None, - device: Optional[Device] = None, - copy: "Optional[Union[bool, np._CopyMode]]" = None, - **kwargs, + dtype: DType | None = None, + device: Device | None = None, + copy: py_bool | None = None, + **kwargs: object, ) -> Array: """ Array API compatibility wrapper for asarray(). @@ -135,90 +156,214 @@ def asarray( See the corresponding documentation in the array library and/or the array API specification for more details. """ + # TODO: respect device keyword? 
+ _helpers._check_device(da, device) + + if isinstance(obj, da.Array): + if dtype is not None and dtype != obj.dtype: + if copy is False: + raise ValueError("Unable to avoid copy when changing dtype") + obj = obj.astype(dtype) + return obj.copy() if copy else obj # pyright: ignore[reportAttributeAccessIssue] + if copy is False: - # copy=False is not yet implemented in dask - raise NotImplementedError("copy=False is not yet implemented") - elif copy is True: - if isinstance(obj, da.Array) and dtype is None: - return obj.copy() - # Go through numpy, since dask copy is no-op by default - obj = np.array(obj, dtype=dtype, copy=True) - return da.array(obj, dtype=dtype) - else: - if not isinstance(obj, da.Array) or dtype is not None and obj.dtype != dtype: - obj = np.asarray(obj, dtype=dtype) - return da.from_array(obj) - return obj - - return da.asarray(obj, dtype=dtype, **kwargs) - -from dask.array import ( - # Element wise aliases - arccos as acos, - arccosh as acosh, - arcsin as asin, - arcsinh as asinh, - arctan as atan, - arctan2 as atan2, - arctanh as atanh, - left_shift as bitwise_left_shift, - right_shift as bitwise_right_shift, - invert as bitwise_invert, - power as pow, - # Other - concatenate as concat, -) + raise ValueError( + "Unable to avoid copy when converting a non-dask object to dask" + ) + + # copy=None to be uniform across dask < 2024.12 and >= 2024.12 + # see https://github.com/dask/dask/pull/11524/ + obj = np.array(obj, dtype=dtype, copy=True) + return da.from_array(obj) + + +# Element wise aliases +from dask.array import arccos as acos +from dask.array import arccosh as acosh +from dask.array import arcsin as asin +from dask.array import arcsinh as asinh +from dask.array import arctan as atan +from dask.array import arctan2 as atan2 +from dask.array import arctanh as atanh + +# Other +from dask.array import concatenate as concat +from dask.array import invert as bitwise_invert +from dask.array import left_shift as bitwise_left_shift +from dask.array import power as pow +from dask.array import right_shift as bitwise_right_shift + # dask.array.clip does not work unless all three arguments are provided. # Furthermore, the masking workaround in common._aliases.clip cannot work with # dask (meaning uint64 promoting to float64 is going to just be unfixed for # now). -@get_xp(da) def clip( x: Array, /, - min: Optional[Union[int, float, Array]] = None, - max: Optional[Union[int, float, Array]] = None, - *, - xp, + min: float | Array | None = None, + max: float | Array | None = None, ) -> Array: - def _isscalar(a): - return isinstance(a, (int, float, type(None))) + """ + Array API compatibility wrapper for clip(). + + See the corresponding documentation in the array library and/or the array API + specification for more details. 
+    """
+
+    def _isscalar(a: float | Array | None, /) -> TypeIs[float | None]:
+        return a is None or isinstance(a, (int, float))
+
     min_shape = () if _isscalar(min) else min.shape
     max_shape = () if _isscalar(max) else max.shape

     # TODO: This won't handle dask unknown shapes
-    import numpy as np
     result_shape = np.broadcast_shapes(x.shape, min_shape, max_shape)

     if min is not None:
-        min = xp.broadcast_to(xp.asarray(min), result_shape)
+        min = da.broadcast_to(da.asarray(min), result_shape)
     if max is not None:
-        max = xp.broadcast_to(xp.asarray(max), result_shape)
+        max = da.broadcast_to(da.asarray(max), result_shape)

     if min is None and max is None:
-        return xp.positive(x)
+        return da.positive(x)

     if min is None:
-        return astype(xp.minimum(x, max), x.dtype)
+        return astype(da.minimum(x, max), x.dtype)
     if max is None:
-        return astype(xp.maximum(x, min), x.dtype)
+        return astype(da.maximum(x, min), x.dtype)
+
+    return astype(da.minimum(da.maximum(x, min), max), x.dtype)
+
+
+def _ensure_single_chunk(x: Array, axis: int) -> tuple[Array, Callable[[Array], Array]]:
+    """
+    Make sure that Array is not broken into multiple chunks along axis.
+
+    Returns
+    -------
+    x : Array
+        The input Array with a single chunk along axis.
+    restore : Callable[[Array], Array]
+        function to apply to the output to rechunk it back into reasonable chunks
+    """
+    if axis < 0:
+        axis += x.ndim
+    if x.numblocks[axis] < 2:
+        return x, lambda x: x
+
+    # Break chunks on other axes in an attempt to keep chunk size low
+    x = x.rechunk({i: -1 if i == axis else "auto" for i in range(x.ndim)})
+
+    # Rather than reconstructing the original chunks, which can be a
+    # very expensive affair, just break down oversized chunks without
+    # incurring any transfers over the network.
+    # This has the downside of a risk of overchunking if the array is
+    # then used in operations against other arrays that match the
+    # original chunking pattern.
+    return x, lambda x: x.rechunk()
+
+
+def sort(
+    x: Array,
+    /,
+    *,
+    axis: int = -1,
+    descending: py_bool = False,
+    stable: py_bool = True,
+) -> Array:
+    """
+    Array API compatibility layer around the lack of sort() in Dask.
+
+    Warnings
+    --------
+    This function temporarily rechunks the array along `axis` to a single chunk.
+    This can be extremely inefficient and can lead to out-of-memory errors.
+
+    See the corresponding documentation in the array library and/or the array API
+    specification for more details.
+    """
+    x, restore = _ensure_single_chunk(x, axis)

-    return astype(xp.minimum(xp.maximum(x, min), max), x.dtype)
+    meta_xp = array_namespace(x._meta)
+    x = da.map_blocks(
+        meta_xp.sort,
+        x,
+        axis=axis,
+        meta=x._meta,
+        dtype=x.dtype,
+        descending=descending,
+        stable=stable,
+    )

-# exclude these from all since
-_da_unsupported = ['sort', 'argsort']
+    return restore(x)
+
+
+def argsort(
+    x: Array,
+    /,
+    *,
+    axis: int = -1,
+    descending: py_bool = False,
+    stable: py_bool = True,
+) -> Array:
+    """
+    Array API compatibility layer around the lack of argsort() in Dask.
+
+    See the corresponding documentation in the array library and/or the array API
+    specification for more details.
+
+    Warnings
+    --------
+    This function temporarily rechunks the array along `axis` into a single chunk.
+    This can be extremely inefficient and can lead to out-of-memory errors.
+ """ + x, restore = _ensure_single_chunk(x, axis) + + meta_xp = array_namespace(x._meta) + dtype = meta_xp.argsort(x._meta).dtype + meta = meta_xp.astype(x._meta, dtype) + x = da.map_blocks( + meta_xp.argsort, + x, + axis=axis, + meta=meta, + dtype=dtype, + descending=descending, + stable=stable, + ) + + return restore(x) + + +# dask.array.count_nonzero does not have keepdims +def count_nonzero( + x: Array, + axis: int | None = None, + keepdims: py_bool = False, +) -> Array: + result = da.count_nonzero(x, axis) + if keepdims: + if axis is None: + return da.reshape(result, [1] * x.ndim) + return da.expand_dims(result, axis) + return result -common_aliases = [alias for alias in _aliases.__all__ if alias not in _da_unsupported] -__all__ = common_aliases + ['__array_namespace_info__', 'asarray', 'bool', - 'acos', 'acosh', 'asin', 'asinh', 'atan', 'atan2', - 'atanh', 'bitwise_left_shift', 'bitwise_invert', - 'bitwise_right_shift', 'concat', 'pow', 'e', - 'inf', 'nan', 'pi', 'newaxis', 'float32', - 'float64', 'int8', 'int16', 'int32', 'int64', - 'uint8', 'uint16', 'uint32', 'uint64', - 'complex64', 'complex128', 'iinfo', 'finfo', - 'can_cast', 'result_type'] +__all__ = [ + "count_nonzero", + "bool", + "int8", "int16", "int32", "int64", + "uint8", "uint16", "uint32", "uint64", + "float32", "float64", + "complex64", "complex128", + "asarray", "astype", "can_cast", "result_type", + "pow", + "concat", + "acos", "acosh", "asin", "asinh", "atan", "atan2", "atanh", + "bitwise_left_shift", "bitwise_right_shift", "bitwise_invert", +] # fmt: skip +__all__ += _aliases.__all__ -_all_ignore = ['get_xp', 'da', 'partial', 'common_aliases', 'np'] +def __dir__() -> list[str]: + return __all__ diff --git a/array_api_compat/dask/array/_info.py b/array_api_compat/dask/array/_info.py index d3b12dc9..2f39fc4b 100644 --- a/array_api_compat/dask/array/_info.py +++ b/array_api_compat/dask/array/_info.py @@ -7,25 +7,50 @@ more details. """ + +# pyright: reportPrivateUsage=false + +from __future__ import annotations + +from typing import Literal, TypeAlias, overload + +import dask.array as da +from numpy import bool_ as bool from numpy import ( + complex64, + complex128, dtype, - bool_ as bool, - intp, + float32, + float64, int8, int16, int32, int64, + intp, uint8, uint16, uint32, uint64, - float32, - float64, - complex64, - complex128, ) -from ...common._helpers import _DASK_DEVICE +from ...common._helpers import _DASK_DEVICE, _check_device, _dask_device +from ...common._typing import ( + Capabilities, + DefaultDTypes, + DType, + DTypeKind, + DTypesAll, + DTypesAny, + DTypesBool, + DTypesComplex, + DTypesIntegral, + DTypesNumeric, + DTypesReal, + DTypesSigned, + DTypesUnsigned, +) +Device: TypeAlias = Literal["cpu"] | _dask_device + class __array_namespace_info__: """ @@ -50,7 +75,7 @@ class __array_namespace_info__: Examples -------- - >>> info = np.__array_namespace_info__() + >>> info = xp.__array_namespace_info__() >>> info.default_dtypes() {'real floating': dask.float64, 'complex floating': dask.complex128, @@ -59,20 +84,31 @@ class __array_namespace_info__: """ - __module__ = 'dask.array' + __module__ = "dask.array" - def capabilities(self): + def capabilities(self) -> Capabilities: """ Return a dictionary of array API library capabilities. The resulting dictionary has the following keys: - **"boolean indexing"**: boolean indicating whether an array library - supports boolean indexing. Always ``False`` for Dask. + supports boolean indexing. 
+
+          Dask supports boolean indexing as long as both the index
+          and the indexed arrays have known shapes.
+          Note, however, that the output .shape and .size properties
+          will contain a non-compliant math.nan instead of None.
 
         - **"data-dependent shapes"**: boolean indicating whether an array
-          library supports data-dependent output shapes. Always ``False`` for
-          Dask.
+          library supports data-dependent output shapes.
+
+          Dask implements unique_values et al.
+          Note, however, that the output .shape and .size properties
+          will contain a non-compliant math.nan instead of None.
+
+        - **"max dimensions"**: integer indicating the maximum number of
+          dimensions supported by the array library.
 
         See
         https://data-apis.org/array-api/latest/API_specification/generated/array_api.info.capabilities.html
@@ -92,20 +128,20 @@
 
         Examples
         --------
-        >>> info = np.__array_namespace_info__()
+        >>> info = xp.__array_namespace_info__()
         >>> info.capabilities()
         {'boolean indexing': True,
-         'data-dependent shapes': True}
+         'data-dependent shapes': True,
+         'max dimensions': 64}
 
         """
         return {
-            "boolean indexing": False,
-            "data-dependent shapes": False,
-            # 'max rank' will be part of the 2024.12 standard
-            # "max rank": 64,
+            "boolean indexing": True,
+            "data-dependent shapes": True,
+            "max dimensions": 64,
        }
 
-    def default_device(self):
+    def default_device(self) -> Device:
         """
         The default device used for new Dask arrays.
 
@@ -120,19 +156,19 @@
 
         Returns
         -------
-        device : str
+        device : Device
             The default device used for new Dask arrays.
 
         Examples
         --------
-        >>> info = np.__array_namespace_info__()
+        >>> info = xp.__array_namespace_info__()
         >>> info.default_device()
         'cpu'
 
         """
         return "cpu"
 
-    def default_dtypes(self, *, device=None):
+    def default_dtypes(self, /, *, device: Device | None = None) -> DefaultDTypes:
         """
         The default data types used for new Dask arrays.
 
@@ -163,7 +199,7 @@
 
         Examples
         --------
-        >>> info = np.__array_namespace_info__()
+        >>> info = xp.__array_namespace_info__()
         >>> info.default_dtypes()
         {'real floating': dask.float64,
          'complex floating': dask.complex128,
@@ -171,11 +207,7 @@
          'indexing': dask.int64}
 
         """
-        if device not in ["cpu", _DASK_DEVICE, None]:
-            raise ValueError(
-                'Device not understood. Only "cpu" or _DASK_DEVICE is allowed, but received:'
-                f' {device}'
-            )
+        _check_device(da, device)
         return {
             "real floating": dtype(float64),
             "complex floating": dtype(complex128),
@@ -183,7 +215,41 @@
             "indexing": dtype(intp),
         }
 
-    def dtypes(self, *, device=None, kind=None):
+    @overload
+    def dtypes(
+        self, /, *, device: Device | None = None, kind: None = None
+    ) -> DTypesAll: ...
+    @overload
+    def dtypes(
+        self, /, *, device: Device | None = None, kind: Literal["bool"]
+    ) -> DTypesBool: ...
+    @overload
+    def dtypes(
+        self, /, *, device: Device | None = None, kind: Literal["signed integer"]
+    ) -> DTypesSigned: ...
+    @overload
+    def dtypes(
+        self, /, *, device: Device | None = None, kind: Literal["unsigned integer"]
+    ) -> DTypesUnsigned: ...
+    @overload
+    def dtypes(
+        self, /, *, device: Device | None = None, kind: Literal["integral"]
+    ) -> DTypesIntegral: ...
+    @overload
+    def dtypes(
+        self, /, *, device: Device | None = None, kind: Literal["real floating"]
+    ) -> DTypesReal: ...
+    @overload
+    def dtypes(
+        self, /, *, device: Device | None = None, kind: Literal["complex floating"]
+    ) -> DTypesComplex: ...
+ @overload + def dtypes( + self, /, *, device: Device | None = None, kind: Literal["numeric"] + ) -> DTypesNumeric: ... + def dtypes( + self, /, *, device: Device | None = None, kind: DTypeKind | None = None + ) -> DTypesAny: """ The array API data types supported by Dask. @@ -229,7 +295,7 @@ def dtypes(self, *, device=None, kind=None): Examples -------- - >>> info = np.__array_namespace_info__() + >>> info = xp.__array_namespace_info__() >>> info.dtypes(kind='signed integer') {'int8': dask.int8, 'int16': dask.int16, @@ -237,11 +303,7 @@ def dtypes(self, *, device=None, kind=None): 'int64': dask.int64} """ - if device not in ["cpu", _DASK_DEVICE, None]: - raise ValueError( - 'Device not understood. Only "cpu" or _DASK_DEVICE is allowed, but received:' - f' {device}' - ) + _check_device(da, device) if kind is None: return { "bool": dtype(bool), @@ -311,13 +373,13 @@ def dtypes(self, *, device=None, kind=None): "complex128": dtype(complex128), } if isinstance(kind, tuple): - res = {} + res: dict[str, DType] = {} for k in kind: res.update(self.dtypes(kind=k)) return res raise ValueError(f"unsupported kind: {kind!r}") - def devices(self): + def devices(self) -> list[Device]: """ The devices supported by Dask. @@ -325,7 +387,7 @@ def devices(self): Returns ------- - devices : list of str + devices : list[Device] The devices supported by Dask. See Also @@ -337,7 +399,7 @@ def devices(self): Examples -------- - >>> info = np.__array_namespace_info__() + >>> info = xp.__array_namespace_info__() >>> info.devices() ['cpu', DASK_DEVICE] diff --git a/array_api_compat/dask/array/fft.py b/array_api_compat/dask/array/fft.py new file mode 100644 index 00000000..44b68e73 --- /dev/null +++ b/array_api_compat/dask/array/fft.py @@ -0,0 +1,16 @@ +from ..._internal import clone_module + +__all__ = clone_module("dask.array.fft", globals()) + +from ...common import _fft +from ..._internal import get_xp + +import dask.array as da + +fftfreq = get_xp(da)(_fft.fftfreq) +rfftfreq = get_xp(da)(_fft.rfftfreq) + +__all__ += ["fftfreq", "rfftfreq"] + +def __dir__() -> list[str]: + return __all__ diff --git a/array_api_compat/dask/array/linalg.py b/array_api_compat/dask/array/linalg.py index 49c26d8b..6b3c1011 100644 --- a/array_api_compat/dask/array/linalg.py +++ b/array_api_compat/dask/array/linalg.py @@ -1,33 +1,20 @@ from __future__ import annotations -from ...common import _linalg -from ..._internal import get_xp +from typing import Literal -# Exports -from dask.array.linalg import * # noqa: F403 -from dask.array import outer +import dask.array as da -# These functions are in both the main and linalg namespaces -from dask.array import matmul, tensordot -from ._aliases import matrix_transpose, vecdot +# The `matmul` and `tensordot` functions are in both the main and linalg namespaces +from dask.array import matmul, outer, tensordot -import dask.array as da +# Exports +from ..._internal import clone_module, get_xp +from ...common import _linalg +from ...common._typing import Array -from typing import TYPE_CHECKING -if TYPE_CHECKING: - from ...common._typing import Array - from typing import Literal +__all__ = clone_module("dask.array.linalg", globals()) -# dask.array.linalg doesn't have __all__. 
If it is added, replace this with -# -# from dask.array.linalg import __all__ as linalg_all -_n = {} -exec('from dask.array.linalg import *', _n) -del _n['__builtins__'] -if 'annotations' in _n: - del _n['annotations'] -linalg_all = list(_n) -del _n +from ._aliases import matrix_transpose, vecdot EighResult = _linalg.EighResult QRResult = _linalg.QRResult @@ -37,8 +24,11 @@ # supports the mode keyword on QR # https://github.com/dask/dask/issues/10388 #qr = get_xp(da)(_linalg.qr) -def qr(x: Array, mode: Literal['reduced', 'complete'] = 'reduced', - **kwargs) -> QRResult: +def qr( # type: ignore[no-redef] + x: Array, + mode: Literal["reduced", "complete"] = "reduced", + **kwargs: object, +) -> QRResult: if mode != "reduced": raise ValueError("dask arrays only support using mode='reduced'") return QRResult(*da.linalg.qr(x, **kwargs)) @@ -51,7 +41,7 @@ def qr(x: Array, mode: Literal['reduced', 'complete'] = 'reduced', # Wrap the svd functions to not pass full_matrices to dask # when full_matrices=False (as that is the default behavior for dask), # and dask doesn't have the full_matrices keyword -def svd(x: Array, full_matrices: bool = True, **kwargs) -> SVDResult: +def svd(x: Array, full_matrices: bool = True, **kwargs: object) -> SVDResult: # type: ignore[no-redef] if full_matrices: raise ValueError("full_matrics=True is not supported by dask.") return da.linalg.svd(x, coerce_signs=False, **kwargs) @@ -64,10 +54,11 @@ def svdvals(x: Array) -> Array: vector_norm = get_xp(da)(_linalg.vector_norm) diagonal = get_xp(da)(_linalg.diagonal) -__all__ = linalg_all + ["trace", "outer", "matmul", "tensordot", - "matrix_transpose", "vecdot", "EighResult", - "QRResult", "SlogdetResult", "SVDResult", "qr", - "cholesky", "matrix_rank", "matrix_norm", "svdvals", - "vector_norm", "diagonal"] +__all__ += ["trace", "outer", "matmul", "tensordot", + "matrix_transpose", "vecdot", "EighResult", + "QRResult", "SlogdetResult", "SVDResult", "qr", + "cholesky", "matrix_rank", "matrix_norm", "svdvals", + "vector_norm", "diagonal"] -_all_ignore = ['get_xp', 'da', 'linalg_all'] +def __dir__() -> list[str]: + return __all__ diff --git a/array_api_compat/numpy/__init__.py b/array_api_compat/numpy/__init__.py index b66f30a2..23379e44 100644 --- a/array_api_compat/numpy/__init__.py +++ b/array_api_compat/numpy/__init__.py @@ -1,10 +1,17 @@ -from numpy import * # noqa: F403 +# ruff: noqa: PLC0414 +from typing import Final -# from numpy import * doesn't overwrite these builtin names -from numpy import abs, max, min, round # noqa: F401 +from .._internal import clone_module + +# This needs to be loaded explicitly before cloning +import numpy.typing # noqa: F401 + +__all__ = clone_module("numpy", globals()) # These imports may overwrite names from the import * above. -from ._aliases import * # noqa: F403 +from . import _aliases +from ._aliases import * # type: ignore[assignment,no-redef] # noqa: F403 +from ._info import __array_namespace_info__ # noqa: F401 # Don't know why, but we have to do an absolute import to import linalg. If we # instead do @@ -13,18 +20,19 @@ # # It doesn't overwrite np.linalg from above. The import is generated # dynamically so that the library can be vendored. 
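A minimal, runnable sketch of the vendoring-safe import trick described in the comment above (illustrative only; the vendored package path in the comments is hypothetical):

    import importlib
    import types

    def load_submodule(package: str, name: str) -> types.ModuleType:
        # importlib.import_module resolves the relative name against whatever
        # package actually contains the caller at runtime, much like
        # __import__(__package__ + '.linalg') does above. If the library is
        # vendored, e.g. as mylib._vendor.array_api_compat, __package__ points
        # at the vendored copy, so its own linalg submodule is imported rather
        # than one from a globally installed array_api_compat.
        return importlib.import_module(name, package=package)

    # load_submodule("array_api_compat.numpy", ".linalg")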
-__import__(__package__ + '.linalg') +__import__(__package__ + ".linalg") -__import__(__package__ + '.fft') +__import__(__package__ + ".fft") -from .linalg import matrix_transpose, vecdot # noqa: F401 +from .linalg import matrix_transpose, vecdot # type: ignore[no-redef] # noqa: F401 -from ..common._helpers import * # noqa: F403 +__array_api_version__: Final = "2024.12" -try: - # Used in asarray(). Not present in older versions. - from numpy import _CopyMode # noqa: F401 -except ImportError: - pass +__all__ = sorted( + set(__all__) + | set(_aliases.__all__) + | {"__array_api_version__", "__array_namespace_info__", "linalg", "fft"} +) -__array_api_version__ = '2022.12' +def __dir__() -> list[str]: + return __all__ diff --git a/array_api_compat/numpy/_aliases.py b/array_api_compat/numpy/_aliases.py index 355215e4..87b3c2f3 100644 --- a/array_api_compat/numpy/_aliases.py +++ b/array_api_compat/numpy/_aliases.py @@ -1,17 +1,16 @@ +# pyright: reportPrivateUsage=false from __future__ import annotations -from ..common import _aliases +from builtins import bool as py_bool +from typing import Any, cast -from .._internal import get_xp - -from ._info import __array_namespace_info__ +import numpy as np -from typing import TYPE_CHECKING -if TYPE_CHECKING: - from typing import Optional, Union - from ._typing import ndarray, Device, Dtype, NestedSequence, SupportsBufferProtocol +from .._internal import get_xp +from ..common import _aliases, _helpers +from ..common._typing import NestedSequence, SupportsBufferProtocol +from ._typing import Array, Device, DType -import numpy as np bool = np.bool_ # Basic renames @@ -46,95 +45,147 @@ unique_counts = get_xp(np)(_aliases.unique_counts) unique_inverse = get_xp(np)(_aliases.unique_inverse) unique_values = get_xp(np)(_aliases.unique_values) -astype = _aliases.astype std = get_xp(np)(_aliases.std) var = get_xp(np)(_aliases.var) cumulative_sum = get_xp(np)(_aliases.cumulative_sum) +cumulative_prod = get_xp(np)(_aliases.cumulative_prod) clip = get_xp(np)(_aliases.clip) permute_dims = get_xp(np)(_aliases.permute_dims) reshape = get_xp(np)(_aliases.reshape) argsort = get_xp(np)(_aliases.argsort) sort = get_xp(np)(_aliases.sort) nonzero = get_xp(np)(_aliases.nonzero) -ceil = get_xp(np)(_aliases.ceil) -floor = get_xp(np)(_aliases.floor) -trunc = get_xp(np)(_aliases.trunc) matmul = get_xp(np)(_aliases.matmul) matrix_transpose = get_xp(np)(_aliases.matrix_transpose) tensordot = get_xp(np)(_aliases.tensordot) +sign = get_xp(np)(_aliases.sign) +finfo = get_xp(np)(_aliases.finfo) +iinfo = get_xp(np)(_aliases.iinfo) -def _supports_buffer_protocol(obj): - try: - memoryview(obj) - except TypeError: - return False - return True # asarray also adds the copy keyword, which is not present in numpy 1.0. # asarray() is different enough between numpy, cupy, and dask, the logic # complicated enough that it's easier to define it separately for each module # rather than trying to combine everything into one function in common/ def asarray( - obj: Union[ - ndarray, - bool, - int, - float, - NestedSequence[bool | int | float], - SupportsBufferProtocol, - ], + obj: Array | complex | NestedSequence[complex] | SupportsBufferProtocol, /, *, - dtype: Optional[Dtype] = None, - device: Optional[Device] = None, - copy: "Optional[Union[bool, np._CopyMode]]" = None, - **kwargs, -) -> ndarray: + dtype: DType | None = None, + device: Device | None = None, + copy: py_bool | None = None, + **kwargs: Any, +) -> Array: """ Array API compatibility wrapper for asarray(). 
     See the corresponding documentation in the array library and/or the array API
     specification for more details.
     """
-    if device not in ["cpu", None]:
-        raise ValueError(f"Unsupported device for NumPy: {device!r}")
-
-    if hasattr(np, '_CopyMode'):
-        if copy is None:
-            copy = np._CopyMode.IF_NEEDED
-        elif copy is False:
-            copy = np._CopyMode.NEVER
-        elif copy is True:
-            copy = np._CopyMode.ALWAYS
-    else:
-        # Not present in older NumPys. In this case, we cannot really support
-        # copy=False.
-        if copy is False:
-            raise NotImplementedError("asarray(copy=False) requires a newer version of NumPy.")
+    _helpers._check_device(np, device)
+
+    # copy=None is unsupported in NumPy 1.x, but the internal np._CopyMode enum is.
+    # Note that copy=False in NumPy 1.x means what copy=None means in NumPy 2.0
+    # and in the Array API.
+    if copy is None:
+        copy = np._CopyMode.IF_NEEDED  # type: ignore[assignment,attr-defined]
+    elif copy is False:
+        copy = np._CopyMode.NEVER  # type: ignore[assignment,attr-defined]
 
     return np.array(obj, copy=copy, dtype=dtype, **kwargs)
 
+
+def astype(
+    x: Array,
+    dtype: DType,
+    /,
+    *,
+    copy: py_bool = True,
+    device: Device | None = None,
+) -> Array:
+    _helpers._check_device(np, device)
+    return x.astype(dtype=dtype, copy=copy)
+
+
+# np.count_nonzero returns a Python int for axis=None and keepdims=False
+# https://github.com/numpy/numpy/issues/17562
+def count_nonzero(
+    x: Array,
+    axis: int | tuple[int, ...] | None = None,
+    keepdims: py_bool = False,
+) -> Array:
+    # NOTE: this is currently incorrectly typed in numpy, but will be fixed in
+    # numpy 2.2.5 and 2.3.0: https://github.com/numpy/numpy/pull/28750
+    result = cast("Any", np.count_nonzero(x, axis=axis, keepdims=keepdims))  # pyright: ignore[reportArgumentType, reportCallIssue]
+    if axis is None and not keepdims:
+        return np.asarray(result)
+    return result
+
+
+# take_along_axis: axis defaults to -1, but in numpy it is a required argument
+def take_along_axis(x: Array, indices: Array, /, *, axis: int = -1) -> Array:
+    return np.take_along_axis(x, indices, axis=axis)
+
+
+# In NumPy < 2, ceil, floor, and trunc upcast integer inputs to floats,
+# whereas the array API requires the integer dtype to be preserved.
+
+def ceil(x: Array, /) -> Array:
+    if np.__version__ < '2' and np.issubdtype(x.dtype, np.integer):
+        return x.copy()
+    return np.ceil(x)
+
+
+def floor(x: Array, /) -> Array:
+    if np.__version__ < '2' and np.issubdtype(x.dtype, np.integer):
+        return x.copy()
+    return np.floor(x)
+
+
+def trunc(x: Array, /) -> Array:
+    if np.__version__ < '2' and np.issubdtype(x.dtype, np.integer):
+        return x.copy()
+    return np.trunc(x)
+
+
 # These functions are completely new here. If the library already has them
 # (i.e., numpy 2.0), use the library version instead of our wrapper.
-if hasattr(np, 'vecdot'): +if hasattr(np, "vecdot"): vecdot = np.vecdot else: - vecdot = get_xp(np)(_aliases.vecdot) + vecdot = get_xp(np)(_aliases.vecdot) # type: ignore[assignment] -if hasattr(np, 'isdtype'): +if hasattr(np, "isdtype"): isdtype = np.isdtype else: isdtype = get_xp(np)(_aliases.isdtype) -if hasattr(np, 'unstack'): +if hasattr(np, "unstack"): unstack = np.unstack else: unstack = get_xp(np)(_aliases.unstack) -__all__ = _aliases.__all__ + ['__array_namespace_info__', 'asarray', 'bool', - 'acos', 'acosh', 'asin', 'asinh', 'atan', - 'atan2', 'atanh', 'bitwise_left_shift', - 'bitwise_invert', 'bitwise_right_shift', - 'concat', 'pow'] - -_all_ignore = ['np', 'get_xp'] +__all__ = _aliases.__all__ + [ + "asarray", + "astype", + "acos", + "acosh", + "asin", + "asinh", + "atan", + "atan2", + "atanh", + "ceil", + "floor", + "trunc", + "bitwise_left_shift", + "bitwise_invert", + "bitwise_right_shift", + "bool", + "concat", + "count_nonzero", + "pow", + "take_along_axis" +] + + +def __dir__() -> list[str]: + return __all__ diff --git a/array_api_compat/numpy/_info.py b/array_api_compat/numpy/_info.py index 62f7ae62..c625c13e 100644 --- a/array_api_compat/numpy/_info.py +++ b/array_api_compat/numpy/_info.py @@ -7,24 +7,29 @@ more details. """ +from __future__ import annotations + +from numpy import bool_ as bool from numpy import ( + complex64, + complex128, dtype, - bool_ as bool, - intp, + float32, + float64, int8, int16, int32, int64, + intp, uint8, uint16, uint32, uint64, - float32, - float64, - complex64, - complex128, ) +from ..common._typing import DefaultDTypes +from ._typing import Device, DType + class __array_namespace_info__: """ @@ -94,14 +99,14 @@ def capabilities(self): >>> info = np.__array_namespace_info__() >>> info.capabilities() {'boolean indexing': True, - 'data-dependent shapes': True} + 'data-dependent shapes': True, + 'max dimensions': 64} """ return { "boolean indexing": True, "data-dependent shapes": True, - # 'max rank' will be part of the 2024.12 standard - # "max rank": 64, + "max dimensions": 64, } def default_device(self): @@ -119,7 +124,7 @@ def default_device(self): Returns ------- - device : str + device : Device The default device used for new NumPy arrays. Examples @@ -131,7 +136,11 @@ def default_device(self): """ return "cpu" - def default_dtypes(self, *, device=None): + def default_dtypes( + self, + *, + device: Device | None = None, + ) -> DefaultDTypes: """ The default data types used for new NumPy arrays. @@ -183,7 +192,12 @@ def default_dtypes(self, *, device=None): "indexing": dtype(intp), } - def dtypes(self, *, device=None, kind=None): + def dtypes( + self, + *, + device: Device | None = None, + kind: str | tuple[str, ...] | None = None, + ) -> dict[str, DType]: """ The array API data types supported by NumPy. @@ -260,7 +274,7 @@ def dtypes(self, *, device=None, kind=None): "complex128": dtype(complex128), } if kind == "bool": - return {"bool": bool} + return {"bool": dtype(bool)} if kind == "signed integer": return { "int8": dtype(int8), @@ -312,13 +326,13 @@ def dtypes(self, *, device=None, kind=None): "complex128": dtype(complex128), } if isinstance(kind, tuple): - res = {} + res: dict[str, DType] = {} for k in kind: res.update(self.dtypes(kind=k)) return res raise ValueError(f"unsupported kind: {kind!r}") - def devices(self): + def devices(self) -> list[Device]: """ The devices supported by NumPy. @@ -326,7 +340,7 @@ def devices(self): Returns ------- - devices : list of str + devices : list[Device] The devices supported by NumPy. 
See Also @@ -344,3 +358,10 @@ def devices(self): """ return ["cpu"] + + +__all__ = ["__array_namespace_info__"] + + +def __dir__() -> list[str]: + return __all__ diff --git a/array_api_compat/numpy/_typing.py b/array_api_compat/numpy/_typing.py index c5ebb5ab..b5fa188c 100644 --- a/array_api_compat/numpy/_typing.py +++ b/array_api_compat/numpy/_typing.py @@ -1,46 +1,29 @@ from __future__ import annotations -__all__ = [ - "ndarray", - "Device", - "Dtype", -] - -import sys -from typing import ( - Literal, - Union, - TYPE_CHECKING, -) - -from numpy import ( - ndarray, - dtype, - int8, - int16, - int32, - int64, - uint8, - uint16, - uint32, - uint64, - float32, - float64, -) - -Device = Literal["cpu"] -if TYPE_CHECKING or sys.version_info >= (3, 9): - Dtype = dtype[Union[ - int8, - int16, - int32, - int64, - uint8, - uint16, - uint32, - uint64, - float32, - float64, - ]] +from typing import TYPE_CHECKING, Any, Literal, TypeAlias + +import numpy as np + +Device: TypeAlias = Literal["cpu"] + +if TYPE_CHECKING: + + # NumPy 1.x on Python 3.10 fails to parse np.dtype[] + DType: TypeAlias = np.dtype[ + np.bool_ + | np.integer[Any] + | np.float32 + | np.float64 + | np.complex64 + | np.complex128 + ] + Array: TypeAlias = np.ndarray[Any, DType] else: - Dtype = dtype + DType: TypeAlias = np.dtype + Array: TypeAlias = np.ndarray + +__all__ = ["Array", "DType", "Device"] + + +def __dir__() -> list[str]: + return __all__ diff --git a/array_api_compat/numpy/fft.py b/array_api_compat/numpy/fft.py index 28667594..a492feb8 100644 --- a/array_api_compat/numpy/fft.py +++ b/array_api_compat/numpy/fft.py @@ -1,10 +1,11 @@ -from numpy.fft import * # noqa: F403 -from numpy.fft import __all__ as fft_all +import numpy as np -from ..common import _fft -from .._internal import get_xp +from .._internal import clone_module -import numpy as np +__all__ = clone_module("numpy.fft", globals()) + +from .._internal import get_xp +from ..common import _fft fft = get_xp(np)(_fft.fft) ifft = get_xp(np)(_fft.ifft) @@ -21,9 +22,9 @@ fftshift = get_xp(np)(_fft.fftshift) ifftshift = get_xp(np)(_fft.ifftshift) -__all__ = fft_all + _fft.__all__ -del get_xp -del np -del fft_all -del _fft +__all__ = sorted(set(__all__) | set(_fft.__all__)) + +def __dir__() -> list[str]: + return __all__ + diff --git a/array_api_compat/numpy/linalg.py b/array_api_compat/numpy/linalg.py index 8f01593b..7168441c 100644 --- a/array_api_compat/numpy/linalg.py +++ b/array_api_compat/numpy/linalg.py @@ -1,14 +1,20 @@ -from numpy.linalg import * # noqa: F403 -from numpy.linalg import __all__ as linalg_all -import numpy as _np +# pyright: reportAttributeAccessIssue=false +# pyright: reportUnknownArgumentType=false +# pyright: reportUnknownMemberType=false +# pyright: reportUnknownVariableType=false +from __future__ import annotations + +import numpy as np + +from .._internal import clone_module, get_xp from ..common import _linalg -from .._internal import get_xp -# These functions are in both the main and linalg namespaces -from ._aliases import matmul, matrix_transpose, tensordot, vecdot # noqa: F401 +__all__ = clone_module("numpy.linalg", globals()) -import numpy as np +# These functions are in both the main and linalg namespaces +from ._aliases import matmul, matrix_transpose, tensordot, vecdot # noqa: F401 +from ._typing import Array cross = get_xp(np)(_linalg.cross) outer = get_xp(np)(_linalg.outer) @@ -38,19 +44,28 @@ # To workaround this, the below is the code from np.linalg.solve except # only calling solve1 in the exactly 1D case. 
+ # This code is here instead of in common because it is numpy specific. Also # note that CuPy's solve() does not currently support broadcasting (see # https://github.com/cupy/cupy/blob/main/cupy/cublas.py#L43). -def solve(x1: _np.ndarray, x2: _np.ndarray, /) -> _np.ndarray: +def solve(x1: Array, x2: Array, /) -> Array: try: - from numpy.linalg._linalg import ( - _makearray, _assert_stacked_2d, _assert_stacked_square, - _commonType, isComplexType, _raise_linalgerror_singular + from numpy.linalg._linalg import ( # type: ignore[attr-defined] + _assert_stacked_2d, + _assert_stacked_square, + _commonType, + _makearray, + _raise_linalgerror_singular, + isComplexType, ) except ImportError: - from numpy.linalg.linalg import ( - _makearray, _assert_stacked_2d, _assert_stacked_square, - _commonType, isComplexType, _raise_linalgerror_singular + from numpy.linalg.linalg import ( # type: ignore[attr-defined] + _assert_stacked_2d, + _assert_stacked_square, + _commonType, + _makearray, + _raise_linalgerror_singular, + isComplexType, ) from numpy.linalg import _umath_linalg @@ -61,6 +76,7 @@ def solve(x1: _np.ndarray, x2: _np.ndarray, /) -> _np.ndarray: t, result_t = _commonType(x1, x2) # This part is different from np.linalg.solve + gufunc: np.ufunc if x2.ndim == 1: gufunc = _umath_linalg.solve1 else: @@ -68,23 +84,45 @@ def solve(x1: _np.ndarray, x2: _np.ndarray, /) -> _np.ndarray: # This does nothing currently but is left in because it will be relevant # when complex dtype support is added to the spec in 2022. - signature = 'DD->D' if isComplexType(t) else 'dd->d' - with _np.errstate(call=_raise_linalgerror_singular, invalid='call', - over='ignore', divide='ignore', under='ignore'): - r = gufunc(x1, x2, signature=signature) + signature = "DD->D" if isComplexType(t) else "dd->d" + with np.errstate( + call=_raise_linalgerror_singular, + invalid="call", + over="ignore", + divide="ignore", + under="ignore", + ): + r: Array = gufunc(x1, x2, signature=signature) return wrap(r.astype(result_t, copy=False)) + # These functions are completely new here. If the library already has them # (i.e., numpy 2.0), use the library version instead of our wrapper. -if hasattr(np.linalg, 'vector_norm'): +if hasattr(np.linalg, "vector_norm"): vector_norm = np.linalg.vector_norm else: vector_norm = get_xp(np)(_linalg.vector_norm) -__all__ = linalg_all + _linalg.__all__ + ['solve'] -del get_xp -del np -del linalg_all -del _linalg +_all = [ + "LinAlgError", + "cond", + "det", + "eig", + "eigvals", + "eigvalsh", + "inv", + "lstsq", + "matrix_power", + "multi_dot", + "norm", + "solve", + "tensorinv", + "tensorsolve", + "vector_norm", +] +__all__ = sorted(set(__all__) | set(_linalg.__all__) | set(_all)) + +def __dir__() -> list[str]: + return __all__ diff --git a/array_api_compat/py.typed b/array_api_compat/py.typed new file mode 100644 index 00000000..e69de29b diff --git a/array_api_compat/torch/__init__.py b/array_api_compat/torch/__init__.py index 172f5279..6cbb6ec2 100644 --- a/array_api_compat/torch/__init__.py +++ b/array_api_compat/torch/__init__.py @@ -1,24 +1,25 @@ -from torch import * # noqa: F403 +from typing import Final -# Several names are not included in the above import * -import torch -for n in dir(torch): - if (n.startswith('_') - or n.endswith('_') - or 'cuda' in n - or 'cpu' in n - or 'backward' in n): - continue - exec(n + ' = torch.' + n) +from .._internal import clone_module + +__all__ = clone_module("torch", globals()) # These imports may overwrite names from the import * above. +from . 
import _aliases from ._aliases import * # noqa: F403 +from ._info import __array_namespace_info__ # noqa: F401 # See the comment in the numpy __init__.py __import__(__package__ + '.linalg') - __import__(__package__ + '.fft') -from ..common._helpers import * # noqa: F403 +__array_api_version__: Final = '2024.12' + +__all__ = sorted( + set(__all__) + | set(_aliases.__all__) + | {"__array_api_version__", "__array_namespace_info__", "linalg", "fft"} +) -__array_api_version__ = '2022.12' +def __dir__() -> list[str]: + return __all__ diff --git a/array_api_compat/torch/_aliases.py b/array_api_compat/torch/_aliases.py index 5ac66bcb..91161d24 100644 --- a/array_api_compat/torch/_aliases.py +++ b/array_api_compat/torch/_aliases.py @@ -1,27 +1,16 @@ from __future__ import annotations -from functools import wraps as _wraps +from collections.abc import Sequence +from functools import reduce as _reduce, wraps as _wraps from builtins import all as _builtin_all, any as _builtin_any - -from ..common._aliases import (matrix_transpose as _aliases_matrix_transpose, - vecdot as _aliases_vecdot, - clip as _aliases_clip, - unstack as _aliases_unstack, - cumulative_sum as _aliases_cumulative_sum, - ) -from .._internal import get_xp - -from ._info import __array_namespace_info__ +from typing import Any, Literal import torch -from typing import TYPE_CHECKING -if TYPE_CHECKING: - from typing import List, Optional, Sequence, Tuple, Union - from ..common._typing import Device - from torch import dtype as Dtype - - array = torch.Tensor +from .._internal import get_xp +from ..common import _aliases +from ..common._typing import NestedSequence, SupportsBufferProtocol +from ._typing import Array, Device, DType _int_dtypes = { torch.uint8, @@ -30,6 +19,12 @@ torch.int32, torch.int64, } +try: + # torch >=2.3 + _int_dtypes |= {torch.uint16, torch.uint32, torch.uint64} +except AttributeError: + pass + _array_api_dtypes = { torch.bool, @@ -40,47 +35,23 @@ torch.complex128, } -_promotion_table = { - # bool - (torch.bool, torch.bool): torch.bool, +_promotion_table = { # ints - (torch.int8, torch.int8): torch.int8, (torch.int8, torch.int16): torch.int16, (torch.int8, torch.int32): torch.int32, (torch.int8, torch.int64): torch.int64, - (torch.int16, torch.int8): torch.int16, - (torch.int16, torch.int16): torch.int16, (torch.int16, torch.int32): torch.int32, (torch.int16, torch.int64): torch.int64, - (torch.int32, torch.int8): torch.int32, - (torch.int32, torch.int16): torch.int32, - (torch.int32, torch.int32): torch.int32, (torch.int32, torch.int64): torch.int64, - (torch.int64, torch.int8): torch.int64, - (torch.int64, torch.int16): torch.int64, - (torch.int64, torch.int32): torch.int64, - (torch.int64, torch.int64): torch.int64, - # uints - (torch.uint8, torch.uint8): torch.uint8, # ints and uints (mixed sign) - (torch.int8, torch.uint8): torch.int16, - (torch.int16, torch.uint8): torch.int16, - (torch.int32, torch.uint8): torch.int32, - (torch.int64, torch.uint8): torch.int64, (torch.uint8, torch.int8): torch.int16, (torch.uint8, torch.int16): torch.int16, (torch.uint8, torch.int32): torch.int32, (torch.uint8, torch.int64): torch.int64, # floats - (torch.float32, torch.float32): torch.float32, (torch.float32, torch.float64): torch.float64, - (torch.float64, torch.float32): torch.float64, - (torch.float64, torch.float64): torch.float64, # complexes - (torch.complex64, torch.complex64): torch.complex64, (torch.complex64, torch.complex128): torch.complex128, - (torch.complex128, torch.complex64): torch.complex128, - 
(torch.complex128, torch.complex128): torch.complex128, # Mixed float and complex (torch.float32, torch.complex64): torch.complex64, (torch.float32, torch.complex128): torch.complex128, @@ -88,6 +59,9 @@ (torch.float64, torch.complex128): torch.complex128, } +_promotion_table.update({(b, a): c for (a, b), c in _promotion_table.items()}) +_promotion_table.update({(a, a): a for a in _array_api_dtypes}) + def _two_arg(f): @_wraps(f) @@ -118,23 +92,50 @@ def _fix_promotion(x1, x2, only_scalar=True): x1 = x1.to(dtype) return x1, x2 -def result_type(*arrays_and_dtypes: Union[array, Dtype]) -> Dtype: - if len(arrays_and_dtypes) == 0: - raise TypeError("At least one array or dtype must be provided") - if len(arrays_and_dtypes) == 1: + +_py_scalars = (bool, int, float, complex) + + +def result_type(*arrays_and_dtypes: Array | DType | complex) -> DType: + num = len(arrays_and_dtypes) + + if num == 0: + raise ValueError("At least one array or dtype must be provided") + + elif num == 1: x = arrays_and_dtypes[0] if isinstance(x, torch.dtype): return x return x.dtype - if len(arrays_and_dtypes) > 2: - return result_type(arrays_and_dtypes[0], result_type(*arrays_and_dtypes[1:])) - x, y = arrays_and_dtypes - xdt = x.dtype if not isinstance(x, torch.dtype) else x - ydt = y.dtype if not isinstance(y, torch.dtype) else y + if num == 2: + x, y = arrays_and_dtypes + return _result_type(x, y) + + else: + # sort scalars so that they are treated last + scalars, others = [], [] + for x in arrays_and_dtypes: + if isinstance(x, _py_scalars): + scalars.append(x) + else: + others.append(x) + if not others: + raise ValueError("At least one array or dtype must be provided") + + # combine left-to-right + return _reduce(_result_type, others + scalars) + + +def _result_type(x: Array | DType | complex, y: Array | DType | complex) -> DType: + if not (isinstance(x, _py_scalars) or isinstance(y, _py_scalars)): + xdt = x if isinstance(x, torch.dtype) else x.dtype + ydt = y if isinstance(y, torch.dtype) else y.dtype - if (xdt, ydt) in _promotion_table: - return _promotion_table[xdt, ydt] + try: + return _promotion_table[xdt, ydt] + except KeyError: + pass # This doesn't result_type(dtype, dtype) for non-array API dtypes # because torch.result_type only accepts tensors. This does however, allow @@ -143,7 +144,8 @@ def result_type(*arrays_and_dtypes: Union[array, Dtype]) -> Dtype: y = torch.tensor([], dtype=y) if isinstance(y, torch.dtype) else y return torch.result_type(x, y) -def can_cast(from_: Union[Dtype, array], to: Dtype, /) -> bool: + +def can_cast(from_: DType | Array, to: DType, /) -> bool: if not isinstance(from_, torch.dtype): from_ = from_.dtype return torch.can_cast(from_, to) @@ -185,29 +187,58 @@ def can_cast(from_: Union[Dtype, array], to: Dtype, /) -> bool: remainder = _two_arg(torch.remainder) subtract = _two_arg(torch.subtract) + +def asarray( + obj: Array | complex | NestedSequence[complex] | SupportsBufferProtocol, + /, + *, + dtype: DType | None = None, + device: Device | None = None, + copy: bool | None = None, + **kwargs: Any, +) -> Array: + # torch.asarray does not respect input->output device propagation + # https://github.com/pytorch/pytorch/issues/150199 + if device is None and isinstance(obj, torch.Tensor): + device = obj.device + return torch.asarray(obj, dtype=dtype, device=device, copy=copy, **kwargs) + + # These wrappers are mostly based on the fact that pytorch uses 'dim' instead # of 'axis'. 
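A short usage sketch of the result_type() logic defined above, before the dim/axis wrappers that follow (the import assumes this package's torch namespace; everything else is plain PyTorch):

    import torch
    from array_api_compat import torch as xp

    i16 = torch.empty((), dtype=torch.int16)
    u8 = torch.empty((), dtype=torch.uint8)

    # Mixed-sign integers come straight from the promotion table
    # (including the mirrored keys added by the update() above).
    assert xp.result_type(i16, u8) == torch.int16

    # Python scalars take part in promotion only after all arrays and
    # dtypes have been combined, so the float scalar cannot upcast float32.
    assert xp.result_type(torch.float32, i16, 1.5) == torch.float32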
 # torch.min and torch.max return a tuple and don't support multiple axes https://github.com/pytorch/pytorch/issues/58745
-def max(x: array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False) -> array:
+def max(x: Array, /, *, axis: int | tuple[int, ...] | None = None, keepdims: bool = False) -> Array:
     # https://github.com/pytorch/pytorch/issues/29137
     if axis == ():
         return torch.clone(x)
     return torch.amax(x, axis, keepdims=keepdims)
 
-def min(x: array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False) -> array:
+def min(x: Array, /, *, axis: int | tuple[int, ...] | None = None, keepdims: bool = False) -> Array:
     # https://github.com/pytorch/pytorch/issues/29137
     if axis == ():
         return torch.clone(x)
     return torch.amin(x, axis, keepdims=keepdims)
 
-clip = get_xp(torch)(_aliases_clip)
-unstack = get_xp(torch)(_aliases_unstack)
-cumulative_sum = get_xp(torch)(_aliases_cumulative_sum)
+clip = get_xp(torch)(_aliases.clip)
+unstack = get_xp(torch)(_aliases.unstack)
+cumulative_sum = get_xp(torch)(_aliases.cumulative_sum)
+cumulative_prod = get_xp(torch)(_aliases.cumulative_prod)
+finfo = get_xp(torch)(_aliases.finfo)
+iinfo = get_xp(torch)(_aliases.iinfo)
+
 
 # torch.sort also returns a tuple
 # https://github.com/pytorch/pytorch/issues/70921
-def sort(x: array, /, *, axis: int = -1, descending: bool = False, stable: bool = True, **kwargs) -> array:
+def sort(
+    x: Array,
+    /,
+    *,
+    axis: int = -1,
+    descending: bool = False,
+    stable: bool = True,
+    **kwargs: object,
+) -> Array:
     return torch.sort(x, dim=axis, descending=descending, stable=stable, **kwargs).values
 
 def _normalize_axes(axis, ndim):
@@ -252,28 +283,35 @@ def _reduce_multiple_axes(f, x, axis, keepdims=False, **kwargs):
             out = torch.unsqueeze(out, a)
     return out
 
-def prod(x: array,
+
+def _sum_prod_no_axis(x: Array, dtype: DType | None) -> Array:
+    """
+    Implements `sum(..., axis=())` and `prod(..., axis=())`.
+
+    Works around https://github.com/pytorch/pytorch/issues/29137
+    """
+    if dtype is not None:
+        return x.clone() if dtype == x.dtype else x.to(dtype)
+
+    # We can't upcast uint8 according to the spec because there is no
+    # torch.uint64, so at least upcast to int64 which is what prod does
+    # when axis=None.
+    if x.dtype in (torch.uint8, torch.int8, torch.int16, torch.int32):
+        return x.to(torch.int64)
+
+    return x.clone()
+
+
+def prod(x: Array,
          /,
          *,
-         axis: Optional[Union[int, Tuple[int, ...]]] = None,
-         dtype: Optional[Dtype] = None,
+         axis: int | tuple[int, ...] | None = None,
+         dtype: DType | None = None,
          keepdims: bool = False,
-         **kwargs) -> array:
-    x = torch.asarray(x)
-    ndim = x.ndim
+         **kwargs: object) -> Array:
 
-    # https://github.com/pytorch/pytorch/issues/29137. Separate from the logic
-    # below because it still needs to upcast.
     if axis == ():
-        if dtype is None:
-            # We can't upcast uint8 according to the spec because there is no
-            # torch.uint64, so at least upcast to int64 which is what sum does
-            # when axis=None.
-            if x.dtype in [torch.int8, torch.int16, torch.int32, torch.uint8]:
-                return x.to(torch.int64)
-            return x.clone()
-        return x.to(dtype)
-
+        return _sum_prod_no_axis(x, dtype)
     # torch.prod doesn't support multiple axes
     # (https://github.com/pytorch/pytorch/issues/56586).
if isinstance(axis, tuple): @@ -282,51 +320,38 @@ def prod(x: array, # torch doesn't support keepdims with axis=None # (https://github.com/pytorch/pytorch/issues/71209) res = torch.prod(x, dtype=dtype, **kwargs) - res = _axis_none_keepdims(res, ndim, keepdims) + res = _axis_none_keepdims(res, x.ndim, keepdims) return res return torch.prod(x, axis, dtype=dtype, keepdims=keepdims, **kwargs) -def sum(x: array, +def sum(x: Array, /, *, - axis: Optional[Union[int, Tuple[int, ...]]] = None, - dtype: Optional[Dtype] = None, + axis: int | tuple[int, ...] | None = None, + dtype: DType | None = None, keepdims: bool = False, - **kwargs) -> array: - x = torch.asarray(x) - ndim = x.ndim + **kwargs: object) -> Array: - # https://github.com/pytorch/pytorch/issues/29137. - # Make sure it upcasts. if axis == (): - if dtype is None: - # We can't upcast uint8 according to the spec because there is no - # torch.uint64, so at least upcast to int64 which is what sum does - # when axis=None. - if x.dtype in [torch.int8, torch.int16, torch.int32, torch.uint8]: - return x.to(torch.int64) - return x.clone() - return x.to(dtype) - + return _sum_prod_no_axis(x, dtype) if axis is None: # torch doesn't support keepdims with axis=None # (https://github.com/pytorch/pytorch/issues/71209) res = torch.sum(x, dtype=dtype, **kwargs) - res = _axis_none_keepdims(res, ndim, keepdims) + res = _axis_none_keepdims(res, x.ndim, keepdims) return res return torch.sum(x, axis, dtype=dtype, keepdims=keepdims, **kwargs) -def any(x: array, +def any(x: Array, /, *, - axis: Optional[Union[int, Tuple[int, ...]]] = None, + axis: int | tuple[int, ...] | None = None, keepdims: bool = False, - **kwargs) -> array: - x = torch.asarray(x) - ndim = x.ndim + **kwargs: object) -> Array: + if axis == (): return x.to(torch.bool) # torch.any doesn't support multiple axes @@ -338,20 +363,19 @@ def any(x: array, # torch doesn't support keepdims with axis=None # (https://github.com/pytorch/pytorch/issues/71209) res = torch.any(x, **kwargs) - res = _axis_none_keepdims(res, ndim, keepdims) + res = _axis_none_keepdims(res, x.ndim, keepdims) return res.to(torch.bool) # torch.any doesn't return bool for uint8 return torch.any(x, axis, keepdims=keepdims).to(torch.bool) -def all(x: array, +def all(x: Array, /, *, - axis: Optional[Union[int, Tuple[int, ...]]] = None, + axis: int | tuple[int, ...] | None = None, keepdims: bool = False, - **kwargs) -> array: - x = torch.asarray(x) - ndim = x.ndim + **kwargs: object) -> Array: + if axis == (): return x.to(torch.bool) # torch.all doesn't support multiple axes @@ -363,18 +387,18 @@ def all(x: array, # torch doesn't support keepdims with axis=None # (https://github.com/pytorch/pytorch/issues/71209) res = torch.all(x, **kwargs) - res = _axis_none_keepdims(res, ndim, keepdims) + res = _axis_none_keepdims(res, x.ndim, keepdims) return res.to(torch.bool) # torch.all doesn't return bool for uint8 return torch.all(x, axis, keepdims=keepdims).to(torch.bool) -def mean(x: array, +def mean(x: Array, /, *, - axis: Optional[Union[int, Tuple[int, ...]]] = None, + axis: int | tuple[int, ...] 
| None = None, keepdims: bool = False, - **kwargs) -> array: + **kwargs: object) -> Array: # https://github.com/pytorch/pytorch/issues/29137 if axis == (): return torch.clone(x) @@ -386,13 +410,13 @@ def mean(x: array, return res return torch.mean(x, axis, keepdims=keepdims, **kwargs) -def std(x: array, +def std(x: Array, /, *, - axis: Optional[Union[int, Tuple[int, ...]]] = None, - correction: Union[int, float] = 0.0, + axis: int | tuple[int, ...] | None = None, + correction: float = 0.0, keepdims: bool = False, - **kwargs) -> array: + **kwargs: object) -> Array: # Note, float correction is not supported # https://github.com/pytorch/pytorch/issues/61492. We don't try to # implement it here for now. @@ -417,13 +441,13 @@ def std(x: array, return res return torch.std(x, axis, correction=_correction, keepdims=keepdims, **kwargs) -def var(x: array, +def var(x: Array, /, *, - axis: Optional[Union[int, Tuple[int, ...]]] = None, - correction: Union[int, float] = 0.0, + axis: int | tuple[int, ...] | None = None, + correction: float = 0.0, keepdims: bool = False, - **kwargs) -> array: + **kwargs: object) -> Array: # Note, float correction is not supported # https://github.com/pytorch/pytorch/issues/61492. We don't try to # implement it here for now. @@ -446,11 +470,11 @@ def var(x: array, # torch.concat doesn't support dim=None # https://github.com/pytorch/pytorch/issues/70925 -def concat(arrays: Union[Tuple[array, ...], List[array]], +def concat(arrays: tuple[Array, ...] | list[Array], /, *, - axis: Optional[int] = 0, - **kwargs) -> array: + axis: int | None = 0, + **kwargs: object) -> Array: if axis is None: arrays = tuple(ar.flatten() for ar in arrays) axis = 0 @@ -459,7 +483,7 @@ def concat(arrays: Union[Tuple[array, ...], List[array]], # torch.squeeze only accepts int dim and doesn't require it # https://github.com/pytorch/pytorch/issues/70924. Support for tuple dim was # added at https://github.com/pytorch/pytorch/pull/89017. -def squeeze(x: array, /, axis: Union[int, Tuple[int, ...]]) -> array: +def squeeze(x: Array, /, axis: int | tuple[int, ...]) -> Array: if isinstance(axis, int): axis = (axis,) for a in axis: @@ -473,41 +497,83 @@ def squeeze(x: array, /, axis: Union[int, Tuple[int, ...]]) -> array: return x # torch.broadcast_to uses size instead of shape -def broadcast_to(x: array, /, shape: Tuple[int, ...], **kwargs) -> array: +def broadcast_to(x: Array, /, shape: tuple[int, ...], **kwargs: object) -> Array: return torch.broadcast_to(x, shape, **kwargs) # torch.permute uses dims instead of axes -def permute_dims(x: array, /, axes: Tuple[int, ...]) -> array: +def permute_dims(x: Array, /, axes: tuple[int, ...]) -> Array: return torch.permute(x, axes) # The axis parameter doesn't work for flip() and roll() # https://github.com/pytorch/pytorch/issues/71210. Also torch.flip() doesn't # accept axis=None -def flip(x: array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, **kwargs) -> array: +def flip(x: Array, /, *, axis: int | tuple[int, ...] | None = None, **kwargs: object) -> Array: if axis is None: axis = tuple(range(x.ndim)) # torch.flip doesn't accept dim as an int but the method does # https://github.com/pytorch/pytorch/issues/18095 return x.flip(axis, **kwargs) -def roll(x: array, /, shift: Union[int, Tuple[int, ...]], *, axis: Optional[Union[int, Tuple[int, ...]]] = None, **kwargs) -> array: +def roll(x: Array, /, shift: int | tuple[int, ...], *, axis: int | tuple[int, ...] 
| None = None, **kwargs: object) -> Array:
     return torch.roll(x, shift, axis, **kwargs)
 
-def nonzero(x: array, /, **kwargs) -> Tuple[array, ...]:
+def nonzero(x: Array, /, **kwargs: object) -> tuple[Array, ...]:
     if x.ndim == 0:
         raise ValueError("nonzero() does not support zero-dimensional arrays")
     return torch.nonzero(x, as_tuple=True, **kwargs)
 
-def where(condition: array, x1: array, x2: array, /) -> array:
+
+# torch uses `dim` instead of `axis`
+def diff(
+    x: Array,
+    /,
+    *,
+    axis: int = -1,
+    n: int = 1,
+    prepend: Array | None = None,
+    append: Array | None = None,
+) -> Array:
+    return torch.diff(x, dim=axis, n=n, prepend=prepend, append=append)
+
+
+# torch uses `dim` instead of `axis` and does not have keepdims
+def count_nonzero(
+    x: Array,
+    /,
+    *,
+    axis: int | tuple[int, ...] | None = None,
+    keepdims: bool = False,
+) -> Array:
+    result = torch.count_nonzero(x, dim=axis)
+    if keepdims:
+        if isinstance(axis, int):
+            return result.unsqueeze(axis)
+        elif isinstance(axis, tuple):
+            n_axis = [x.ndim + ax if ax < 0 else ax for ax in axis]
+            sh = [1 if i in n_axis else x.shape[i] for i in range(x.ndim)]
+            return torch.reshape(result, sh)
+        return _axis_none_keepdims(result, x.ndim, keepdims)
+    else:
+        return result
+
+
+# "repeat" in the array API is torch.repeat_interleave; torch also uses
+# `dim` instead of `axis`
+def repeat(x: Array, repeats: int | Array, /, *, axis: int | None = None) -> Array:
+    return torch.repeat_interleave(x, repeats, axis)
+
+
+def where(condition: Array, x1: Array | complex, x2: Array | complex, /) -> Array:
     x1, x2 = _fix_promotion(x1, x2)
     return torch.where(condition, x1, x2)
 
+
 # torch.reshape doesn't have the copy keyword
-def reshape(x: array,
+def reshape(x: Array,
             /,
-            shape: Tuple[int, ...],
-            copy: Optional[bool] = None,
-            **kwargs) -> array:
+            shape: tuple[int, ...],
+            *,
+            copy: bool | None = None,
+            **kwargs: object) -> Array:
     if copy is not None:
         raise NotImplementedError("torch.reshape doesn't yet support the copy keyword")
     return torch.reshape(x, shape, **kwargs)
@@ -516,14 +582,14 @@ def reshape(x: array,
 # (https://github.com/pytorch/pytorch/issues/70915), and doesn't support some
 # keyword argument combinations
 # (https://github.com/pytorch/pytorch/issues/70914)
-def arange(start: Union[int, float],
+def arange(start: float,
            /,
-           stop: Optional[Union[int, float]] = None,
-           step: Union[int, float] = 1,
+           stop: float | None = None,
+           step: float = 1,
            *,
-           dtype: Optional[Dtype] = None,
-           device: Optional[Device] = None,
-           **kwargs) -> array:
+           dtype: DType | None = None,
+           device: Device | None = None,
+           **kwargs: object) -> Array:
     if stop is None:
         start, stop = 0, start
     if step > 0 and stop <= start or step < 0 and stop >= start:
@@ -538,13 +604,13 @@ def arange(start: Union[int, float],
 # torch.eye does not accept None as a default for the second argument and
 # doesn't support off-diagonals (https://github.com/pytorch/pytorch/issues/70910)
 def eye(n_rows: int,
-        n_cols: Optional[int] = None,
+        n_cols: int | None = None,
         /,
         *,
         k: int = 0,
-        dtype: Optional[Dtype] = None,
-        device: Optional[Device] = None,
-        **kwargs) -> array:
+        dtype: DType | None = None,
+        device: Device | None = None,
+        **kwargs: object) -> Array:
     if n_cols is None:
         n_cols = n_rows
     z = torch.zeros(n_rows, n_cols, dtype=dtype, device=device, **kwargs)
@@ -553,70 +619,81 @@ def eye(n_rows: int,
     return z
 
 # torch.linspace doesn't have the endpoint parameter
-def linspace(start: Union[int, float],
-             stop: Union[int, float],
+def linspace(start: float,
+             stop: float,
              /,
             num: int,
             *,
-             dtype: Optional[Dtype] = 
None, - device: Optional[Device] = None, + dtype: DType | None = None, + device: Device | None = None, endpoint: bool = True, - **kwargs) -> array: + **kwargs: object) -> Array: if not endpoint: return torch.linspace(start, stop, num+1, dtype=dtype, device=device, **kwargs)[:-1] return torch.linspace(start, stop, num, dtype=dtype, device=device, **kwargs) # torch.full does not accept an int size # https://github.com/pytorch/pytorch/issues/70906 -def full(shape: Union[int, Tuple[int, ...]], - fill_value: Union[bool, int, float, complex], +def full(shape: int | tuple[int, ...], + fill_value: complex, *, - dtype: Optional[Dtype] = None, - device: Optional[Device] = None, - **kwargs) -> array: + dtype: DType | None = None, + device: Device | None = None, + **kwargs: object) -> Array: if isinstance(shape, int): shape = (shape,) return torch.full(shape, fill_value, dtype=dtype, device=device, **kwargs) # ones, zeros, and empty do not accept shape as a keyword argument -def ones(shape: Union[int, Tuple[int, ...]], +def ones(shape: int | tuple[int, ...], *, - dtype: Optional[Dtype] = None, - device: Optional[Device] = None, - **kwargs) -> array: + dtype: DType | None = None, + device: Device | None = None, + **kwargs: object) -> Array: return torch.ones(shape, dtype=dtype, device=device, **kwargs) -def zeros(shape: Union[int, Tuple[int, ...]], +def zeros(shape: int | tuple[int, ...], *, - dtype: Optional[Dtype] = None, - device: Optional[Device] = None, - **kwargs) -> array: + dtype: DType | None = None, + device: Device | None = None, + **kwargs: object) -> Array: return torch.zeros(shape, dtype=dtype, device=device, **kwargs) -def empty(shape: Union[int, Tuple[int, ...]], +def empty(shape: int | tuple[int, ...], *, - dtype: Optional[Dtype] = None, - device: Optional[Device] = None, - **kwargs) -> array: + dtype: DType | None = None, + device: Device | None = None, + **kwargs: object) -> Array: return torch.empty(shape, dtype=dtype, device=device, **kwargs) # tril and triu do not call the keyword argument k -def tril(x: array, /, *, k: int = 0) -> array: +def tril(x: Array, /, *, k: int = 0) -> Array: return torch.tril(x, k) -def triu(x: array, /, *, k: int = 0) -> array: +def triu(x: Array, /, *, k: int = 0) -> Array: return torch.triu(x, k) # Functions that aren't in torch https://github.com/pytorch/pytorch/issues/58742 -def expand_dims(x: array, /, *, axis: int = 0) -> array: +def expand_dims(x: Array, /, *, axis: int = 0) -> Array: return torch.unsqueeze(x, axis) -def astype(x: array, dtype: Dtype, /, *, copy: bool = True) -> array: - return x.to(dtype, copy=copy) -def broadcast_arrays(*arrays: array) -> List[array]: +def astype( + x: Array, + dtype: DType, + /, + *, + copy: bool = True, + device: Device | None = None, +) -> Array: + if device is not None: + return x.to(device, dtype=dtype, copy=copy) + return x.to(dtype=dtype, copy=copy) + + +def broadcast_arrays(*arrays: Array) -> list[Array]: shape = torch.broadcast_shapes(*[a.shape for a in arrays]) return [torch.broadcast_to(a, shape) for a in arrays] @@ -626,7 +703,7 @@ def broadcast_arrays(*arrays: array) -> List[array]: UniqueInverseResult) # https://github.com/pytorch/pytorch/issues/70920 -def unique_all(x: array) -> UniqueAllResult: +def unique_all(x: Array) -> UniqueAllResult: # torch.unique doesn't support returning indices. # https://github.com/pytorch/pytorch/issues/36748. 
The workaround # suggested in that issue doesn't actually function correctly (it relies @@ -639,7 +716,7 @@ def unique_all(x: array) -> UniqueAllResult: # counts[torch.isnan(values)] = 1 # return UniqueAllResult(values, indices, inverse_indices, counts) -def unique_counts(x: array) -> UniqueCountsResult: +def unique_counts(x: Array) -> UniqueCountsResult: values, counts = torch.unique(x, return_counts=True) # torch.unique incorrectly gives a 0 count for nan values. @@ -647,27 +724,34 @@ def unique_counts(x: array) -> UniqueCountsResult: counts[torch.isnan(values)] = 1 return UniqueCountsResult(values, counts) -def unique_inverse(x: array) -> UniqueInverseResult: +def unique_inverse(x: Array) -> UniqueInverseResult: values, inverse = torch.unique(x, return_inverse=True) return UniqueInverseResult(values, inverse) -def unique_values(x: array) -> array: +def unique_values(x: Array) -> Array: return torch.unique(x) -def matmul(x1: array, x2: array, /, **kwargs) -> array: +def matmul(x1: Array, x2: Array, /, **kwargs: object) -> Array: # torch.matmul doesn't type promote (but differently from _fix_promotion) x1, x2 = _fix_promotion(x1, x2, only_scalar=False) return torch.matmul(x1, x2, **kwargs) -matrix_transpose = get_xp(torch)(_aliases_matrix_transpose) -_vecdot = get_xp(torch)(_aliases_vecdot) +matrix_transpose = get_xp(torch)(_aliases.matrix_transpose) +_vecdot = get_xp(torch)(_aliases.vecdot) -def vecdot(x1: array, x2: array, /, *, axis: int = -1) -> array: +def vecdot(x1: Array, x2: Array, /, *, axis: int = -1) -> Array: x1, x2 = _fix_promotion(x1, x2, only_scalar=False) return _vecdot(x1, x2, axis=axis) # torch.tensordot uses dims instead of axes -def tensordot(x1: array, x2: array, /, *, axes: Union[int, Tuple[Sequence[int], Sequence[int]]] = 2, **kwargs) -> array: +def tensordot( + x1: Array, + x2: Array, + /, + *, + axes: int | tuple[Sequence[int], Sequence[int]] = 2, + **kwargs: object, +) -> Array: # Note: torch.tensordot fails with integer dtypes when there is only 1 # element in the axis (https://github.com/pytorch/pytorch/issues/84530). x1, x2 = _fix_promotion(x1, x2, only_scalar=False) @@ -675,8 +759,10 @@ def tensordot(x1: array, x2: array, /, *, axes: Union[int, Tuple[Sequence[int], def isdtype( - dtype: Dtype, kind: Union[Dtype, str, Tuple[Union[Dtype, str], ...]], - *, _tuple=True, # Disallow nested tuples + dtype: DType, + kind: DType | str | tuple[DType | str, ...], + *, + _tuple: bool = True, # Disallow nested tuples ) -> bool: """ Returns a boolean indicating whether a provided dtype is of a specified data type ``kind``. @@ -710,14 +796,19 @@ def isdtype( else: return dtype == kind -def take(x: array, indices: array, /, *, axis: Optional[int] = None, **kwargs) -> array: +def take(x: Array, indices: Array, /, *, axis: int | None = None, **kwargs: object) -> Array: if axis is None: if x.ndim != 1: raise ValueError("axis must be specified when ndim > 1") axis = 0 return torch.index_select(x, axis, indices, **kwargs) -def sign(x: array, /) -> array: + +def take_along_axis(x: Array, indices: Array, /, *, axis: int = -1) -> Array: + return torch.take_along_dim(x, indices, dim=axis) + + +def sign(x: Array, /) -> Array: # torch sign() does not support complex numbers and does not propagate # nans. 
See https://github.com/data-apis/array-api-compat/issues/136
     if x.dtype.is_complex:
@@ -732,14 +823,21 @@ def sign(x: array, /) -> array:
     return out
 
-__all__ = ['__array_namespace_info__', 'result_type', 'can_cast',
+def meshgrid(*arrays: Array, indexing: Literal['xy', 'ij'] = 'xy') -> list[Array]:
+    # enforce the array API default of 'xy'; torch.meshgrid defaults to 'ij'
+    # TODO: is the return type a list or a tuple
+    return list(torch.meshgrid(*arrays, indexing=indexing))
+
+
+__all__ = ['asarray', 'result_type', 'can_cast',
            'permute_dims', 'bitwise_invert', 'newaxis', 'conj', 'add',
            'atan2', 'bitwise_and', 'bitwise_left_shift', 'bitwise_or',
-           'bitwise_right_shift', 'bitwise_xor', 'copysign', 'divide',
+           'bitwise_right_shift', 'bitwise_xor', 'copysign', 'count_nonzero',
+           'diff', 'divide',
            'equal', 'floor_divide', 'greater', 'greater_equal', 'hypot',
            'less', 'less_equal', 'logaddexp', 'maximum', 'minimum',
            'multiply', 'not_equal', 'pow', 'remainder', 'subtract', 'max',
-           'min', 'clip', 'unstack', 'cumulative_sum', 'sort', 'prod', 'sum',
+           'min', 'clip', 'unstack', 'cumulative_sum', 'cumulative_prod', 'sort', 'prod', 'sum',
            'any', 'all', 'mean', 'std', 'var', 'concat', 'squeeze',
            'broadcast_to', 'flip', 'roll', 'nonzero', 'where', 'reshape',
            'arange', 'eye', 'linspace', 'full', 'ones', 'zeros', 'empty',
@@ -747,6 +845,4 @@ def sign(x: array, /) -> array:
            'UniqueAllResult', 'UniqueCountsResult', 'UniqueInverseResult',
            'unique_all', 'unique_counts', 'unique_inverse', 'unique_values',
            'matmul', 'matrix_transpose', 'vecdot', 'tensordot', 'isdtype',
-           'take', 'sign']
-
-_all_ignore = ['torch', 'get_xp']
+           'take', 'take_along_axis', 'sign', 'finfo', 'iinfo', 'repeat', 'meshgrid']
diff --git a/array_api_compat/torch/_info.py b/array_api_compat/torch/_info.py
index 264caa9e..818e5d37 100644
--- a/array_api_compat/torch/_info.py
+++ b/array_api_compat/torch/_info.py
@@ -34,7 +34,7 @@ class __array_namespace_info__:
 
     Examples
     --------
-    >>> info = np.__array_namespace_info__()
+    >>> info = xp.__array_namespace_info__()
     >>> info.default_dtypes()
     {'real floating': numpy.float64,
      'complex floating': numpy.complex128,
@@ -76,17 +76,17 @@ def capabilities(self):
 
         Examples
         --------
-        >>> info = np.__array_namespace_info__()
+        >>> info = xp.__array_namespace_info__()
         >>> info.capabilities()
         {'boolean indexing': True,
-         'data-dependent shapes': True}
+         'data-dependent shapes': True,
+         'max dimensions': 64}
 
         """
         return {
             "boolean indexing": True,
             "data-dependent shapes": True,
-            # 'max rank' will be part of the 2024.12 standard
-            # "max rank": 64,
+            "max dimensions": 64,
         }
 
     def default_device(self):
@@ -102,15 +102,24 @@
 
         Returns
         -------
-        device : str
+        device : Device
             The default device used for new PyTorch arrays.
 
         Examples
        --------
-        >>> info = np.__array_namespace_info__()
+        >>> info = xp.__array_namespace_info__()
         >>> info.default_device()
-        'cpu'
+        device(type='cpu')
+
+        Notes
+        -----
+        This method returns the static default device set when PyTorch is
+        initialized. However, the *current* device used by creation functions
+        (``empty`` etc.) can be changed at runtime.
+
+        See Also
+        --------
+        https://github.com/data-apis/array-api/issues/835
         """
         return torch.device("cpu")
 
@@ -120,9 +129,9 @@ def default_dtypes(self, *, device=None):
 
         Parameters
         ----------
-        device : str, optional
-            The device to get the default data types for. For PyTorch, only
-            ``'cpu'`` is allowed.
+        device : Device, optional
+            The device to get the default data types for.
+            Unused for PyTorch, as all devices use the same default dtypes.
Returns ------- @@ -139,7 +148,7 @@ def default_dtypes(self, *, device=None): Examples -------- - >>> info = np.__array_namespace_info__() + >>> info = xp.__array_namespace_info__() >>> info.default_dtypes() {'real floating': torch.float32, 'complex floating': torch.complex64, @@ -250,8 +259,9 @@ def dtypes(self, *, device=None, kind=None): Parameters ---------- - device : str, optional + device : Device, optional The device to get the data types for. + Unused for PyTorch, as all devices use the same dtypes. kind : str or tuple of str, optional The kind of data types to return. If ``None``, all data types are returned. If a string, only data types of that kind are returned. @@ -287,7 +297,7 @@ def dtypes(self, *, device=None, kind=None): Examples -------- - >>> info = np.__array_namespace_info__() + >>> info = xp.__array_namespace_info__() >>> info.dtypes(kind='signed integer') {'int8': numpy.int8, 'int16': numpy.int16, @@ -310,7 +320,7 @@ def devices(self): Returns ------- - devices : list of str + devices : list[Device] The devices supported by PyTorch. See Also @@ -322,7 +332,7 @@ def devices(self): Examples -------- - >>> info = np.__array_namespace_info__() + >>> info = xp.__array_namespace_info__() >>> info.devices() [device(type='cpu'), device(type='mps', index=0), device(type='meta')] @@ -333,6 +343,7 @@ def devices(self): # device: try: torch.device('notadevice') + raise AssertionError("unreachable") # pragma: nocover except RuntimeError as e: # The error message is something like: # "Expected one of cpu, cuda, ipu, xpu, mkldnn, opengl, opencl, ideep, hip, ve, fpga, ort, xla, lazy, vulkan, mps, meta, hpu, mtia, privateuseone device type at start of device string: notadevice" diff --git a/array_api_compat/torch/_typing.py b/array_api_compat/torch/_typing.py new file mode 100644 index 00000000..52670871 --- /dev/null +++ b/array_api_compat/torch/_typing.py @@ -0,0 +1,3 @@ +__all__ = ["Array", "Device", "DType"] + +from torch import device as Device, dtype as DType, Tensor as Array diff --git a/array_api_compat/torch/fft.py b/array_api_compat/torch/fft.py index 3c9117ee..76342980 100644 --- a/array_api_compat/torch/fft.py +++ b/array_api_compat/torch/fft.py @@ -1,86 +1,82 @@ from __future__ import annotations -from typing import TYPE_CHECKING -if TYPE_CHECKING: - import torch - array = torch.Tensor - from typing import Union, Sequence, Literal +from collections.abc import Sequence +from typing import Literal -from torch.fft import * # noqa: F403 +import torch import torch.fft +from ._typing import Array +from .._internal import clone_module + +__all__ = clone_module("torch.fft", globals()) + # Several torch fft functions do not map axes to dim def fftn( - x: array, + x: Array, /, *, s: Sequence[int] = None, axes: Sequence[int] = None, norm: Literal["backward", "ortho", "forward"] = "backward", - **kwargs, -) -> array: + **kwargs: object, +) -> Array: return torch.fft.fftn(x, s=s, dim=axes, norm=norm, **kwargs) def ifftn( - x: array, + x: Array, /, *, s: Sequence[int] = None, axes: Sequence[int] = None, norm: Literal["backward", "ortho", "forward"] = "backward", - **kwargs, -) -> array: + **kwargs: object, +) -> Array: return torch.fft.ifftn(x, s=s, dim=axes, norm=norm, **kwargs) def rfftn( - x: array, + x: Array, /, *, s: Sequence[int] = None, axes: Sequence[int] = None, norm: Literal["backward", "ortho", "forward"] = "backward", - **kwargs, -) -> array: + **kwargs: object, +) -> Array: return torch.fft.rfftn(x, s=s, dim=axes, norm=norm, **kwargs) def irfftn( - x: array, + x: Array, 
    /,
    *,
    s: Sequence[int] = None,
    axes: Sequence[int] = None,
    norm: Literal["backward", "ortho", "forward"] = "backward",
-    **kwargs,
-) -> array:
+    **kwargs: object,
+) -> Array:
    return torch.fft.irfftn(x, s=s, dim=axes, norm=norm, **kwargs)

def fftshift(
-    x: array,
+    x: Array,
    /,
    *,
-    axes: Union[int, Sequence[int]] = None,
-    **kwargs,
-) -> array:
+    axes: int | Sequence[int] = None,
+    **kwargs: object,
+) -> Array:
    return torch.fft.fftshift(x, dim=axes, **kwargs)

def ifftshift(
-    x: array,
+    x: Array,
    /,
    *,
-    axes: Union[int, Sequence[int]] = None,
-    **kwargs,
-) -> array:
+    axes: int | Sequence[int] = None,
+    **kwargs: object,
+) -> Array:
    return torch.fft.ifftshift(x, dim=axes, **kwargs)

-__all__ = torch.fft.__all__ + [
-    "fftn",
-    "ifftn",
-    "rfftn",
-    "irfftn",
-    "fftshift",
-    "ifftshift",
-]
+__all__ += ["fftn", "ifftn", "rfftn", "irfftn", "fftshift", "ifftshift"]

-_all_ignore = ['torch']
+def __dir__() -> list[str]:
+    return __all__
diff --git a/array_api_compat/torch/linalg.py b/array_api_compat/torch/linalg.py
index e26198b9..08271d22 100644
--- a/array_api_compat/torch/linalg.py
+++ b/array_api_compat/torch/linalg.py
@@ -1,42 +1,35 @@
from __future__ import annotations

-from typing import TYPE_CHECKING
-if TYPE_CHECKING:
-    import torch
-    array = torch.Tensor
-    from torch import dtype as Dtype
-    from typing import Optional, Union, Tuple, Literal
-    inf = float('inf')
+import torch
+import torch.linalg

-from ._aliases import _fix_promotion, sum
-
-from torch.linalg import * # noqa: F403
+from .._internal import clone_module

-# torch.linalg doesn't define __all__
-# from torch.linalg import __all__ as linalg_all
-from torch import linalg as torch_linalg
-linalg_all = [i for i in dir(torch_linalg) if not i.startswith('_')]
+__all__ = clone_module("torch.linalg", globals())

# outer is implemented in torch but isn't in the linalg namespace
from torch import outer
+from ._aliases import _fix_promotion, sum
# These functions are in both the main and linalg namespaces
from ._aliases import matmul, matrix_transpose, tensordot
+from ._typing import Array, DType
+from ..common._typing import JustInt, JustFloat

# Note: torch.linalg.cross does not default to axis=-1 (it defaults to the
# first axis with size 3), see https://github.com/pytorch/pytorch/issues/58743
# torch.cross also does not support broadcasting when it would add new
# dimensions https://github.com/pytorch/pytorch/issues/39656
-def cross(x1: array, x2: array, /, *, axis: int = -1) -> array:
+def cross(x1: Array, x2: Array, /, *, axis: int = -1) -> Array:
    x1, x2 = _fix_promotion(x1, x2, only_scalar=False)
    if not (-min(x1.ndim, x2.ndim) <= axis < max(x1.ndim, x2.ndim)):
        raise ValueError(f"axis {axis} out of bounds for cross product of arrays with shapes {x1.shape} and {x2.shape}")
    if not (x1.shape[axis] == x2.shape[axis] == 3):
        raise ValueError(f"cross product axis must have size 3, got {x1.shape[axis]} and {x2.shape[axis]}")
    x1, x2 = torch.broadcast_tensors(x1, x2)
-    return torch_linalg.cross(x1, x2, dim=axis)
+    return torch.linalg.cross(x1, x2, dim=axis)

-def vecdot(x1: array, x2: array, /, *, axis: int = -1, **kwargs) -> array:
+def vecdot(x1: Array, x2: Array, /, *, axis: int = -1, **kwargs: object) -> Array:
    from ._aliases import isdtype

    x1, x2 = _fix_promotion(x1, x2, only_scalar=False)
@@ -58,7 +51,7 @@ def vecdot(x1: array, x2: array, /, *, axis: int = -1, **kwargs) -> array:
            return res[..., 0, 0]
    return torch.linalg.vecdot(x1, x2, dim=axis, **kwargs)

-def solve(x1: array, x2: array, /, **kwargs) -> array:
+def solve(x1:
Array, x2: Array, /, **kwargs: object) -> Array: x1, x2 = _fix_promotion(x1, x2, only_scalar=False) # Torch tries to emulate NumPy 1 solve behavior by using batched 1-D solve # whenever @@ -79,19 +72,20 @@ def solve(x1: array, x2: array, /, **kwargs) -> array: return torch.linalg.solve(x1, x2, **kwargs) # torch.trace doesn't support the offset argument and doesn't support stacking -def trace(x: array, /, *, offset: int = 0, dtype: Optional[Dtype] = None) -> array: +def trace(x: Array, /, *, offset: int = 0, dtype: DType | None = None) -> Array: # Use our wrapped sum to make sure it does upcasting correctly return sum(torch.diagonal(x, offset=offset, dim1=-2, dim2=-1), axis=-1, dtype=dtype) def vector_norm( - x: array, + x: Array, /, *, - axis: Optional[Union[int, Tuple[int, ...]]] = None, + axis: int | tuple[int, ...] | None = None, keepdims: bool = False, - ord: Union[int, float, Literal[inf, -inf]] = 2, - **kwargs, -) -> array: + # JustFloat stands for inf | -inf, which are not valid for Literal + ord: JustInt | JustFloat = 2, + **kwargs: object, +) -> Array: # torch.vector_norm incorrectly treats axis=() the same as axis=None if axis == (): out = kwargs.get('out') @@ -113,9 +107,8 @@ def vector_norm( return out return torch.linalg.vector_norm(x, ord=ord, axis=axis, keepdim=keepdims, **kwargs) -__all__ = linalg_all + ['outer', 'matmul', 'matrix_transpose', 'tensordot', - 'cross', 'vecdot', 'solve', 'trace', 'vector_norm'] - -_all_ignore = ['torch_linalg', 'sum'] +__all__ += ['outer', 'matmul', 'matrix_transpose', 'tensordot', + 'cross', 'vecdot', 'solve', 'trace', 'vector_norm'] -del linalg_all +def __dir__() -> list[str]: + return __all__ diff --git a/cupy-xfails.txt b/cupy-xfails.txt index fb7b03da..0a91cafe 100644 --- a/cupy-xfails.txt +++ b/cupy-xfails.txt @@ -11,12 +11,10 @@ array_api_tests/test_array_object.py::test_scalar_casting[__index__(int64)] # testsuite bug (https://github.com/data-apis/array-api-tests/issues/172) array_api_tests/test_array_object.py::test_getitem -# copy=False is not yet implemented -array_api_tests/test_creation_functions.py::test_asarray_arrays - -# finfo test is testing that the result is a float instead of float32 (see -# also https://github.com/data-apis/array-api/issues/405) +# attributes are np.float32 instead of float +# (see also https://github.com/data-apis/array-api/issues/405) array_api_tests/test_data_type_functions.py::test_finfo[float32] +array_api_tests/test_data_type_functions.py::test_finfo[complex64] # Some array attributes are missing, and we do not wrap the array object array_api_tests/test_has_names.py::test_has_names[array_method-__array_namespace__] @@ -36,6 +34,16 @@ array_api_tests/test_operators_and_elementwise_functions.py::test_subtract[__sub array_api_tests/test_operators_and_elementwise_functions.py::test_add[__add__(x, s)] # floating point inaccuracy array_api_tests/test_operators_and_elementwise_functions.py::test_remainder[remainder(x1, x2)] +# incomplete NEP50 support in CuPy 13.x (fixed in 14.0.0a1) +array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[pow] +array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[logaddexp] +array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[nextafter] +array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[hypot] +array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[atan2] 
+array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[divide]
+array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[copysign]
+array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[maximum]

# cupy (arg)min/max wrong with infinities
# https://github.com/cupy/cupy/issues/7424
@@ -173,10 +181,23 @@ array_api_tests/test_special_cases.py::test_unary[trunc(x_i is -0) -> -0]
array_api_tests/test_fft.py::test_fftn
array_api_tests/test_fft.py::test_ifftn
array_api_tests/test_fft.py::test_rfftn
+
+# observed in the 1.10 release process; likely related to the fft xfails above
+array_api_tests/test_fft.py::test_irfftn

# 2023.12 support
# cupy.ndarray cannot be specified as `repeats` argument.
array_api_tests/test_manipulation_functions.py::test_repeat
array_api_tests/test_signatures.py::test_func_signature[from_dlpack]
-array_api_tests/test_signatures.py::test_func_signature[astype]
array_api_tests/test_signatures.py::test_array_method_signature[__dlpack__]
+
+# 2024.12 support
+array_api_tests/test_signatures.py::test_func_signature[bitwise_and]
+array_api_tests/test_signatures.py::test_func_signature[bitwise_left_shift]
+array_api_tests/test_signatures.py::test_func_signature[bitwise_or]
+array_api_tests/test_signatures.py::test_func_signature[bitwise_right_shift]
+array_api_tests/test_signatures.py::test_func_signature[bitwise_xor]
+array_api_tests/test_special_cases.py::test_binary[nextafter(x1_i is +0 and x2_i is -0) -> -0]
+
+# cupy 13.x follows numpy 1.x w/o weak promotion: result_type(int32, uint8, 1) != result_type(int32, uint8)
+array_api_tests/test_data_type_functions.py::TestResultType::test_with_scalars
diff --git a/dask-skips.txt b/dask-skips.txt
index 2a67d75d..a16a8588 100644
--- a/dask-skips.txt
+++ b/dask-skips.txt
@@ -1,17 +1,9 @@
-# FFT isn't conformant
-array_api_tests/test_fft.py
-array_api_tests/test_signatures.py::test_extension_func_signature[fft.fft]
-array_api_tests/test_signatures.py::test_extension_func_signature[fft.ifft]
-array_api_tests/test_signatures.py::test_extension_func_signature[fft.fftn]
-array_api_tests/test_signatures.py::test_extension_func_signature[fft.ifftn]
-array_api_tests/test_signatures.py::test_extension_func_signature[fft.rfft]
-array_api_tests/test_signatures.py::test_extension_func_signature[fft.irfft]
-array_api_tests/test_signatures.py::test_extension_func_signature[fft.rfftn]
-array_api_tests/test_signatures.py::test_extension_func_signature[fft.irfftn]
-array_api_tests/test_signatures.py::test_extension_func_signature[fft.hfft]
-array_api_tests/test_signatures.py::test_extension_func_signature[fft.ihfft]
-array_api_tests/test_signatures.py::test_extension_func_signature[fft.fftfreq]
-array_api_tests/test_signatures.py::test_extension_func_signature[fft.rfftfreq]
+# NOTE: dask tests run on a very small number of examples in CI due to
+# slowness. This causes very high flakiness in the tests.
+# Before changing this file, please run with at least 200 examples.
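+#
+# A sketch of such a local run, following array-api-tests conventions (the
+# --max-examples/--skips-file/--xfails-file options come from array-api-tests,
+# -n from pytest-xdist; the checkout path is illustrative):
+#
+#   ARRAY_API_TESTS_MODULE=array_api_compat.dask.array \
+#     pytest array-api-tests/array_api_tests/ -n 4 --max-examples=200 \
+#     --skips-file dask-skips.txt --xfails-file dask-xfails.txt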
-# slow and not implemented in dask -array_api_tests/test_linalg.py::test_matrix_power +# Passes, but extremely slow +array_api_tests/test_linalg.py::test_outer + +# Hangs +array_api_tests/test_creation_functions.py::test_eye diff --git a/dask-xfails.txt b/dask-xfails.txt index 1e9c421c..3efb4f96 100644 --- a/dask-xfails.txt +++ b/dask-xfails.txt @@ -1,73 +1,46 @@ -# This fails in dask -# import dask.array as da -# a = da.array([1]).reshape((1,1)) -# key = (0, slice(None, None, -1)) -# a[key] = da.array([1]) - -# Failing hypothesis test case -#x=dask.array -#| Draw 1 (key): (slice(None, None, None), slice(None, None, None)) -#| Draw 2 (value): dask.array - -# Various shape mismatches e.g. -ValueError: shape mismatch: value array of shape (0, 2) could not be broadcast to indexing result of shape (0, 2) -array_api_tests/test_array_object.py::test_setitem +# NOTE: dask tests run on a very small number of examples in CI due to +# slowness. This causes very high flakiness in the tests. +# Before changing this file, please run with at least 200 examples. -# Fails since bad upcast from uint8 -> int64 -# MRE: -# a = da.array(0, dtype="uint8") -# b = da.array(False) -# a[b] = 0 -array_api_tests/test_array_object.py::test_setitem_masking +# Broken edge case with shape 0 +# https://github.com/dask/dask/issues/11800 +array_api_tests/test_array_object.py::test_setitem # Various indexing errors array_api_tests/test_array_object.py::test_getitem_masking -# asarray(copy=False) is not yet implemented -# copied from numpy xfails, TODO: should this pass with dask? -array_api_tests/test_creation_functions.py::test_asarray_arrays - # zero division error, and typeerror: tuple indices must be integers or slices not tuple array_api_tests/test_creation_functions.py::test_eye -# finfo(float32).eps returns float32 but should return float +# attributes are np.float32 instead of float +# (see also https://github.com/data-apis/array-api/issues/405) array_api_tests/test_data_type_functions.py::test_finfo[float32] +array_api_tests/test_data_type_functions.py::test_finfo[complex64] -# out[-1]=dask.aray but should be some floating number +# out[-1]=dask.array but should be some floating number # (I think the test is not forcing the op to be computed?) 
array_api_tests/test_creation_functions.py::test_linspace

-# out.shape=(2,) but should be (1,)
+# Shape mismatch
array_api_tests/test_indexing_functions.py::test_take

-# out=-0, but should be +0
-array_api_tests/test_special_cases.py::test_binary[__pow__(x1_i is -0 and x2_i > 0 and not (x2_i.is_integer() and x2_i % 2 == 1)) -> +0]
-array_api_tests/test_special_cases.py::test_iop[__ipow__(x1_i is -0 and x2_i > 0 and not (x2_i.is_integer() and x2_i % 2 == 1)) -> +0]
-
-# output is nan but should be infinity
-array_api_tests/test_special_cases.py::test_iop[__ipow__(x1_i is -infinity and x2_i > 0 and not (x2_i.is_integer() and x2_i % 2 == 1)) -> +infinity]
-array_api_tests/test_special_cases.py::test_binary[__pow__(x1_i is -infinity and x2_i > 0 and not (x2_i.is_integer() and x2_i % 2 == 1)) -> +infinity]

+# missing `take_along_axis`, https://github.com/dask/dask/issues/3663
array_api_tests/test_indexing_functions.py::test_take_along_axis

-# No sorting in dask
-array_api_tests/test_has_names.py::test_has_names[sorting-argsort]
-array_api_tests/test_has_names.py::test_has_names[sorting-sort]
-array_api_tests/test_sorting_functions.py::test_argsort
-array_api_tests/test_sorting_functions.py::test_sort
-array_api_tests/test_signatures.py::test_func_signature[argsort]
-array_api_tests/test_signatures.py::test_func_signature[sort]
-
-# Array methods and attributes not already on np.ndarray cannot be wrapped
+# Array methods and attributes not already on da.Array cannot be wrapped
array_api_tests/test_has_names.py::test_has_names[array_method-__array_namespace__]
array_api_tests/test_has_names.py::test_has_names[array_method-to_device]
array_api_tests/test_has_names.py::test_has_names[array_attribute-device]
array_api_tests/test_has_names.py::test_has_names[array_attribute-mT]

-# Fails because shape is NaN since we don't materialize it yet
+# Data-dependent output shape
+# These tests fail as array-api-tests doesn't cope with unknown shapes
+# Also, output shape is (math.nan, ) instead of (None, )
+# Also, da.unique() doesn't accept `equal_nan`, which causes non-compliant
+# output when there are NaNs in the input.
array_api_tests/test_searching_functions.py::test_nonzero
array_api_tests/test_set_functions.py::test_unique_all
array_api_tests/test_set_functions.py::test_unique_counts
-
-# Different error but same cause as above, we're just trying to do ndindex on nan shape
array_api_tests/test_set_functions.py::test_unique_inverse
array_api_tests/test_set_functions.py::test_unique_values
@@ -75,24 +48,17 @@ array_api_tests/test_set_functions.py::test_unique_values

# fails for ndim > 2
array_api_tests/test_linalg.py::test_svdvals
-array_api_tests/test_linalg.py::test_cholesky
-# dtype mismatch got uint64, but should be uint8, NPY_PROMOTION_STATE=weak doesn't help :(
+
+# dtype mismatch: got uint64, but should be uint8; NPY_PROMOTION_STATE=weak doesn't help
array_api_tests/test_linalg.py::test_tensordot
# AssertionError: out.dtype=uint64, but should be uint8 [tensordot(uint8, uint8)]
array_api_tests/test_linalg.py::test_linalg_tensordot

-# AssertionError: out.shape=(1,), but should be () [linalg.vector_norm(keepdims=True)]
-array_api_tests/test_linalg.py::test_vector_norm
-
# ZeroDivisionError in dask's normalize_chunks/auto_chunks internals
array_api_tests/test_linalg.py::test_inv
array_api_tests/test_linalg.py::test_matrix_power

-# did not raise error for invalid shapes
-array_api_tests/test_linalg.py::test_matmul
-array_api_tests/test_linalg.py::test_linalg_matmul
-
# Linalg - these don't exist in dask
array_api_tests/test_signatures.py::test_extension_func_signature[linalg.cross]
array_api_tests/test_signatures.py::test_extension_func_signature[linalg.det]
@@ -105,6 +71,7 @@ array_api_tests/test_linalg.py::test_cross
array_api_tests/test_linalg.py::test_det
array_api_tests/test_linalg.py::test_eigh
array_api_tests/test_linalg.py::test_eigvalsh
+array_api_tests/test_linalg.py::test_matrix_rank
array_api_tests/test_linalg.py::test_pinv
array_api_tests/test_linalg.py::test_slogdet
array_api_tests/test_has_names.py::test_has_names[linalg-cross]
@@ -115,17 +82,10 @@ array_api_tests/test_has_names.py::test_has_names[linalg-matrix_power]
array_api_tests/test_has_names.py::test_has_names[linalg-pinv]
array_api_tests/test_has_names.py::test_has_names[linalg-slogdet]

-array_api_tests/test_linalg.py::test_matrix_norm
-array_api_tests/test_linalg.py::test_matrix_rank
-
-# missing mode kw
-# https://github.com/dask/dask/issues/10388
-array_api_tests/test_linalg.py::test_qr
-
# Constructing the input arrays fails due to a weird shape error...
array_api_tests/test_linalg.py::test_solve -# missing full_matrics kw +# missing full_matrices kw # https://github.com/dask/dask/issues/10389 # also only supports 2-d inputs array_api_tests/test_linalg.py::test_svd @@ -140,18 +100,51 @@ array_api_tests/test_has_names.py::test_has_names[creation-from_dlpack] array_api_tests/test_has_names.py::test_has_names[array_method-__dlpack__] array_api_tests/test_has_names.py::test_has_names[array_method-__dlpack_device__] -# Some cases unsupported by dask -array_api_tests/test_manipulation_functions.py::test_roll - # No mT on dask array array_api_tests/meta/test_hypothesis_helpers.py::test_symmetric_matrices -# The test suite is incorrectly checking sums that have loss of significance -# (https://github.com/data-apis/array-api-tests/issues/168) -array_api_tests/test_statistical_functions.py::test_sum -array_api_tests/test_statistical_functions.py::test_prod +# Edge case of args near 2**63 +# https://github.com/dask/dask/issues/11706 +array_api_tests/test_creation_functions.py::test_arange + +# da.searchsorted with a sorter argument is not supported +array_api_tests/test_searching_functions.py::test_searchsorted # 2023.12 support array_api_tests/test_manipulation_functions.py::test_repeat -array_api_tests/test_searching_functions.py::test_searchsorted -array_api_tests/test_signatures.py::test_func_signature[astype] + +# 2024.12 support +array_api_tests/test_array_object.py::test_getitem_arrays_and_ints_1[1] +array_api_tests/test_array_object.py::test_getitem_arrays_and_ints_1[None] +array_api_tests/test_array_object.py::test_getitem_arrays_and_ints_2[1] +array_api_tests/test_array_object.py::test_getitem_arrays_and_ints_2[None] +array_api_tests/test_has_names.py::test_has_names[indexing-take_along_axis] +array_api_tests/test_signatures.py::test_func_signature[count_nonzero] +array_api_tests/test_signatures.py::test_func_signature[take_along_axis] + +array_api_tests/test_linalg.py::test_cholesky +array_api_tests/test_linalg.py::test_linalg_matmul +array_api_tests/test_linalg.py::test_matmul +array_api_tests/test_linalg.py::test_matrix_norm +array_api_tests/test_linalg.py::test_qr +array_api_tests/test_manipulation_functions.py::test_roll + +# Stubs have a comment: (**note**: libraries may return ``NaN`` to match Python behavior.) 
+array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is +infinity and isfinite(x2_i) and x2_i > 0) -> +infinity] +array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is +infinity and isfinite(x2_i) and x2_i < 0) -> -infinity] +array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is -infinity and isfinite(x2_i) and x2_i > 0) -> -infinity] +array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is -infinity and isfinite(x2_i) and x2_i < 0) -> +infinity] +array_api_tests/test_special_cases.py::test_binary[floor_divide(isfinite(x1_i) and x1_i > 0 and x2_i is -infinity) -> -0] +array_api_tests/test_special_cases.py::test_binary[floor_divide(isfinite(x1_i) and x1_i < 0 and x2_i is +infinity) -> -0] +array_api_tests/test_special_cases.py::test_binary[__floordiv__(x1_i is +infinity and isfinite(x2_i) and x2_i > 0) -> +infinity] +array_api_tests/test_special_cases.py::test_binary[__floordiv__(x1_i is +infinity and isfinite(x2_i) and x2_i < 0) -> -infinity] +array_api_tests/test_special_cases.py::test_binary[__floordiv__(x1_i is -infinity and isfinite(x2_i) and x2_i > 0) -> -infinity] +array_api_tests/test_special_cases.py::test_binary[__floordiv__(x1_i is -infinity and isfinite(x2_i) and x2_i < 0) -> +infinity] +array_api_tests/test_special_cases.py::test_binary[__floordiv__(isfinite(x1_i) and x1_i > 0 and x2_i is -infinity) -> -0] +array_api_tests/test_special_cases.py::test_binary[__floordiv__(isfinite(x1_i) and x1_i < 0 and x2_i is +infinity) -> -0] +array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(x1_i is +infinity and isfinite(x2_i) and x2_i > 0) -> +infinity] +array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(x1_i is +infinity and isfinite(x2_i) and x2_i < 0) -> -infinity] +array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(x1_i is -infinity and isfinite(x2_i) and x2_i > 0) -> -infinity] +array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(x1_i is -infinity and isfinite(x2_i) and x2_i < 0) -> +infinity] +array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(isfinite(x1_i) and x1_i > 0 and x2_i is -infinity) -> -0] +array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(isfinite(x1_i) and x1_i < 0 and x2_i is +infinity) -> -0] diff --git a/docs/_static/custom.css b/docs/_static/custom.css index bac04989..c712f02d 100644 --- a/docs/_static/custom.css +++ b/docs/_static/custom.css @@ -10,3 +10,17 @@ body { html { scroll-behavior: auto; } + +/* Make checkboxes from the tasklist extension ('- [ ]' in Markdown) not add bullet points to the checkboxes. + +This can be removed once https://github.com/executablebooks/mdit-py-plugins/issues/59 is addressed. +*/ + +.contains-task-list { + list-style: none; +} + +/* Make the checkboxes indented like they are bullets */ +.task-list-item-checkbox { + margin: 0 0.2em 0.25em -1.4em; +} diff --git a/docs/changelog.md b/docs/changelog.md index d1e67333..6f6c1251 100644 --- a/docs/changelog.md +++ b/docs/changelog.md @@ -1,5 +1,176 @@ # Changelog +## 1.12.0 (2025-05-13) + + +### Major changes + +- The build system has been updated to use `pyproject.toml` instead of `setup.py` +- Support for Python 3.9 has been dropped. The minimum supported Python version is now + 3.10; the minimum supported NumPy version is 1.22. +- The `linalg` extension works correctly with `pytorch>=2.7`. +- Multiple improvements to handling of devices in CuPy and PyTorch backends. 
+  Support for multiple devices in CuPy is still immature and you should use
+  context managers rather than relying on input-output device propagation or
+  on the `device` parameter. Please report any issues you encounter.
+
+### Minor changes
+
+- `finfo` and `iinfo` functions now accept array arguments, in accordance with the
+  Array API spec;
+- `torch.asarray` function propagates the device of the input array. This works around
+  the [pytorch issue #150199](https://github.com/pytorch/pytorch/issues/150199);
+- `torch.repeat` function is now available;
+- `torch.count_nonzero` function now correctly handles the case of a tuple `axis`
+  argument and `keepdims=True`;
+- `torch.meshgrid` wrapper defaults to `indexing="xy"`, in accordance with the
+  array API specification;
+- `cupy.asarray` function now implements the `copy=False` argument, albeit
+  at the risk of making a temporary copy.
+- In `numpy.take_along_axis` and `cupy.take_along_axis` the `axis` parameter now
+  defaults to -1, in accordance with the Array API spec.
+
+
+The following users contributed to this release:
+
+Evgeni Burovski,
+Lucas Colley,
+Neil Girdhar,
+Joren Hammudoglu,
+Guido Imperiale
+
+
+## 1.11.2 (2025-03-20)
+
+This is a bugfix release with no new features compared to version 1.11.
+
+- fix the `result_type` wrapper for pytorch. Previously, `result_type` had multiple
+  issues with scalar arguments.
+- fix several issues with `clip` wrappers. Previously, `clip` was failing to allow
+  behaviors which the 2024.12 standard leaves unspecified but the array
+  libraries allow.
+
+The following users contributed to this release:
+
+Evgeni Burovski
+Guido Imperiale
+Magnus Dalen Kvalevåg
+
+
+## 1.11.1 (2025-03-04)
+
+This is a bugfix release with no new features compared to version 1.11.
+
+### Major Changes
+
+- fix `count_nonzero` wrappers: work around the lack of the `keepdims` argument in
+  several array libraries (torch, dask, cupy); work around numpy returning python
+  ints for some input combinations.
+
+### Minor Changes
+
+- running self-tests does not require all array libraries. Missing libraries are
+  skipped.
+
+The following users contributed to this release:
+
+Evgeni Burovski
+Guido Imperiale
+
+
+## 1.11.0 (2025-02-27)
+
+### Major Changes
+
+This release targets the 2024.12 Array API revision. This includes:
+
+  - `__array_api_version__` for the wrapped APIs is now set to `2024.12`;
+  - Wrappers for `count_nonzero`;
+  - Wrappers for `cumulative_prod`;
+  - Wrappers for `take_along_axis` (with the exception of Dask);
+  - Wrappers for `diff`;
+  - The `capabilities()` dict contains a `max dimensions` key;
+  - Python scalars are accepted as arguments to `result_type`;
+  - `fft.fftfreq` and `fft.rfftfreq` functions now accept an optional `dtype`
+    argument to control the output data type.
+
+Note that these wrappers, as well as other 2024.12 features, are relatively undertested
+in this release, and may have rough edges. Please report any issues you encounter
+in [the issue tracker](https://github.com/data-apis/array-api-compat/issues).
+
+New functions to test properties of arrays:
+  - `is_writeable_array` (benefits NumPy, JAX, Sparse)
+  - `is_lazy_array` (benefits JAX, Dask, ndonnx)
+
+Improved support for JAX:
+  - Workarounds for the `.device` attribute and the `to_device` function
+    not working correctly within `jax.jit`
+
+### Minor Changes
+
+- Several improvements to `dask.array` wrappers:
+
+  - `size` returns None for arrays of unknown shapes.
+  - `astype(..., copy=True)` always copies, independently of the Dask version.
+  - implementations of `sort` and `argsort` are now available. Note that these
+    implementations are relatively crude, and might be memory intensive.
+  - `asarray` no longer accidentally materializes the Dask graph.
+
+- `torch` wrappers contain unsigned integer dtypes of widths >8 bits, `uint16`,
+  `uint32` and `uint64` if PyTorch version is at least 2.3. Note that the
+  unsigned integer support is incomplete in PyTorch itself, see
+  [gh-253](https://github.com/data-apis/array-api-compat/pull/253).
+
+### Authors
+
+The following users contributed to this release:
+
+Athan Reines
+Evgeni Burovski
+Guido Imperiale
+Lucas Colley
+Ralf Gommers
+Thomas Li
+
+
+## 1.10.0 (2024-12-25)
+
+### Major Changes
+
+- New function `is_writeable_array` adds transparent support for readonly
+  arrays, such as JAX arrays or numpy arrays with `.flags.writeable=False`.
+
+- `asarray(..., copy=None)` with `dask` backend always copies, so that
+  `copy=None` and `copy=True` are equivalent for the `dask` backend.
+  This change is made to be forward compatible with the `dask==2024.12`
+  release.
+
+
+### Minor Changes
+
+- `array_namespace` accepts (and ignores) `None` and python scalars (int, float,
+  complex, bool). This change simplifies downstream adoption for functions
+  whose arguments can be either arrays or scalars.
+
+- `vecdot` conjugates its first argument, as stipulated by the Array API spec.
+  Previously, conjugation of the first argument was missing.
+
+
+## 1.9.1 (2024-10-29)
+
+### Major Changes
+
+- `__array_api_version__` for the wrapped APIs is now set to `2023.12`.
+
+### Minor Changes
+
+- Wrap `sign` so that it always uses the standard definition for complex
+  numbers, and always propagates nans.
+
+- Wrap dask.array.fft.
+
+- Re-add `python_requires` to the package metadata.
+
 ## 1.9 (2024-10-??)

 ### Major Changes
@@ -30,6 +201,10 @@

 ### Minor Changes

+- NumPy 2.0 is now wrapped again. Previously it was not wrapped because it has
+  full 2022.12 array API support, but it now requires wrapping again for
+  2023.12 support.
+
 - Support for JAX 0.4.32 and newer which implements the array API directly in
   `jax.numpy`.

diff --git a/docs/conf.py b/docs/conf.py
index d8d5c2da..ac9e6dd7 100644
--- a/docs/conf.py
+++ b/docs/conf.py
@@ -38,7 +38,8 @@
 templates_path = ['_templates']
 exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store']

-myst_enable_extensions = ["dollarmath", "linkify"]
+myst_enable_extensions = ["dollarmath", "linkify", "tasklist"]
+myst_enable_checkboxes = True

 napoleon_use_rtype = False
 napoleon_use_param = False
diff --git a/docs/dev/implementation-notes.md b/docs/dev/implementation-notes.md
index 48fb01f1..7f9c4634 100644
--- a/docs/dev/implementation-notes.md
+++ b/docs/dev/implementation-notes.md
@@ -47,3 +47,17 @@ identical
 PyTorch uses a similar layout in `array_api_compat/torch/`, but it differs
 enough from NumPy/CuPy that very few common wrappers for those libraries are
 reused. Dask is close to NumPy in behavior and so most Dask functions also
 reuse the NumPy/CuPy common wrappers.
+
+Occasionally, a wrapper implementation will need to reference another wrapper
+implementation, rather than the base `xp` version. The easiest way to do this
+is to call `array_namespace`, like:
+
+```py
+wrapped_xp = array_namespace(x)
+wrapped_xp.wrapped_func(...)
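+# (Illustrative: `wrapped_xp` is the array-api-compat namespace for x's
+# library, so the attribute lookup resolves to the compat wrapper rather
+# than the bare library function; `wrapped_func` is a placeholder name.)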
+``` + +Also, if there is a very minor difference required for wrapping, say, CuPy and +NumPy, they can still use a common implementation in `common/_aliases.py` and +use the `is_*_namespace()` or `is_*_function()` [helper +functions](../helper-functions.rst) to branch as necessary. diff --git a/docs/dev/releasing.md b/docs/dev/releasing.md index b2236597..1ee17709 100644 --- a/docs/dev/releasing.md +++ b/docs/dev/releasing.md @@ -1,59 +1,108 @@ # Releasing -To release, first make sure that all CI tests are passing on `main`. +- [ ] **Create a PR with a release branch** -Note that CuPy must be tested manually (it isn't tested on CI). Use the script + This makes it easy to verify that CI is passing, and also gives you a place + to push up updates to the changelog and any last minute fixes for the + release. -``` -./test_cupy.sh -``` +- [ ] **Double check the release branch is fully merged with `main`.** -on a machine with a CUDA GPU. + (e.g., if the release branch is called `release`) -Once you are ready to release, create a PR with a release branch, so that you -can verify that CI is passing. You must edit + ``` + git checkout main + git pull + git checkout release + git merge main + ``` -``` -array_api_compat/__init__.py -``` +- [ ] **Make sure that all CI tests are passing.** -and update the version (the version is not computed from the tag because that -would break vendorability). You should also edit + Note that the GitHub action that publishes to PyPI does not check if CI is + passing before publishing. So you need to check this manually. -``` -docs/changelog.md -``` + This does mean you can ignore CI failures, but ideally you should fix any + failures or update the `*-xfails.txt` files before tagging, so that CI and + the CuPy tests fully pass. Otherwise it will be hard to tell what things are + breaking in the future. It's also a good idea to remove any xpasses from + those files (but be aware that some xfails are from flaky failures, so + unless you know the underlying issue has been fixed, an xpass test is + probably still xfail). -with the changes for the release. +- [ ] **Test CuPy.** -Once everything is ready, create a tag + CuPy must be tested manually (it isn't tested on CI, see + https://github.com/data-apis/array-api-compat/issues/197). Use the script -``` -git tag -a -``` + ``` + ./test_cupy.sh + ``` -(note the tag names are not prefixed, for instance, the tag for version 1.5 is -just `1.5`) + on a machine with a CUDA GPU. -and push it to GitHub -``` -git push origin -``` +- [ ] **Update the version.** -Check that the `publish distributions` action on the tag build works. Note -that this action will run even if the other CI fails, so you must make sure -that CI is passing *before* tagging. + You must edit -This does mean you can ignore CI failures, but ideally you should fix any -failures or update the `*-xfails.txt` files before tagging, so that CI and the -cupy tests pass. Otherwise it will be hard to tell what things are breaking in -the future. It's also a good idea to remove any xpasses from those files (but -be aware that some xfails are from flaky failures, so unless you know the -underlying issue has been fixed, an xpass test is probably still xfail). + ``` + array_api_compat/__init__.py + ``` -If the publish action fails for some reason and didn't upload the release to -PyPI, you will need to delete the tag and try again. + and update the version (the version is not computed from the tag because + that would break vendorability). 
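+
+  As a sketch, the version attribute in that file is a plain string (the
+  value shown here is illustrative):
+
+  ```
+  __version__ = '1.12.0'
+  ```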
-After the PyPI package is published, the conda-forge bot should update the -feedstock automatically. +- [ ] **Update the [changelog](../changelog.md).** + + Edit + + ``` + docs/changelog.md + ``` + + with the changes for the release. + +- [ ] **Create the release tag.** + + Once everything is ready, create a tag + + ``` + git tag -a + ``` + + (note the tag names are not prefixed, for instance, the tag for version 1.5 is + just `1.5`) + +- [ ] **Push the tag to GitHub.** + + *This is the final step. Doing this will build and publish the release!* + + ``` + git push origin + ``` + + This will trigger the [`publish + distributions`](https://github.com/data-apis/array-api-compat/actions/workflows/publish-package.yml) + GitHub Action that will build the release and push it to PyPI. + +- [ ] **Check that the [`publish + distributions`](https://github.com/data-apis/array-api-compat/actions/workflows/publish-package.yml) + action build on the tag worked.** Note that this action will run even if the + other CI fails, so you must make sure that CI is passing *before* tagging. + + If it failed for some reason, you may need to delete the tag and try again. + +- [ ] **Merge the release branch.** + + This way any changes you made in the branch, such as updates to the + changelog or xfails files, are updated in `main`. This will also make the + docs update (the docs are published automatically from the sources on + `main`). + +- [ ] **Update conda-forge.** + + After the PyPI package is published, the conda-forge bot should update the + feedstock automatically after some time. The bot should automerge, so in + most cases you don't need to do anything here, unless some metadata on the + feedstock needs to be updated. diff --git a/docs/dev/tests.md b/docs/dev/tests.md index 6d9d1d7b..18fb7cf5 100644 --- a/docs/dev/tests.md +++ b/docs/dev/tests.md @@ -7,7 +7,7 @@ the array API standard. There are also array-api-compat specific tests in These tests should be limited to things that are not tested by the test suite, e.g., tests for [helper functions](../helper-functions.rst) or for behavior that is not strictly required by the standard. To run these tests, install the -dependencies from `requirements-dev.txt` (array-api-compat has [no hard +dependencies from the `dev` optional group (array-api-compat has [no hard runtime dependencies](no-dependencies)). array-api-tests is run against all supported libraries are tested on CI diff --git a/docs/helper-functions.rst b/docs/helper-functions.rst index f44dc070..155eda9a 100644 --- a/docs/helper-functions.rst +++ b/docs/helper-functions.rst @@ -51,6 +51,8 @@ yet. .. autofunction:: is_jax_array .. autofunction:: is_pydata_sparse_array .. autofunction:: is_ndonnx_array +.. autofunction:: is_writeable_array +.. autofunction:: is_lazy_array .. autofunction:: is_numpy_namespace .. autofunction:: is_cupy_namespace .. autofunction:: is_torch_namespace diff --git a/docs/index.md b/docs/index.md index b268e61a..c5c15174 100644 --- a/docs/index.md +++ b/docs/index.md @@ -12,8 +12,8 @@ each array library itself fully compatible with the array API, but this requires making backwards incompatible changes in many cases, so this will take some time. -Currently all libraries here are implemented against the [2022.12 -version](https://data-apis.org/array-api/2022.12/) of the standard. +Currently all libraries here are implemented against the [2024.12 +version](https://data-apis.org/array-api/2024.12/) of the standard. 
## Installation diff --git a/docs/requirements.txt b/docs/requirements.txt deleted file mode 100644 index dbec7740..00000000 --- a/docs/requirements.txt +++ /dev/null @@ -1,6 +0,0 @@ -furo -linkify-it-py -myst-parser -sphinx -sphinx-copybutton -sphinx-autobuild diff --git a/docs/supported-array-libraries.md b/docs/supported-array-libraries.md index a016a636..46fcdc27 100644 --- a/docs/supported-array-libraries.md +++ b/docs/supported-array-libraries.md @@ -36,23 +36,16 @@ deviations from the standard should be noted: 50](https://numpy.org/neps/nep-0050-scalar-promotion.html) and https://github.com/numpy/numpy/issues/22341) -- `asarray()` does not support `copy=False`. - - Functions which are not wrapped may not have the same type annotations as the spec. - Functions which are not wrapped may not use positional-only arguments. -The minimum supported NumPy version is 1.21. However, this older version of +The minimum supported NumPy version is 1.22. However, this older version of NumPy has a few issues: - `unique_*` will not compare nans as unequal. -- `finfo()` has no `smallest_normal`. - No `from_dlpack` or `__dlpack__`. -- `argmax()` and `argmin()` do not have `keepdims`. -- `qr()` doesn't support matrix stacks. -- `asarray()` doesn't support `copy=True` (as noted above, `copy=False` is not - supported even in the latest NumPy). - Type promotion behavior will be value based for 0-D arrays (and there is no `NPY_PROMOTION_STATE=weak` to disable this). @@ -72,8 +65,8 @@ version. attribute in the spec. Use the {func}`~.size()` helper function as a portable workaround. -- PyTorch does not have unsigned integer types other than `uint8`, and no - attempt is made to implement them here. +- PyTorch has incomplete support for unsigned integer types other + than `uint8`, and no attempt is made to implement them here. - PyTorch has type promotion semantics that differ from the array API specification for 0-D tensor objects. The array functions in this wrapper @@ -100,8 +93,6 @@ version. - As with NumPy, type annotations and positional-only arguments may not exactly match the spec for functions that are not wrapped at all. -The minimum supported PyTorch version is 1.13. - (jax-support)= ## [JAX](https://jax.readthedocs.io/en/latest/) @@ -131,9 +122,17 @@ For `linalg`, several methods are missing, for example: - `matrix_rank` Other methods may only be partially implemented or return incorrect results at times. -The minimum supported Dask version is 2023.12.0. - (sparse-support)= ## [Sparse](https://sparse.pydata.org/en/stable/) Similar to JAX, `sparse` Array API support is contained directly in `sparse`. + +(ndonnx-support)= +## [ndonnx](https://github.com/quantco/ndonnx) + +Similar to JAX, `ndonnx` Array API support is contained directly in `ndonnx`. + +(array-api-strict-support)= +## [array-api-strict](https://data-apis.org/array-api-strict/) + +array-api-strict exists only to test support for the Array API, so it does not need any wrappers. 
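+
+As a quick illustration of that point, a sketch using array-api-strict
+directly, with no compat wrapper involved:
+
+```py
+import array_api_strict as xp
+
+a = xp.asarray([1, 2, 3])
+s = xp.sum(a)  # standard array API functions work as-is
+```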
diff --git a/numpy-1-21-xfails.txt b/numpy-1-21-xfails.txt deleted file mode 100644 index 459b33e3..00000000 --- a/numpy-1-21-xfails.txt +++ /dev/null @@ -1,260 +0,0 @@ -# asarray(copy=False) is not yet implemented -array_api_tests/test_creation_functions.py::test_asarray_arrays - -# https://github.com/data-apis/array-api-tests/issues/195 -array_api_tests/test_creation_functions.py::test_linspace - -# finfo(float32).eps returns float32 but should return float -array_api_tests/test_data_type_functions.py::test_finfo[float32] - -# Array methods and attributes not already on np.ndarray cannot be wrapped -array_api_tests/test_has_names.py::test_has_names[array_method-__array_namespace__] -array_api_tests/test_has_names.py::test_has_names[array_method-to_device] -array_api_tests/test_has_names.py::test_has_names[array_attribute-device] -array_api_tests/test_has_names.py::test_has_names[array_attribute-mT] - -# linalg tests require https://github.com/data-apis/array-api-tests/pull/101 -# cleanups. Also some tests are using .mT -array_api_tests/test_linalg.py::test_eigvalsh -array_api_tests/test_linalg.py::test_solve -array_api_tests/test_linalg.py::test_trace - -# Array methods and attributes not already on np.ndarray cannot be wrapped -array_api_tests/test_signatures.py::test_array_method_signature[__array_namespace__] -array_api_tests/test_signatures.py::test_array_method_signature[to_device] - -# NumPy deviates in some special cases for floordiv -array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is +infinity and isfinite(x2_i) and x2_i > 0) -> +infinity] -array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is +infinity and isfinite(x2_i) and x2_i < 0) -> -infinity] -array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is -infinity and isfinite(x2_i) and x2_i > 0) -> -infinity] -array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is -infinity and isfinite(x2_i) and x2_i < 0) -> +infinity] -array_api_tests/test_special_cases.py::test_binary[floor_divide(isfinite(x1_i) and x1_i > 0 and x2_i is -infinity) -> -0] -array_api_tests/test_special_cases.py::test_binary[floor_divide(isfinite(x1_i) and x1_i < 0 and x2_i is +infinity) -> -0] -array_api_tests/test_special_cases.py::test_binary[__floordiv__(x1_i is +infinity and isfinite(x2_i) and x2_i > 0) -> +infinity] -array_api_tests/test_special_cases.py::test_binary[__floordiv__(x1_i is +infinity and isfinite(x2_i) and x2_i < 0) -> -infinity] -array_api_tests/test_special_cases.py::test_binary[__floordiv__(x1_i is -infinity and isfinite(x2_i) and x2_i > 0) -> -infinity] -array_api_tests/test_special_cases.py::test_binary[__floordiv__(x1_i is -infinity and isfinite(x2_i) and x2_i < 0) -> +infinity] -array_api_tests/test_special_cases.py::test_binary[__floordiv__(isfinite(x1_i) and x1_i > 0 and x2_i is -infinity) -> -0] -array_api_tests/test_special_cases.py::test_binary[__floordiv__(isfinite(x1_i) and x1_i < 0 and x2_i is +infinity) -> -0] -array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(x1_i is +infinity and isfinite(x2_i) and x2_i > 0) -> +infinity] -array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(x1_i is +infinity and isfinite(x2_i) and x2_i < 0) -> -infinity] -array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(x1_i is -infinity and isfinite(x2_i) and x2_i > 0) -> -infinity] -array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(x1_i is -infinity and isfinite(x2_i) and x2_i < 0) -> +infinity] 
-array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(isfinite(x1_i) and x1_i > 0 and x2_i is -infinity) -> -0] -array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(isfinite(x1_i) and x1_i < 0 and x2_i is +infinity) -> -0] - -# https://github.com/numpy/numpy/issues/21213 -array_api_tests/test_special_cases.py::test_binary[__pow__(x1_i is -infinity and x2_i > 0 and not (x2_i.is_integer() and x2_i % 2 == 1)) -> +infinity] -array_api_tests/test_special_cases.py::test_binary[__pow__(x1_i is -0 and x2_i > 0 and not (x2_i.is_integer() and x2_i % 2 == 1)) -> +0] -array_api_tests/test_special_cases.py::test_iop[__ipow__(x1_i is -infinity and x2_i > 0 and not (x2_i.is_integer() and x2_i % 2 == 1)) -> +infinity] -array_api_tests/test_special_cases.py::test_iop[__ipow__(x1_i is -0 and x2_i > 0 and not (x2_i.is_integer() and x2_i % 2 == 1)) -> +0] -array_api_tests/meta/test_hypothesis_helpers.py::test_symmetric_matrices - -# NumPy 1.21 specific XFAILS -############################ - -# finfo has no smallest_normal -array_api_tests/test_data_type_functions.py::test_finfo[float64] - -# dlpack stuff -array_api_tests/test_has_names.py::test_has_names[creation-from_dlpack] -array_api_tests/test_has_names.py::test_has_names[array_method-__dlpack__] -array_api_tests/test_has_names.py::test_has_names[array_method-__dlpack_device__] -array_api_tests/test_signatures.py::test_func_signature[from_dlpack] -array_api_tests/test_signatures.py::test_array_method_signature[__dlpack__] -array_api_tests/test_signatures.py::test_array_method_signature[__dlpack_device__] - -# qr() doesn't support matrix stacks -array_api_tests/test_linalg.py::test_qr - -# cross has some promotion bug that is fixed in newer numpy versions -array_api_tests/test_linalg.py::test_cross - -# vector_norm with ord=-1 which has since been fixed -# https://github.com/numpy/numpy/issues/21083 -array_api_tests/test_linalg.py::test_vector_norm - -# argmax and argmin do not support keepdims -array_api_tests/test_searching_functions.py::test_argmax -array_api_tests/test_searching_functions.py::test_argmin -array_api_tests/test_signatures.py::test_func_signature[argmax] -array_api_tests/test_signatures.py::test_func_signature[argmin] - -# unique doesn't support comparing nans as unequal -array_api_tests/test_set_functions.py::test_unique_all -array_api_tests/test_set_functions.py::test_unique_counts -array_api_tests/test_set_functions.py::test_unique_inverse -array_api_tests/test_set_functions.py::test_unique_values - -# The test suite is incorrectly checking sums that have loss of significance -# (https://github.com/data-apis/array-api-tests/issues/168) -array_api_tests/test_statistical_functions.py::test_sum -array_api_tests/test_statistical_functions.py::test_prod - -# NumPy 1.21 doesn't support NPY_PROMOTION_STATE=weak, so many tests fail with -# type promotion issues -array_api_tests/test_manipulation_functions.py::test_concat -array_api_tests/test_operators_and_elementwise_functions.py::test_add[__add__(x, s)] -array_api_tests/test_operators_and_elementwise_functions.py::test_add[__add__(x1, x2)] -array_api_tests/test_operators_and_elementwise_functions.py::test_add[__iadd__(x, s)] -array_api_tests/test_operators_and_elementwise_functions.py::test_add[add(x1, x2)] -array_api_tests/test_operators_and_elementwise_functions.py::test_atan2 -array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_and[__and__(x, s)] -array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_and[__and__(x1, x2)] 
-array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_and[__iand__(x, s)] -array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_and[bitwise_and(x1, x2)] -array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_left_shift[__ilshift__(x, s)] -array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_left_shift[__lshift__(x, s)] -array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_left_shift[__lshift__(x1, x2)] -array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_left_shift[bitwise_left_shift(x1, x2)] -array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_or[__ior__(x, s)] -array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_or[__or__(x, s)] -array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_or[__or__(x1, x2)] -array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_or[bitwise_or(x1, x2)] -array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_right_shift[__irshift__(x, s)] -array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_right_shift[__rshift__(x, s)] -array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_right_shift[__rshift__(x1, x2)] -array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_right_shift[bitwise_right_shift(x1, x2)] -array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_xor[__ixor__(x, s)] -array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_xor[__xor__(x, s)] -array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_xor[__xor__(x1, x2)] -array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_xor[bitwise_xor(x1, x2)] -array_api_tests/test_operators_and_elementwise_functions.py::test_copysign -array_api_tests/test_operators_and_elementwise_functions.py::test_divide[__truediv__(x, s)] -array_api_tests/test_operators_and_elementwise_functions.py::test_divide[__truediv__(x1, x2)] -array_api_tests/test_operators_and_elementwise_functions.py::test_divide[divide(x1, x2)] -array_api_tests/test_operators_and_elementwise_functions.py::test_equal[__eq__(x1, x2)] -array_api_tests/test_operators_and_elementwise_functions.py::test_equal[equal(x1, x2)] -array_api_tests/test_operators_and_elementwise_functions.py::test_floor_divide[__floordiv__(x, s)] -array_api_tests/test_operators_and_elementwise_functions.py::test_floor_divide[__floordiv__(x1, x2)] -array_api_tests/test_operators_and_elementwise_functions.py::test_floor_divide[__ifloordiv__(x, s)] -array_api_tests/test_operators_and_elementwise_functions.py::test_floor_divide[floor_divide(x1, x2)] -array_api_tests/test_operators_and_elementwise_functions.py::test_greater[__gt__(x1, x2)] -array_api_tests/test_operators_and_elementwise_functions.py::test_greater[greater(x1, x2)] -array_api_tests/test_operators_and_elementwise_functions.py::test_greater_equal[greater_equal(x1, x2)] -array_api_tests/test_operators_and_elementwise_functions.py::test_hypot -array_api_tests/test_operators_and_elementwise_functions.py::test_less[__lt__(x1, x2)] -array_api_tests/test_operators_and_elementwise_functions.py::test_less[less(x1, x2)] -array_api_tests/test_operators_and_elementwise_functions.py::test_less_equal[__le__(x1, x2)] -array_api_tests/test_operators_and_elementwise_functions.py::test_less_equal[less_equal(x1, x2)] 
-array_api_tests/test_operators_and_elementwise_functions.py::test_logaddexp -array_api_tests/test_operators_and_elementwise_functions.py::test_maximum -array_api_tests/test_operators_and_elementwise_functions.py::test_minimum -array_api_tests/test_operators_and_elementwise_functions.py::test_multiply[__imul__(x, s)] -array_api_tests/test_operators_and_elementwise_functions.py::test_multiply[__mul__(x, s)] -array_api_tests/test_operators_and_elementwise_functions.py::test_multiply[__mul__(x1, x2)] -array_api_tests/test_operators_and_elementwise_functions.py::test_multiply[multiply(x1, x2)] -array_api_tests/test_operators_and_elementwise_functions.py::test_not_equal[__ne__(x1, x2)] -array_api_tests/test_operators_and_elementwise_functions.py::test_not_equal[not_equal(x1, x2)] -array_api_tests/test_operators_and_elementwise_functions.py::test_pow[__ipow__(x, s)] -array_api_tests/test_operators_and_elementwise_functions.py::test_pow[__pow__(x, s)] -array_api_tests/test_operators_and_elementwise_functions.py::test_pow[__pow__(x1, x2)] -array_api_tests/test_operators_and_elementwise_functions.py::test_pow[pow(x1, x2)] -array_api_tests/test_operators_and_elementwise_functions.py::test_remainder[__imod__(x, s)] -array_api_tests/test_operators_and_elementwise_functions.py::test_remainder[__mod__(x, s)] -array_api_tests/test_operators_and_elementwise_functions.py::test_remainder[__mod__(x1, x2)] -array_api_tests/test_operators_and_elementwise_functions.py::test_remainder[remainder(x1, x2)] -array_api_tests/test_operators_and_elementwise_functions.py::test_subtract[__isub__(x, s)] -array_api_tests/test_operators_and_elementwise_functions.py::test_subtract[__sub__(x, s)] -array_api_tests/test_operators_and_elementwise_functions.py::test_subtract[__sub__(x1, x2)] -array_api_tests/test_operators_and_elementwise_functions.py::test_subtract[subtract(x1, x2)] -array_api_tests/test_searching_functions.py::test_where -array_api_tests/test_special_cases.py::test_binary[__add__((x1_i is +0 or x1_i == -0) and isfinite(x2_i) and x2_i != 0) -> x2_i] -array_api_tests/test_special_cases.py::test_binary[__add__(isfinite(x1_i) and x1_i != 0 and (x2_i is +0 or x2_i == -0)) -> x1_i] -array_api_tests/test_special_cases.py::test_binary[__floordiv__(x1_i < 0 and x2_i is +0) -> -infinity] -array_api_tests/test_special_cases.py::test_binary[__floordiv__(x1_i < 0 and x2_i is -0) -> +infinity] -array_api_tests/test_special_cases.py::test_binary[__floordiv__(x1_i > 0 and x2_i is +0) -> +infinity] -array_api_tests/test_special_cases.py::test_binary[__floordiv__(x1_i > 0 and x2_i is -0) -> -infinity] -array_api_tests/test_special_cases.py::test_binary[__floordiv__(x1_i is +0 and x2_i < 0) -> -0] -array_api_tests/test_special_cases.py::test_binary[__floordiv__(x1_i is +0 and x2_i > 0) -> +0] -array_api_tests/test_special_cases.py::test_binary[__floordiv__(x1_i is -0 and x2_i < 0) -> +0] -array_api_tests/test_special_cases.py::test_binary[__floordiv__(x1_i is -0 and x2_i > 0) -> -0] -array_api_tests/test_special_cases.py::test_binary[__mod__(isfinite(x1_i) and x1_i < 0 and x2_i is +infinity) -> x2_i] -array_api_tests/test_special_cases.py::test_binary[__mod__(isfinite(x1_i) and x1_i < 0 and x2_i is -infinity) -> x1_i] -array_api_tests/test_special_cases.py::test_binary[__mod__(isfinite(x1_i) and x1_i > 0 and x2_i is +infinity) -> x1_i] -array_api_tests/test_special_cases.py::test_binary[__mod__(isfinite(x1_i) and x1_i > 0 and x2_i is -infinity) -> x2_i] -array_api_tests/test_special_cases.py::test_binary[__mod__(x1_i is +0 and 
x2_i < 0) -> -0] -array_api_tests/test_special_cases.py::test_binary[__mod__(x1_i is +0 and x2_i > 0) -> +0] -array_api_tests/test_special_cases.py::test_binary[__mod__(x1_i is -0 and x2_i < 0) -> -0] -array_api_tests/test_special_cases.py::test_binary[__mod__(x1_i is -0 and x2_i > 0) -> +0] -array_api_tests/test_special_cases.py::test_binary[__pow__(abs(x1_i) < 1 and x2_i is +infinity) -> +0] -array_api_tests/test_special_cases.py::test_binary[__pow__(abs(x1_i) < 1 and x2_i is -infinity) -> +infinity] -array_api_tests/test_special_cases.py::test_binary[__pow__(abs(x1_i) > 1 and x2_i is +infinity) -> +infinity] -array_api_tests/test_special_cases.py::test_binary[__pow__(abs(x1_i) > 1 and x2_i is -infinity) -> +0] -array_api_tests/test_special_cases.py::test_binary[__pow__(x1_i < 0 and isfinite(x1_i) and isfinite(x2_i) and not x2_i.is_integer()) -> NaN] -array_api_tests/test_special_cases.py::test_binary[__pow__(x1_i is +0 and x2_i < 0) -> +infinity] -array_api_tests/test_special_cases.py::test_binary[__pow__(x1_i is +0 and x2_i > 0) -> +0] -array_api_tests/test_special_cases.py::test_binary[__pow__(x1_i is +infinity and x2_i < 0) -> +0] -array_api_tests/test_special_cases.py::test_binary[__pow__(x1_i is +infinity and x2_i > 0) -> +infinity] -array_api_tests/test_special_cases.py::test_binary[__pow__(x1_i is -0 and x2_i < 0 and not (x2_i.is_integer() and x2_i % 2 == 1)) -> +infinity] -array_api_tests/test_special_cases.py::test_binary[__pow__(x1_i is -0 and x2_i < 0 and x2_i.is_integer() and x2_i % 2 == 1) -> -infinity] -array_api_tests/test_special_cases.py::test_binary[__pow__(x1_i is -0 and x2_i > 0 and x2_i.is_integer() and x2_i % 2 == 1) -> -0] -array_api_tests/test_special_cases.py::test_binary[__pow__(x1_i is -infinity and x2_i < 0 and not (x2_i.is_integer() and x2_i % 2 == 1)) -> +0] -array_api_tests/test_special_cases.py::test_binary[__pow__(x1_i is -infinity and x2_i < 0 and x2_i.is_integer() and x2_i % 2 == 1) -> -0] -array_api_tests/test_special_cases.py::test_binary[__pow__(x1_i is -infinity and x2_i > 0 and x2_i.is_integer() and x2_i % 2 == 1) -> -infinity] -array_api_tests/test_special_cases.py::test_binary[__pow__(x1_i is NaN and not x2_i == 0) -> NaN] -array_api_tests/test_special_cases.py::test_binary[__truediv__(x1_i < 0 and x2_i is +0) -> -infinity] -array_api_tests/test_special_cases.py::test_binary[__truediv__(x1_i < 0 and x2_i is -0) -> +infinity] -array_api_tests/test_special_cases.py::test_binary[__truediv__(x1_i > 0 and x2_i is +0) -> +infinity] -array_api_tests/test_special_cases.py::test_binary[__truediv__(x1_i > 0 and x2_i is -0) -> -infinity] -array_api_tests/test_special_cases.py::test_binary[__truediv__(x1_i is +0 and x2_i < 0) -> -0] -array_api_tests/test_special_cases.py::test_binary[__truediv__(x1_i is +0 and x2_i > 0) -> +0] -array_api_tests/test_special_cases.py::test_binary[__truediv__(x1_i is -0 and x2_i < 0) -> +0] -array_api_tests/test_special_cases.py::test_binary[__truediv__(x1_i is -0 and x2_i > 0) -> -0] -array_api_tests/test_special_cases.py::test_binary[add((x1_i is +0 or x1_i == -0) and isfinite(x2_i) and x2_i != 0) -> x2_i] -array_api_tests/test_special_cases.py::test_binary[add(isfinite(x1_i) and x1_i != 0 and (x2_i is +0 or x2_i == -0)) -> x1_i] -array_api_tests/test_special_cases.py::test_binary[atan2(x1_i < 0 and x2_i is +0) -> roughly -pi/2] -array_api_tests/test_special_cases.py::test_binary[atan2(x1_i < 0 and x2_i is -0) -> roughly -pi/2] -array_api_tests/test_special_cases.py::test_binary[atan2(x1_i > 0 and x2_i is +0) -> roughly 
+pi/2] -array_api_tests/test_special_cases.py::test_binary[atan2(x1_i > 0 and x2_i is -0) -> roughly +pi/2] -array_api_tests/test_special_cases.py::test_binary[divide(x1_i < 0 and x2_i is +0) -> -infinity] -array_api_tests/test_special_cases.py::test_binary[divide(x1_i < 0 and x2_i is -0) -> +infinity] -array_api_tests/test_special_cases.py::test_binary[divide(x1_i > 0 and x2_i is +0) -> +infinity] -array_api_tests/test_special_cases.py::test_binary[divide(x1_i > 0 and x2_i is -0) -> -infinity] -array_api_tests/test_special_cases.py::test_binary[divide(x1_i is +0 and x2_i < 0) -> -0] -array_api_tests/test_special_cases.py::test_binary[divide(x1_i is +0 and x2_i > 0) -> +0] -array_api_tests/test_special_cases.py::test_binary[divide(x1_i is -0 and x2_i < 0) -> +0] -array_api_tests/test_special_cases.py::test_binary[divide(x1_i is -0 and x2_i > 0) -> -0] -array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i < 0 and x2_i is +0) -> -infinity] -array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i < 0 and x2_i is -0) -> +infinity] -array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i > 0 and x2_i is +0) -> +infinity] -array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i > 0 and x2_i is -0) -> -infinity] -array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is +0 and x2_i < 0) -> -0] -array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is +0 and x2_i > 0) -> +0] -array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is -0 and x2_i < 0) -> +0] -array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is -0 and x2_i > 0) -> -0] -array_api_tests/test_special_cases.py::test_binary[pow(abs(x1_i) < 1 and x2_i is +infinity) -> +0] -array_api_tests/test_special_cases.py::test_binary[pow(abs(x1_i) < 1 and x2_i is -infinity) -> +infinity] -array_api_tests/test_special_cases.py::test_binary[pow(abs(x1_i) > 1 and x2_i is +infinity) -> +infinity] -array_api_tests/test_special_cases.py::test_binary[pow(abs(x1_i) > 1 and x2_i is -infinity) -> +0] -array_api_tests/test_special_cases.py::test_binary[pow(x1_i < 0 and isfinite(x1_i) and isfinite(x2_i) and not x2_i.is_integer()) -> NaN] -array_api_tests/test_special_cases.py::test_binary[pow(x1_i is +0 and x2_i < 0) -> +infinity] -array_api_tests/test_special_cases.py::test_binary[pow(x1_i is +0 and x2_i > 0) -> +0] -array_api_tests/test_special_cases.py::test_binary[pow(x1_i is +infinity and x2_i < 0) -> +0] -array_api_tests/test_special_cases.py::test_binary[pow(x1_i is +infinity and x2_i > 0) -> +infinity] -array_api_tests/test_special_cases.py::test_binary[pow(x1_i is -0 and x2_i < 0 and not (x2_i.is_integer() and x2_i % 2 == 1)) -> +infinity] -array_api_tests/test_special_cases.py::test_binary[pow(x1_i is -0 and x2_i < 0 and x2_i.is_integer() and x2_i % 2 == 1) -> -infinity] -array_api_tests/test_special_cases.py::test_binary[pow(x1_i is -0 and x2_i > 0 and not (x2_i.is_integer() and x2_i % 2 == 1)) -> +0] -array_api_tests/test_special_cases.py::test_binary[pow(x1_i is -0 and x2_i > 0 and x2_i.is_integer() and x2_i % 2 == 1) -> -0] -array_api_tests/test_special_cases.py::test_binary[pow(x1_i is -infinity and x2_i < 0 and not (x2_i.is_integer() and x2_i % 2 == 1)) -> +0] -array_api_tests/test_special_cases.py::test_binary[pow(x1_i is -infinity and x2_i < 0 and x2_i.is_integer() and x2_i % 2 == 1) -> -0] -array_api_tests/test_special_cases.py::test_binary[pow(x1_i is -infinity and x2_i > 0 and not (x2_i.is_integer() and x2_i % 2 == 1)) -> 
+infinity] -array_api_tests/test_special_cases.py::test_binary[pow(x1_i is -infinity and x2_i > 0 and x2_i.is_integer() and x2_i % 2 == 1) -> -infinity] -array_api_tests/test_special_cases.py::test_binary[pow(x1_i is NaN and not x2_i == 0) -> NaN] -array_api_tests/test_special_cases.py::test_binary[remainder(isfinite(x1_i) and x1_i < 0 and x2_i is +infinity) -> x2_i] -array_api_tests/test_special_cases.py::test_binary[remainder(isfinite(x1_i) and x1_i < 0 and x2_i is -infinity) -> x1_i] -array_api_tests/test_special_cases.py::test_binary[remainder(isfinite(x1_i) and x1_i > 0 and x2_i is +infinity) -> x1_i] -array_api_tests/test_special_cases.py::test_binary[remainder(isfinite(x1_i) and x1_i > 0 and x2_i is -infinity) -> x2_i] -array_api_tests/test_special_cases.py::test_binary[remainder(x1_i is +0 and x2_i < 0) -> -0] -array_api_tests/test_special_cases.py::test_binary[remainder(x1_i is +0 and x2_i > 0) -> +0] -array_api_tests/test_special_cases.py::test_binary[remainder(x1_i is -0 and x2_i < 0) -> -0] -array_api_tests/test_special_cases.py::test_binary[remainder(x1_i is -0 and x2_i > 0) -> +0] -array_api_tests/test_special_cases.py::test_iop[__iadd__(x1_i is -0 and x2_i is -0) -> -0] - -# 2023.12 support -array_api_tests/test_searching_functions.py::test_searchsorted -array_api_tests/test_signatures.py::test_func_signature[from_dlpack] -array_api_tests/test_signatures.py::test_func_signature[astype] -array_api_tests/test_signatures.py::test_array_method_signature[__dlpack__] -# uint64 repeats not supported -array_api_tests/test_manipulation_functions.py::test_repeat diff --git a/numpy-1-22-xfails.txt b/numpy-1-22-xfails.txt new file mode 100644 index 00000000..5df1b6d7 --- /dev/null +++ b/numpy-1-22-xfails.txt @@ -0,0 +1,175 @@ +# attributes are np.float32 instead of float +# (see also https://github.com/data-apis/array-api/issues/405) +array_api_tests/test_data_type_functions.py::test_finfo[float32] +array_api_tests/test_data_type_functions.py::test_finfo[complex64] + +# Array methods and attributes not already on np.ndarray cannot be wrapped +array_api_tests/test_has_names.py::test_has_names[array_method-__array_namespace__] +array_api_tests/test_has_names.py::test_has_names[array_method-to_device] +array_api_tests/test_has_names.py::test_has_names[array_attribute-device] +array_api_tests/test_has_names.py::test_has_names[array_attribute-mT] + +# Array methods and attributes not already on np.ndarray cannot be wrapped +array_api_tests/test_signatures.py::test_array_method_signature[__array_namespace__] +array_api_tests/test_signatures.py::test_array_method_signature[to_device] + +# NumPy deviates in some special cases for floordiv +array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is +infinity and isfinite(x2_i) and x2_i > 0) -> +infinity] +array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is +infinity and isfinite(x2_i) and x2_i < 0) -> -infinity] +array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is -infinity and isfinite(x2_i) and x2_i > 0) -> -infinity] +array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is -infinity and isfinite(x2_i) and x2_i < 0) -> +infinity] +array_api_tests/test_special_cases.py::test_binary[floor_divide(isfinite(x1_i) and x1_i > 0 and x2_i is -infinity) -> -0] +array_api_tests/test_special_cases.py::test_binary[floor_divide(isfinite(x1_i) and x1_i < 0 and x2_i is +infinity) -> -0] +array_api_tests/test_special_cases.py::test_binary[__floordiv__(x1_i is +infinity and isfinite(x2_i) and 
x2_i > 0) -> +infinity] +array_api_tests/test_special_cases.py::test_binary[__floordiv__(x1_i is +infinity and isfinite(x2_i) and x2_i < 0) -> -infinity] +array_api_tests/test_special_cases.py::test_binary[__floordiv__(x1_i is -infinity and isfinite(x2_i) and x2_i > 0) -> -infinity] +array_api_tests/test_special_cases.py::test_binary[__floordiv__(x1_i is -infinity and isfinite(x2_i) and x2_i < 0) -> +infinity] +array_api_tests/test_special_cases.py::test_binary[__floordiv__(isfinite(x1_i) and x1_i > 0 and x2_i is -infinity) -> -0] +array_api_tests/test_special_cases.py::test_binary[__floordiv__(isfinite(x1_i) and x1_i < 0 and x2_i is +infinity) -> -0] +array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(x1_i is +infinity and isfinite(x2_i) and x2_i > 0) -> +infinity] +array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(x1_i is +infinity and isfinite(x2_i) and x2_i < 0) -> -infinity] +array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(x1_i is -infinity and isfinite(x2_i) and x2_i > 0) -> -infinity] +array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(x1_i is -infinity and isfinite(x2_i) and x2_i < 0) -> +infinity] +array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(isfinite(x1_i) and x1_i > 0 and x2_i is -infinity) -> -0] +array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(isfinite(x1_i) and x1_i < 0 and x2_i is +infinity) -> -0] + +# https://github.com/numpy/numpy/issues/21213 +array_api_tests/meta/test_hypothesis_helpers.py::test_symmetric_matrices + +# NumPy 1.22 specific XFAILS +############################ + +# cross has some promotion bug that is fixed in newer numpy versions +array_api_tests/test_linalg.py::test_cross + +# linspace(-0.0, -1.0, num=1) returns +0.0 instead of -0.0. +# Fixed in newer numpy versions. 
+array_api_tests/test_creation_functions.py::test_linspace + +# vector_norm had a bug with ord=-1 which has since been fixed +# https://github.com/numpy/numpy/issues/21083 +array_api_tests/test_linalg.py::test_vector_norm + +# NumPy 1.22 doesn't support NPY_PROMOTION_STATE=weak, so many tests fail with +# type promotion issues +# NOTE: some of these may not fail until one runs array-api-tests with +# --max-examples 100000 +array_api_tests/test_manipulation_functions.py::test_concat +array_api_tests/test_operators_and_elementwise_functions.py::test_add[__add__(x, s)] +array_api_tests/test_operators_and_elementwise_functions.py::test_add[__add__(x1, x2)] +array_api_tests/test_operators_and_elementwise_functions.py::test_add[__iadd__(x, s)] +array_api_tests/test_operators_and_elementwise_functions.py::test_add[add(x1, x2)] +array_api_tests/test_operators_and_elementwise_functions.py::test_atan2 +array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_and[__and__(x, s)] +array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_and[__and__(x1, x2)] +array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_and[__iand__(x, s)] +array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_and[bitwise_and(x1, x2)] +array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_left_shift[__ilshift__(x, s)] +array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_left_shift[__lshift__(x, s)] +array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_left_shift[__lshift__(x1, x2)] +array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_left_shift[bitwise_left_shift(x1, x2)] +array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_or[__ior__(x, s)] +array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_or[__or__(x, s)] +array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_or[__or__(x1, x2)] +array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_or[bitwise_or(x1, x2)] +array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_right_shift[__irshift__(x, s)] +array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_right_shift[__rshift__(x, s)] +array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_right_shift[__rshift__(x1, x2)] +array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_right_shift[bitwise_right_shift(x1, x2)] +array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_xor[__ixor__(x, s)] +array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_xor[__xor__(x, s)] +array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_xor[__xor__(x1, x2)] +array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_xor[bitwise_xor(x1, x2)] +array_api_tests/test_operators_and_elementwise_functions.py::test_copysign +array_api_tests/test_operators_and_elementwise_functions.py::test_divide[__truediv__(x, s)] +array_api_tests/test_operators_and_elementwise_functions.py::test_divide[__truediv__(x1, x2)] +array_api_tests/test_operators_and_elementwise_functions.py::test_divide[divide(x1, x2)] +array_api_tests/test_operators_and_elementwise_functions.py::test_equal[__eq__(x1, x2)] +array_api_tests/test_operators_and_elementwise_functions.py::test_equal[equal(x1, x2)] +array_api_tests/test_operators_and_elementwise_functions.py::test_floor_divide[__floordiv__(x, s)]
+array_api_tests/test_operators_and_elementwise_functions.py::test_floor_divide[__floordiv__(x1, x2)] +array_api_tests/test_operators_and_elementwise_functions.py::test_floor_divide[__ifloordiv__(x, s)] +array_api_tests/test_operators_and_elementwise_functions.py::test_floor_divide[floor_divide(x1, x2)] +array_api_tests/test_operators_and_elementwise_functions.py::test_greater[__gt__(x1, x2)] +array_api_tests/test_operators_and_elementwise_functions.py::test_greater[greater(x1, x2)] +array_api_tests/test_operators_and_elementwise_functions.py::test_greater_equal[__ge__(x1, x2)] +array_api_tests/test_operators_and_elementwise_functions.py::test_greater_equal[greater_equal(x1, x2)] +array_api_tests/test_operators_and_elementwise_functions.py::test_hypot +array_api_tests/test_operators_and_elementwise_functions.py::test_less[__lt__(x1, x2)] +array_api_tests/test_operators_and_elementwise_functions.py::test_less[less(x1, x2)] +array_api_tests/test_operators_and_elementwise_functions.py::test_less_equal[__le__(x1, x2)] +array_api_tests/test_operators_and_elementwise_functions.py::test_less_equal[less_equal(x1, x2)] +array_api_tests/test_operators_and_elementwise_functions.py::test_logaddexp +array_api_tests/test_operators_and_elementwise_functions.py::test_maximum +array_api_tests/test_operators_and_elementwise_functions.py::test_minimum +array_api_tests/test_operators_and_elementwise_functions.py::test_multiply[__imul__(x, s)] +array_api_tests/test_operators_and_elementwise_functions.py::test_multiply[__mul__(x, s)] +array_api_tests/test_operators_and_elementwise_functions.py::test_multiply[__mul__(x1, x2)] +array_api_tests/test_operators_and_elementwise_functions.py::test_multiply[multiply(x1, x2)] +array_api_tests/test_operators_and_elementwise_functions.py::test_not_equal[__ne__(x1, x2)] +array_api_tests/test_operators_and_elementwise_functions.py::test_not_equal[not_equal(x1, x2)] +array_api_tests/test_operators_and_elementwise_functions.py::test_pow[__ipow__(x, s)] +array_api_tests/test_operators_and_elementwise_functions.py::test_pow[__pow__(x, s)] +array_api_tests/test_operators_and_elementwise_functions.py::test_pow[__pow__(x1, x2)] +array_api_tests/test_operators_and_elementwise_functions.py::test_pow[pow(x1, x2)] +array_api_tests/test_operators_and_elementwise_functions.py::test_remainder[__imod__(x, s)] +array_api_tests/test_operators_and_elementwise_functions.py::test_remainder[__mod__(x, s)] +array_api_tests/test_operators_and_elementwise_functions.py::test_remainder[__mod__(x1, x2)] +array_api_tests/test_operators_and_elementwise_functions.py::test_remainder[remainder(x1, x2)] +array_api_tests/test_operators_and_elementwise_functions.py::test_subtract[__isub__(x, s)] +array_api_tests/test_operators_and_elementwise_functions.py::test_subtract[__sub__(x, s)] +array_api_tests/test_operators_and_elementwise_functions.py::test_subtract[__sub__(x1, x2)] +array_api_tests/test_operators_and_elementwise_functions.py::test_subtract[subtract(x1, x2)] +array_api_tests/test_searching_functions.py::test_where +array_api_tests/test_special_cases.py::test_iop[__iadd__(x1_i is -0 and x2_i is -0) -> -0] + +array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[add] +array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[divide] +array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[hypot] +array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[subtract] 
+array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[logaddexp] +array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[nextafter] +array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[multiply] +array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[maximum] +array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[minimum] +array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[copysign] +array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[pow] +array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[atan2] + +array_api_tests/test_operators_and_elementwise_functions.py::test_where_with_scalars + +# 2023.12 support +array_api_tests/test_has_names.py::test_has_names[creation-from_dlpack] +array_api_tests/test_signatures.py::test_func_signature[from_dlpack] +array_api_tests/test_signatures.py::test_array_method_signature[__dlpack__] +# uint64 repeats not supported +array_api_tests/test_manipulation_functions.py::test_repeat + +# 2024.12 support +array_api_tests/test_signatures.py::test_func_signature[bitwise_and] +array_api_tests/test_signatures.py::test_func_signature[bitwise_left_shift] +array_api_tests/test_signatures.py::test_func_signature[bitwise_or] +array_api_tests/test_signatures.py::test_func_signature[bitwise_right_shift] +array_api_tests/test_signatures.py::test_func_signature[bitwise_xor] +array_api_tests/test_data_type_functions.py::TestResultType::test_with_scalars + +# Stubs have a comment: (**note**: libraries may return ``NaN`` to match Python behavior.); Apparently, NumPy does just that +array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is +infinity and isfinite(x2_i) and x2_i > 0) -> +infinity] +array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is +infinity and isfinite(x2_i) and x2_i < 0) -> -infinity] +array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is -infinity and isfinite(x2_i) and x2_i > 0) -> -infinity] +array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is -infinity and isfinite(x2_i) and x2_i < 0) -> +infinity] +array_api_tests/test_special_cases.py::test_binary[floor_divide(isfinite(x1_i) and x1_i > 0 and x2_i is -infinity) -> -0] +array_api_tests/test_special_cases.py::test_binary[floor_divide(isfinite(x1_i) and x1_i < 0 and x2_i is +infinity) -> -0] +array_api_tests/test_special_cases.py::test_binary[__floordiv__(x1_i is +infinity and isfinite(x2_i) and x2_i > 0) -> +infinity] +array_api_tests/test_special_cases.py::test_binary[__floordiv__(x1_i is +infinity and isfinite(x2_i) and x2_i < 0) -> -infinity] +array_api_tests/test_special_cases.py::test_binary[__floordiv__(x1_i is -infinity and isfinite(x2_i) and x2_i > 0) -> -infinity] +array_api_tests/test_special_cases.py::test_binary[__floordiv__(x1_i is -infinity and isfinite(x2_i) and x2_i < 0) -> +infinity] +array_api_tests/test_special_cases.py::test_binary[__floordiv__(isfinite(x1_i) and x1_i > 0 and x2_i is -infinity) -> -0] +array_api_tests/test_special_cases.py::test_binary[__floordiv__(isfinite(x1_i) and x1_i < 0 and x2_i is +infinity) -> -0] +array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(x1_i is +infinity and isfinite(x2_i)
and x2_i > 0) -> +infinity] +array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(x1_i is +infinity and isfinite(x2_i) and x2_i < 0) -> -infinity] +array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(x1_i is -infinity and isfinite(x2_i) and x2_i > 0) -> -infinity] +array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(x1_i is -infinity and isfinite(x2_i) and x2_i < 0) -> +infinity] +array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(isfinite(x1_i) and x1_i > 0 and x2_i is -infinity) -> -0] +array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(isfinite(x1_i) and x1_i < 0 and x2_i is +infinity) -> -0] diff --git a/numpy-1-26-xfails.txt b/numpy-1-26-xfails.txt index 57b80e7e..98cb9f6c 100644 --- a/numpy-1-26-xfails.txt +++ b/numpy-1-26-xfails.txt @@ -1,8 +1,7 @@ -# asarray(copy=False) is not yet implemented -array_api_tests/test_creation_functions.py::test_asarray_arrays - -# finfo(float32).eps returns float32 but should return float +# attributes are np.float32 instead of float +# (see also https://github.com/data-apis/array-api/issues/405) array_api_tests/test_data_type_functions.py::test_finfo[float32] +array_api_tests/test_data_type_functions.py::test_finfo[complex64] # Array methods and attributes not already on np.ndarray cannot be wrapped array_api_tests/test_has_names.py::test_has_names[array_method-__array_namespace__] @@ -35,21 +34,40 @@ array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(isfinite(x1_i) and array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(isfinite(x1_i) and x1_i < 0 and x2_i is +infinity) -> -0] # https://github.com/numpy/numpy/issues/21213 -array_api_tests/test_special_cases.py::test_binary[__pow__(x1_i is -infinity and x2_i > 0 and not (x2_i.is_integer() and x2_i % 2 == 1)) -> +infinity] -array_api_tests/test_special_cases.py::test_binary[__pow__(x1_i is -0 and x2_i > 0 and not (x2_i.is_integer() and x2_i % 2 == 1)) -> +0] -array_api_tests/test_special_cases.py::test_iop[__ipow__(x1_i is -infinity and x2_i > 0 and not (x2_i.is_integer() and x2_i % 2 == 1)) -> +infinity] -array_api_tests/test_special_cases.py::test_iop[__ipow__(x1_i is -0 and x2_i > 0 and not (x2_i.is_integer() and x2_i % 2 == 1)) -> +0] array_api_tests/meta/test_hypothesis_helpers.py::test_symmetric_matrices -# The test suite is incorrectly checking sums that have loss of significance -# (https://github.com/data-apis/array-api-tests/issues/168) -array_api_tests/test_statistical_functions.py::test_sum -array_api_tests/test_statistical_functions.py::test_prod - # 2023.12 support -array_api_tests/test_searching_functions.py::test_searchsorted array_api_tests/test_signatures.py::test_func_signature[from_dlpack] -array_api_tests/test_signatures.py::test_func_signature[astype] array_api_tests/test_signatures.py::test_array_method_signature[__dlpack__] # uint64 repeats not supported array_api_tests/test_manipulation_functions.py::test_repeat + +# 2024.12 support +array_api_tests/test_signatures.py::test_func_signature[bitwise_and] +array_api_tests/test_signatures.py::test_func_signature[bitwise_left_shift] +array_api_tests/test_signatures.py::test_func_signature[bitwise_or] +array_api_tests/test_signatures.py::test_func_signature[bitwise_right_shift] +array_api_tests/test_signatures.py::test_func_signature[bitwise_xor] +array_api_tests/test_data_type_functions.py::TestResultType::test_with_scalars + +array_api_tests/test_operators_and_elementwise_functions.py::test_where_with_scalars + +# Stubs have a comment: (**note**: libraries may 
return ``NaN`` to match Python behavior.); Apparently, NumPy does just that +array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is +infinity and isfinite(x2_i) and x2_i > 0) -> +infinity] +array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is +infinity and isfinite(x2_i) and x2_i < 0) -> -infinity] +array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is -infinity and isfinite(x2_i) and x2_i > 0) -> -infinity] +array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is -infinity and isfinite(x2_i) and x2_i < 0) -> +infinity] +array_api_tests/test_special_cases.py::test_binary[floor_divide(isfinite(x1_i) and x1_i > 0 and x2_i is -infinity) -> -0] +array_api_tests/test_special_cases.py::test_binary[floor_divide(isfinite(x1_i) and x1_i < 0 and x2_i is +infinity) -> -0] +array_api_tests/test_special_cases.py::test_binary[__floordiv__(x1_i is +infinity and isfinite(x2_i) and x2_i > 0) -> +infinity] +array_api_tests/test_special_cases.py::test_binary[__floordiv__(x1_i is +infinity and isfinite(x2_i) and x2_i < 0) -> -infinity] +array_api_tests/test_special_cases.py::test_binary[__floordiv__(x1_i is -infinity and isfinite(x2_i) and x2_i > 0) -> -infinity] +array_api_tests/test_special_cases.py::test_binary[__floordiv__(x1_i is -infinity and isfinite(x2_i) and x2_i < 0) -> +infinity] +array_api_tests/test_special_cases.py::test_binary[__floordiv__(isfinite(x1_i) and x1_i > 0 and x2_i is -infinity) -> -0] +array_api_tests/test_special_cases.py::test_binary[__floordiv__(isfinite(x1_i) and x1_i < 0 and x2_i is +infinity) -> -0] +array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(x1_i is +infinity and isfinite(x2_i) and x2_i > 0) -> +infinity] +array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(x1_i is +infinity and isfinite(x2_i) and x2_i < 0) -> -infinity] +array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(x1_i is -infinity and isfinite(x2_i) and x2_i > 0) -> -infinity] +array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(x1_i is -infinity and isfinite(x2_i) and x2_i < 0) -> +infinity] +array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(isfinite(x1_i) and x1_i > 0 and x2_i is -infinity) -> -0] +array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(isfinite(x1_i) and x1_i < 0 and x2_i is +infinity) -> -0] diff --git a/numpy-dev-xfails.txt b/numpy-dev-xfails.txt index 23a83e1e..972d2346 100644 --- a/numpy-dev-xfails.txt +++ b/numpy-dev-xfails.txt @@ -1,17 +1,7 @@ -# finfo(float32).eps returns float32 but should return float +# attributes are np.float32 instead of float +# (see also https://github.com/data-apis/array-api/issues/405) array_api_tests/test_data_type_functions.py::test_finfo[float32] - -# https://github.com/numpy/numpy/issues/21213 -array_api_tests/test_special_cases.py::test_binary[__pow__(x1_i is -infinity and x2_i > 0 and not (x2_i.is_integer() and x2_i % 2 == 1)) -> +infinity] -array_api_tests/test_special_cases.py::test_binary[__pow__(x1_i is -0 and x2_i > 0 and not (x2_i.is_integer() and x2_i % 2 == 1)) -> +0] -array_api_tests/test_special_cases.py::test_iop[__ipow__(x1_i is -infinity and x2_i > 0 and not (x2_i.is_integer() and x2_i % 2 == 1)) -> +infinity] -array_api_tests/test_special_cases.py::test_iop[__ipow__(x1_i is -0 and x2_i > 0 and not (x2_i.is_integer() and x2_i % 2 == 1)) -> +0] - -# The test suite is incorrectly checking sums that have loss of significance -# (https://github.com/data-apis/array-api-tests/issues/168) 
-array_api_tests/test_statistical_functions.py::test_sum -array_api_tests/test_statistical_functions.py::test_prod -array_api_tests/test_statistical_functions.py::test_cumulative_sum +array_api_tests/test_data_type_functions.py::test_finfo[complex64] # The test suite cannot properly get the signature for vecdot # https://github.com/numpy/numpy/pull/26237 @@ -19,9 +9,32 @@ array_api_tests/test_signatures.py::test_func_signature[vecdot] array_api_tests/test_signatures.py::test_extension_func_signature[linalg.vecdot] # 2023.12 support -# Argument 'device' missing from signature -array_api_tests/test_signatures.py::test_func_signature[astype] -array_api_tests/test_signatures.py::test_func_signature[from_dlpack] -array_api_tests/test_signatures.py::test_array_method_signature[__dlpack__] # uint64 repeats not supported array_api_tests/test_manipulation_functions.py::test_repeat + +# 2024.12 support +array_api_tests/test_signatures.py::test_func_signature[bitwise_and] +array_api_tests/test_signatures.py::test_func_signature[bitwise_left_shift] +array_api_tests/test_signatures.py::test_func_signature[bitwise_or] +array_api_tests/test_signatures.py::test_func_signature[bitwise_right_shift] +array_api_tests/test_signatures.py::test_func_signature[bitwise_xor] + +# Stubs have a comment: (**note**: libraries may return ``NaN`` to match Python behavior.); Apparently, NumPy does just that +array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is +infinity and isfinite(x2_i) and x2_i > 0) -> +infinity] +array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is +infinity and isfinite(x2_i) and x2_i < 0) -> -infinity] +array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is -infinity and isfinite(x2_i) and x2_i > 0) -> -infinity] +array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is -infinity and isfinite(x2_i) and x2_i < 0) -> +infinity] +array_api_tests/test_special_cases.py::test_binary[floor_divide(isfinite(x1_i) and x1_i > 0 and x2_i is -infinity) -> -0] +array_api_tests/test_special_cases.py::test_binary[floor_divide(isfinite(x1_i) and x1_i < 0 and x2_i is +infinity) -> -0] +array_api_tests/test_special_cases.py::test_binary[__floordiv__(x1_i is +infinity and isfinite(x2_i) and x2_i > 0) -> +infinity] +array_api_tests/test_special_cases.py::test_binary[__floordiv__(x1_i is +infinity and isfinite(x2_i) and x2_i < 0) -> -infinity] +array_api_tests/test_special_cases.py::test_binary[__floordiv__(x1_i is -infinity and isfinite(x2_i) and x2_i > 0) -> -infinity] +array_api_tests/test_special_cases.py::test_binary[__floordiv__(x1_i is -infinity and isfinite(x2_i) and x2_i < 0) -> +infinity] +array_api_tests/test_special_cases.py::test_binary[__floordiv__(isfinite(x1_i) and x1_i > 0 and x2_i is -infinity) -> -0] +array_api_tests/test_special_cases.py::test_binary[__floordiv__(isfinite(x1_i) and x1_i < 0 and x2_i is +infinity) -> -0] +array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(x1_i is +infinity and isfinite(x2_i) and x2_i > 0) -> +infinity] +array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(x1_i is +infinity and isfinite(x2_i) and x2_i < 0) -> -infinity] +array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(x1_i is -infinity and isfinite(x2_i) and x2_i > 0) -> -infinity] +array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(x1_i is -infinity and isfinite(x2_i) and x2_i < 0) -> +infinity] +array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(isfinite(x1_i) and x1_i > 0 and x2_i is 
-infinity) -> -0] +array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(isfinite(x1_i) and x1_i < 0 and x2_i is +infinity) -> -0] diff --git a/numpy-skips.txt b/numpy-skips.txt index cbf7235b..e69de29b 100644 --- a/numpy-skips.txt +++ b/numpy-skips.txt @@ -1,11 +0,0 @@ -# These tests cause a core dump on CI, so we have to skip them entirely -array_api_tests/test_operators_and_elementwise_functions.py::test_floor_divide[__floordiv__(x, s)] -array_api_tests/test_operators_and_elementwise_functions.py::test_floor_divide[__floordiv__(x1, x2)] -array_api_tests/test_operators_and_elementwise_functions.py::test_floor_divide[__ifloordiv__(x, s)] -array_api_tests/test_operators_and_elementwise_functions.py::test_floor_divide[__ifloordiv__(x1, x2)] -array_api_tests/test_operators_and_elementwise_functions.py::test_floor_divide[floor_divide(x1, x2)] -array_api_tests/test_operators_and_elementwise_functions.py::test_remainder[__imod__(x, s)] -array_api_tests/test_operators_and_elementwise_functions.py::test_remainder[__imod__(x1, x2)] -array_api_tests/test_operators_and_elementwise_functions.py::test_remainder[__mod__(x, s)] -array_api_tests/test_operators_and_elementwise_functions.py::test_remainder[__mod__(x1, x2)] -array_api_tests/test_operators_and_elementwise_functions.py::test_remainder[remainder(x1, x2)] diff --git a/numpy-xfails.txt b/numpy-xfails.txt index 1c9d98f6..632b4ec3 100644 --- a/numpy-xfails.txt +++ b/numpy-xfails.txt @@ -1,7 +1,25 @@ -# finfo(float32).eps returns float32 but should return float +# attributes are np.float32 instead of float +# (see also https://github.com/data-apis/array-api/issues/405) array_api_tests/test_data_type_functions.py::test_finfo[float32] +array_api_tests/test_data_type_functions.py::test_finfo[complex64] -# NumPy deviates in some special cases for floordiv +# The test suite cannot properly get the signature for vecdot +# https://github.com/numpy/numpy/pull/26237 +array_api_tests/test_signatures.py::test_func_signature[vecdot] +array_api_tests/test_signatures.py::test_extension_func_signature[linalg.vecdot] + +# 2023.12 support +# uint64 repeats not supported +array_api_tests/test_manipulation_functions.py::test_repeat + +# 2024.12 support +array_api_tests/test_signatures.py::test_func_signature[bitwise_and] +array_api_tests/test_signatures.py::test_func_signature[bitwise_left_shift] +array_api_tests/test_signatures.py::test_func_signature[bitwise_or] +array_api_tests/test_signatures.py::test_func_signature[bitwise_right_shift] +array_api_tests/test_signatures.py::test_func_signature[bitwise_xor] + +# Stubs have a comment: (**note**: libraries may return ``NaN`` to match Python behavior.); Apparently, NumPy does just that array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is +infinity and isfinite(x2_i) and x2_i > 0) -> +infinity] array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is +infinity and isfinite(x2_i) and x2_i < 0) -> -infinity] array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is -infinity and isfinite(x2_i) and x2_i > 0) -> -infinity] @@ -21,27 +39,3 @@ array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(x1_i is -infinity array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(isfinite(x1_i) and x1_i > 0 and x2_i is -infinity) -> -0] array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(isfinite(x1_i) and x1_i < 0 and x2_i is +infinity) -> -0] -# https://github.com/numpy/numpy/issues/21213 -array_api_tests/test_special_cases.py::test_binary[__pow__(x1_i is 
-infinity and x2_i > 0 and not (x2_i.is_integer() and x2_i % 2 == 1)) -> +infinity] -array_api_tests/test_special_cases.py::test_binary[__pow__(x1_i is -0 and x2_i > 0 and not (x2_i.is_integer() and x2_i % 2 == 1)) -> +0] -array_api_tests/test_special_cases.py::test_iop[__ipow__(x1_i is -infinity and x2_i > 0 and not (x2_i.is_integer() and x2_i % 2 == 1)) -> +infinity] -array_api_tests/test_special_cases.py::test_iop[__ipow__(x1_i is -0 and x2_i > 0 and not (x2_i.is_integer() and x2_i % 2 == 1)) -> +0] -array_api_tests/meta/test_hypothesis_helpers.py::test_symmetric_matrices - -# The test suite is incorrectly checking sums that have loss of significance -# (https://github.com/data-apis/array-api-tests/issues/168) -array_api_tests/test_statistical_functions.py::test_sum -array_api_tests/test_statistical_functions.py::test_prod - -# The test suite cannot properly get the signature for vecdot -# https://github.com/numpy/numpy/pull/26237 -array_api_tests/test_signatures.py::test_func_signature[vecdot] -array_api_tests/test_signatures.py::test_extension_func_signature[linalg.vecdot] - -# 2023.12 support -array_api_tests/test_searching_functions.py::test_searchsorted -array_api_tests/test_signatures.py::test_func_signature[from_dlpack] -array_api_tests/test_signatures.py::test_func_signature[astype] -array_api_tests/test_signatures.py::test_array_method_signature[__dlpack__] -# uint64 repeats not supported -array_api_tests/test_manipulation_functions.py::test_repeat diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 00000000..ec054417 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,120 @@ +[build-system] +requires = ["setuptools", "setuptools-scm"] +build-backend = "setuptools.build_meta" + +[project] +name = "array-api-compat" +dynamic = ["version"] +description = "A wrapper around NumPy and other array libraries to make them compatible with the Array API standard" +readme = "README.md" +requires-python = ">=3.10" +license = "MIT" +authors = [{name = "Consortium for Python Data API Standards"}] +classifiers = [ + "Operating System :: OS Independent", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", + "Programming Language :: Python :: 3.13", + "Topic :: Software Development :: Libraries :: Python Modules", + "Typing :: Typed", +] + +[project.optional-dependencies] +cupy = ["cupy"] +dask = ["dask>=2024.9.0"] +jax = ["jax"] +# Note: array-api-compat follows scikit-learn minimum dependencies, which support +# much older versions of NumPy than what SPEC0 recommends. 
+numpy = ["numpy>=1.22"] +pytorch = ["torch"] +sparse = ["sparse>=0.15.1"] +ndonnx = ["ndonnx"] +docs = [ + "furo", + "linkify-it-py", + "myst-parser", + "sphinx", + "sphinx-copybutton", + "sphinx-autobuild", +] +dev = [ + "array-api-strict", + "dask[array]>=2024.9.0", + "jax[cpu]", + "ndonnx", + "numpy>=1.22", + "pytest", + "torch", + "sparse>=0.15.1", +] + +[project.urls] +homepage = "https://data-apis.org/array-api-compat/" +repository = "https://github.com/data-apis/array-api-compat/" + +[tool.setuptools.dynamic] +version = {attr = "array_api_compat.__version__"} + +[tool.setuptools.packages.find] +include = ["array_api_compat*"] +namespaces = false + +[tool.ruff.lint] +preview = true +select = [ +# Defaults +"E4", "E7", "E9", "F", +# Undefined export +"F822", +# Useless import alias +"PLC0414" +] + +ignore = [ + # Module import not at top of file + "E402", + # Do not use bare `except` + "E722" +] + + +[tool.mypy] +files = ["array_api_compat"] +disallow_incomplete_defs = true +disallow_untyped_decorators = true +disallow_untyped_defs = false # TODO +ignore_missing_imports = false +no_implicit_optional = true +show_error_codes = true +warn_redundant_casts = true +warn_unused_ignores = true +warn_unreachable = true + +[[tool.mypy.overrides]] +module = ["cupy.*", "cupy_backends.*", "dask.*", "jax.*", "ndonnx.*", "sparse.*", "torch.*"] +ignore_missing_imports = true + + +[tool.pyright] +include = ["src", "tests"] +pythonPlatform = "All" + +reportAny = false +reportExplicitAny = false +# missing type stubs +reportAttributeAccessIssue = false +reportUnknownMemberType = false +reportUnknownVariableType = false +# Redundant with mypy checks +reportMissingImports = false +reportMissingTypeStubs = false +# false positives for input validation +reportUnreachable = false +# ruff handles this +reportUnusedParameter = false + +executionEnvironments = [ + { root = "array_api_compat" }, +] diff --git a/requirements-dev.txt b/requirements-dev.txt deleted file mode 100644 index c9d10f71..00000000 --- a/requirements-dev.txt +++ /dev/null @@ -1,8 +0,0 @@ -array-api-strict -dask[array] -jax[cpu] -numpy -pytest -torch -sparse >=0.15.1 -ndonnx diff --git a/ruff.toml b/ruff.toml deleted file mode 100644 index 72e111b5..00000000 --- a/ruff.toml +++ /dev/null @@ -1,17 +0,0 @@ -[lint] -preview = true -select = [ -# Defaults -"E4", "E7", "E9", "F", -# Undefined export -"F822", -# Useless import alias -"PLC0414" -] - -ignore = [ - # Module import not at top of file - "E402", - # Do not use bare `except` - "E722" -] diff --git a/setup.py b/setup.py deleted file mode 100644 index 6e7a9949..00000000 --- a/setup.py +++ /dev/null @@ -1,34 +0,0 @@ -from setuptools import setup, find_packages - -with open("README.md", "r") as fh: - long_description = fh.read() - -import array_api_compat - -setup( - name='array_api_compat', - version=array_api_compat.__version__, - packages=find_packages(include=["array_api_compat*"]), - author="Consortium for Python Data API Standards", - description="A wrapper around NumPy and other array libraries to make them compatible with the Array API standard", - long_description=long_description, - long_description_content_type="text/markdown", - url="https://data-apis.org/array-api-compat/", - license="MIT", - extras_require={ - "numpy": "numpy", - "cupy": "cupy", - "jax": "jax", - "pytorch": "pytorch", - "dask": "dask", - "sparse": "sparse >=0.15.1", - }, - classifiers=[ - "Programming Language :: Python :: 3", - "Programming Language :: Python :: 3.9", - "Programming Language :: Python :: 
3.10", - "Programming Language :: Python :: 3.11", - "License :: OSI Approved :: MIT License", - "Operating System :: OS Independent", - ], -) diff --git a/test_cupy.sh b/test_cupy.sh index 2e176aa1..a6974333 100755 --- a/test_cupy.sh +++ b/test_cupy.sh @@ -26,5 +26,5 @@ mkdir -p $SCRIPT_DIR/.hypothesis ln -s $SCRIPT_DIR/.hypothesis .hypothesis export ARRAY_API_TESTS_MODULE=array_api_compat.cupy -export ARRAY_API_TESTS_VERSION=2023.12 +export ARRAY_API_TESTS_VERSION=2024.12 pytest array_api_tests/ ${PYTEST_ARGS} --xfails-file $SCRIPT_DIR/cupy-xfails.txt "$@" diff --git a/tests/_helpers.py b/tests/_helpers.py index e2a7e1d1..17865aa0 100644 --- a/tests/_helpers.py +++ b/tests/_helpers.py @@ -1,18 +1,14 @@ from importlib import import_module -import sys import pytest wrapped_libraries = ["numpy", "cupy", "torch", "dask.array"] -all_libraries = wrapped_libraries + ["jax.numpy"] - -# `sparse` added array API support as of Python 3.10. -if sys.version_info >= (3, 10): - all_libraries.append('sparse') +all_libraries = wrapped_libraries + [ + "array_api_strict", "jax.numpy", "ndonnx", "sparse" +] def import_(library, wrapper=False): - if library == 'cupy': - pytest.importorskip(library) + pytest.importorskip(library) if wrapper: if 'jax' in library: # JAX v0.4.32 implements the array API directly in jax.numpy @@ -20,9 +16,18 @@ def import_(library, wrapper=False): jax_numpy = import_module("jax.numpy") if not hasattr(jax_numpy, "__array_api_version__"): library = 'jax.experimental.array_api' - elif library.startswith('sparse'): - library = 'sparse' - else: + elif library in wrapped_libraries: library = 'array_api_compat.' + library return import_module(library) + + +def xfail(request: pytest.FixtureRequest, reason: str) -> None: + """ + XFAIL the currently running test. + + Unlike ``pytest.xfail``, allow rest of test to execute instead of immediately + halting it, so that it may result in a XPASS. + xref https://github.com/pandas-dev/pandas/issues/38902 + """ + request.node.add_marker(pytest.mark.xfail(reason=reason)) diff --git a/tests/test_all.py b/tests/test_all.py index 969d5cfb..c36aef67 100644 --- a/tests/test_all.py +++ b/tests/test_all.py @@ -1,44 +1,311 @@ -""" -Test that files that define __all__ aren't missing any exports. +"""Test exported names""" -You can add names that shouldn't be exported to _all_ignore, like +import builtins -_all_ignore = ['sys'] +import numpy as np +import pytest -This is preferable to del-ing the names as this will break any name that is -used inside of a function. Note that names starting with an underscore are automatically ignored. 
-""" +from array_api_compat._internal import clone_module +from ._helpers import wrapped_libraries -import sys +NAMES = { + "": [ + # Inspection + "__array_api_version__", + "__array_namespace_info__", + # Submodules + "fft", + "linalg", + # Constants + "e", + "inf", + "nan", + "newaxis", + "pi", + # Creation Functions + "arange", + "asarray", + "empty", + "empty_like", + "eye", + "from_dlpack", + "full", + "full_like", + "linspace", + "meshgrid", + "ones", + "ones_like", + "tril", + "triu", + "zeros", + "zeros_like", + # Data Type Functions + "astype", + "can_cast", + "finfo", + "iinfo", + "isdtype", + "result_type", + # Data Types + "bool", + "int8", + "int16", + "int32", + "int64", + "uint8", + "uint16", + "uint32", + "uint64", + "float32", + "float64", + "complex64", + "complex128", + # Elementwise Functions + "abs", + "acos", + "acosh", + "add", + "asin", + "asinh", + "atan", + "atan2", + "atanh", + "bitwise_and", + "bitwise_left_shift", + "bitwise_invert", + "bitwise_or", + "bitwise_right_shift", + "bitwise_xor", + "ceil", + "clip", + "conj", + "copysign", + "cos", + "cosh", + "divide", + "equal", + "exp", + "expm1", + "floor", + "floor_divide", + "greater", + "greater_equal", + "hypot", + "imag", + "isfinite", + "isinf", + "isnan", + "less", + "less_equal", + "log", + "log1p", + "log2", + "log10", + "logaddexp", + "logical_and", + "logical_not", + "logical_or", + "logical_xor", + "maximum", + "minimum", + "multiply", + "negative", + "nextafter", + "not_equal", + "positive", + "pow", + "real", + "reciprocal", + "remainder", + "round", + "sign", + "signbit", + "sin", + "sinh", + "square", + "sqrt", + "subtract", + "tan", + "tanh", + "trunc", + # Indexing Functions + "take", + "take_along_axis", + # Linear Algebra Functions + "matmul", + "matrix_transpose", + "tensordot", + "vecdot", + # Manipulation Functions + "broadcast_arrays", + "broadcast_to", + "concat", + "expand_dims", + "flip", + "moveaxis", + "permute_dims", + "repeat", + "reshape", + "roll", + "squeeze", + "stack", + "tile", + "unstack", + # Searching Functions + "argmax", + "argmin", + "count_nonzero", + "nonzero", + "searchsorted", + "where", + # Set Functions + "unique_all", + "unique_counts", + "unique_inverse", + "unique_values", + # Sorting Functions + "argsort", + "sort", + # Statistical Functions + "cumulative_prod", + "cumulative_sum", + "max", + "mean", + "min", + "prod", + "std", + "sum", + "var", + # Utility Functions + "all", + "any", + "diff", + ], + "fft": [ + "fft", + "ifft", + "fftn", + "ifftn", + "rfft", + "irfft", + "rfftn", + "irfftn", + "hfft", + "ihfft", + "fftfreq", + "rfftfreq", + "fftshift", + "ifftshift", + ], + "linalg": [ + "cholesky", + "cross", + "det", + "diagonal", + "eigh", + "eigvalsh", + "inv", + "matmul", + "matrix_norm", + "matrix_power", + "matrix_rank", + "matrix_transpose", + "outer", + "pinv", + "qr", + "slogdet", + "solve", + "svd", + "svdvals", + "tensordot", + "trace", + "vecdot", + "vector_norm", + ], +} -from ._helpers import import_, wrapped_libraries +XFAILS = { + ("numpy", ""): ["from_dlpack"] if np.__version__ < "1.23" else [], + ("dask.array", ""): ["from_dlpack", "take_along_axis"], + ("dask.array", "linalg"): [ + "cross", + "det", + "eigh", + "eigvalsh", + "matrix_power", + "pinv", + "slogdet", + ], +} -import pytest -@pytest.mark.parametrize("library", ["common"] + wrapped_libraries) -def test_all(library): - import_(library, wrapper=True) +def all_names(mod): + """Return all names available in a module.""" + objs = {} + clone_module(mod.__name__, objs) + return 
set(objs) + + +def get_mod(library, module, *, compat): + if compat: + library = f"array_api_compat.{library}" + xp = pytest.importorskip(library) + return getattr(xp, module) if module else xp + + +@pytest.mark.parametrize("module", list(NAMES)) +@pytest.mark.parametrize("library", wrapped_libraries) +def test_array_api_names(library, module): + """Test that __all__ isn't missing any exports + dictated by the Standard. + """ + mod = get_mod(library, module, compat=True) + missing = set(NAMES[module]) - all_names(mod) + xfail = set(XFAILS.get((library, module), [])) + xpass = xfail - missing + fails = missing - xfail + assert not xpass, f"Names in XFAILS are defined: {xpass}" + assert not fails, f"Missing exports: {fails}" + + +@pytest.mark.parametrize("module", list(NAMES)) +@pytest.mark.parametrize("library", wrapped_libraries) +def test_compat_doesnt_hide_names(library, module): + """The base namespace can have more names than the ones explicitly exported + by array-api-compat. Test that we're not suppressing them. + """ + bare_mod = get_mod(library, module, compat=False) + compat_mod = get_mod(library, module, compat=True) + + missing = all_names(bare_mod) - all_names(compat_mod) + missing = {name for name in missing if not name.startswith("_")} + assert not missing, f"Non-Array API names have been hidden: {missing}" - for mod_name in sys.modules: - if not mod_name.startswith('array_api_compat.' + library): - continue - module = sys.modules[mod_name] +@pytest.mark.parametrize("module", list(NAMES)) +@pytest.mark.parametrize("library", wrapped_libraries) +def test_compat_doesnt_add_names(library, module): + """Test that array-api-compat isn't adding names to the namespace + besides those defined by the Array API Standard. + """ + bare_mod = get_mod(library, module, compat=False) + compat_mod = get_mod(library, module, compat=True) - # TODO: We should define __all__ in the __init__.py files and test it - # there too. 
- if not hasattr(module, '__all__'): - continue + aapi_names = set(NAMES[module]) + spurious = all_names(compat_mod) - all_names(bare_mod) - aapi_names + # Quietly ignore *Result dataclasses + spurious = {name for name in spurious if not name.endswith("Result")} + assert not spurious, ( + f"array-api-compat is adding non-Array API names: {spurious}" + ) - dir_names = [n for n in dir(module) if not n.startswith('_')] - if '__array_namespace_info__' in dir(module): - dir_names.append('__array_namespace_info__') - ignore_all_names = getattr(module, '_all_ignore', []) - ignore_all_names += ['annotations', 'TYPE_CHECKING'] - dir_names = set(dir_names) - set(ignore_all_names) - all_names = module.__all__ - if set(dir_names) != set(all_names): - assert set(dir_names) - set(all_names) == set(), f"Some dir() names not included in __all__ for {mod_name}" - assert set(all_names) - set(dir_names) == set(), f"Some __all__ names not in dir() for {mod_name}" +@pytest.mark.parametrize( + "name", [name for name in NAMES[""] if hasattr(builtins, name)] +) +@pytest.mark.parametrize("library", wrapped_libraries) +def test_builtins_collision(library, name): + """Test that xp.bool is not accidentally builtins.bool, etc.""" + xp = pytest.importorskip(f"array_api_compat.{library}") + assert getattr(xp, name) is not getattr(builtins, name) diff --git a/tests/test_array_namespace.py b/tests/test_array_namespace.py index e35e31e1..311efc37 100644 --- a/tests/test_array_namespace.py +++ b/tests/test_array_namespace.py @@ -2,67 +2,72 @@ import sys import warnings -import jax import numpy as np import pytest -import torch import array_api_compat from array_api_compat import array_namespace -from ._helpers import import_, all_libraries, wrapped_libraries +from ._helpers import all_libraries, wrapped_libraries, xfail + @pytest.mark.parametrize("use_compat", [True, False, None]) -@pytest.mark.parametrize("api_version", [None, "2021.12", "2022.12"]) -@pytest.mark.parametrize("library", all_libraries + ['array_api_strict']) -def test_array_namespace(library, api_version, use_compat): - xp = import_(library) +@pytest.mark.parametrize( + "api_version", [None, "2021.12", "2022.12", "2023.12", "2024.12"] +) +@pytest.mark.parametrize("library", all_libraries) +def test_array_namespace(request, library, api_version, use_compat): + xp = pytest.importorskip(library) array = xp.asarray([1.0, 2.0, 3.0]) - if use_compat is True and library in {'array_api_strict', 'jax.numpy', 'sparse'}: + if use_compat and library not in wrapped_libraries: pytest.raises(ValueError, lambda: array_namespace(array, use_compat=use_compat)) return - namespace = array_api_compat.array_namespace(array, api_version=api_version, use_compat=use_compat) + if (library == "sparse" and api_version in ("2023.12", "2024.12")) or ( + library == "jax.numpy" and api_version in ("2021.12", "2022.12", "2023.12") + ): + xfail(request, "Unsupported API version") + + with warnings.catch_warnings(): + warnings.simplefilter('ignore', UserWarning) + namespace = array_namespace(array, api_version=api_version, use_compat=use_compat) if use_compat is False or use_compat is None and library not in wrapped_libraries: - if library == "jax.numpy" and use_compat is None: - import jax.numpy - if hasattr(jax.numpy, "__array_api_version__"): - # JAX v0.4.32 or later uses jax.numpy directly - assert namespace == jax.numpy - else: - # JAX v0.4.31 or earlier uses jax.experimental.array_api - import jax.experimental.array_api - assert namespace == jax.experimental.array_api + if library == 
"jax.numpy" and not hasattr(xp, "__array_api_version__"): + # Backwards compatibility for JAX <0.4.32 + import jax.experimental.array_api + assert namespace == jax.experimental.array_api else: assert namespace == xp + elif library == "dask.array": + assert namespace == array_api_compat.dask.array else: - if library == "dask.array": - assert namespace == array_api_compat.dask.array - else: - assert namespace == getattr(array_api_compat, library) + assert namespace == getattr(array_api_compat, library) if library == "numpy": # check that the same namespace is returned for NumPy scalars - scalar_namespace = array_api_compat.array_namespace( - xp.float64(0.0), api_version=api_version, use_compat=use_compat - ) - assert scalar_namespace == namespace - - # Check that array_namespace works even if jax.experimental.array_api - # hasn't been imported yet (it monkeypatches __array_namespace__ - # onto JAX arrays, but we should support them regardless). The only way to - # do this is to use a subprocess, since we cannot un-import it and another - # test probably already imported it. - if library == "jax.numpy" and sys.version_info >= (3, 9): - code = f"""\ + with warnings.catch_warnings(): + warnings.simplefilter('ignore', UserWarning) + + scalar_namespace = array_namespace( + xp.float64(0.0), api_version=api_version, use_compat=use_compat + ) + assert scalar_namespace == namespace + + +def test_jax_backwards_compat(): + """On JAX <0.4.32, test that array_namespace works even if + jax.experimental.array_api has not been imported yet. + """ + pytest.importorskip("jax") + code = """\ import sys import jax.numpy import array_api_compat -array = jax.numpy.asarray([1.0, 2.0, 3.0]) +array = jax.numpy.asarray([1.0, 2.0, 3.0]) assert 'jax.experimental.array_api' not in sys.modules -namespace = array_api_compat.array_namespace(array, api_version={api_version!r}) +namespace = array_api_compat.array_namespace(array) if hasattr(jax.numpy, '__array_api_version__'): assert namespace == jax.numpy @@ -70,13 +75,15 @@ def test_array_namespace(library, api_version, use_compat): import jax.experimental.array_api assert namespace == jax.experimental.array_api """ - subprocess.run([sys.executable, "-c", code], check=True) + subprocess.check_call([sys.executable, "-c", code]) + def test_jax_zero_gradient(): + jax = pytest.importorskip("jax") jx = jax.numpy.arange(4) jax_zero = jax.vmap(jax.grad(jax.numpy.float32, allow_int=True))(jx) - assert (array_api_compat.get_namespace(jax_zero) is - array_api_compat.get_namespace(jx)) + assert array_namespace(jax_zero) is array_namespace(jx) + def test_array_namespace_errors(): pytest.raises(TypeError, lambda: array_namespace([1])) @@ -86,25 +93,40 @@ def test_array_namespace_errors(): pytest.raises(TypeError, lambda: array_namespace((x, x))) pytest.raises(TypeError, lambda: array_namespace(x, (x, x))) -def test_array_namespace_errors_torch(): - y = torch.asarray([1, 2]) - x = np.asarray([1, 2]) - pytest.raises(TypeError, lambda: array_namespace(x, y)) - -def test_api_version(): - x = torch.asarray([1, 2]) - torch_ = import_("torch", wrapper=True) - assert array_namespace(x, api_version="2022.12") == torch_ - assert array_namespace(x, api_version=None) == torch_ - assert array_namespace(x) == torch_ - # Should issue a warning - with warnings.catch_warnings(record=True) as w: - assert array_namespace(x, api_version="2021.12") == torch_ - assert len(w) == 1 - assert "2021.12" in str(w[0].message) - - pytest.raises(ValueError, lambda: array_namespace(x, api_version="2020.12")) + 
+@pytest.mark.parametrize("library", all_libraries) +def test_array_namespace_many_args(library): + xp = pytest.importorskip(library) + a = xp.asarray(1) + b = xp.asarray(2) + assert array_namespace(a, b) is array_namespace(a) + + +def test_array_namespace_mismatch(): + xp = pytest.importorskip("array_api_strict") + with pytest.raises(TypeError, match="Multiple namespaces"): + array_namespace(np.asarray(1), xp.asarray(1)) + def test_get_namespace(): # Backwards compatible wrapper - assert array_api_compat.get_namespace is array_api_compat.array_namespace + assert array_api_compat.get_namespace is array_namespace + + +@pytest.mark.parametrize("library", all_libraries) +def test_python_scalars(library): + xp = pytest.importorskip(library) + a = xp.asarray([1, 2]) + xp = array_namespace(a) + + pytest.raises(TypeError, lambda: array_namespace(1)) + pytest.raises(TypeError, lambda: array_namespace(1.0)) + pytest.raises(TypeError, lambda: array_namespace(1j)) + pytest.raises(TypeError, lambda: array_namespace(True)) + pytest.raises(TypeError, lambda: array_namespace(None)) + + assert array_namespace(a, 1) is xp + assert array_namespace(a, 1.0) is xp + assert array_namespace(a, 1j) is xp + assert array_namespace(a, True) is xp + assert array_namespace(a, None) is xp diff --git a/tests/test_common.py b/tests/test_common.py index e1cfa9eb..85ed032e 100644 --- a/tests/test_common.py +++ b/tests/test_common.py @@ -1,18 +1,25 @@ +import math + +import pytest +import numpy as np +import array +from numpy.testing import assert_equal + from array_api_compat import ( # noqa: F401 is_numpy_array, is_cupy_array, is_torch_array, is_dask_array, is_jax_array, is_pydata_sparse_array, + is_ndonnx_array, is_numpy_namespace, is_cupy_namespace, is_torch_namespace, is_dask_namespace, is_jax_namespace, is_pydata_sparse_namespace, + is_array_api_strict_namespace, is_ndonnx_namespace, ) -from array_api_compat import is_array_api_obj, device, to_device - -from ._helpers import import_, wrapped_libraries, all_libraries +from array_api_compat import ( + device, is_array_api_obj, is_lazy_array, is_writeable_array, size, to_device +) +from array_api_compat.common._helpers import _DASK_DEVICE +from ._helpers import all_libraries, import_, wrapped_libraries, xfail -import pytest -import numpy as np -import array -from numpy.testing import assert_allclose is_array_functions = { 'numpy': 'is_numpy_array', @@ -21,6 +28,7 @@ 'dask.array': 'is_dask_array', 'jax.numpy': 'is_jax_array', 'sparse': 'is_pydata_sparse_array', + 'ndonnx': 'is_ndonnx_array', } is_namespace_functions = { @@ -30,6 +38,8 @@ 'dask.array': 'is_dask_namespace', 'jax.numpy': 'is_jax_namespace', 'sparse': 'is_pydata_sparse_namespace', + 'array_api_strict': 'is_array_api_strict_namespace', + 'ndonnx': 'is_ndonnx_namespace', } @@ -55,18 +65,154 @@ def test_is_xp_namespace(library, func): assert is_func(lib) == (func == is_namespace_functions[library]) +@pytest.mark.parametrize('library', all_libraries) +def test_xp_is_array_generics(library): + """ + Test that selecting a scalar out of an xp array always returns + an object that matches exactly one of the is_*_array functions: + that of the same library, or is_numpy_array.
+ """ + lib = import_(library) + x = lib.asarray([1, 2, 3]) + x0 = x[0] + + matches = [] + for library2, func in is_array_functions.items(): + is_func = globals()[func] + if is_func(x0): + matches.append(library2) + + if library == "array_api_strict": + # There is no is_array_api_strict_array() function + assert matches == [] + else: + assert matches in ([library], ["numpy"]) + + @pytest.mark.parametrize("library", all_libraries) -def test_device(library): - xp = import_(library, wrapper=True) +def test_is_writeable_array(library): + lib = import_(library) + x = lib.asarray([1, 2, 3]) + if is_writeable_array(x): + x[1] = 4 + else: + with pytest.raises((TypeError, ValueError)): + x[1] = 4 + + +def test_is_writeable_array_numpy(): + x = np.asarray([1, 2, 3]) + assert is_writeable_array(x) + x.flags.writeable = False + assert not is_writeable_array(x) + + +@pytest.mark.parametrize("library", all_libraries) +def test_size(library): + xp = import_(library) + x = xp.asarray([1, 2, 3]) + assert size(x) == 3 + + +@pytest.mark.parametrize("library", all_libraries) +def test_size_none(library): + if library == "sparse": + pytest.skip("No arange(); no indexing by sparse arrays") + + xp = import_(library) + x = xp.arange(10) + x = x[x < 5] + + # dask.array now has shape=(nan, ) and size=nan + # ndonnx now has shape=(None, ) and size=None + # Eager libraries have shape=(5, ) and size=5 + assert size(x) in (None, 5) - # We can't test much for device() and to_device() other than that - # x.to_device(x.device) works. +@pytest.mark.parametrize("library", all_libraries) +def test_is_lazy_array(library): + lib = import_(library) + x = lib.asarray([1, 2, 3]) + assert isinstance(is_lazy_array(x), bool) + + +@pytest.mark.parametrize("shape", [(math.nan,), (1, math.nan), (None, ), (1, None)]) +def test_is_lazy_array_nan_size(shape, monkeypatch): + """Test is_lazy_array() on an unknown Array API compliant object + with NaN (like Dask) or None (like ndonnx) in its shape + """ + xp = import_("array_api_strict") + x = xp.asarray(1) + assert not is_lazy_array(x) + monkeypatch.setattr(type(x), "shape", shape) + assert is_lazy_array(x) + + +@pytest.mark.parametrize("exc", [TypeError, AssertionError]) +def test_is_lazy_array_bool_raises(exc, monkeypatch): + """Test is_lazy_array() on an unknown Array API compliant object + where calling bool() raises: + - TypeError: e.g. like jitted JAX. This is the proper exception which + lazy arrays should raise as per the Array API specification + - something else: e.g. 
like Dask, where bool() triggers compute() + which can result in any kind of exception to be raised + """ + xp = import_("array_api_strict") + x = xp.asarray(1) + assert not is_lazy_array(x) + + def __bool__(self): + raise exc("Hello world") + + monkeypatch.setattr(type(x), "__bool__", __bool__) + assert is_lazy_array(x) + + +@pytest.mark.parametrize( + 'func', + list(is_array_functions.values()) + + ["is_array_api_obj", "is_lazy_array", "is_writeable_array"] +) +def test_is_array_any_object(func): + """Test that is_*_array functions return False and don't raise on non-array objects + """ + func = globals()[func] + + # These objects are missing attributes such as __name__ + assert not func(object()) + assert not func(None) + assert not func(1) + + class C: + pass + + assert not func(C()) + + +@pytest.mark.parametrize("library", all_libraries) +def test_device_to_device(library, request): + if library == "ndonnx": + xfail(request, reason="Stub raises ValueError") + if library == "sparse": + xfail(request, reason="No __array_namespace_info__()") + if library == "array_api_strict": + if np.__version__ < "2": + xfail(request, reason="no copy argument of np.asarray") + + xp = import_(library, wrapper=True) + devices = xp.__array_namespace_info__().devices() + + # Default device x = xp.asarray([1, 2, 3]) dev = device(x) - x2 = to_device(x, dev) - assert device(x) == device(x2) + for dev in devices: + if dev is None: # JAX >=0.5.3 + continue + if dev is _DASK_DEVICE: # TODO this needs a better design + continue + y = to_device(x, dev) + assert device(y) == dev @pytest.mark.parametrize("library", wrapped_libraries) @@ -85,32 +231,45 @@ def test_to_device_host(library): # a `device(x)` query; however, what's really important # here is that we can test portably after calling # to_device(x, "cpu") to return to host - assert_allclose(x, expected) + assert_equal(x, expected) @pytest.mark.parametrize("target_library", is_array_functions.keys()) @pytest.mark.parametrize("source_library", is_array_functions.keys()) def test_asarray_cross_library(source_library, target_library, request): if source_library == "dask.array" and target_library == "torch": - # Allow rest of test to execute instead of immediately xfailing - # xref https://github.com/pandas-dev/pandas/issues/38902 - # TODO: remove xfail once # https://github.com/dask/dask/issues/8260 is resolved - request.node.add_marker(pytest.mark.xfail(reason="Bug in dask raising error on conversion")) - if source_library == "cupy" and target_library != "cupy": + xfail(request, reason="Bug in dask raising error on conversion") + + elif ( + source_library == "ndonnx" + and target_library not in ("array_api_strict", "ndonnx", "numpy") + ): + xfail(request, reason="The truth value of lazy Array Array(dtype=Boolean) is unknown") + elif source_library == "ndonnx" and target_library == "numpy": + xfail(request, reason="produces numpy array of ndonnx scalar arrays") + elif target_library == "ndonnx" and source_library in ("torch", "dask.array", "jax.numpy"): + xfail(request, reason="unable to infer dtype") + + elif source_library == "jax.numpy" and target_library == "torch": + xfail(request, reason="casts int to float") + elif source_library == "cupy" and target_library != "cupy": # cupy explicitly disallows implicit conversions to CPU pytest.skip(reason="cupy does not support implicit conversion to CPU") elif source_library == "sparse" and target_library != "sparse": pytest.skip(reason="`sparse` does not allow implicit densification") + src_lib = 
import_(source_library, wrapper=True) tgt_lib = import_(target_library, wrapper=True) is_tgt_type = globals()[is_array_functions[target_library]] - a = src_lib.asarray([1, 2, 3]) + a = src_lib.asarray([1, 2, 3], dtype=src_lib.int32) b = tgt_lib.asarray(a) assert is_tgt_type(b), f"Expected {b} to be a {tgt_lib.ndarray}, but was {type(b)}" + assert b.dtype == tgt_lib.int32 + @pytest.mark.parametrize("library", wrapped_libraries) def test_asarray_copy(library): @@ -121,91 +280,99 @@ def test_asarray_copy(library): xp = import_(library, wrapper=True) asarray = xp.asarray is_lib_func = globals()[is_array_functions[library]] - all = xp.all if library != 'dask.array' else lambda x: xp.all(x).compute() - - if library == 'numpy' and xp.__version__[0] < '2' and not hasattr(xp, '_CopyMode') : - supports_copy_false = False - elif library in ['cupy', 'dask.array']: - supports_copy_false = False - else: - supports_copy_false = True a = asarray([1]) b = asarray(a, copy=True) assert is_lib_func(b) a[0] = 0 - assert all(b[0] == 1) - assert all(a[0] == 0) + assert b[0] == 1 + assert a[0] == 0 a = asarray([1]) - if supports_copy_false: - b = asarray(a, copy=False) - assert is_lib_func(b) - a[0] = 0 - assert all(b[0] == 0) - else: - pytest.raises(NotImplementedError, lambda: asarray(a, copy=False)) - a = asarray([1]) - if supports_copy_false: - pytest.raises(ValueError, lambda: asarray(a, copy=False, - dtype=xp.float64)) - else: - pytest.raises(NotImplementedError, lambda: asarray(a, copy=False, dtype=xp.float64)) + # Test copy=False within the same namespace + b = asarray(a, copy=False) + assert is_lib_func(b) + a[0] = 0 + assert b[0] == 0 + with pytest.raises(ValueError): + asarray(a, copy=False, dtype=xp.float64) + # copy=None defaults to False when possible a = asarray([1]) b = asarray(a, copy=None) assert is_lib_func(b) a[0] = 0 - assert all(b[0] == 0) + assert b[0] == 0 + # copy=None defaults to True when impossible a = asarray([1.0], dtype=xp.float32) assert a.dtype == xp.float32 b = asarray(a, dtype=xp.float64, copy=None) assert is_lib_func(b) assert b.dtype == xp.float64 a[0] = 0.0 - assert all(b[0] == 1.0) + assert b[0] == 1.0 + # copy=None defaults to False when possible a = asarray([1.0], dtype=xp.float64) assert a.dtype == xp.float64 b = asarray(a, dtype=xp.float64, copy=None) assert is_lib_func(b) assert b.dtype == xp.float64 a[0] = 0.0 - assert all(b[0] == 0.0) + assert b[0] == 0.0 # Python built-in types for obj in [True, 0, 0.0, 0j, [0], [[0]]]: - asarray(obj, copy=True) # No error - asarray(obj, copy=None) # No error - if supports_copy_false: - pytest.raises(ValueError, lambda: asarray(obj, copy=False)) - else: - pytest.raises(NotImplementedError, lambda: asarray(obj, copy=False)) + asarray(obj, copy=True) # No error + asarray(obj, copy=None) # No error + + with pytest.raises(ValueError): + asarray(obj, copy=False) # Use the standard library array to test the buffer protocol - a = array.array('f', [1.0]) + a = array.array("f", [1.0]) b = asarray(a, copy=True) assert is_lib_func(b) a[0] = 0.0 - assert all(b[0] == 1.0) + assert b[0] == 1.0 - a = array.array('f', [1.0]) - if supports_copy_false: + a = array.array("f", [1.0]) + if library in ("cupy", "dask.array"): + with pytest.raises(ValueError): + asarray(a, copy=False) + else: b = asarray(a, copy=False) assert is_lib_func(b) a[0] = 0.0 - assert all(b[0] == 0.0) - else: - pytest.raises(NotImplementedError, lambda: asarray(a, copy=False)) + assert b[0] == 0.0 - a = array.array('f', [1.0]) + a = array.array("f", [1.0]) b = asarray(a, copy=None) 
assert is_lib_func(b) a[0] = 0.0 - if library == 'cupy': + if library in ("cupy", "dask.array"): # A copy is required for libraries where the default device is not CPU - assert all(b[0] == 1.0) + # dask changed behaviour of copy=None in 2024.12 to copy; + # this wrapper ensures the same behaviour in older versions too. + # https://github.com/dask/dask/pull/11524/ + assert b[0] == 1.0 else: - assert all(b[0] == 0.0) + # copy=None defaults to False when possible + assert b[0] == 0.0 + + +@pytest.mark.parametrize("library", ["numpy", "cupy", "torch"]) +def test_clip_out(library): + """Test non-standard out= parameter for clip() + + (see "Avoid Restricting Behavior that is Outside the Scope of the Standard" + in https://data-apis.org/array-api-compat/dev/special-considerations.html) + """ + xp = import_(library, wrapper=True) + x = xp.asarray([10, 20, 30]) + out = xp.zeros_like(x) + xp.clip(x, 15, 25, out=out) + expect = xp.asarray([15, 20, 25]) + assert xp.all(out == expect) diff --git a/tests/test_copies_or_views.py b/tests/test_copies_or_views.py new file mode 100644 index 00000000..ec8995f7 --- /dev/null +++ b/tests/test_copies_or_views.py @@ -0,0 +1,64 @@ +""" +A collection of tests to make sure that wrapped namespaces agree with the bare ones +on whether to return a view or a copy of inputs. +""" +import pytest +from ._helpers import import_, wrapped_libraries + + +FUNC_INPUTS = [ + # func_name, arr_input, dtype, scalar_value + ('abs', [1, 2], 'int8', 3), + ('abs', [1, 2], 'float32', 3.), + ('ceil', [1, 2], 'int8', 3), + ('clip', [1, 2], 'int8', 3), + ('conj', [1, 2], 'int8', 3), + ('floor', [1, 2], 'int8', 3), + ('imag', [1j, 2j], 'complex64', 3), + ('positive', [1, 2], 'int8', 3), + ('real', [1., 2.], 'float32', 3.), + ('round', [1, 2], 'int8', 3), + ('sign', [0, 0], 'float32', 3), + ('trunc', [1, 2], 'int8', 3), + ('trunc', [1, 2], 'float32', 3), +] + + +def ensure_unary(func, arr): + """Make a trivial unary function from func.""" + if func.__name__ == 'clip': + return lambda x: func(x, arr[0], arr[1]) + return func + + +def is_view(func, a, value): + """Apply `func`, mutate the output; does the input change?""" + b = func(a) + b[0] = value + return a[0] == value + + +@pytest.mark.parametrize('xp_name', wrapped_libraries + ['array_api_strict']) +@pytest.mark.parametrize('inputs', FUNC_INPUTS, ids=[inp[0] for inp in FUNC_INPUTS]) +def test_view_or_copy(inputs, xp_name): + bare_xp = import_(xp_name, wrapper=False) + wrapped_xp = import_(xp_name, wrapper=True) + + func_name, arr_input, dtype_str, value = inputs + dtype = getattr(bare_xp, dtype_str) + + bare_func = getattr(bare_xp, func_name) + bare_func = ensure_unary(bare_func, arr_input) + + wrapped_func = getattr(wrapped_xp, func_name) + wrapped_func = ensure_unary(wrapped_func, arr_input) + + # bare namespace: mutate the output, does the input change? + a = bare_xp.asarray(arr_input, dtype=dtype) + is_view_bare = is_view(bare_func, a, value) + + # wrapped namespace: mutate the output, does the input change? 
+ a1 = wrapped_xp.asarray(arr_input, dtype=dtype) + is_view_wrapped = is_view(wrapped_func, a1, value) + + assert is_view_bare == is_view_wrapped diff --git a/tests/test_cupy.py b/tests/test_cupy.py new file mode 100644 index 00000000..4745b983 --- /dev/null +++ b/tests/test_cupy.py @@ -0,0 +1,45 @@ +import pytest +from array_api_compat import device, to_device + +xp = pytest.importorskip("array_api_compat.cupy") +from cupy.cuda import Stream + + +@pytest.mark.parametrize( + "make_stream", + [ + lambda: Stream(), + lambda: Stream(non_blocking=True), + lambda: Stream(null=True), + lambda: Stream(ptds=True), + ], +) +def test_to_device_with_stream(make_stream): + devices = xp.__array_namespace_info__().devices() + + a = xp.asarray([1, 2, 3]) + for dev in devices: + # Streams are device-specific and must be created within + # the context of the device... + with dev: + stream = make_stream() + # ... however, to_device() does not need to be inside the + # device context. + b = to_device(a, dev, stream=stream) + assert device(b) == dev + + +def test_to_device_with_dlpack_stream(): + devices = xp.__array_namespace_info__().devices() + + a = xp.asarray([1, 2, 3]) + for dev in devices: + # Streams are device-specific and must be created within + # the context of the device... + with dev: + s1 = Stream() + + # ... however, to_device() does not need to be inside the + # device context. + b = to_device(a, dev, stream=s1.ptr) + assert device(b) == dev diff --git a/tests/test_dask.py b/tests/test_dask.py new file mode 100644 index 00000000..fb0a84d4 --- /dev/null +++ b/tests/test_dask.py @@ -0,0 +1,183 @@ +from contextlib import contextmanager + +import numpy as np +import pytest + +try: + import dask + import dask.array as da +except ImportError: + pytestmark = pytest.skip(allow_module_level=True, reason="dask not found") + +from array_api_compat import array_namespace + + +@pytest.fixture +def xp(): + """Fixture returning the wrapped dask namespace""" + return array_namespace(da.empty(0)) + + +@contextmanager +def assert_no_compute(): + """ + Context manager that raises if at any point inside it anything calls compute() + or persist(), e.g. as it can be triggered implicitly by __bool__, __array__, etc. 
+ """ + + def get(dsk, *args, **kwargs): + raise AssertionError("Called compute() or persist()") + + with dask.config.set(scheduler=get): + yield + + +def test_assert_no_compute(): + """Test the assert_no_compute context manager""" + a = da.asarray(True) + with pytest.raises(AssertionError, match="Called compute"): + with assert_no_compute(): + bool(a) + + # Exiting the context manager restores the original scheduler + assert bool(a) is True + + +# Test no_compute for functions that use generic _aliases with xp=np + + +def test_unary_ops_no_compute(xp): + with assert_no_compute(): + a = xp.asarray([1.5, -1.5]) + xp.ceil(a) + xp.floor(a) + xp.trunc(a) + xp.sign(a) + + +def test_matmul_tensordot_no_compute(xp): + A = da.ones((4, 4), chunks=2) + B = da.zeros((4, 4), chunks=2) + with assert_no_compute(): + xp.matmul(A, B) + xp.tensordot(A, B) + + +# Test no_compute for functions that are fully bespoke for dask + + +def test_asarray_no_compute(xp): + with assert_no_compute(): + a = xp.arange(10) + xp.asarray(a) + xp.asarray(a, dtype=np.int16) + xp.asarray(a, dtype=a.dtype) + xp.asarray(a, copy=True) + xp.asarray(a, copy=True, dtype=np.int16) + xp.asarray(a, copy=True, dtype=a.dtype) + xp.asarray(a, copy=False) + xp.asarray(a, copy=False, dtype=a.dtype) + + +@pytest.mark.parametrize("copy", [True, False]) +def test_astype_no_compute(xp, copy): + with assert_no_compute(): + a = xp.arange(10) + xp.astype(a, np.int16, copy=copy) + xp.astype(a, a.dtype, copy=copy) + + +def test_clip_no_compute(xp): + with assert_no_compute(): + a = xp.arange(10) + xp.clip(a) + xp.clip(a, 1) + xp.clip(a, 1, 8) + + +@pytest.mark.parametrize("chunks", (5, 10)) +def test_sort_argsort_nocompute(xp, chunks): + with assert_no_compute(): + a = xp.arange(10, chunks=chunks) + xp.sort(a) + xp.argsort(a) + + +def test_generators_are_lazy(xp): + """ + Test that generator functions are fully lazy, e.g. that + da.ones(n) is not implemented as da.asarray(np.ones(n)) + """ + size = 100_000_000_000 # 800 GB + chunks = size // 10 # 10x 80 GB chunks + + with assert_no_compute(): + xp.zeros(size, chunks=chunks) + xp.ones(size, chunks=chunks) + xp.empty(size, chunks=chunks) + xp.full(size, fill_value=123, chunks=chunks) + a = xp.arange(size, chunks=chunks) + xp.zeros_like(a) + xp.ones_like(a) + xp.empty_like(a) + xp.full_like(a, fill_value=123) + + +@pytest.mark.parametrize("axis", [0, 1]) +@pytest.mark.parametrize("func", ["sort", "argsort"]) +def test_sort_argsort_chunks(xp, func, axis): + """Test that sort and argsort are functionally correct when + the array is chunked along the sort axis, e.g. the sort is + not just local to each chunk. + """ + a = da.random.random((10, 10), chunks=(5, 5)) + actual = getattr(xp, func)(a, axis=axis) + expect = getattr(np, func)(a.compute(), axis=axis) + np.testing.assert_array_equal(actual, expect) + + +@pytest.mark.parametrize( + "shape,chunks", + [ + # 3 GiB; 128 MiB per chunk; must rechunk before sorting. + # Sort chunks can be 128 MiB each; no need for final rechunk. + ((20_000, 20_000), "auto"), + # 3 GiB; 128 MiB per chunk; must rechunk before sorting. + # Must sort on two 1.5 GiB chunks; benefits from final rechunk. + ((2, 2**30 * 3 // 16), "auto"), + # 3 GiB; 1.5 GiB per chunk; no need to rechunk before sorting. + # Surely the user must know what they're doing, so don't + # perform the final rechunk. 
+ ((2, 2**30 * 3 // 16), (1, -1)), + ], +) +@pytest.mark.parametrize("func", ["sort", "argsort"]) +def test_sort_argsort_chunk_size(xp, func, shape, chunks): + """ + Test that sort and argsort produce reasonably-sized chunks + in the output array, even if they had to go through a single + huge chunk to perform the operation. + """ + a = da.random.random(shape, chunks=chunks) + b = getattr(xp, func)(a) + max_chunk_size = max(b.chunks[0]) * max(b.chunks[1]) * b.dtype.itemsize + assert ( + max_chunk_size <= 128 * 1024 * 1024 # 128 MiB + or b.chunks == a.chunks + ) + + +@pytest.mark.parametrize("func", ["sort", "argsort"]) +def test_sort_argsort_meta(xp, func): + """Test meta-namespace other than numpy""" + mxp = pytest.importorskip("array_api_strict") + typ = type(mxp.asarray(0)) + a = da.random.random(10) + b = a.map_blocks(mxp.asarray) + assert isinstance(b._meta, typ) + c = getattr(xp, func)(b) + assert isinstance(c._meta, typ) + d = c.compute() + # Note: np.sort(array_api_strict.asarray(0)) would return a numpy array + assert isinstance(d, typ) + np.testing.assert_array_equal(d, getattr(np, func)(a.compute())) diff --git a/tests/test_jax.py b/tests/test_jax.py new file mode 100644 index 00000000..285958d4 --- /dev/null +++ b/tests/test_jax.py @@ -0,0 +1,38 @@ +from numpy.testing import assert_equal +import pytest + +from array_api_compat import device, to_device + +try: + import jax + import jax.numpy as jnp +except ImportError: + pytestmark = pytest.skip(allow_module_level=True, reason="jax not found") + +HAS_JAX_0_4_31 = jax.__version__ >= "0.4.31" + + +@pytest.mark.parametrize( + "func", + [ + lambda x: jnp.zeros(1, device=device(x)), + lambda x: jnp.zeros_like(jnp.ones(1, device=device(x))), + lambda x: jnp.zeros_like(jnp.empty(1, device=device(x))), + lambda x: jnp.full(1, fill_value=0, device=device(x)), + pytest.param( + lambda x: jnp.asarray([0], device=device(x)), + marks=pytest.mark.skipif( + not HAS_JAX_0_4_31, reason="asarray() has no device= parameter" + ), + ), + lambda x: to_device(jnp.zeros(1), device(x)), + ] +) +def test_device_jit(func): + # Test workaround for https://github.com/jax-ml/jax/issues/26000 + # Also test missing to_device() method in JAX < 0.4.31 + # when inside jax.jit, even after importing jax.experimental.array_api + + x = jnp.ones(1) + assert_equal(func(x), jnp.asarray([0])) + assert_equal(jax.jit(func)(x), jnp.asarray([0])) diff --git a/tests/test_torch.py b/tests/test_torch.py new file mode 100644 index 00000000..7adb4ab3 --- /dev/null +++ b/tests/test_torch.py @@ -0,0 +1,119 @@ +"""Test "unspecified" behavior which we cannot easily test in the Array API test suite.
+""" +import itertools + +import pytest + +try: + import torch +except ImportError: + pytestmark = pytest.skip(allow_module_level=True, reason="pytorch not found") + +from array_api_compat import torch as xp + + +class TestResultType: + def test_empty(self): + with pytest.raises(ValueError): + xp.result_type() + + def test_one_arg(self): + for x in [1, 1.0, 1j, '...', None]: + with pytest.raises((ValueError, AttributeError)): + xp.result_type(x) + + for x in [xp.float32, xp.int64, torch.complex64]: + assert xp.result_type(x) == x + + for x in [xp.asarray(True, dtype=xp.bool), xp.asarray(1, dtype=xp.complex64)]: + assert xp.result_type(x) == x.dtype + + def test_two_args(self): + # Only include here things "unspecified" in the spec + + # scalar, tensor or tensor,tensor + for x, y in [ + (1., 1j), + (1j, xp.arange(3)), + (True, xp.asarray(3.)), + (xp.ones(3) == 1, 1j*xp.ones(3)), + ]: + assert xp.result_type(x, y) == torch.result_type(x, y) + + # dtype, scalar + for x, y in [ + (1j, xp.int64), + (True, xp.float64), + ]: + assert xp.result_type(x, y) == torch.result_type(x, xp.empty([], dtype=y)) + + # dtype, dtype + for x, y in [ + (xp.bool, xp.complex64) + ]: + xt, yt = xp.empty([], dtype=x), xp.empty([], dtype=y) + assert xp.result_type(x, y) == torch.result_type(xt, yt) + + def test_multi_arg(self): + torch.set_default_dtype(torch.float32) + + args = [1., 5, 3, torch.asarray([3], dtype=torch.float16), 5, 6, 1.] + assert xp.result_type(*args) == torch.float16 + + args = [1, 2, 3j, xp.arange(3, dtype=xp.float32), 4, 5, 6] + assert xp.result_type(*args) == xp.complex64 + + args = [1, 2, 3j, xp.float64, 4, 5, 6] + assert xp.result_type(*args) == xp.complex128 + + args = [1, 2, 3j, xp.float64, 4, xp.asarray(3, dtype=xp.int16), 5, 6, False] + assert xp.result_type(*args) == xp.complex128 + + i64 = xp.ones(1, dtype=xp.int64) + f16 = xp.ones(1, dtype=xp.float16) + for i in itertools.permutations([i64, f16, 1.0, 1.0]): + assert xp.result_type(*i) == xp.float16, f"{i}" + + with pytest.raises(ValueError): + xp.result_type(1, 2, 3, 4) + + + @pytest.mark.parametrize("default_dt", ['float32', 'float64']) + @pytest.mark.parametrize("dtype_a", + (xp.int32, xp.int64, xp.float32, xp.float64, xp.complex64, xp.complex128) + ) + @pytest.mark.parametrize("dtype_b", + (xp.int32, xp.int64, xp.float32, xp.float64, xp.complex64, xp.complex128) + ) + def test_gh_273(self, default_dt, dtype_a, dtype_b): + # Regression test for https://github.com/data-apis/array-api-compat/issues/273 + + try: + prev_default = torch.get_default_dtype() + default_dtype = getattr(torch, default_dt) + torch.set_default_dtype(default_dtype) + + a = xp.asarray([2, 1], dtype=dtype_a) + b = xp.asarray([1, -1], dtype=dtype_b) + dtype_1 = xp.result_type(a, b, 1.0) + dtype_2 = xp.result_type(b, a, 1.0) + assert dtype_1 == dtype_2 + finally: + torch.set_default_dtype(prev_default) + + +def test_meshgrid(): + """Verify that array_api_compat.torch.meshgrid defaults to indexing='xy'.""" + + x, y = xp.asarray([1, 2]), xp.asarray([4]) + + X, Y = xp.meshgrid(x, y) + + # output of torch.meshgrid(x, y, indexing='xy') -- indexing='ij' is different + X_xy, Y_xy = xp.asarray([[1, 2]]), xp.asarray([[4, 4]]) + + assert X.shape == X_xy.shape + assert xp.all(X == X_xy) + + assert Y.shape == Y_xy.shape + assert xp.all(Y == Y_xy) diff --git a/tests/test_vendoring.py b/tests/test_vendoring.py index 70083b49..8b561551 100644 --- a/tests/test_vendoring.py +++ b/tests/test_vendoring.py @@ -16,11 +16,13 @@ def test_vendoring_cupy(): def test_vendoring_torch(): + 
pytest.importorskip("torch") from vendor_test import uses_torch uses_torch._test_torch() def test_vendoring_dask(): + pytest.importorskip("dask") from vendor_test import uses_dask uses_dask._test_dask() diff --git a/torch-skips.txt b/torch-skips.txt index cbf7235b..e69de29b 100644 --- a/torch-skips.txt +++ b/torch-skips.txt @@ -1,11 +0,0 @@ -# These tests cause a core dump on CI, so we have to skip them entirely -array_api_tests/test_operators_and_elementwise_functions.py::test_floor_divide[__floordiv__(x, s)] -array_api_tests/test_operators_and_elementwise_functions.py::test_floor_divide[__floordiv__(x1, x2)] -array_api_tests/test_operators_and_elementwise_functions.py::test_floor_divide[__ifloordiv__(x, s)] -array_api_tests/test_operators_and_elementwise_functions.py::test_floor_divide[__ifloordiv__(x1, x2)] -array_api_tests/test_operators_and_elementwise_functions.py::test_floor_divide[floor_divide(x1, x2)] -array_api_tests/test_operators_and_elementwise_functions.py::test_remainder[__imod__(x, s)] -array_api_tests/test_operators_and_elementwise_functions.py::test_remainder[__imod__(x1, x2)] -array_api_tests/test_operators_and_elementwise_functions.py::test_remainder[__mod__(x, s)] -array_api_tests/test_operators_and_elementwise_functions.py::test_remainder[__mod__(x1, x2)] -array_api_tests/test_operators_and_elementwise_functions.py::test_remainder[remainder(x1, x2)] diff --git a/torch-xfails.txt b/torch-xfails.txt index c7abe2e9..989df0c8 100644 --- a/torch-xfails.txt +++ b/torch-xfails.txt @@ -8,31 +8,15 @@ array_api_tests/test_array_object.py::test_getitem array_api_tests/test_array_object.py::test_setitem # Masking doesn't support 0 dimensions in the mask array_api_tests/test_array_object.py::test_getitem_masking -# torch doesn't have uint dtypes other than uint8 -array_api_tests/test_array_object.py::test_scalar_casting[__int__(uint16)] -array_api_tests/test_array_object.py::test_scalar_casting[__int__(uint32)] -array_api_tests/test_array_object.py::test_scalar_casting[__int__(uint64)] -array_api_tests/test_array_object.py::test_scalar_casting[__index__(uint16)] -array_api_tests/test_array_object.py::test_scalar_casting[__index__(uint32)] -array_api_tests/test_array_object.py::test_scalar_casting[__index__(uint64)] # Overflow error from large inputs array_api_tests/test_creation_functions.py::test_arange # pytorch linspace bug (should be fixed in torch 2.0) -array_api_tests/test_creation_functions.py::test_linspace - -# torch doesn't have higher uint dtypes -array_api_tests/test_data_type_functions.py::test_iinfo[uint16] -array_api_tests/test_data_type_functions.py::test_iinfo[uint32] -array_api_tests/test_data_type_functions.py::test_iinfo[uint64] # We cannot wrap the tensor object array_api_tests/test_has_names.py::test_has_names[array_method-__array_namespace__] array_api_tests/test_has_names.py::test_has_names[array_method-to_device] -# tensordot doesn't allow integer dtypes in some corner cases -array_api_tests/test_linalg.py::test_tensordot - # We cannot wrap the tensor object array_api_tests/test_operators_and_elementwise_functions.py::test_add[__add__(x1, x2)] array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_and[__and__(x1, x2)] @@ -45,6 +29,10 @@ array_api_tests/test_operators_and_elementwise_functions.py::test_divide[__trued array_api_tests/test_operators_and_elementwise_functions.py::test_equal[__eq__(x1, x2)] array_api_tests/test_operators_and_elementwise_functions.py::test_not_equal[__ne__(x1, x2)]
array_api_tests/test_operators_and_elementwise_functions.py::test_floor_divide[__floordiv__(x1, x2)] +array_api_tests/test_operators_and_elementwise_functions.py::test_floor_divide[__floordiv__(x, s)] +array_api_tests/test_operators_and_elementwise_functions.py::test_floor_divide[__ifloordiv__(x, s)] +array_api_tests/test_operators_and_elementwise_functions.py::test_floor_divide[__ifloordiv__(x1, x2)] +array_api_tests/test_operators_and_elementwise_functions.py::test_floor_divide[floor_divide(x1, x2)] array_api_tests/test_operators_and_elementwise_functions.py::test_greater[__gt__(x1, x2)] array_api_tests/test_operators_and_elementwise_functions.py::test_greater_equal[__ge__(x1, x2)] array_api_tests/test_operators_and_elementwise_functions.py::test_less[__lt__(x1, x2)] @@ -56,13 +44,11 @@ array_api_tests/test_operators_and_elementwise_functions.py::test_pow[__pow__(x1 array_api_tests/test_operators_and_elementwise_functions.py::test_remainder[__imod__(x1, x2)] array_api_tests/test_operators_and_elementwise_functions.py::test_subtract[__sub__(x1, x2)] +# inverse trig functions are too inaccurate on CPU +array_api_tests/test_operators_and_elementwise_functions.py::test_acos +array_api_tests/test_operators_and_elementwise_functions.py::test_atan +array_api_tests/test_operators_and_elementwise_functions.py::test_asin -# overflow near float max -array_api_tests/test_operators_and_elementwise_functions.py::test_log1p - -# torch doesn't handle shifting by more than the bit size correctly -# https://github.com/pytorch/pytorch/issues/70904 -array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_right_shift[bitwise_right_shift(x1, x2)] # Torch bug for remainder in some cases with large values array_api_tests/test_operators_and_elementwise_functions.py::test_remainder[remainder(x1, x2)] @@ -73,11 +59,6 @@ array_api_tests/test_set_functions.py::test_unique_all # (https://github.com/pytorch/pytorch/issues/94106) array_api_tests/test_set_functions.py::test_unique_inverse -# The test suite incorrectly divides by 0 here -# (https://github.com/data-apis/array-api-tests/issues/170) -array_api_tests/test_signatures.py::test_func_signature[floor_divide] -array_api_tests/test_signatures.py::test_func_signature[remainder] -array_api_tests/test_signatures.py::test_array_method_signature[__mod__] # We cannot add attributes to the tensor object array_api_tests/test_signatures.py::test_array_method_signature[__array_namespace__] @@ -86,13 +67,6 @@ array_api_tests/test_signatures.py::test_array_method_signature[to_device] # We do not attempt to work around special-case differences (most are on # tensor methods which we couldn't fix anyway). 
-array_api_tests/test_special_cases.py::test_binary[__add__((x1_i is +0 or x1_i == -0) and isfinite(x2_i) and x2_i != 0) -> x2_i] -array_api_tests/test_special_cases.py::test_binary[__add__(isfinite(x1_i) and x1_i != 0 and (x2_i is +0 or x2_i == -0)) -> x1_i] -array_api_tests/test_special_cases.py::test_binary[__add__(isfinite(x1_i) and x1_i != 0 and x2_i == -x1_i) -> +0] -array_api_tests/test_special_cases.py::test_binary[__add__(isfinite(x1_i) and x2_i is +infinity) -> +infinity] -array_api_tests/test_special_cases.py::test_binary[__add__(isfinite(x1_i) and x2_i is -infinity) -> -infinity] -array_api_tests/test_special_cases.py::test_binary[__add__(x1_i is +infinity and isfinite(x2_i)) -> +infinity] -array_api_tests/test_special_cases.py::test_binary[__add__(x1_i is -infinity and isfinite(x2_i)) -> -infinity] array_api_tests/test_special_cases.py::test_binary[__floordiv__(isfinite(x1_i) and x1_i < 0 and x2_i is +infinity) -> -0] array_api_tests/test_special_cases.py::test_binary[__floordiv__(isfinite(x1_i) and x1_i < 0 and x2_i is -infinity) -> +0] array_api_tests/test_special_cases.py::test_binary[__floordiv__(isfinite(x1_i) and x1_i > 0 and x2_i is +infinity) -> +0] @@ -117,41 +91,6 @@ array_api_tests/test_special_cases.py::test_binary[__mod__(x1_i is +0 and x2_i < array_api_tests/test_special_cases.py::test_binary[__mod__(x1_i is +0 and x2_i > 0) -> +0] array_api_tests/test_special_cases.py::test_binary[__mod__(x1_i is -0 and x2_i < 0) -> -0] array_api_tests/test_special_cases.py::test_binary[__mod__(x1_i is -0 and x2_i > 0) -> +0] -array_api_tests/test_special_cases.py::test_binary[__pow__(abs(x1_i) < 1 and x2_i is +infinity) -> +0] -array_api_tests/test_special_cases.py::test_binary[__pow__(abs(x1_i) < 1 and x2_i is -infinity) -> +infinity] -array_api_tests/test_special_cases.py::test_binary[__pow__(abs(x1_i) > 1 and x2_i is +infinity) -> +infinity] -array_api_tests/test_special_cases.py::test_binary[__pow__(abs(x1_i) > 1 and x2_i is -infinity) -> +0] -array_api_tests/test_special_cases.py::test_binary[__pow__(x1_i < 0 and isfinite(x1_i) and isfinite(x2_i) and not x2_i.is_integer()) -> NaN] -array_api_tests/test_special_cases.py::test_binary[__pow__(x1_i is +0 and x2_i < 0) -> +infinity] -array_api_tests/test_special_cases.py::test_binary[__pow__(x1_i is +0 and x2_i > 0) -> +0] -array_api_tests/test_special_cases.py::test_binary[__pow__(x1_i is +infinity and x2_i < 0) -> +0] -array_api_tests/test_special_cases.py::test_binary[__pow__(x1_i is +infinity and x2_i > 0) -> +infinity] -array_api_tests/test_special_cases.py::test_binary[__pow__(x1_i is -0 and x2_i < 0 and not (x2_i.is_integer() and x2_i % 2 == 1)) -> +infinity] -array_api_tests/test_special_cases.py::test_binary[__pow__(x1_i is -0 and x2_i < 0 and x2_i.is_integer() and x2_i % 2 == 1) -> -infinity] -array_api_tests/test_special_cases.py::test_binary[__pow__(x1_i is -0 and x2_i > 0 and not (x2_i.is_integer() and x2_i % 2 == 1)) -> +0] -array_api_tests/test_special_cases.py::test_binary[__pow__(x1_i is -0 and x2_i > 0 and x2_i.is_integer() and x2_i % 2 == 1) -> -0] -array_api_tests/test_special_cases.py::test_binary[__pow__(x1_i is -infinity and x2_i < 0 and not (x2_i.is_integer() and x2_i % 2 == 1)) -> +0] -array_api_tests/test_special_cases.py::test_binary[__pow__(x1_i is -infinity and x2_i < 0 and x2_i.is_integer() and x2_i % 2 == 1) -> -0] -array_api_tests/test_special_cases.py::test_binary[__pow__(x1_i is -infinity and x2_i > 0 and not (x2_i.is_integer() and x2_i % 2 == 1)) -> +infinity] 
-array_api_tests/test_special_cases.py::test_binary[__pow__(x1_i is -infinity and x2_i > 0 and x2_i.is_integer() and x2_i % 2 == 1) -> -infinity] -array_api_tests/test_special_cases.py::test_binary[__pow__(x1_i is NaN and not x2_i == 0) -> NaN] -array_api_tests/test_special_cases.py::test_binary[__truediv__(isfinite(x1_i) and x1_i < 0 and x2_i is +infinity) -> -0] -array_api_tests/test_special_cases.py::test_binary[__truediv__(isfinite(x1_i) and x1_i < 0 and x2_i is -infinity) -> +0] -array_api_tests/test_special_cases.py::test_binary[__truediv__(isfinite(x1_i) and x1_i > 0 and x2_i is +infinity) -> +0] -array_api_tests/test_special_cases.py::test_binary[__truediv__(isfinite(x1_i) and x1_i > 0 and x2_i is -infinity) -> -0] -array_api_tests/test_special_cases.py::test_binary[__truediv__(x1_i < 0 and x2_i is +0) -> -infinity] -array_api_tests/test_special_cases.py::test_binary[__truediv__(x1_i < 0 and x2_i is -0) -> +infinity] -array_api_tests/test_special_cases.py::test_binary[__truediv__(x1_i > 0 and x2_i is +0) -> +infinity] -array_api_tests/test_special_cases.py::test_binary[__truediv__(x1_i > 0 and x2_i is -0) -> -infinity] -array_api_tests/test_special_cases.py::test_binary[__truediv__(x1_i is +0 and x2_i < 0) -> -0] -array_api_tests/test_special_cases.py::test_binary[__truediv__(x1_i is +0 and x2_i > 0) -> +0] -array_api_tests/test_special_cases.py::test_binary[__truediv__(x1_i is +infinity and isfinite(x2_i) and x2_i < 0) -> -infinity] -array_api_tests/test_special_cases.py::test_binary[__truediv__(x1_i is +infinity and isfinite(x2_i) and x2_i > 0) -> +infinity] -array_api_tests/test_special_cases.py::test_binary[__truediv__(x1_i is -0 and x2_i < 0) -> +0] -array_api_tests/test_special_cases.py::test_binary[__truediv__(x1_i is -0 and x2_i > 0) -> -0] -array_api_tests/test_special_cases.py::test_binary[__truediv__(x1_i is -infinity and isfinite(x2_i) and x2_i < 0) -> +infinity] -array_api_tests/test_special_cases.py::test_binary[__truediv__(x1_i is -infinity and isfinite(x2_i) and x2_i > 0) -> -infinity] -array_api_tests/test_special_cases.py::test_binary[add(isfinite(x1_i) and x1_i != 0 and x2_i == -x1_i) -> +0] array_api_tests/test_special_cases.py::test_binary[floor_divide(isfinite(x1_i) and x1_i < 0 and x2_i is +infinity) -> -0] array_api_tests/test_special_cases.py::test_binary[floor_divide(isfinite(x1_i) and x1_i > 0 and x2_i is -infinity) -> -0] array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is +infinity and isfinite(x2_i) and x2_i < 0) -> -infinity] @@ -160,7 +99,6 @@ array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is -infinit array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is -infinity and isfinite(x2_i) and x2_i > 0) -> -infinity] array_api_tests/test_special_cases.py::test_binary[remainder(x1_i is +0 and x2_i < 0) -> -0] array_api_tests/test_special_cases.py::test_binary[remainder(x1_i is -0 and x2_i > 0) -> +0] -array_api_tests/test_special_cases.py::test_iop[__iadd__(isfinite(x1_i) and x1_i != 0 and x2_i == -x1_i) -> +0] array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(isfinite(x1_i) and x1_i < 0 and x2_i is +infinity) -> -0] array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(isfinite(x1_i) and x1_i > 0 and x2_i is -infinity) -> -0] array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(x1_i is +infinity and isfinite(x2_i) and x2_i < 0) -> -infinity] @@ -172,31 +110,57 @@ array_api_tests/test_special_cases.py::test_iop[__imod__(x1_i is -0 and x2_i > 0 # Float correction is not 
supported by pytorch # (https://github.com/data-apis/array-api-tests/issues/168) -array_api_tests/test_special_cases.py::test_empty_arrays[std] -array_api_tests/test_special_cases.py::test_empty_arrays[var] -array_api_tests/test_special_cases.py::test_nan_propagation[std] -array_api_tests/test_special_cases.py::test_nan_propagation[var] array_api_tests/test_statistical_functions.py::test_std array_api_tests/test_statistical_functions.py::test_var -# The test suite is incorrectly checking sums that have loss of significance -# (https://github.com/data-apis/array-api-tests/issues/168) -array_api_tests/test_statistical_functions.py::test_sum -array_api_tests/test_statistical_functions.py::test_prod # These functions do not yet support complex numbers -array_api_tests/test_operators_and_elementwise_functions.py::test_expm1 array_api_tests/test_operators_and_elementwise_functions.py::test_round array_api_tests/test_set_functions.py::test_unique_counts array_api_tests/test_set_functions.py::test_unique_values +# finfo/iinfo.dtype is a string instead of a dtype +array_api_tests/test_data_type_functions.py::test_finfo_dtype +array_api_tests/test_data_type_functions.py::test_iinfo_dtype + # 2023.12 support -array_api_tests/test_has_names.py::test_has_names[manipulation-repeat] +# https://github.com/pytorch/pytorch/issues/151311: torch.repeat_interleave rejects short integers array_api_tests/test_manipulation_functions.py::test_repeat -array_api_tests/test_signatures.py::test_func_signature[repeat] # Argument 'device' missing from signature array_api_tests/test_signatures.py::test_func_signature[from_dlpack] # Argument 'max_version' missing from signature array_api_tests/test_signatures.py::test_array_method_signature[__dlpack__] -# Argument 'device' missing from signature -array_api_tests/test_signatures.py::test_func_signature[astype] + +# 2024.12 support +array_api_tests/test_signatures.py::test_func_signature[bitwise_and] +array_api_tests/test_signatures.py::test_func_signature[bitwise_left_shift] +array_api_tests/test_signatures.py::test_func_signature[bitwise_or] +array_api_tests/test_signatures.py::test_func_signature[bitwise_right_shift] +array_api_tests/test_signatures.py::test_func_signature[bitwise_xor] +array_api_tests/test_signatures.py::test_array_method_signature[__and__] +array_api_tests/test_signatures.py::test_array_method_signature[__lshift__] +array_api_tests/test_signatures.py::test_array_method_signature[__or__] +array_api_tests/test_signatures.py::test_array_method_signature[__rshift__] +array_api_tests/test_signatures.py::test_array_method_signature[__xor__] + +# 2024.12 support: binary functions reject python scalar arguments +array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[atan2] +array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[copysign] +array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[divide] +array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[hypot] +array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[logaddexp] +array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[maximum] +array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[minimum] +array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[nextafter] + +# https://github.com/pytorch/pytorch/issues/149815 
+array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[equal] +array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[not_equal] +array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[less] +array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[less_equal] +array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[greater] +array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[greater_equal] + +array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_bool[logical_and] +array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_bool[logical_or] +array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_bool[logical_xor] diff --git a/vendor_test/uses_torch.py b/vendor_test/uses_torch.py index 5804aaff..747ecd51 100644 --- a/vendor_test/uses_torch.py +++ b/vendor_test/uses_torch.py @@ -23,7 +23,7 @@ def _test_torch(): assert isinstance(b, torch.Tensor) assert isinstance(res, torch.Tensor) - torch.testing.assert_allclose(res, [[1., 2., 3.]]) + torch.testing.assert_close(res, torch.as_tensor([[1., 2., 3.]])) assert is_torch_array(res) assert is_torch_namespace(torch) and is_torch_namespace(torch_compat)
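For reference, the last hunk follows from torch.testing.assert_allclose being deprecated in favor of torch.testing.assert_close; the expected value is wrapped in torch.as_tensor so that both arguments are tensors, whereas assert_allclose used to coerce the plain list itself. A minimal standalone sketch of the migrated assertion (assumes only torch is installed):

    import torch

    res = torch.asarray([[1., 2., 3.]])
    # assert_close compares two tensors with dtype- and shape-aware tolerances
    torch.testing.assert_close(res, torch.as_tensor([[1., 2., 3.]]))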