diff --git a/maint_tools/vendor_array_api_extra.sh b/maint_tools/vendor_array_api_extra.sh index ead6e2e62c43f..5cd51631cbdbb 100755 --- a/maint_tools/vendor_array_api_extra.sh +++ b/maint_tools/vendor_array_api_extra.sh @@ -6,7 +6,7 @@ set -o nounset set -o errexit URL="https://github.com/data-apis/array-api-extra.git" -VERSION="v0.7.1" +VERSION="v0.8.0" ROOT_DIR=sklearn/externals/array_api_extra diff --git a/sklearn/externals/array_api_extra/__init__.py b/sklearn/externals/array_api_extra/__init__.py index 924c23b9351a3..b5654902f0e66 100644 --- a/sklearn/externals/array_api_extra/__init__.py +++ b/sklearn/externals/array_api_extra/__init__.py @@ -1,6 +1,6 @@ """Extra array functions built on top of the array API standard.""" -from ._delegation import isclose, pad +from ._delegation import isclose, one_hot, pad from ._lib._at import at from ._lib._funcs import ( apply_where, @@ -8,6 +8,7 @@ broadcast_shapes, cov, create_diagonal, + default_dtype, expand_dims, kron, nunique, @@ -16,7 +17,7 @@ ) from ._lib._lazy import lazy_apply -__version__ = "0.7.1" +__version__ = "0.8.0" # pylint: disable=duplicate-code __all__ = [ @@ -27,11 +28,13 @@ "broadcast_shapes", "cov", "create_diagonal", + "default_dtype", "expand_dims", "isclose", "kron", "lazy_apply", "nunique", + "one_hot", "pad", "setdiff1d", "sinc", diff --git a/sklearn/externals/array_api_extra/_delegation.py b/sklearn/externals/array_api_extra/_delegation.py index bb11b7ee24773..756841c8e53fd 100644 --- a/sklearn/externals/array_api_extra/_delegation.py +++ b/sklearn/externals/array_api_extra/_delegation.py @@ -4,31 +4,21 @@ from types import ModuleType from typing import Literal -from ._lib import Backend, _funcs -from ._lib._utils._compat import array_namespace +from ._lib import _funcs +from ._lib._utils._compat import ( + array_namespace, + is_cupy_namespace, + is_dask_namespace, + is_jax_namespace, + is_numpy_namespace, + is_pydata_sparse_namespace, + is_torch_namespace, +) +from ._lib._utils._compat import device as get_device from ._lib._utils._helpers import asarrays -from ._lib._utils._typing import Array +from ._lib._utils._typing import Array, DType -__all__ = ["isclose", "pad"] - - -def _delegate(xp: ModuleType, *backends: Backend) -> bool: - """ - Check whether `xp` is one of the `backends` to delegate to. - - Parameters - ---------- - xp : array_namespace - Array namespace to check. - *backends : IsNamespace - Arbitrarily many backends (from the ``IsNamespace`` enum) to check. - - Returns - ------- - bool - ``True`` if `xp` matches one of the `backends`, ``False`` otherwise. - """ - return any(backend.is_namespace(xp) for backend in backends) +__all__ = ["isclose", "one_hot", "pad"] def isclose( @@ -108,16 +98,98 @@ def isclose( """ xp = array_namespace(a, b) if xp is None else xp - if _delegate(xp, Backend.NUMPY, Backend.CUPY, Backend.DASK, Backend.JAX): + if ( + is_numpy_namespace(xp) + or is_cupy_namespace(xp) + or is_dask_namespace(xp) + or is_jax_namespace(xp) + ): return xp.isclose(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan) - if _delegate(xp, Backend.TORCH): + if is_torch_namespace(xp): a, b = asarrays(a, b, xp=xp) # Array API 2024.12 support return xp.isclose(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan) return _funcs.isclose(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan, xp=xp) +def one_hot( + x: Array, + /, + num_classes: int, + *, + dtype: DType | None = None, + axis: int = -1, + xp: ModuleType | None = None, +) -> Array: + """ + One-hot encode the given indices. 
+ + Each index in the input `x` is encoded as a vector of zeros of length `num_classes` + with the element at the given index set to one. + + Parameters + ---------- + x : array + An array with integral dtype whose values are between `0` and `num_classes - 1`. + num_classes : int + Number of classes in the one-hot dimension. + dtype : DType, optional + The dtype of the return value. Defaults to the default float dtype (usually + float64). + axis : int, optional + Position in the expanded axes where the new axis is placed. Default: -1. + xp : array_namespace, optional + The standard-compatible namespace for `x`. Default: infer. + + Returns + ------- + array + An array having the same shape as `x` except for a new axis at the position + given by `axis` having size `num_classes`. If `axis` is unspecified, it + defaults to -1, which appends a new axis. + + If ``x < 0`` or ``x >= num_classes``, then the result is undefined, may raise + an exception, or may even cause a bad state. `x` is not checked. + + Examples + -------- + >>> import array_api_extra as xpx + >>> import array_api_strict as xp + >>> xpx.one_hot(xp.asarray([1, 2, 0]), 3) + Array([[0., 1., 0.], + [0., 0., 1.], + [1., 0., 0.]], dtype=array_api_strict.float64) + """ + # Validate inputs. + if xp is None: + xp = array_namespace(x) + if not xp.isdtype(x.dtype, "integral"): + msg = "x must have an integral dtype." + raise TypeError(msg) + if dtype is None: + dtype = _funcs.default_dtype(xp, device=get_device(x)) + # Delegate where possible. + if is_jax_namespace(xp): + from jax.nn import one_hot as jax_one_hot + + return jax_one_hot(x, num_classes, dtype=dtype, axis=axis) + if is_torch_namespace(xp): + from torch.nn.functional import one_hot as torch_one_hot + + x = xp.astype(x, xp.int64) # PyTorch only supports int64 here. 
+ try: + out = torch_one_hot(x, num_classes) + except RuntimeError as e: + raise IndexError from e + else: + out = _funcs.one_hot(x, num_classes, xp=xp) + out = xp.astype(out, dtype, copy=False) + if axis != -1: + out = xp.moveaxis(out, -1, axis) + return out + + def pad( x: Array, pad_width: int | tuple[int, int] | Sequence[tuple[int, int]], @@ -159,14 +231,19 @@ def pad( msg = "Only `'constant'` mode is currently supported" raise NotImplementedError(msg) + if ( + is_numpy_namespace(xp) + or is_cupy_namespace(xp) + or is_jax_namespace(xp) + or is_pydata_sparse_namespace(xp) + ): + return xp.pad(x, pad_width, mode, constant_values=constant_values) + # https://github.com/pytorch/pytorch/blob/cf76c05b4dc629ac989d1fb8e789d4fac04a095a/torch/_numpy/_funcs_impl.py#L2045-L2056 - if _delegate(xp, Backend.TORCH): + if is_torch_namespace(xp): pad_width = xp.asarray(pad_width) pad_width = xp.broadcast_to(pad_width, (x.ndim, 2)) pad_width = xp.flip(pad_width, axis=(0,)).flatten() return xp.nn.functional.pad(x, tuple(pad_width), value=constant_values) # type: ignore[arg-type] # pyright: ignore[reportArgumentType] - if _delegate(xp, Backend.NUMPY, Backend.JAX, Backend.CUPY, Backend.SPARSE): - return xp.pad(x, pad_width, mode, constant_values=constant_values) - return _funcs.pad(x, pad_width, constant_values=constant_values, xp=xp) diff --git a/sklearn/externals/array_api_extra/_lib/__init__.py b/sklearn/externals/array_api_extra/_lib/__init__.py index b83d7e8c5c2b7..d7b3203346da0 100644 --- a/sklearn/externals/array_api_extra/_lib/__init__.py +++ b/sklearn/externals/array_api_extra/_lib/__init__.py @@ -1,5 +1 @@ """Internals of array-api-extra.""" - -from ._backends import Backend - -__all__ = ["Backend"] diff --git a/sklearn/externals/array_api_extra/_lib/_at.py b/sklearn/externals/array_api_extra/_lib/_at.py index 22e18d2c0c30c..870884b86ce9d 100644 --- a/sklearn/externals/array_api_extra/_lib/_at.py +++ b/sklearn/externals/array_api_extra/_lib/_at.py @@ -8,10 +8,12 @@ from types import ModuleType from typing import TYPE_CHECKING, ClassVar, cast +from ._utils import _compat from ._utils._compat import ( array_namespace, is_dask_array, is_jax_array, + is_torch_array, is_writeable_array, ) from ._utils._helpers import meta_namespace @@ -298,7 +300,7 @@ def _op( and idx.dtype == xp.bool and idx.shape == x.shape ): - y_xp = xp.asarray(y, dtype=x.dtype) + y_xp = xp.asarray(y, dtype=x.dtype, device=_compat.device(x)) if y_xp.ndim == 0: if out_of_place_op: # add(), subtract(), ... # suppress inf warnings on Dask @@ -344,6 +346,13 @@ def _op( msg = f"Can't update read-only array {x}" raise ValueError(msg) + # Work around bug in PyTorch where __setitem__ doesn't + # always support mismatched dtypes + # https://github.com/pytorch/pytorch/issues/150017 + if is_torch_array(y): + y = xp.astype(y, x.dtype, copy=False) + + # Backends without boolean indexing (other than JAX) crash here if in_place_op: # add(), subtract(), ... 
             x[idx] = in_place_op(x[idx], y)
         else:  # set()
diff --git a/sklearn/externals/array_api_extra/_lib/_backends.py b/sklearn/externals/array_api_extra/_lib/_backends.py
index f044281ac17c9..f64e14791f901 100644
--- a/sklearn/externals/array_api_extra/_lib/_backends.py
+++ b/sklearn/externals/array_api_extra/_lib/_backends.py
@@ -1,51 +1,46 @@
-"""Backends with which array-api-extra interacts in delegation and testing."""
+"""Backends against which array-api-extra runs its tests."""
 
-from collections.abc import Callable
-from enum import Enum
-from types import ModuleType
-from typing import cast
+from __future__ import annotations
 
-from ._utils import _compat
+from enum import Enum
 
 __all__ = ["Backend"]
 
 
-class Backend(Enum):  # numpydoc ignore=PR01,PR02  # type: ignore[no-subclass-any]
+class Backend(Enum):  # numpydoc ignore=PR02
     """
     All array library backends explicitly tested by array-api-extra.
 
     Parameters
     ----------
     value : str
-        Name of the backend's module.
-    is_namespace : Callable[[ModuleType], bool]
-        Function to check whether an input module is the array namespace
-        corresponding to the backend.
+        Tag of the backend's module, in the format ``<modname>[:<tag>]``.
     """
 
-    ARRAY_API_STRICT = "array_api_strict", _compat.is_array_api_strict_namespace
-    NUMPY = "numpy", _compat.is_numpy_namespace
-    NUMPY_READONLY = "numpy_readonly", _compat.is_numpy_namespace
-    CUPY = "cupy", _compat.is_cupy_namespace
-    TORCH = "torch", _compat.is_torch_namespace
-    DASK = "dask.array", _compat.is_dask_namespace
-    SPARSE = "sparse", _compat.is_pydata_sparse_namespace
-    JAX = "jax.numpy", _compat.is_jax_namespace
-
-    def __new__(
-        cls, value: str, _is_namespace: Callable[[ModuleType], bool]
-    ):  # numpydoc ignore=GL08
-        obj = object.__new__(cls)
-        obj._value_ = value
-        return obj
-
-    def __init__(
-        self,
-        value: str,  # noqa: ARG002  # pylint: disable=unused-argument
-        is_namespace: Callable[[ModuleType], bool],
-    ):  # numpydoc ignore=GL08
-        self.is_namespace = is_namespace
+    # Use :<tag> to prevent Enum from deduplicating items with the same value
+    ARRAY_API_STRICT = "array_api_strict"
+    ARRAY_API_STRICTEST = "array_api_strict:strictest"
+    NUMPY = "numpy"
+    NUMPY_READONLY = "numpy:readonly"
+    CUPY = "cupy"
+    TORCH = "torch"
+    TORCH_GPU = "torch:gpu"
+    DASK = "dask.array"
+    SPARSE = "sparse"
+    JAX = "jax.numpy"
+    JAX_GPU = "jax.numpy:gpu"
 
     def __str__(self) -> str:  # type: ignore[explicit-override]  # pyright: ignore[reportImplicitOverride]  # numpydoc ignore=RT01
         """Pretty-print parameterized test names."""
-        return cast(str, self.value)
+        return (
+            self.name.lower().replace("_gpu", ":gpu").replace("_readonly", ":readonly")
+        )
+
+    @property
+    def modname(self) -> str:  # numpydoc ignore=RT01
+        """Module name to be imported."""
+        return self.value.split(":")[0]
+
+    def like(self, *others: Backend) -> bool:  # numpydoc ignore=PR01,RT01
+        """Check if this backend uses the same module as others."""
+        return any(self.modname == other.modname for other in others)
diff --git a/sklearn/externals/array_api_extra/_lib/_funcs.py b/sklearn/externals/array_api_extra/_lib/_funcs.py
index efe2f377968ec..69dfe6a4297de 100644
--- a/sklearn/externals/array_api_extra/_lib/_funcs.py
+++ b/sklearn/externals/array_api_extra/_lib/_funcs.py
@@ -4,18 +4,19 @@
 import warnings
 from collections.abc import Callable, Sequence
 from types import ModuleType, NoneType
-from typing import cast, overload
+from typing import Literal, cast, overload
 
 from ._at import at
 from ._utils import _compat, _helpers
-from ._utils._compat import (
-    array_namespace,
-
is_dask_namespace, - is_jax_array, - is_jax_namespace, +from ._utils._compat import array_namespace, is_dask_namespace, is_jax_array +from ._utils._helpers import ( + asarrays, + capabilities, + eager_shape, + meta_namespace, + ndindex, ) -from ._utils._helpers import asarrays, eager_shape, meta_namespace, ndindex -from ._utils._typing import Array +from ._utils._typing import Array, Device, DType __all__ = [ "apply_where", @@ -152,7 +153,7 @@ def _apply_where( # type: ignore[explicit-any] # numpydoc ignore=PR01,RT01 ) -> Array: """Helper of `apply_where`. On Dask, this runs on a single chunk.""" - if is_jax_namespace(xp): + if not capabilities(xp, device=_compat.device(cond))["boolean indexing"]: # jax.jit does not support assignment by boolean mask return xp.where(cond, f1(*args), f2(*args) if f2 is not None else fill_value) @@ -374,6 +375,23 @@ def cov(m: Array, /, *, xp: ModuleType | None = None) -> Array: return xp.squeeze(c, axis=axes) +def one_hot( + x: Array, + /, + num_classes: int, + *, + xp: ModuleType, +) -> Array: # numpydoc ignore=PR01,RT01 + """See docstring in `array_api_extra._delegation.py`.""" + # TODO: Benchmark whether this is faster on the NumPy backend: + # if is_numpy_array(x): + # out = xp.zeros((x.size, num_classes), dtype=dtype) + # out[xp.arange(x.size), xp.reshape(x, (-1,))] = 1 + # return xp.reshape(out, (*x.shape, num_classes)) + range_num_classes = xp.arange(num_classes, dtype=x.dtype, device=_compat.device(x)) + return x[..., xp.newaxis] == range_num_classes + + def create_diagonal( x: Array, /, *, offset: int = 0, xp: ModuleType | None = None ) -> Array: @@ -437,6 +455,44 @@ def create_diagonal( return xp.reshape(diag, (*batch_dims, n, n)) +def default_dtype( + xp: ModuleType, + kind: Literal[ + "real floating", "complex floating", "integral", "indexing" + ] = "real floating", + *, + device: Device | None = None, +) -> DType: + """ + Return the default dtype for the given namespace and device. + + This is a convenience shorthand for + ``xp.__array_namespace_info__().default_dtypes(device=device)[kind]``. + + Parameters + ---------- + xp : array_namespace + The standard-compatible namespace for which to get the default dtype. + kind : {'real floating', 'complex floating', 'integral', 'indexing'}, optional + The kind of dtype to return. Default is 'real floating'. + device : Device, optional + The device for which to get the default dtype. Default: current device. + + Returns + ------- + dtype + The default dtype for the given namespace, kind, and device. + """ + dtypes = xp.__array_namespace_info__().default_dtypes(device=device) + try: + return dtypes[kind] + except KeyError as e: + domain = ("real floating", "complex floating", "integral", "indexing") + assert set(dtypes) == set(domain), f"Non-compliant namespace: {dtypes}" + msg = f"Unknown kind '{kind}'. Expected one of {domain}." + raise ValueError(msg) from e + + def expand_dims( a: Array, /, *, axis: int | tuple[int, ...] = (0,), xp: ModuleType | None = None ) -> Array: @@ -708,14 +764,33 @@ def nunique(x: Array, /, *, xp: ModuleType | None = None) -> Array: # size= is JAX-specific # https://github.com/data-apis/array-api/issues/883 _, counts = xp.unique_counts(x, size=_compat.size(x)) - return xp.astype(counts, xp.bool).sum() - - _, counts = xp.unique_counts(x) - n = _compat.size(counts) - # FIXME https://github.com/data-apis/array-api-compat/pull/231 - if n is None: # e.g. 
Dask, ndonnx - return xp.astype(counts, xp.bool).sum() - return xp.asarray(n, device=_compat.device(x)) + return (counts > 0).sum() + + # There are 3 general use cases: + # 1. backend has unique_counts and it returns an array with known shape + # 2. backend has unique_counts and it returns a None-sized array; + # e.g. Dask, ndonnx + # 3. backend does not have unique_counts; e.g. wrapped JAX + if capabilities(xp, device=_compat.device(x))["data-dependent shapes"]: + # xp has unique_counts; O(n) complexity + _, counts = xp.unique_counts(x) + n = _compat.size(counts) + if n is None: + return xp.sum(xp.ones_like(counts)) + return xp.asarray(n, device=_compat.device(x)) + + # xp does not have unique_counts; O(n*logn) complexity + x = xp.reshape(x, (-1,)) + x = xp.sort(x) + mask = x != xp.roll(x, -1) + default_int = default_dtype(xp, "integral", device=_compat.device(x)) + return xp.maximum( + # Special cases: + # - array is size 0 + # - array has all elements equal to each other + xp.astype(xp.any(~mask), default_int), + xp.sum(xp.astype(mask, default_int)), + ) def pad( diff --git a/sklearn/externals/array_api_extra/_lib/_lazy.py b/sklearn/externals/array_api_extra/_lib/_lazy.py index 7b45eff91cda4..d13d08f883753 100644 --- a/sklearn/externals/array_api_extra/_lib/_lazy.py +++ b/sklearn/externals/array_api_extra/_lib/_lazy.py @@ -144,7 +144,12 @@ def lazy_apply( # type: ignore[valid-type] # numpydoc ignore=GL07,SA04 Dask This allows applying eager functions to Dask arrays. - The Dask graph won't be computed. + The Dask graph won't be computed until the user calls ``compute()`` or + ``persist()`` down the line. + + The function name will be prominently visible on the user-facing Dask + dashboard and on Prometheus metrics, so it is recommended for it to be + meaningful. `lazy_apply` doesn't know if `func` reduces along any axes; also, shape changes are non-trivial in chunked Dask arrays. For these reasons, all inputs diff --git a/sklearn/externals/array_api_extra/_lib/_testing.py b/sklearn/externals/array_api_extra/_lib/_testing.py index e5ec16a64c73e..16a9d10231a7d 100644 --- a/sklearn/externals/array_api_extra/_lib/_testing.py +++ b/sklearn/externals/array_api_extra/_lib/_testing.py @@ -5,10 +5,13 @@ See also ..testing for public testing utilities. """ +from __future__ import annotations + import math from types import ModuleType -from typing import cast +from typing import Any, cast +import numpy as np import pytest from ._utils._compat import ( @@ -16,16 +19,24 @@ is_array_api_strict_namespace, is_cupy_namespace, is_dask_namespace, + is_jax_namespace, + is_numpy_namespace, is_pydata_sparse_namespace, + is_torch_array, is_torch_namespace, + to_device, ) -from ._utils._typing import Array +from ._utils._typing import Array, Device -__all__ = ["xp_assert_close", "xp_assert_equal"] +__all__ = ["as_numpy_array", "xp_assert_close", "xp_assert_equal", "xp_assert_less"] def _check_ns_shape_dtype( - actual: Array, desired: Array + actual: Array, + desired: Array, + check_dtype: bool, + check_shape: bool, + check_scalar: bool, ) -> ModuleType: # numpydoc ignore=RT03 """ Assert that namespace, shape and dtype of the two arrays match. @@ -36,6 +47,11 @@ def _check_ns_shape_dtype( The array produced by the tested function. desired : Array The expected array (typically hardcoded). 
+    check_dtype, check_shape : bool, default: True
+        Whether to check agreement between actual and desired dtypes and shapes
+    check_scalar : bool, default: False
+        NumPy only: whether to check agreement between actual and desired types -
+        0d array vs scalar.
 
     Returns
     -------
@@ -47,25 +63,86 @@
     msg = f"namespaces do not match: {actual_xp} != f{desired_xp}"
     assert actual_xp == desired_xp, msg
 
-    actual_shape = actual.shape
-    desired_shape = desired.shape
+    # Dask uses nan instead of None for unknown shapes
+    actual_shape = cast(tuple[float, ...], actual.shape)
+    desired_shape = cast(tuple[float, ...], desired.shape)
+    assert None not in actual_shape  # Requires explicit support
+    assert None not in desired_shape
     if is_dask_namespace(desired_xp):
-        # Dask uses nan instead of None for unknown shapes
-        if any(math.isnan(i) for i in cast(tuple[float, ...], actual_shape)):
+        if any(math.isnan(i) for i in actual_shape):
             actual_shape = actual.compute().shape  # type: ignore[attr-defined]  # pyright: ignore[reportAttributeAccessIssue]
-        if any(math.isnan(i) for i in cast(tuple[float, ...], desired_shape)):
+        if any(math.isnan(i) for i in desired_shape):
             desired_shape = desired.compute().shape  # type: ignore[attr-defined]  # pyright: ignore[reportAttributeAccessIssue]
 
-    msg = f"shapes do not match: {actual_shape} != f{desired_shape}"
-    assert actual_shape == desired_shape, msg
-
-    msg = f"dtypes do not match: {actual.dtype} != {desired.dtype}"
-    assert actual.dtype == desired.dtype, msg
+    if check_shape:
+        msg = f"shapes do not match: {actual_shape} != {desired_shape}"
+        assert actual_shape == desired_shape, msg
+    else:
+        # Ignore shape, but check flattened size. This is normally done by
+        # np.testing.assert_array_equal etc even when strict=False, but not for
+        # non-materializable arrays.
+        actual_size = math.prod(actual_shape)  # pyright: ignore[reportUnknownArgumentType]
+        desired_size = math.prod(desired_shape)  # pyright: ignore[reportUnknownArgumentType]
+        msg = f"sizes do not match: {actual_size} != {desired_size}"
+        assert actual_size == desired_size, msg
+
+    if check_dtype:
+        msg = f"dtypes do not match: {actual.dtype} != {desired.dtype}"
+        assert actual.dtype == desired.dtype, msg
+
+    if is_numpy_namespace(actual_xp) and check_scalar:
+        # Only NumPy distinguishes between scalars and arrays; check only when
+        # check_scalar is set.
+        _msg = (
+            "array-ness does not match:\n Actual: "
+            f"{type(actual)}\n Desired: {type(desired)}"
+        )
+        assert np.isscalar(actual) == np.isscalar(desired), _msg
 
     return desired_xp
 
 
-def xp_assert_equal(actual: Array, desired: Array, err_msg: str = "") -> None:
+def _is_materializable(x: Array) -> bool:
+    """
+    Return True if you can call `as_numpy_array(x)`; False otherwise.
+    """
+    # Important: here we assume that we're not tracing -
+    # e.g. we're not inside `jax.jit` nor `cupy.cuda.Stream.begin_capture`.
+    return not is_torch_array(x) or x.device.type != "meta"  # type: ignore[attr-defined]  # pyright: ignore[reportAttributeAccessIssue]
+
+
+def as_numpy_array(array: Array, *, xp: ModuleType) -> np.typing.NDArray[Any]:  # type: ignore[explicit-any]
+    """
+    Convert array to NumPy, bypassing GPU-CPU transfer guards and densification guards.
+ """ + if is_cupy_namespace(xp): + return xp.asnumpy(array) + if is_pydata_sparse_namespace(xp): + return array.todense() # type: ignore[attr-defined] # pyright: ignore[reportAttributeAccessIssue] + + if is_torch_namespace(xp): + array = to_device(array, "cpu") + if is_array_api_strict_namespace(xp): + cpu: Device = xp.Device("CPU_DEVICE") + array = to_device(array, cpu) + if is_jax_namespace(xp): + import jax + + # Note: only needed if the transfer guard is enabled + cpu = cast(Device, jax.devices("cpu")[0]) + array = to_device(array, cpu) + + return np.asarray(array) + + +def xp_assert_equal( + actual: Array, + desired: Array, + *, + err_msg: str = "", + check_dtype: bool = True, + check_shape: bool = True, + check_scalar: bool = False, +) -> None: """ Array-API compatible version of `np.testing.assert_array_equal`. @@ -77,47 +154,60 @@ def xp_assert_equal(actual: Array, desired: Array, err_msg: str = "") -> None: The expected array (typically hardcoded). err_msg : str, optional Error message to display on failure. + check_dtype, check_shape : bool, default: True + Whether to check agreement between actual and desired dtypes and shapes + check_scalar : bool, default: False + NumPy only: whether to check agreement between actual and desired types - + 0d array vs scalar. See Also -------- xp_assert_close : Similar function for inexact equality checks. numpy.testing.assert_array_equal : Similar function for NumPy arrays. """ - xp = _check_ns_shape_dtype(actual, desired) + xp = _check_ns_shape_dtype(actual, desired, check_dtype, check_shape, check_scalar) + if not _is_materializable(actual): + return + actual_np = as_numpy_array(actual, xp=xp) + desired_np = as_numpy_array(desired, xp=xp) + np.testing.assert_array_equal(actual_np, desired_np, err_msg=err_msg) - if is_cupy_namespace(xp): - xp.testing.assert_array_equal(actual, desired, err_msg=err_msg) - elif is_torch_namespace(xp): - # PyTorch recommends using `rtol=0, atol=0` like this - # to test for exact equality - xp.testing.assert_close( - actual, - desired, - rtol=0, - atol=0, - equal_nan=True, - check_dtype=False, - msg=err_msg or None, - ) - else: - import numpy as np # pylint: disable=import-outside-toplevel - if is_pydata_sparse_namespace(xp): - actual = actual.todense() # type: ignore[attr-defined] # pyright: ignore[reportAttributeAccessIssue] - desired = desired.todense() # type: ignore[attr-defined] # pyright: ignore[reportAttributeAccessIssue] +def xp_assert_less( + x: Array, + y: Array, + *, + err_msg: str = "", + check_dtype: bool = True, + check_shape: bool = True, + check_scalar: bool = False, +) -> None: + """ + Array-API compatible version of `np.testing.assert_array_less`. - actual_np = None - desired_np = None - if is_array_api_strict_namespace(xp): - # __array__ doesn't work on array-api-strict device arrays - # We need to convert to the CPU device first - actual_np = np.asarray(xp.asarray(actual, device=xp.Device("CPU_DEVICE"))) - desired_np = np.asarray(xp.asarray(desired, device=xp.Device("CPU_DEVICE"))) + Parameters + ---------- + x, y : Array + The arrays to compare according to ``x < y`` (elementwise). + err_msg : str, optional + Error message to display on failure. + check_dtype, check_shape : bool, default: True + Whether to check agreement between actual and desired dtypes and shapes + check_scalar : bool, default: False + NumPy only: whether to check agreement between actual and desired types - + 0d array vs scalar. 
- # JAX/Dask arrays work with `np.testing` - actual_np = actual if actual_np is None else actual_np - desired_np = desired if desired_np is None else desired_np - np.testing.assert_array_equal(actual_np, desired_np, err_msg=err_msg) # pyright: ignore[reportUnknownArgumentType] + See Also + -------- + xp_assert_close : Similar function for inexact equality checks. + numpy.testing.assert_array_equal : Similar function for NumPy arrays. + """ + xp = _check_ns_shape_dtype(x, y, check_dtype, check_shape, check_scalar) + if not _is_materializable(x): + return + x_np = as_numpy_array(x, xp=xp) + y_np = as_numpy_array(y, xp=xp) + np.testing.assert_array_less(x_np, y_np, err_msg=err_msg) def xp_assert_close( @@ -127,6 +217,9 @@ def xp_assert_close( rtol: float | None = None, atol: float = 0, err_msg: str = "", + check_dtype: bool = True, + check_shape: bool = True, + check_scalar: bool = False, ) -> None: """ Array-API compatible version of `np.testing.assert_allclose`. @@ -143,6 +236,11 @@ def xp_assert_close( Absolute tolerance. Default: 0. err_msg : str, optional Error message to display on failure. + check_dtype, check_shape : bool, default: True + Whether to check agreement between actual and desired dtypes and shapes + check_scalar : bool, default: False + NumPy only: whether to check agreement between actual and desired types - + 0d array vs scalar. See Also -------- @@ -154,55 +252,33 @@ def xp_assert_close( ----- The default `atol` and `rtol` differ from `xp.all(xpx.isclose(a, b))`. """ - xp = _check_ns_shape_dtype(actual, desired) - - floating = xp.isdtype(actual.dtype, ("real floating", "complex floating")) - if rtol is None and floating: - # multiplier of 4 is used as for `np.float64` this puts the default `rtol` - # roughly half way between sqrt(eps) and the default for - # `numpy.testing.assert_allclose`, 1e-7 - rtol = xp.finfo(actual.dtype).eps ** 0.5 * 4 - elif rtol is None: - rtol = 1e-7 - - if is_cupy_namespace(xp): - xp.testing.assert_allclose( - actual, desired, rtol=rtol, atol=atol, err_msg=err_msg - ) - elif is_torch_namespace(xp): - xp.testing.assert_close( - actual, desired, rtol=rtol, atol=atol, equal_nan=True, msg=err_msg or None - ) - else: - import numpy as np # pylint: disable=import-outside-toplevel - - if is_pydata_sparse_namespace(xp): - actual = actual.todense() # type: ignore[attr-defined] # pyright: ignore[reportAttributeAccessIssue] - desired = desired.todense() # type: ignore[attr-defined] # pyright: ignore[reportAttributeAccessIssue] - - actual_np = None - desired_np = None - if is_array_api_strict_namespace(xp): - # __array__ doesn't work on array-api-strict device arrays - # We need to convert to the CPU device first - actual_np = np.asarray(xp.asarray(actual, device=xp.Device("CPU_DEVICE"))) - desired_np = np.asarray(xp.asarray(desired, device=xp.Device("CPU_DEVICE"))) - - # JAX/Dask arrays work with `np.testing` - actual_np = actual if actual_np is None else actual_np - desired_np = desired if desired_np is None else desired_np - - assert isinstance(rtol, float) - np.testing.assert_allclose( # pyright: ignore[reportCallIssue] - actual_np, # type: ignore[arg-type] # pyright: ignore[reportArgumentType] - desired_np, # type: ignore[arg-type] # pyright: ignore[reportArgumentType] - rtol=rtol, - atol=atol, - err_msg=err_msg, - ) - - -def xfail(request: pytest.FixtureRequest, reason: str) -> None: + xp = _check_ns_shape_dtype(actual, desired, check_dtype, check_shape, check_scalar) + if not _is_materializable(actual): + return + + if rtol is None: + if 
xp.isdtype(actual.dtype, ("real floating", "complex floating")): + # multiplier of 4 is used as for `np.float64` this puts the default `rtol` + # roughly half way between sqrt(eps) and the default for + # `numpy.testing.assert_allclose`, 1e-7 + rtol = xp.finfo(actual.dtype).eps ** 0.5 * 4 + else: + rtol = 1e-7 + + actual_np = as_numpy_array(actual, xp=xp) + desired_np = as_numpy_array(desired, xp=xp) + np.testing.assert_allclose( # pyright: ignore[reportCallIssue] + actual_np, + desired_np, + rtol=rtol, # pyright: ignore[reportArgumentType] + atol=atol, + err_msg=err_msg, + ) + + +def xfail( + request: pytest.FixtureRequest, *, reason: str, strict: bool | None = None +) -> None: """ XFAIL the currently running test. @@ -216,5 +292,13 @@ def xfail(request: pytest.FixtureRequest, reason: str) -> None: ``request`` argument of the test function. reason : str Reason for the expected failure. + strict: bool, optional + If True, the test will be marked as failed if it passes. + If False, the test will be marked as passed if it fails. + Default: ``xfail_strict`` value in ``pyproject.toml``, or False if absent. """ - request.node.add_marker(pytest.mark.xfail(reason=reason)) + if strict is not None: + marker = pytest.mark.xfail(reason=reason, strict=strict) + else: + marker = pytest.mark.xfail(reason=reason) + request.node.add_marker(marker) diff --git a/sklearn/externals/array_api_extra/_lib/_utils/_compat.py b/sklearn/externals/array_api_extra/_lib/_utils/_compat.py index b9997450d23b5..82ce76b8ecbcd 100644 --- a/sklearn/externals/array_api_extra/_lib/_utils/_compat.py +++ b/sklearn/externals/array_api_extra/_lib/_utils/_compat.py @@ -2,6 +2,7 @@ # Allow packages that vendor both `array-api-extra` and # `array-api-compat` to override the import location +# pylint: disable=duplicate-code try: from ...._array_api_compat_vendor import ( array_namespace, @@ -23,6 +24,7 @@ is_torch_namespace, is_writeable_array, size, + to_device, ) except ImportError: from array_api_compat import ( @@ -45,6 +47,7 @@ is_torch_namespace, is_writeable_array, size, + to_device, ) __all__ = [ @@ -67,4 +70,5 @@ "is_torch_namespace", "is_writeable_array", "size", + "to_device", ] diff --git a/sklearn/externals/array_api_extra/_lib/_utils/_compat.pyi b/sklearn/externals/array_api_extra/_lib/_utils/_compat.pyi index f40d7556dee87..48addda41c5bc 100644 --- a/sklearn/externals/array_api_extra/_lib/_utils/_compat.pyi +++ b/sklearn/externals/array_api_extra/_lib/_utils/_compat.pyi @@ -4,6 +4,7 @@ from __future__ import annotations from types import ModuleType +from typing import Any, TypeGuard # TODO import from typing (requires Python >=3.13) from typing_extensions import TypeIs @@ -12,29 +13,33 @@ from ._typing import Array, Device # pylint: disable=missing-class-docstring,unused-argument -class Namespace(ModuleType): - def device(self, x: Array, /) -> Device: ... - def array_namespace( *xs: Array | complex | None, api_version: str | None = None, use_compat: bool | None = None, -) -> Namespace: ... +) -> ModuleType: ... def device(x: Array, /) -> Device: ... def is_array_api_obj(x: object, /) -> TypeIs[Array]: ... -def is_array_api_strict_namespace(xp: ModuleType, /) -> TypeIs[Namespace]: ... -def is_cupy_namespace(xp: ModuleType, /) -> TypeIs[Namespace]: ... -def is_dask_namespace(xp: ModuleType, /) -> TypeIs[Namespace]: ... -def is_jax_namespace(xp: ModuleType, /) -> TypeIs[Namespace]: ... -def is_numpy_namespace(xp: ModuleType, /) -> TypeIs[Namespace]: ... 
-def is_pydata_sparse_namespace(xp: ModuleType, /) -> TypeIs[Namespace]: ... -def is_torch_namespace(xp: ModuleType, /) -> TypeIs[Namespace]: ... -def is_cupy_array(x: object, /) -> TypeIs[Array]: ... -def is_dask_array(x: object, /) -> TypeIs[Array]: ... -def is_jax_array(x: object, /) -> TypeIs[Array]: ... -def is_numpy_array(x: object, /) -> TypeIs[Array]: ... -def is_pydata_sparse_array(x: object, /) -> TypeIs[Array]: ... -def is_torch_array(x: object, /) -> TypeIs[Array]: ... -def is_lazy_array(x: object, /) -> TypeIs[Array]: ... -def is_writeable_array(x: object, /) -> TypeIs[Array]: ... +def is_array_api_strict_namespace(xp: ModuleType, /) -> bool: ... +def is_cupy_namespace(xp: ModuleType, /) -> bool: ... +def is_dask_namespace(xp: ModuleType, /) -> bool: ... +def is_jax_namespace(xp: ModuleType, /) -> bool: ... +def is_numpy_namespace(xp: ModuleType, /) -> bool: ... +def is_pydata_sparse_namespace(xp: ModuleType, /) -> bool: ... +def is_torch_namespace(xp: ModuleType, /) -> bool: ... +def is_cupy_array(x: object, /) -> TypeGuard[Array]: ... +def is_dask_array(x: object, /) -> TypeGuard[Array]: ... +def is_jax_array(x: object, /) -> TypeGuard[Array]: ... +def is_numpy_array(x: object, /) -> TypeGuard[Array]: ... +def is_pydata_sparse_array(x: object, /) -> TypeGuard[Array]: ... +def is_torch_array(x: object, /) -> TypeGuard[Array]: ... +def is_lazy_array(x: object, /) -> TypeGuard[Array]: ... +def is_writeable_array(x: object, /) -> TypeGuard[Array]: ... def size(x: Array, /) -> int | None: ... +def to_device( # type: ignore[explicit-any] + x: Array, + device: Device, # pylint: disable=redefined-outer-name + /, + *, + stream: int | Any | None = None, +) -> Array: ... diff --git a/sklearn/externals/array_api_extra/_lib/_utils/_helpers.py b/sklearn/externals/array_api_extra/_lib/_utils/_helpers.py index 9882d72e6c0ac..3e43fa91204d9 100644 --- a/sklearn/externals/array_api_extra/_lib/_utils/_helpers.py +++ b/sklearn/externals/array_api_extra/_lib/_utils/_helpers.py @@ -2,32 +2,61 @@ from __future__ import annotations +import io import math -from collections.abc import Generator, Iterable +import pickle +import types +from collections.abc import Callable, Generator, Iterable +from functools import wraps from types import ModuleType -from typing import TYPE_CHECKING, cast +from typing import ( + TYPE_CHECKING, + Any, + ClassVar, + Generic, + Literal, + ParamSpec, + TypeAlias, + TypeVar, + cast, +) from . import _compat from ._compat import ( array_namespace, is_array_api_obj, is_dask_namespace, + is_jax_namespace, is_numpy_array, + is_pydata_sparse_namespace, + is_torch_namespace, ) -from ._typing import Array +from ._typing import Array, Device if TYPE_CHECKING: # pragma: no cover - # TODO import from typing (requires Python >=3.13) - from typing_extensions import TypeIs + # TODO import from typing (requires Python >=3.12 and >=3.13) + from typing_extensions import TypeIs, override +else: + + def override(func): + return func + + +P = ParamSpec("P") +T = TypeVar("T") __all__ = [ "asarrays", + "capabilities", "eager_shape", "in1d", "is_python_scalar", + "jax_autojit", "mean", "meta_namespace", + "pickle_flatten", + "pickle_unflatten", ] @@ -270,3 +299,298 @@ def meta_namespace( # Quietly skip scalars and None's metas = [cast(Array | None, getattr(a, "_meta", None)) for a in arrays] return array_namespace(*metas) + + +def capabilities( + xp: ModuleType, *, device: Device | None = None +) -> dict[str, int | None]: + """ + Return patched ``xp.__array_namespace_info__().capabilities()``. 
+
+    TODO this helper should eventually be removed once all the special cases
+    it handles are fixed in the respective backends.
+
+    Parameters
+    ----------
+    xp : array_namespace
+        The standard-compatible namespace.
+    device : Device, optional
+        The device to use.
+
+    Returns
+    -------
+    dict
+        Capabilities of the namespace.
+    """
+    if is_pydata_sparse_namespace(xp):
+        # No __array_namespace_info__(); no indexing by sparse arrays
+        return {
+            "boolean indexing": False,
+            "data-dependent shapes": True,
+            "max dimensions": None,
+        }
+    out = xp.__array_namespace_info__().capabilities()
+    if is_jax_namespace(xp) and out["boolean indexing"]:
+        # FIXME https://github.com/jax-ml/jax/issues/27418
+        # Fixed in jax >=0.6.0
+        out = out.copy()
+        out["boolean indexing"] = False
+    if is_torch_namespace(xp):
+        # FIXME https://github.com/data-apis/array-api/issues/945
+        device = xp.get_default_device() if device is None else xp.device(device)
+        if device.type == "meta":  # type: ignore[union-attr]  # pyright: ignore[reportAttributeAccessIssue,reportOptionalMemberAccess]
+            out = out.copy()
+            out["boolean indexing"] = False
+            out["data-dependent shapes"] = False
+    return out
+
+
+_BASIC_PICKLED_TYPES = frozenset((
+    bool, int, float, complex, str, bytes, bytearray,
+    list, tuple, dict, set, frozenset, range, slice,
+    types.NoneType, types.EllipsisType,
+))  # fmt: skip
+_BASIC_REST_TYPES = frozenset((
+    type, types.BuiltinFunctionType, types.FunctionType, types.ModuleType
+))  # fmt: skip
+
+FlattenRest: TypeAlias = tuple[object, ...]
+
+
+def pickle_flatten(
+    obj: object, cls: type[T] | tuple[type[T], ...]
+) -> tuple[list[T], FlattenRest]:
+    """
+    Use the pickle machinery to extract objects out of an arbitrary container.
+
+    Unlike regular ``pickle.dumps``, this function always succeeds.
+
+    Parameters
+    ----------
+    obj : object
+        The object to pickle.
+    cls : type | tuple[type, ...]
+        One or multiple classes to extract from the object.
+        The instances of these classes inside ``obj`` will not be pickled.
+
+    Returns
+    -------
+    instances : list[cls]
+        All instances of ``cls`` found inside ``obj`` (not pickled).
+    rest
+        Opaque object containing the pickled bytes plus all other objects where
+        ``__reduce__`` / ``__reduce_ex__`` is either not implemented or raised.
+        These are unpickleable objects, types, modules, and functions.
+
+        This object is *typically* hashable save for fairly exotic objects
+        that are neither pickleable nor hashable.
+
+        This object is pickleable if everything except ``instances`` was pickleable
+        in the input object.
+
+    See Also
+    --------
+    pickle_unflatten : Reverse function.
+
+    Examples
+    --------
+    >>> class A:
+    ...     def __repr__(self):
+    ...         return "<A>"
+    >>> class NS:
+    ...     def __repr__(self):
+    ...         return "<NS>"
+    ...     def __reduce__(self):
+    ...         assert False, "not serializable"
+    >>> obj = {1: A(), 2: [A(), NS(), A()]}
+    >>> instances, rest = pickle_flatten(obj, A)
+    >>> instances
+    [<A>, <A>, <A>]
+    >>> pickle_unflatten(instances, rest)
+    {1: <A>, 2: [<A>, <NS>, <A>]}
+
+    This can also be used to swap inner objects; the only constraint is that
+    the number of objects in and out must be the same:
+
+    >>> pickle_unflatten(["foo", "bar", "baz"], rest)
+    {1: "foo", 2: ["bar", <NS>, "baz"]}
+    """
+    instances: list[T] = []
+    rest: list[object] = []
+
+    class Pickler(pickle.Pickler):  # numpydoc ignore=GL08
+        """
+        Use the `pickle.Pickler.persistent_id` hook to extract objects.
+ """ + + @override + def persistent_id( + self, obj: object + ) -> Literal[0, 1, None]: # numpydoc ignore=GL08 + if isinstance(obj, cls): + instances.append(obj) # type: ignore[arg-type] + return 0 + + typ_ = type(obj) + if typ_ in _BASIC_PICKLED_TYPES: # No subclasses! + # If obj is a collection, recursively descend inside it + return None + if typ_ in _BASIC_REST_TYPES: + rest.append(obj) + return 1 + + try: + # Note: a class that defines __slots__ without defining __getstate__ + # cannot be pickled with __reduce__(), but can with __reduce_ex__(5) + _ = obj.__reduce_ex__(pickle.HIGHEST_PROTOCOL) + except Exception: # pylint: disable=broad-exception-caught + rest.append(obj) + return 1 + + # Object can be pickled. Let the Pickler recursively descend inside it. + return None + + f = io.BytesIO() + p = Pickler(f, protocol=pickle.HIGHEST_PROTOCOL) + p.dump(obj) + return instances, (f.getvalue(), *rest) + + +def pickle_unflatten(instances: Iterable[object], rest: FlattenRest) -> Any: # type: ignore[explicit-any] + """ + Reverse of ``pickle_flatten``. + + Parameters + ---------- + instances : Iterable + Inner objects to be reinserted into the flattened container. + rest : FlattenRest + Extra bits, as returned by ``pickle_flatten``. + + Returns + ------- + object + The outer object originally passed to ``pickle_flatten`` after a + pickle->unpickle round-trip. + + See Also + -------- + pickle_flatten : Serializing function. + pickle.loads : Standard unpickle function. + + Notes + ----- + The `instances` iterable must yield at least the same number of elements as the ones + returned by ``pickle_flatten``, but the elements do not need to be the same objects + or even the same types of objects. Excess elements, if any, will be left untouched. + """ + iters = iter(instances), iter(rest) + pik = cast(bytes, next(iters[1])) + + class Unpickler(pickle.Unpickler): # numpydoc ignore=GL08 + """Mirror of the overridden Pickler in pickle_flatten.""" + + @override + def persistent_load(self, pid: Literal[0, 1]) -> object: # numpydoc ignore=GL08 + try: + return next(iters[pid]) + except StopIteration as e: + msg = "Not enough objects to unpickle" + raise ValueError(msg) from e + + f = io.BytesIO(pik) + return Unpickler(f).load() + + +class _AutoJITWrapper(Generic[T]): # numpydoc ignore=PR01 + """ + Helper of :func:`jax_autojit`. + + Wrap arbitrary inputs and outputs of the jitted function and + convert them to/from PyTrees. + """ + + obj: T + _registered: ClassVar[bool] = False + __slots__: tuple[str, ...] = ("obj",) + + def __init__(self, obj: T) -> None: # numpydoc ignore=GL08 + self._register() + self.obj = obj + + @classmethod + def _register(cls): # numpydoc ignore=SS06 + """ + Register upon first use instead of at import time, to avoid + globally importing JAX. + """ + if not cls._registered: + import jax + + jax.tree_util.register_pytree_node( + cls, + lambda obj: pickle_flatten(obj, jax.Array), # pyright: ignore[reportUnknownArgumentType] + lambda aux_data, children: pickle_unflatten(children, aux_data), # pyright: ignore[reportUnknownArgumentType] + ) + cls._registered = True + + +def jax_autojit( + func: Callable[P, T], +) -> Callable[P, T]: # numpydoc ignore=PR01,RT01,SS03 + """ + Wrap `func` with ``jax.jit``, with the following differences: + + - Python scalar arguments and return values are not automatically converted to + ``jax.Array`` objects. + - All non-array arguments are automatically treated as static. 
+ Unlike ``jax.jit``, static arguments must be either hashable or serializable with + ``pickle``. + - Unlike ``jax.jit``, non-array arguments and return values are not limited to + tuple/list/dict, but can be any object serializable with ``pickle``. + - Automatically descend into non-array arguments and find ``jax.Array`` objects + inside them, then rebuild the arguments when entering `func`, swapping the JAX + concrete arrays with tracer objects. + - Automatically descend into non-array return values and find ``jax.Array`` objects + inside them, then rebuild them downstream of exiting the JIT, swapping the JAX + tracer objects with concrete arrays. + + See Also + -------- + jax.jit : JAX JIT compilation function. + + Notes + ----- + These are useful choices *for testing purposes only*, which is how this function is + intended to be used. The output of ``jax.jit`` is a C++ level callable, that + directly dispatches to the compiled kernel after the initial call. In comparison, + ``jax_autojit`` incurs a much higher dispatch time. + + Additionally, consider:: + + def f(x: Array, y: float, plus: bool) -> Array: + return x + y if plus else x - y + + j1 = jax.jit(f, static_argnames="plus") + j2 = jax_autojit(f) + + In the above example, ``j2`` requires a lot less setup to be tested effectively than + ``j1``, but on the flip side it means that it will be re-traced for every different + value of ``y``, which likely makes it not fit for purpose in production. + """ + import jax + + @jax.jit # type: ignore[misc] # pyright: ignore[reportUntypedFunctionDecorator] + def inner( # type: ignore[decorated-any,explicit-any] # numpydoc ignore=GL08 + wargs: _AutoJITWrapper[Any], + ) -> _AutoJITWrapper[T]: + args, kwargs = wargs.obj + res = func(*args, **kwargs) # pyright: ignore[reportCallIssue] + return _AutoJITWrapper(res) + + @wraps(func) + def outer(*args: P.args, **kwargs: P.kwargs) -> T: # numpydoc ignore=GL08 + wargs = _AutoJITWrapper((args, kwargs)) + return inner(wargs).obj + + return outer diff --git a/sklearn/externals/array_api_extra/_lib/_utils/_typing.py b/sklearn/externals/array_api_extra/_lib/_utils/_typing.py index d32a3a07c1ee9..8204be4759610 100644 --- a/sklearn/externals/array_api_extra/_lib/_utils/_typing.py +++ b/sklearn/externals/array_api_extra/_lib/_utils/_typing.py @@ -1,5 +1,5 @@ # numpydoc ignore=GL08 -# pylint: disable=missing-module-docstring +# pylint: disable=missing-module-docstring,duplicate-code Array = object DType = object diff --git a/sklearn/externals/array_api_extra/testing.py b/sklearn/externals/array_api_extra/testing.py index 4f8288cf582ec..3979f9ddf65c1 100644 --- a/sklearn/externals/array_api_extra/testing.py +++ b/sklearn/externals/array_api_extra/testing.py @@ -7,12 +7,15 @@ from __future__ import annotations import contextlib -from collections.abc import Callable, Iterable, Iterator, Sequence +import enum +import warnings +from collections.abc import Callable, Iterator, Sequence from functools import wraps from types import ModuleType from typing import TYPE_CHECKING, Any, ParamSpec, TypeVar, cast from ._lib._utils._compat import is_dask_namespace, is_jax_namespace +from ._lib._utils._helpers import jax_autojit, pickle_flatten, pickle_unflatten __all__ = ["lazy_xp_function", "patch_lazy_xp_functions"] @@ -26,7 +29,7 @@ # Sphinx hacks SchedulerGetCallable = object - def override(func: object) -> object: + def override(func): return func @@ -36,13 +39,22 @@ def override(func: object) -> object: _ufuncs_tags: dict[object, dict[str, Any]] = {} # type: 
ignore[explicit-any] +class Deprecated(enum.Enum): + """Unique type for deprecated parameters.""" + + DEPRECATED = 1 + + +DEPRECATED = Deprecated.DEPRECATED + + def lazy_xp_function( # type: ignore[explicit-any] func: Callable[..., Any], *, - allow_dask_compute: int = 0, + allow_dask_compute: bool | int = False, jax_jit: bool = True, - static_argnums: int | Sequence[int] | None = None, - static_argnames: str | Iterable[str] | None = None, + static_argnums: Deprecated = DEPRECATED, + static_argnames: Deprecated = DEPRECATED, ) -> None: # numpydoc ignore=GL07 """ Tag a function to be tested on lazy backends. @@ -59,9 +71,10 @@ def lazy_xp_function( # type: ignore[explicit-any] ---------- func : callable Function to be tested. - allow_dask_compute : int, optional - Number of times `func` is allowed to internally materialize the Dask graph. This - is typically triggered by ``bool()``, ``float()``, or ``np.asarray()``. + allow_dask_compute : bool | int, optional + Whether `func` is allowed to internally materialize the Dask graph, or maximum + number of times it is allowed to do so. This is typically triggered by + ``bool()``, ``float()``, or ``np.asarray()``. Set to 1 if you are aware that `func` converts the input parameters to NumPy and want to let it do so at least for the time being, knowing that it is going to be @@ -75,19 +88,37 @@ def lazy_xp_function( # type: ignore[explicit-any] a test function that invokes `func` multiple times should still work with this parameter set to 1. - Default: 0, meaning that `func` must be fully lazy and never materialize the + Set to True to allow `func` to materialize the graph an unlimited number + of times. + + Default: False, meaning that `func` must be fully lazy and never materialize the graph. jax_jit : bool, optional - Set to True to replace `func` with ``jax.jit(func)`` after calling the - :func:`patch_lazy_xp_functions` test helper with ``xp=jax.numpy``. Set to False - if `func` is only compatible with eager (non-jitted) JAX. Default: True. - static_argnums : int | Sequence[int], optional - Passed to jax.jit. Positional arguments to treat as static (compile-time - constant). Default: infer from `static_argnames` using - `inspect.signature(func)`. - static_argnames : str | Iterable[str], optional - Passed to jax.jit. Named arguments to treat as static (compile-time constant). - Default: infer from `static_argnums` using `inspect.signature(func)`. + Set to True to replace `func` with a smart variant of ``jax.jit(func)`` after + calling the :func:`patch_lazy_xp_functions` test helper with ``xp=jax.numpy``. + This is the default behaviour. + Set to False if `func` is only compatible with eager (non-jitted) JAX. + + Unlike with vanilla ``jax.jit``, all arguments and return types that are not JAX + arrays are treated as static; the function can accept and return arbitrary + wrappers around JAX arrays. This difference is because, in real life, most users + won't wrap the function directly with ``jax.jit`` but rather they will use it + within their own code, which is itself then wrapped by ``jax.jit``, and + internally consume the function's outputs. + + In other words, the pattern that is being tested is:: + + >>> @jax.jit + ... def user_func(x): + ... y = user_prepares_inputs(x) + ... z = func(y, some_static_arg=True) + ... return user_consumes(z) + + Default: True. 
+ static_argnums : + Deprecated; ignored + static_argnames : + Deprecated; ignored See Also -------- @@ -104,7 +135,7 @@ def lazy_xp_function( # type: ignore[explicit-any] def test_myfunc(xp): a = xp.asarray([1, 2]) - # When xp=jax.numpy, this is the same as `b = jax.jit(myfunc)(a)` + # When xp=jax.numpy, this is similar to `b = jax.jit(myfunc)(a)` # When xp=dask.array, crash on compute() or persist() b = myfunc(a) @@ -164,12 +195,20 @@ def test_myfunc(xp): b = mymodule.myfunc(a) # This is wrapped when xp=jax.numpy or xp=dask.array c = naked.myfunc(a) # This is not """ + if static_argnums is not DEPRECATED or static_argnames is not DEPRECATED: + warnings.warn( + ( + "The `static_argnums` and `static_argnames` parameters are deprecated " + "and ignored. They will be removed in a future version." + ), + DeprecationWarning, + stacklevel=2, + ) tags = { "allow_dask_compute": allow_dask_compute, "jax_jit": jax_jit, - "static_argnums": static_argnums, - "static_argnames": static_argnames, } + try: func._lazy_xp_function = tags # type: ignore[attr-defined] # pylint: disable=protected-access # pyright: ignore[reportFunctionMemberAccess] except AttributeError: # @cython.vectorize @@ -235,23 +274,17 @@ def iter_tagged() -> ( # type: ignore[explicit-any] if is_dask_namespace(xp): for mod, name, func, tags in iter_tagged(): n = tags["allow_dask_compute"] + if n is True: + n = 1_000_000 + elif n is False: + n = 0 wrapped = _dask_wrap(func, n) monkeypatch.setattr(mod, name, wrapped) elif is_jax_namespace(xp): - import jax - for mod, name, func, tags in iter_tagged(): if tags["jax_jit"]: - # suppress unused-ignore to run mypy in -e lint as well as -e dev - wrapped = cast( # type: ignore[explicit-any] - Callable[..., Any], - jax.jit( - func, - static_argnums=tags["static_argnums"], - static_argnames=tags["static_argnames"], - ), - ) + wrapped = jax_autojit(func) monkeypatch.setattr(mod, name, wrapped) @@ -300,6 +333,7 @@ def _dask_wrap( After the function returns, materialize the graph in order to re-raise exceptions. """ import dask + import dask.array as da func_name = getattr(func, "__name__", str(func)) n_str = f"only up to {n}" if n else "no" @@ -319,6 +353,8 @@ def wrapper(*args: P.args, **kwargs: P.kwargs) -> T: # numpydoc ignore=GL08 # Block until the graph materializes and reraise exceptions. This allows # `pytest.raises` and `pytest.warns` to work as expected. Note that this would # not work on scheduler='distributed', as it would not block. - return dask.persist(out, scheduler="threads")[0] # type: ignore[attr-defined,no-untyped-call,func-returns-value,index] # pyright: ignore[reportPrivateImportUsage] + arrays, rest = pickle_flatten(out, da.Array) + arrays = dask.persist(arrays, scheduler="threads")[0] # type: ignore[attr-defined,no-untyped-call,func-returns-value,index] # pyright: ignore[reportPrivateImportUsage] + return pickle_unflatten(arrays, rest) # pyright: ignore[reportUnknownArgumentType] return wrapper
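
As a quick reviewer aid (not part of the patch): a minimal sketch exercising the two new public helpers vendored above, `one_hot` and `default_dtype`. It assumes NumPy >= 2.0 (for `__array_namespace_info__`) and a scikit-learn checkout where the vendored package is importable as `sklearn.externals.array_api_extra`; upstream users would write `import array_api_extra as xpx` instead.

import numpy as np

from sklearn.externals import array_api_extra as xpx

# one_hot(): each integer index in `x` becomes a vector of zeros of length
# num_classes with a single one, along a new axis at position `axis`
# (default -1, i.e. appended at the end).
x = np.asarray([1, 2, 0])
print(xpx.one_hot(x, 3))
# [[0. 1. 0.]
#  [0. 0. 1.]
#  [1. 0. 0.]]

# The result dtype defaults to default_dtype(xp), the namespace's default
# "real floating" dtype (float64 on NumPy); other kinds can be queried too.
print(xpx.default_dtype(np))              # float64
print(xpx.default_dtype(np, "integral"))  # int64 on most platforms

On JAX and PyTorch the same `one_hot` call delegates to `jax.nn.one_hot` and `torch.nn.functional.one_hot` respectively, per the `_delegation.py` hunk above; other backends fall back to the comparison-based `_funcs.one_hot`.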