Skip to content

ENH: at support for bool mask in Dask and JAX #121

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Jan 25, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
92 changes: 80 additions & 12 deletions src/array_api_extra/_lib/_at.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,12 @@
from types import ModuleType
from typing import ClassVar, cast

from ._utils._compat import array_namespace, is_jax_array, is_writeable_array
from ._utils._compat import (
array_namespace,
is_dask_array,
is_jax_array,
is_writeable_array,
)
from ._utils._typing import Array, Index


Expand Down Expand Up @@ -141,6 +146,25 @@
not explicitly covered by ``array-api-compat``, are not supported by update
methods.

Boolean masks are supported on Dask and jitted JAX arrays exclusively
when `idx` has the same shape as `x` and `y` is 0-dimensional.
Note that this support is not available in JAX's native
``x.at[mask].set(y)``.

This pattern::

>>> mask = m(x)
>>> x[mask] = f(x[mask])

Can't be replaced by `at`, as it won't work on Dask and JAX inside jax.jit::

>>> mask = m(x)
>>> x = xpx.at(x, mask).set(f(x[mask]) # Crash on Dask and jax.jit

You should instead use::

>>> x = xp.where(m(x), f(x), x)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

TODO suggest to use lazywhere


Examples
--------
Given either of these equivalent expressions::
Expand Down Expand Up @@ -189,6 +213,7 @@
self,
at_op: _AtOp,
in_place_op: Callable[[Array, Array | object], Array] | None,
out_of_place_op: Callable[[Array, Array], Array] | None,
y: Array | object,
/,
copy: bool | None,
Expand All @@ -210,6 +235,16 @@

x[idx] = y

out_of_place_op : Callable[[Array, Array], Array] | None
Out-of-place operation to apply when idx is a boolean mask and the backend
doesn't support in-place updates::

x = xp.where(idx, out_of_place_op(x, y), x)

If None::

x = xp.where(idx, y, x)

y : array or object
Right-hand side of the operation.
copy : bool or None
Expand All @@ -223,6 +258,7 @@
Updated `x`.
"""
x, idx = self._x, self._idx
xp = array_namespace(x, y) if xp is None else xp

if idx is _undef:
msg = (
Expand All @@ -247,15 +283,41 @@
else:
writeable = is_writeable_array(x)

# JAX inside jax.jit and Dask don't support in-place updates with boolean
# mask. However we can handle the common special case of 0-dimensional y
# with where(idx, y, x) instead.
if (
(is_dask_array(idx) or is_jax_array(idx))
and idx.dtype == xp.bool
and idx.shape == x.shape
):
y_xp = xp.asarray(y, dtype=x.dtype)
if y_xp.ndim == 0:
if out_of_place_op:

Check warning on line 296 in src/array_api_extra/_lib/_at.py

View check run for this annotation

Codecov / codecov/patch

src/array_api_extra/_lib/_at.py#L294-L296

Added lines #L294 - L296 were not covered by tests
# FIXME: suppress inf warnings on dask with lazywhere
out = xp.where(idx, out_of_place_op(x, y_xp), x)

Check warning on line 298 in src/array_api_extra/_lib/_at.py

View check run for this annotation

Codecov / codecov/patch

src/array_api_extra/_lib/_at.py#L298

Added line #L298 was not covered by tests
# Undo int->float promotion on JAX after _AtOp.DIVIDE
out = xp.astype(out, x.dtype, copy=False)

Check warning on line 300 in src/array_api_extra/_lib/_at.py

View check run for this annotation

Codecov / codecov/patch

src/array_api_extra/_lib/_at.py#L300

Added line #L300 was not covered by tests
else:
out = xp.where(idx, y_xp, x)

Check warning on line 302 in src/array_api_extra/_lib/_at.py

View check run for this annotation

Codecov / codecov/patch

src/array_api_extra/_lib/_at.py#L302

Added line #L302 was not covered by tests

if copy:
return out
x[()] = out
return x

Check warning on line 307 in src/array_api_extra/_lib/_at.py

View check run for this annotation

Codecov / codecov/patch

src/array_api_extra/_lib/_at.py#L304-L307

Added lines #L304 - L307 were not covered by tests
# else: this will work on eager JAX and crash on jax.jit and Dask

if copy:
if is_jax_array(x):
# Use JAX's at[]
func = cast(Callable[[Array], Array], getattr(x.at[idx], at_op.value))
return func(y)
out = func(y)

Check warning on line 314 in src/array_api_extra/_lib/_at.py

View check run for this annotation

Codecov / codecov/patch

src/array_api_extra/_lib/_at.py#L314

Added line #L314 was not covered by tests
# Undo int->float promotion on JAX after _AtOp.DIVIDE
return xp.astype(out, x.dtype, copy=False)

Check warning on line 316 in src/array_api_extra/_lib/_at.py

View check run for this annotation

Codecov / codecov/patch

src/array_api_extra/_lib/_at.py#L316

Added line #L316 was not covered by tests

# Emulate at[] behaviour for non-JAX arrays
# with a copy followed by an update
if xp is None:
xp = array_namespace(x)

x = xp.asarray(x, copy=True)
if writeable is False:
# A copy of a read-only numpy array is writeable
Expand Down Expand Up @@ -283,7 +345,7 @@
xp: ModuleType | None = None,
) -> Array: # numpydoc ignore=PR01,RT01
"""Apply ``x[idx] = y`` and return the update array."""
return self._op(_AtOp.SET, None, y, copy=copy, xp=xp)
return self._op(_AtOp.SET, None, None, y, copy=copy, xp=xp)

def add(
self,
Expand All @@ -297,7 +359,7 @@
# Note for this and all other methods based on _iop:
# operator.iadd and operator.add subtly differ in behaviour, as
# only iadd will trigger exceptions when y has an incompatible dtype.
return self._op(_AtOp.ADD, operator.iadd, y, copy=copy, xp=xp)
return self._op(_AtOp.ADD, operator.iadd, operator.add, y, copy=copy, xp=xp)

def subtract(
self,
Expand All @@ -307,7 +369,9 @@
xp: ModuleType | None = None,
) -> Array: # numpydoc ignore=PR01,RT01
"""Apply ``x[idx] -= y`` and return the updated array."""
return self._op(_AtOp.SUBTRACT, operator.isub, y, copy=copy, xp=xp)
return self._op(
_AtOp.SUBTRACT, operator.isub, operator.sub, y, copy=copy, xp=xp
)

def multiply(
self,
Expand All @@ -317,7 +381,9 @@
xp: ModuleType | None = None,
) -> Array: # numpydoc ignore=PR01,RT01
"""Apply ``x[idx] *= y`` and return the updated array."""
return self._op(_AtOp.MULTIPLY, operator.imul, y, copy=copy, xp=xp)
return self._op(
_AtOp.MULTIPLY, operator.imul, operator.mul, y, copy=copy, xp=xp
)

def divide(
self,
Expand All @@ -327,7 +393,9 @@
xp: ModuleType | None = None,
) -> Array: # numpydoc ignore=PR01,RT01
"""Apply ``x[idx] /= y`` and return the updated array."""
return self._op(_AtOp.DIVIDE, operator.itruediv, y, copy=copy, xp=xp)
return self._op(
_AtOp.DIVIDE, operator.itruediv, operator.truediv, y, copy=copy, xp=xp
)

def power(
self,
Expand All @@ -337,7 +405,7 @@
xp: ModuleType | None = None,
) -> Array: # numpydoc ignore=PR01,RT01
"""Apply ``x[idx] **= y`` and return the updated array."""
return self._op(_AtOp.POWER, operator.ipow, y, copy=copy, xp=xp)
return self._op(_AtOp.POWER, operator.ipow, operator.pow, y, copy=copy, xp=xp)

def min(
self,
Expand All @@ -349,7 +417,7 @@
"""Apply ``x[idx] = minimum(x[idx], y)`` and return the updated array."""
xp = array_namespace(self._x) if xp is None else xp
y = xp.asarray(y)
return self._op(_AtOp.MIN, xp.minimum, y, copy=copy, xp=xp)
return self._op(_AtOp.MIN, xp.minimum, xp.minimum, y, copy=copy, xp=xp)

def max(
self,
Expand All @@ -361,4 +429,4 @@
"""Apply ``x[idx] = maximum(x[idx], y)`` and return the updated array."""
xp = array_namespace(self._x) if xp is None else xp
y = xp.asarray(y)
return self._op(_AtOp.MAX, xp.maximum, y, copy=copy, xp=xp)
return self._op(_AtOp.MAX, xp.maximum, xp.maximum, y, copy=copy, xp=xp)
3 changes: 3 additions & 0 deletions src/array_api_extra/_lib/_utils/_compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
device,
is_array_api_strict_namespace,
is_cupy_namespace,
is_dask_array,
is_dask_namespace,
is_jax_array,
is_jax_namespace,
Expand All @@ -23,6 +24,7 @@
device,
is_array_api_strict_namespace,
is_cupy_namespace,
is_dask_array,
is_dask_namespace,
is_jax_array,
is_jax_namespace,
Expand All @@ -38,6 +40,7 @@
"device",
"is_array_api_strict_namespace",
"is_cupy_namespace",
"is_dask_array",
"is_dask_namespace",
"is_jax_array",
"is_jax_namespace",
Expand Down
1 change: 1 addition & 0 deletions src/array_api_extra/_lib/_utils/_compat.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ def is_jax_namespace(xp: ModuleType, /) -> bool: ...
def is_numpy_namespace(xp: ModuleType, /) -> bool: ...
def is_pydata_sparse_namespace(xp: ModuleType, /) -> bool: ...
def is_torch_namespace(xp: ModuleType, /) -> bool: ...
def is_dask_array(x: object, /) -> bool: ...
def is_jax_array(x: object, /) -> bool: ...
def is_writeable_array(x: object, /) -> bool: ...
def size(x: Array, /) -> int | None: ...
87 changes: 67 additions & 20 deletions tests/test_at.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import math
import pickle
from collections.abc import Callable, Generator
from contextlib import contextmanager
Expand All @@ -15,6 +16,12 @@
from array_api_extra._lib._utils._typing import Array, Index
from array_api_extra.testing import lazy_xp_function

pytestmark = [
pytest.mark.skip_xp_backend(
Backend.SPARSE, reason="read-only backend without .at support"
)
]


def at_op( # type: ignore[no-any-explicit]
x: Array,
Expand Down Expand Up @@ -70,9 +77,6 @@ def assert_copy(array: Array, copy: bool | None) -> Generator[None, None, None]:
xp_assert_equal(xp.all(array == array_orig), xp.asarray(copy))


@pytest.mark.skip_xp_backend(
Backend.SPARSE, reason="read-only backend without .at support"
)
@pytest.mark.parametrize(
("kwargs", "expect_copy"),
[
Expand Down Expand Up @@ -100,14 +104,7 @@ def assert_copy(array: Array, copy: bool | None) -> Generator[None, None, None]:
[
(False, False),
(False, True),
pytest.param(
True,
False,
marks=(
pytest.mark.skip_xp_backend(Backend.JAX, reason="TODO special case"),
pytest.mark.skip_xp_backend(Backend.DASK, reason="TODO special case"),
),
),
(True, False), # Uses xp.where(idx, y, x) on JAX and Dask
pytest.param(
True,
True,
Expand Down Expand Up @@ -176,22 +173,72 @@ def test_alternate_index_syntax():
at(a, 0)[0].set(4)


@pytest.mark.parametrize("copy", [True, False])
@pytest.mark.parametrize(
"op", [_AtOp.ADD, _AtOp.SUBTRACT, _AtOp.MULTIPLY, _AtOp.DIVIDE, _AtOp.POWER]
)
def test_iops_incompatible_dtype(op: _AtOp, copy: bool):
@pytest.mark.parametrize("copy", [True, None])
@pytest.mark.parametrize("bool_mask", [False, True])
@pytest.mark.parametrize("op", list(_AtOp))
def test_incompatible_dtype(
xp: ModuleType, library: Backend, op: _AtOp, copy: bool | None, bool_mask: bool
):
"""Test that at() replicates the backend's behaviour for
in-place operations with incompatible dtypes.

Note:
Behavior is backend-specific, but only two behaviors are allowed:
1. raise an exception, or
2. return the same dtype as x, disregarding y.dtype (no broadcasting).

Note that __i<op>__ and __<op>__ behave differently, and we want to
replicate the behavior of __i<op>__:

>>> a = np.asarray([1, 2, 3])
>>> a / 1.5
array([0. , 0.66666667, 1.33333333])
>>> a /= 1.5
UFuncTypeError: Cannot cast ufunc 'divide' output from dtype('float64')
to dtype('int64') with casting rule 'same_kind'
"""
x = np.asarray([2, 4])
with pytest.raises(TypeError, match="Cannot cast ufunc"):
at_op(x, slice(None), op, 1.1, copy=copy)
x = xp.asarray([2, 4])
idx = xp.asarray([True, False]) if bool_mask else slice(None)
z = None

if library is Backend.JAX:
if bool_mask:
z = at_op(x, idx, op, 1.1, copy=copy)
else:
with pytest.warns(FutureWarning, match="cannot safely cast"):
z = at_op(x, idx, op, 1.1, copy=copy)

elif library is Backend.DASK:
if op in (_AtOp.MIN, _AtOp.MAX):
pytest.xfail(reason="need array-api-compat 1.11")
z = at_op(x, idx, op, 1.1, copy=copy)

elif library is Backend.ARRAY_API_STRICT and op is not _AtOp.SET:
with pytest.raises(Exception, match=r"cast|promote|dtype"):
at_op(x, idx, op, 1.1, copy=copy)

elif op in (_AtOp.SET, _AtOp.MIN, _AtOp.MAX):
# There is no __i<op>__ version of these operations
z = at_op(x, idx, op, 1.1, copy=copy)

else:
with pytest.raises(Exception, match=r"cast|promote|dtype"):
at_op(x, idx, op, 1.1, copy=copy)

assert z is None or z.dtype == x.dtype


def test_bool_mask_nd(xp: ModuleType):
x = xp.asarray([[1, 2, 3], [4, 5, 6]])
idx = xp.asarray([[True, False, False], [False, True, True]])
z = at_op(x, idx, _AtOp.SET, 0)
xp_assert_equal(z, xp.asarray([[0, 2, 3], [4, 0, 0]]))


@pytest.mark.skip_xp_backend(Backend.DASK, reason="FIXME need scipy's lazywhere")
@pytest.mark.parametrize("bool_mask", [False, True])
def test_no_inf_warnings(xp: ModuleType, bool_mask: bool):
x = xp.asarray([math.inf, 1.0, 2.0])
idx = ~xp.isinf(x) if bool_mask else slice(1, None)
# inf - inf -> nan with a warning
z = at_op(x, idx, _AtOp.SUBTRACT, math.inf)
xp_assert_equal(z, xp.asarray([math.inf, -math.inf, -math.inf]))
2 changes: 2 additions & 0 deletions vendor_tests/test_vendor.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ def test_vendor_compat():
array_namespace,
device,
is_cupy_namespace,
is_dask_array,
is_dask_namespace,
is_jax_array,
is_jax_namespace,
Expand All @@ -20,6 +21,7 @@ def test_vendor_compat():
assert array_namespace(x) is xp
device(x)
assert not is_cupy_namespace(xp)
assert not is_dask_array(x)
assert not is_dask_namespace(xp)
assert not is_jax_array(x)
assert not is_jax_namespace(xp)
Expand Down
Loading