From b69c945d0a5db5fa50f4259998acea664fd5d599 Mon Sep 17 00:00:00 2001 From: crusaderky Date: Tue, 21 Jan 2025 16:36:47 +0000 Subject: [PATCH 1/4] bool mask support --- src/array_api_extra/_lib/_at.py | 92 ++++++++++++++++++--- src/array_api_extra/_lib/_utils/_compat.py | 3 + src/array_api_extra/_lib/_utils/_compat.pyi | 1 + tests/test_at.py | 81 +++++++++++++++--- vendor_tests/test_vendor.py | 2 + 5 files changed, 154 insertions(+), 25 deletions(-) diff --git a/src/array_api_extra/_lib/_at.py b/src/array_api_extra/_lib/_at.py index 927a7300..e0dab54b 100644 --- a/src/array_api_extra/_lib/_at.py +++ b/src/array_api_extra/_lib/_at.py @@ -9,7 +9,12 @@ from types import ModuleType from typing import ClassVar, cast -from ._utils._compat import array_namespace, is_jax_array, is_writeable_array +from ._utils._compat import ( + array_namespace, + is_dask_array, + is_jax_array, + is_writeable_array, +) from ._utils._typing import Array, Index @@ -141,6 +146,25 @@ class at: # pylint: disable=invalid-name # numpydoc ignore=PR02 not explicitly covered by ``array-api-compat``, are not supported by update methods. + Boolean masks are supported on Dask and jitted JAX arrays exclusively + when `idx` has the same shape as `x` and `y` is 0-dimensional. + Note that this is support is not available in JAX's native + ``x.at[mask].set(y)``. + + This pattern:: + + >>> mask = m(x) + >>> x[mask] = f(x[mask]) + + Can't be replaced by `at`, as it won't work on Dask and JAX inside jax.jit:: + + >>> mask = m(x) + >>> x = xpx.at(x, mask).set(f(x[mask]) # Crash on Dask and jax.jit + + You should instead use:: + + >>> x = xp.where(m(x), f(x), x) + Examples -------- Given either of these equivalent expressions:: @@ -189,6 +213,7 @@ def _op( self, at_op: _AtOp, in_place_op: Callable[[Array, Array | object], Array] | None, + out_of_place_op: Callable[[Array, Array], Array] | None, y: Array | object, /, copy: bool | None, @@ -210,6 +235,16 @@ def _op( x[idx] = y + out_of_place_op : Callable[[Array, Array], Array] | None + Out-of-place operation to apply when idx is a boolean mask and the backend + doesn't support in-place updates:: + + x = xp.where(idx, out_of_place_op(x, y), x) + + If None:: + + x = xp.where(idx, y, x) + y : array or object Right-hand side of the operation. copy : bool or None @@ -223,6 +258,7 @@ def _op( Updated `x`. """ x, idx = self._x, self._idx + xp = array_namespace(x, y) if xp is None else xp if idx is _undef: msg = ( @@ -247,15 +283,41 @@ def _op( else: writeable = is_writeable_array(x) + # JAX inside jax.jit and Dask don't support in-place updates with boolean + # mask. However we can handle the common special case of 0-dimensional y + # with where(idx, y, x) instead. + if ( + (is_dask_array(idx) or is_jax_array(idx)) + and idx.dtype == xp.bool + and idx.shape == x.shape + ): + y_xp = xp.asarray(y, dtype=x.dtype) + if y_xp.ndim == 0: + if out_of_place_op: + # FIXME: suppress inf warnings on dask with lazywhere + out = xp.where(idx, out_of_place_op(x, y_xp), x) + # Undo int->float promotion on JAX after _AtOp.DIVIDE + out = xp.astype(out, x.dtype, copy=False) + else: + out = xp.where(idx, y_xp, x) + + if copy: + return out + x[()] = out + return x + # else: this will work on eager JAX and crash on jax.jit and Dask + if copy: if is_jax_array(x): # Use JAX's at[] func = cast(Callable[[Array], Array], getattr(x.at[idx], at_op.value)) - return func(y) + out = func(y) + # Undo int->float promotion on JAX after _AtOp.DIVIDE + return xp.astype(out, x.dtype, copy=False) + # Emulate at[] behaviour for non-JAX arrays # with a copy followed by an update - if xp is None: - xp = array_namespace(x) + x = xp.asarray(x, copy=True) if writeable is False: # A copy of a read-only numpy array is writeable @@ -283,7 +345,7 @@ def set( xp: ModuleType | None = None, ) -> Array: # numpydoc ignore=PR01,RT01 """Apply ``x[idx] = y`` and return the update array.""" - return self._op(_AtOp.SET, None, y, copy=copy, xp=xp) + return self._op(_AtOp.SET, None, None, y, copy=copy, xp=xp) def add( self, @@ -297,7 +359,7 @@ def add( # Note for this and all other methods based on _iop: # operator.iadd and operator.add subtly differ in behaviour, as # only iadd will trigger exceptions when y has an incompatible dtype. - return self._op(_AtOp.ADD, operator.iadd, y, copy=copy, xp=xp) + return self._op(_AtOp.ADD, operator.iadd, operator.add, y, copy=copy, xp=xp) def subtract( self, @@ -307,7 +369,9 @@ def subtract( xp: ModuleType | None = None, ) -> Array: # numpydoc ignore=PR01,RT01 """Apply ``x[idx] -= y`` and return the updated array.""" - return self._op(_AtOp.SUBTRACT, operator.isub, y, copy=copy, xp=xp) + return self._op( + _AtOp.SUBTRACT, operator.isub, operator.sub, y, copy=copy, xp=xp + ) def multiply( self, @@ -317,7 +381,9 @@ def multiply( xp: ModuleType | None = None, ) -> Array: # numpydoc ignore=PR01,RT01 """Apply ``x[idx] *= y`` and return the updated array.""" - return self._op(_AtOp.MULTIPLY, operator.imul, y, copy=copy, xp=xp) + return self._op( + _AtOp.MULTIPLY, operator.imul, operator.mul, y, copy=copy, xp=xp + ) def divide( self, @@ -327,7 +393,9 @@ def divide( xp: ModuleType | None = None, ) -> Array: # numpydoc ignore=PR01,RT01 """Apply ``x[idx] /= y`` and return the updated array.""" - return self._op(_AtOp.DIVIDE, operator.itruediv, y, copy=copy, xp=xp) + return self._op( + _AtOp.DIVIDE, operator.itruediv, operator.truediv, y, copy=copy, xp=xp + ) def power( self, @@ -337,7 +405,7 @@ def power( xp: ModuleType | None = None, ) -> Array: # numpydoc ignore=PR01,RT01 """Apply ``x[idx] **= y`` and return the updated array.""" - return self._op(_AtOp.POWER, operator.ipow, y, copy=copy, xp=xp) + return self._op(_AtOp.POWER, operator.ipow, operator.pow, y, copy=copy, xp=xp) def min( self, @@ -349,7 +417,7 @@ def min( """Apply ``x[idx] = minimum(x[idx], y)`` and return the updated array.""" xp = array_namespace(self._x) if xp is None else xp y = xp.asarray(y) - return self._op(_AtOp.MIN, xp.minimum, y, copy=copy, xp=xp) + return self._op(_AtOp.MIN, xp.minimum, xp.minimum, y, copy=copy, xp=xp) def max( self, @@ -361,4 +429,4 @@ def max( """Apply ``x[idx] = maximum(x[idx], y)`` and return the updated array.""" xp = array_namespace(self._x) if xp is None else xp y = xp.asarray(y) - return self._op(_AtOp.MAX, xp.maximum, y, copy=copy, xp=xp) + return self._op(_AtOp.MAX, xp.maximum, xp.maximum, y, copy=copy, xp=xp) diff --git a/src/array_api_extra/_lib/_utils/_compat.py b/src/array_api_extra/_lib/_utils/_compat.py index 707e9553..f4def9f3 100644 --- a/src/array_api_extra/_lib/_utils/_compat.py +++ b/src/array_api_extra/_lib/_utils/_compat.py @@ -8,6 +8,7 @@ device, is_array_api_strict_namespace, is_cupy_namespace, + is_dask_array, is_dask_namespace, is_jax_array, is_jax_namespace, @@ -23,6 +24,7 @@ device, is_array_api_strict_namespace, is_cupy_namespace, + is_dask_array, is_dask_namespace, is_jax_array, is_jax_namespace, @@ -38,6 +40,7 @@ "device", "is_array_api_strict_namespace", "is_cupy_namespace", + "is_dask_array", "is_dask_namespace", "is_jax_array", "is_jax_namespace", diff --git a/src/array_api_extra/_lib/_utils/_compat.pyi b/src/array_api_extra/_lib/_utils/_compat.pyi index 1e81c984..e409091e 100644 --- a/src/array_api_extra/_lib/_utils/_compat.pyi +++ b/src/array_api_extra/_lib/_utils/_compat.pyi @@ -25,6 +25,7 @@ def is_jax_namespace(xp: ModuleType, /) -> bool: ... def is_numpy_namespace(xp: ModuleType, /) -> bool: ... def is_pydata_sparse_namespace(xp: ModuleType, /) -> bool: ... def is_torch_namespace(xp: ModuleType, /) -> bool: ... +def is_dask_array(x: object, /) -> bool: ... def is_jax_array(x: object, /) -> bool: ... def is_writeable_array(x: object, /) -> bool: ... def size(x: Array, /) -> int | None: ... diff --git a/tests/test_at.py b/tests/test_at.py index 744e3aaf..aff7ce2a 100644 --- a/tests/test_at.py +++ b/tests/test_at.py @@ -1,3 +1,4 @@ +import math import pickle from collections.abc import Callable, Generator from contextlib import contextmanager @@ -100,14 +101,7 @@ def assert_copy(array: Array, copy: bool | None) -> Generator[None, None, None]: [ (False, False), (False, True), - pytest.param( - True, - False, - marks=( - pytest.mark.skip_xp_backend(Backend.JAX, reason="TODO special case"), - pytest.mark.skip_xp_backend(Backend.DASK, reason="TODO special case"), - ), - ), + (True, False), # Uses xp.where(idx, y, x) on JAX and Dask pytest.param( True, True, @@ -176,11 +170,16 @@ def test_alternate_index_syntax(): at(a, 0)[0].set(4) -@pytest.mark.parametrize("copy", [True, False]) +@pytest.mark.skip_xp_backend( + Backend.SPARSE, reason="read-only backend without .at support" +) +@pytest.mark.parametrize("copy", [True, None]) @pytest.mark.parametrize( "op", [_AtOp.ADD, _AtOp.SUBTRACT, _AtOp.MULTIPLY, _AtOp.DIVIDE, _AtOp.POWER] ) -def test_iops_incompatible_dtype(op: _AtOp, copy: bool): +def test_iops_incompatible_dtype( + xp: ModuleType, library: Backend, op: _AtOp, copy: bool | None +): """Test that at() replicates the backend's behaviour for in-place operations with incompatible dtypes. @@ -192,6 +191,62 @@ def test_iops_incompatible_dtype(op: _AtOp, copy: bool): UFuncTypeError: Cannot cast ufunc 'divide' output from dtype('float64') to dtype('int64') with casting rule 'same_kind' """ - x = np.asarray([2, 4]) - with pytest.raises(TypeError, match="Cannot cast ufunc"): - at_op(x, slice(None), op, 1.1, copy=copy) + x = xp.asarray([2, 4]) + + if library is Backend.DASK: + z = at_op(x, slice(None), op, 1.1, copy=copy) + assert z.dtype == x.dtype + + elif library is Backend.JAX: + with pytest.warns(FutureWarning, match="cannot safely cast"): + z = at_op(x, slice(None), op, 1.1, copy=copy) + assert z.dtype == x.dtype + + else: + with pytest.raises(Exception, match=r"cast|promote|dtype"): + at_op(x, slice(None), op, 1.1, copy=copy) + + +@pytest.mark.skip_xp_backend( + Backend.SPARSE, reason="read-only backend without .at support" +) +@pytest.mark.parametrize( + "op", [_AtOp.ADD, _AtOp.SUBTRACT, _AtOp.MULTIPLY, _AtOp.DIVIDE, _AtOp.POWER] +) +def test_bool_mask_incompatible_dtype(xp: ModuleType, library: Backend, op: _AtOp): + """ + When xp.where(idx, y, x) would promote the dtype of the output + to y.dtype, at(x, idx).set(y) must retain x.dtype instead + """ + x = xp.asarray([1, 2]) + idx = xp.asarray([True, False]) + if library in (Backend.DASK, Backend.JAX): + z = at_op(x, idx, op, 1.1) + assert z.dtype == x.dtype + + else: + with pytest.raises(Exception, match=r"cast|promote|dtype"): + at_op(x, idx, op, 1.1) + + +@pytest.mark.skip_xp_backend( + Backend.SPARSE, reason="read-only backend without .at support" +) +def test_bool_mask_nd(xp: ModuleType): + x = xp.asarray([[1, 2, 3], [4, 5, 6]]) + idx = xp.asarray([[True, False, False], [False, True, True]]) + z = at_op(x, idx, _AtOp.SET, 0) + xp_assert_equal(z, xp.asarray([[0, 2, 3], [4, 0, 0]])) + + +@pytest.mark.skip_xp_backend( + Backend.SPARSE, reason="read-only backend without .at support" +) +@pytest.mark.skip_xp_backend(Backend.DASK, reason="FIXME need scipy's lazywhere") +@pytest.mark.parametrize("bool_mask", [False, True]) +def test_no_inf_warnings(xp: ModuleType, bool_mask: bool): + x = xp.asarray([math.inf, 1.0, 2.0]) + idx = ~xp.isinf(x) if bool_mask else slice(1, None) + # inf - inf -> nan with a warning + z = at_op(x, idx, _AtOp.SUBTRACT, math.inf) + xp_assert_equal(z, xp.asarray([math.inf, -math.inf, -math.inf])) diff --git a/vendor_tests/test_vendor.py b/vendor_tests/test_vendor.py index 38249378..914a0a1d 100644 --- a/vendor_tests/test_vendor.py +++ b/vendor_tests/test_vendor.py @@ -7,6 +7,7 @@ def test_vendor_compat(): array_namespace, device, is_cupy_namespace, + is_dask_array, is_dask_namespace, is_jax_array, is_jax_namespace, @@ -20,6 +21,7 @@ def test_vendor_compat(): assert array_namespace(x) is xp device(x) assert not is_cupy_namespace(xp) + assert not is_dask_array(x) assert not is_dask_namespace(xp) assert not is_jax_array(x) assert not is_jax_namespace(xp) From eb6b721580efd9461528abb3d2c4966365c5b367 Mon Sep 17 00:00:00 2001 From: Guido Imperiale Date: Sat, 25 Jan 2025 13:40:28 +0000 Subject: [PATCH 2/4] Update src/array_api_extra/_lib/_at.py Co-authored-by: Lucas Colley --- src/array_api_extra/_lib/_at.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/array_api_extra/_lib/_at.py b/src/array_api_extra/_lib/_at.py index e0dab54b..e5cb3875 100644 --- a/src/array_api_extra/_lib/_at.py +++ b/src/array_api_extra/_lib/_at.py @@ -148,7 +148,7 @@ class at: # pylint: disable=invalid-name # numpydoc ignore=PR02 Boolean masks are supported on Dask and jitted JAX arrays exclusively when `idx` has the same shape as `x` and `y` is 0-dimensional. - Note that this is support is not available in JAX's native + Note that this support is not available in JAX's native ``x.at[mask].set(y)``. This pattern:: From a990846c750de47445454a40b9df65308321b81a Mon Sep 17 00:00:00 2001 From: crusaderky Date: Sat, 25 Jan 2025 22:26:51 +0000 Subject: [PATCH 3/4] rework incompatible_dtype test --- tests/test_at.py | 91 +++++++++++++++++++++++------------------------- 1 file changed, 43 insertions(+), 48 deletions(-) diff --git a/tests/test_at.py b/tests/test_at.py index aff7ce2a..42376526 100644 --- a/tests/test_at.py +++ b/tests/test_at.py @@ -16,6 +16,12 @@ from array_api_extra._lib._utils._typing import Array, Index from array_api_extra.testing import lazy_xp_function +pytestmark = [ + pytest.mark.skip_xp_backend( + Backend.SPARSE, reason="read-only backend without .at support" + ) +] + def at_op( # type: ignore[no-any-explicit] x: Array, @@ -71,9 +77,6 @@ def assert_copy(array: Array, copy: bool | None) -> Generator[None, None, None]: xp_assert_equal(xp.all(array == array_orig), xp.asarray(copy)) -@pytest.mark.skip_xp_backend( - Backend.SPARSE, reason="read-only backend without .at support" -) @pytest.mark.parametrize( ("kwargs", "expect_copy"), [ @@ -170,68 +173,63 @@ def test_alternate_index_syntax(): at(a, 0)[0].set(4) -@pytest.mark.skip_xp_backend( - Backend.SPARSE, reason="read-only backend without .at support" -) @pytest.mark.parametrize("copy", [True, None]) -@pytest.mark.parametrize( - "op", [_AtOp.ADD, _AtOp.SUBTRACT, _AtOp.MULTIPLY, _AtOp.DIVIDE, _AtOp.POWER] -) -def test_iops_incompatible_dtype( - xp: ModuleType, library: Backend, op: _AtOp, copy: bool | None +@pytest.mark.parametrize("bool_mask", [False, True]) +@pytest.mark.parametrize("op", list(_AtOp)) +def test_incompatible_dtype( + xp: ModuleType, library: Backend, op: _AtOp, copy: bool | None, bool_mask: bool ): """Test that at() replicates the backend's behaviour for in-place operations with incompatible dtypes. - Note: + Behavior is backend-specific, but only two behaviors are allowed: + 1. raise an exception, or + 2. return the same dtype as x, disregarding y.dtype (no broadcasting). + + Note that __i__ and ____ behave differently, and we want to + replicate the behavior of __i__: + >>> a = np.asarray([1, 2, 3]) >>> a / 1.5 array([0. , 0.66666667, 1.33333333]) >>> a /= 1.5 UFuncTypeError: Cannot cast ufunc 'divide' output from dtype('float64') to dtype('int64') with casting rule 'same_kind' + + See Also + -------- """ x = xp.asarray([2, 4]) - - if library is Backend.DASK: - z = at_op(x, slice(None), op, 1.1, copy=copy) - assert z.dtype == x.dtype - - elif library is Backend.JAX: - with pytest.warns(FutureWarning, match="cannot safely cast"): - z = at_op(x, slice(None), op, 1.1, copy=copy) - assert z.dtype == x.dtype - - else: + idx = xp.asarray([True, False]) if bool_mask else slice(None) + z = None + + if library is Backend.JAX: + if bool_mask: + z = at_op(x, idx, op, 1.1, copy=copy) + else: + with pytest.warns(FutureWarning, match="cannot safely cast"): + z = at_op(x, idx, op, 1.1, copy=copy) + + elif library is Backend.DASK: + if op in (_AtOp.MIN, _AtOp.MAX): + pytest.xfail(reason="need array-api-compat 1.11") + z = at_op(x, idx, op, 1.1, copy=copy) + + elif library is Backend.ARRAY_API_STRICT and op is not _AtOp.SET: with pytest.raises(Exception, match=r"cast|promote|dtype"): - at_op(x, slice(None), op, 1.1, copy=copy) + at_op(x, idx, op, 1.1, copy=copy) - -@pytest.mark.skip_xp_backend( - Backend.SPARSE, reason="read-only backend without .at support" -) -@pytest.mark.parametrize( - "op", [_AtOp.ADD, _AtOp.SUBTRACT, _AtOp.MULTIPLY, _AtOp.DIVIDE, _AtOp.POWER] -) -def test_bool_mask_incompatible_dtype(xp: ModuleType, library: Backend, op: _AtOp): - """ - When xp.where(idx, y, x) would promote the dtype of the output - to y.dtype, at(x, idx).set(y) must retain x.dtype instead - """ - x = xp.asarray([1, 2]) - idx = xp.asarray([True, False]) - if library in (Backend.DASK, Backend.JAX): - z = at_op(x, idx, op, 1.1) - assert z.dtype == x.dtype + elif op in (_AtOp.SET, _AtOp.MIN, _AtOp.MAX): + # There is no __i__ version of these operations + z = at_op(x, idx, op, 1.1, copy=copy) else: with pytest.raises(Exception, match=r"cast|promote|dtype"): - at_op(x, idx, op, 1.1) + at_op(x, idx, op, 1.1, copy=copy) + + assert z is None or z.dtype == x.dtype -@pytest.mark.skip_xp_backend( - Backend.SPARSE, reason="read-only backend without .at support" -) def test_bool_mask_nd(xp: ModuleType): x = xp.asarray([[1, 2, 3], [4, 5, 6]]) idx = xp.asarray([[True, False, False], [False, True, True]]) @@ -239,9 +237,6 @@ def test_bool_mask_nd(xp: ModuleType): xp_assert_equal(z, xp.asarray([[0, 2, 3], [4, 0, 0]])) -@pytest.mark.skip_xp_backend( - Backend.SPARSE, reason="read-only backend without .at support" -) @pytest.mark.skip_xp_backend(Backend.DASK, reason="FIXME need scipy's lazywhere") @pytest.mark.parametrize("bool_mask", [False, True]) def test_no_inf_warnings(xp: ModuleType, bool_mask: bool): From 803c6703797418e702b7713369a962c439053e65 Mon Sep 17 00:00:00 2001 From: Lucas Colley Date: Sat, 25 Jan 2025 23:08:14 +0000 Subject: [PATCH 4/4] typo --- tests/test_at.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/tests/test_at.py b/tests/test_at.py index 42376526..e5c1bbee 100644 --- a/tests/test_at.py +++ b/tests/test_at.py @@ -195,9 +195,6 @@ def test_incompatible_dtype( >>> a /= 1.5 UFuncTypeError: Cannot cast ufunc 'divide' output from dtype('float64') to dtype('int64') with casting rule 'same_kind' - - See Also - -------- """ x = xp.asarray([2, 4]) idx = xp.asarray([True, False]) if bool_mask else slice(None)