2 changes: 1 addition & 1 deletion maint_tools/vendor_array_api_extra.sh
@@ -6,7 +6,7 @@ set -o nounset
set -o errexit

URL="https://github.com/data-apis/array-api-extra.git"
VERSION="v0.7.1"
VERSION="v0.8.0"

ROOT_DIR=sklearn/externals/array_api_extra

7 changes: 5 additions & 2 deletions sklearn/externals/array_api_extra/__init__.py
@@ -1,13 +1,14 @@
"""Extra array functions built on top of the array API standard."""

from ._delegation import isclose, pad
from ._delegation import isclose, one_hot, pad
from ._lib._at import at
from ._lib._funcs import (
    apply_where,
    atleast_nd,
    broadcast_shapes,
    cov,
    create_diagonal,
    default_dtype,
    expand_dims,
    kron,
    nunique,
@@ -16,7 +17,7 @@
)
from ._lib._lazy import lazy_apply

__version__ = "0.7.1"
__version__ = "0.8.0"

# pylint: disable=duplicate-code
__all__ = [
@@ -27,11 +28,13 @@
    "broadcast_shapes",
    "cov",
    "create_diagonal",
    "default_dtype",
    "expand_dims",
    "isclose",
    "kron",
    "lazy_apply",
    "nunique",
    "one_hot",
    "pad",
    "setdiff1d",
    "sinc",
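A quick sanity check of the new exports (a sketch: it assumes the vendored copy is importable as shown; upstream the package is plain array_api_extra):

import numpy as np
from sklearn.externals import array_api_extra as xpx

print(xpx.__version__)                 # 0.8.0
print("one_hot" in xpx.__all__)        # True, new in this release
print("default_dtype" in xpx.__all__)  # True, new in this release

# one_hot is the headline addition: shape (2,) -> (2, 3).
print(xpx.one_hot(np.asarray([0, 2]), 3).shape)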
135 changes: 106 additions & 29 deletions sklearn/externals/array_api_extra/_delegation.py
@@ -4,31 +4,21 @@
from types import ModuleType
from typing import Literal

from ._lib import Backend, _funcs
from ._lib._utils._compat import array_namespace
from ._lib import _funcs
from ._lib._utils._compat import (
    array_namespace,
    is_cupy_namespace,
    is_dask_namespace,
    is_jax_namespace,
    is_numpy_namespace,
    is_pydata_sparse_namespace,
    is_torch_namespace,
)
from ._lib._utils._compat import device as get_device
from ._lib._utils._helpers import asarrays
from ._lib._utils._typing import Array
from ._lib._utils._typing import Array, DType

__all__ = ["isclose", "pad"]


def _delegate(xp: ModuleType, *backends: Backend) -> bool:
"""
Check whether `xp` is one of the `backends` to delegate to.

Parameters
----------
xp : array_namespace
Array namespace to check.
*backends : IsNamespace
Arbitrarily many backends (from the ``IsNamespace`` enum) to check.

Returns
-------
bool
``True`` if `xp` matches one of the `backends`, ``False`` otherwise.
"""
return any(backend.is_namespace(xp) for backend in backends)
__all__ = ["isclose", "one_hot", "pad"]


def isclose(
@@ -108,16 +98,98 @@ def isclose(
"""
xp = array_namespace(a, b) if xp is None else xp

if _delegate(xp, Backend.NUMPY, Backend.CUPY, Backend.DASK, Backend.JAX):
if (
is_numpy_namespace(xp)
or is_cupy_namespace(xp)
or is_dask_namespace(xp)
or is_jax_namespace(xp)
):
return xp.isclose(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan)

if _delegate(xp, Backend.TORCH):
if is_torch_namespace(xp):
a, b = asarrays(a, b, xp=xp) # Array API 2024.12 support
return xp.isclose(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan)

return _funcs.isclose(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan, xp=xp)


def one_hot(
    x: Array,
    /,
    num_classes: int,
    *,
    dtype: DType | None = None,
    axis: int = -1,
    xp: ModuleType | None = None,
) -> Array:
    """
    One-hot encode the given indices.

    Each index in the input `x` is encoded as a vector of zeros of length `num_classes`
    with the element at the given index set to one.

    Parameters
    ----------
    x : array
        An array with integral dtype whose values are between `0` and `num_classes - 1`.
    num_classes : int
        Number of classes in the one-hot dimension.
    dtype : DType, optional
        The dtype of the return value. Defaults to the default float dtype (usually
        float64).
    axis : int, optional
        Position in the expanded axes where the new axis is placed. Default: -1.
    xp : array_namespace, optional
        The standard-compatible namespace for `x`. Default: infer.

    Returns
    -------
    array
        An array having the same shape as `x` except for a new axis at the position
        given by `axis` having size `num_classes`. If `axis` is unspecified, it
        defaults to -1, which appends a new axis.

        If ``x < 0`` or ``x >= num_classes``, then the result is undefined, may raise
        an exception, or may even cause a bad state. `x` is not checked.

    Examples
    --------
    >>> import array_api_extra as xpx
    >>> import array_api_strict as xp
    >>> xpx.one_hot(xp.asarray([1, 2, 0]), 3)
    Array([[0., 1., 0.],
           [0., 0., 1.],
           [1., 0., 0.]], dtype=array_api_strict.float64)
    """
    # Validate inputs.
    if xp is None:
        xp = array_namespace(x)
    if not xp.isdtype(x.dtype, "integral"):
        msg = "x must have an integral dtype."
        raise TypeError(msg)
    if dtype is None:
        dtype = _funcs.default_dtype(xp, device=get_device(x))
    # Delegate where possible.
    if is_jax_namespace(xp):
        from jax.nn import one_hot as jax_one_hot

        return jax_one_hot(x, num_classes, dtype=dtype, axis=axis)
    if is_torch_namespace(xp):
        from torch.nn.functional import one_hot as torch_one_hot

        x = xp.astype(x, xp.int64)  # PyTorch only supports int64 here.
        try:
            out = torch_one_hot(x, num_classes)
        except RuntimeError as e:
            raise IndexError from e
    else:
        out = _funcs.one_hot(x, num_classes, xp=xp)
    out = xp.astype(out, dtype, copy=False)
    if axis != -1:
        out = xp.moveaxis(out, -1, axis)
    return out
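An illustrative sketch (assuming a NumPy-backed array; names follow the signature above) of how axis and dtype steer the result:

import numpy as np
import array_api_extra as xpx

x = np.asarray([1, 2, 0])

# Default axis=-1 appends the class axis: shape (3,) -> (3, 4).
print(xpx.one_hot(x, 4).shape)          # (3, 4)

# axis=0 moves the class axis to the front via xp.moveaxis.
print(xpx.one_hot(x, 4, axis=0).shape)  # (4, 3)

# An explicit dtype bypasses the default_dtype lookup.
print(xpx.one_hot(x, 4, dtype=np.int32).dtype)  # int32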


def pad(
    x: Array,
    pad_width: int | tuple[int, int] | Sequence[tuple[int, int]],
@@ -159,14 +231,19 @@ def pad(
msg = "Only `'constant'` mode is currently supported"
raise NotImplementedError(msg)

if (
is_numpy_namespace(xp)
or is_cupy_namespace(xp)
or is_jax_namespace(xp)
or is_pydata_sparse_namespace(xp)
):
return xp.pad(x, pad_width, mode, constant_values=constant_values)

# https://github.com/pytorch/pytorch/blob/cf76c05b4dc629ac989d1fb8e789d4fac04a095a/torch/_numpy/_funcs_impl.py#L2045-L2056
if _delegate(xp, Backend.TORCH):
if is_torch_namespace(xp):
pad_width = xp.asarray(pad_width)
pad_width = xp.broadcast_to(pad_width, (x.ndim, 2))
pad_width = xp.flip(pad_width, axis=(0,)).flatten()
return xp.nn.functional.pad(x, tuple(pad_width), value=constant_values) # type: ignore[arg-type] # pyright: ignore[reportArgumentType]

if _delegate(xp, Backend.NUMPY, Backend.JAX, Backend.CUPY, Backend.SPARSE):
return xp.pad(x, pad_width, mode, constant_values=constant_values)

return _funcs.pad(x, pad_width, constant_values=constant_values, xp=xp)
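The flip-and-flatten above exists because torch.nn.functional.pad orders pad widths from the last axis inward, while NumPy-style pad_width starts at the first axis. A self-contained check of that equivalence (assumes NumPy and PyTorch are installed):

import numpy as np
import torch

x_np = np.arange(6.0).reshape(2, 3)
x_pt = torch.arange(6.0).reshape(2, 3)

# NumPy: ((before_0, after_0), (before_1, after_1)), first axis first.
out_np = np.pad(x_np, ((1, 2), (3, 4)))

# torch: flat (left, right, top, bottom), last axis first.
out_pt = torch.nn.functional.pad(x_pt, (3, 4, 1, 2))

assert out_np.shape == tuple(out_pt.shape) == (5, 10)
assert np.array_equal(out_np, out_pt.numpy())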
4 changes: 0 additions & 4 deletions sklearn/externals/array_api_extra/_lib/__init__.py
@@ -1,5 +1 @@
"""Internals of array-api-extra."""

from ._backends import Backend

__all__ = ["Backend"]
11 changes: 10 additions & 1 deletion sklearn/externals/array_api_extra/_lib/_at.py
@@ -8,10 +8,12 @@
from types import ModuleType
from typing import TYPE_CHECKING, ClassVar, cast

from ._utils import _compat
from ._utils._compat import (
    array_namespace,
    is_dask_array,
    is_jax_array,
    is_torch_array,
    is_writeable_array,
)
from ._utils._helpers import meta_namespace
@@ -298,7 +300,7 @@ def _op(
            and idx.dtype == xp.bool
            and idx.shape == x.shape
        ):
            y_xp = xp.asarray(y, dtype=x.dtype)
            y_xp = xp.asarray(y, dtype=x.dtype, device=_compat.device(x))
            if y_xp.ndim == 0:
                if out_of_place_op:  # add(), subtract(), ...
                    # suppress inf warnings on Dask
@@ -344,6 +346,13 @@ def _op(
msg = f"Can't update read-only array {x}"
raise ValueError(msg)

# Work around bug in PyTorch where __setitem__ doesn't
# always support mismatched dtypes
# https://github.com/pytorch/pytorch/issues/150017
if is_torch_array(y):
y = xp.astype(y, x.dtype, copy=False)

# Backends without boolean indexing (other than JAX) crash here
if in_place_op: # add(), subtract(), ...
x[idx] = in_place_op(x[idx], y)
else: # set()
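For context, the cast above guards the in-place __setitem__ path of the at operator. A minimal sketch of that public API (NumPy shown; with a torch y of mismatched dtype, the astype above is what keeps the assignment from raising):

import numpy as np
import array_api_extra as xpx

x = np.zeros(4)
mask = np.asarray([True, False, True, False])

# JAX-style update helper: in place where the backend allows it,
# otherwise on a copy. Boolean-mask set() lands in the code above.
y = xpx.at(x)[mask].set(1.0)
print(y)  # [1. 0. 1. 0.]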
63 changes: 29 additions & 34 deletions sklearn/externals/array_api_extra/_lib/_backends.py
@@ -1,51 +1,46 @@
"""Backends with which array-api-extra interacts in delegation and testing."""
"""Backends against which array-api-extra runs its tests."""

from collections.abc import Callable
from enum import Enum
from types import ModuleType
from typing import cast
from __future__ import annotations

from ._utils import _compat
from enum import Enum

__all__ = ["Backend"]


class Backend(Enum): # numpydoc ignore=PR01,PR02 # type: ignore[no-subclass-any]
class Backend(Enum): # numpydoc ignore=PR02
"""
All array library backends explicitly tested by array-api-extra.

Parameters
----------
value : str
Name of the backend's module.
is_namespace : Callable[[ModuleType], bool]
Function to check whether an input module is the array namespace
corresponding to the backend.
Tag of the backend's module, in the format ``<namespace>[:<extra tag>]``.
"""

ARRAY_API_STRICT = "array_api_strict", _compat.is_array_api_strict_namespace
NUMPY = "numpy", _compat.is_numpy_namespace
NUMPY_READONLY = "numpy_readonly", _compat.is_numpy_namespace
CUPY = "cupy", _compat.is_cupy_namespace
TORCH = "torch", _compat.is_torch_namespace
DASK = "dask.array", _compat.is_dask_namespace
SPARSE = "sparse", _compat.is_pydata_sparse_namespace
JAX = "jax.numpy", _compat.is_jax_namespace

    def __new__(
        cls, value: str, _is_namespace: Callable[[ModuleType], bool]
    ):  # numpydoc ignore=GL08
        obj = object.__new__(cls)
        obj._value_ = value
        return obj

    def __init__(
        self,
        value: str,  # noqa: ARG002  # pylint: disable=unused-argument
        is_namespace: Callable[[ModuleType], bool],
    ):  # numpydoc ignore=GL08
        self.is_namespace = is_namespace
    # Use :<tag> to prevent Enum from deduplicating items with the same value
    ARRAY_API_STRICT = "array_api_strict"
    ARRAY_API_STRICTEST = "array_api_strict:strictest"
    NUMPY = "numpy"
    NUMPY_READONLY = "numpy:readonly"
    CUPY = "cupy"
    TORCH = "torch"
    TORCH_GPU = "torch:gpu"
    DASK = "dask.array"
    SPARSE = "sparse"
    JAX = "jax.numpy"
    JAX_GPU = "jax.numpy:gpu"

    def __str__(self) -> str:  # type: ignore[explicit-override]  # pyright: ignore[reportImplicitOverride]  # numpydoc ignore=RT01
        """Pretty-print parameterized test names."""
        return cast(str, self.value)
        return (
            self.name.lower().replace("_gpu", ":gpu").replace("_readonly", ":readonly")
        )

    @property
    def modname(self) -> str:  # numpydoc ignore=RT01
        """Module name to be imported."""
        return self.value.split(":")[0]

    def like(self, *others: Backend) -> bool:  # numpydoc ignore=PR01,RT01
        """Check if this backend uses the same module as others."""
        return any(self.modname == other.modname for other in others)
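A short sketch of how the retagged enum behaves, derived from the methods above (upstream import path shown; the vendored copy lives under sklearn.externals):

from array_api_extra._lib._backends import Backend

print(str(Backend.NUMPY_READONLY))        # numpy:readonly, pretty test IDs
print(Backend.TORCH_GPU.modname)          # torch, the ":gpu" tag is stripped
print(Backend.JAX_GPU.like(Backend.JAX))  # True, same module, different tag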