From 6a15d6d06ca43bee4d9a0f666b48afe1a4e91f28 Mon Sep 17 00:00:00 2001 From: crusaderky Date: Wed, 29 Jan 2025 16:22:06 +0000 Subject: [PATCH 1/3] ENH: Array API 2024.12: binary ops vs. Python scalars --- docs/conf.py | 1 + src/array_api_extra/_delegation.py | 4 +- src/array_api_extra/_lib/_funcs.py | 11 ++- src/array_api_extra/_lib/_utils/_helpers.py | 84 ++++++++++++++++++ tests/test_funcs.py | 68 ++++++++++++--- tests/test_utils.py | 97 ++++++++++++++++++++- 6 files changed, 246 insertions(+), 19 deletions(-) diff --git a/docs/conf.py b/docs/conf.py index 79000c96..afa3bd5e 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -53,6 +53,7 @@ intersphinx_mapping = { "python": ("https://docs.python.org/3", None), + "array-api": ("https://data-apis.org/array-api/draft", None), "jax": ("https://jax.readthedocs.io/en/latest", None), } diff --git a/src/array_api_extra/_delegation.py b/src/array_api_extra/_delegation.py index b7bc9a84..d2aec2e4 100644 --- a/src/array_api_extra/_delegation.py +++ b/src/array_api_extra/_delegation.py @@ -52,8 +52,8 @@ def isclose( Parameters ---------- - a, b : Array - Input arrays to compare. + a, b : Array | int | float | complex | bool + Input objects to compare. At least one must be an Array API object. rtol : array_like, optional The relative tolerance parameter (see Notes). atol : array_like, optional diff --git a/src/array_api_extra/_lib/_funcs.py b/src/array_api_extra/_lib/_funcs.py index 2fdc084a..fd1c023f 100644 --- a/src/array_api_extra/_lib/_funcs.py +++ b/src/array_api_extra/_lib/_funcs.py @@ -12,6 +12,7 @@ from ._at import at from ._utils import _compat, _helpers from ._utils._compat import array_namespace, is_jax_array +from ._utils._helpers import asarrays from ._utils._typing import Array __all__ = [ @@ -315,6 +316,7 @@ def isclose( xp: ModuleType, ) -> Array: # numpydoc ignore=PR01,RT01 """See docstring in array_api_extra._delegation.""" + a, b = asarrays(a, b, xp=xp) a_inexact = xp.isdtype(a.dtype, ("real floating", "complex floating")) b_inexact = xp.isdtype(b.dtype, ("real floating", "complex floating")) @@ -356,8 +358,8 @@ def kron(a: Array, b: Array, /, *, xp: ModuleType | None = None) -> Array: Parameters ---------- - a, b : array - Input arrays. + a, b : Array | int | float | complex + Input arrays or scalars. At least one must be an Array API object. xp : array_namespace, optional The standard-compatible namespace for `a` and `b`. Default: infer. @@ -420,10 +422,10 @@ def kron(a: Array, b: Array, /, *, xp: ModuleType | None = None) -> Array: """ if xp is None: xp = array_namespace(a, b) + a, b = asarrays(a, b, xp=xp) - b = xp.asarray(b) singletons = (1,) * (b.ndim - a.ndim) - a = xp.broadcast_to(xp.asarray(a), singletons + a.shape) + a = xp.broadcast_to(a, singletons + a.shape) nd_b, nd_a = b.ndim, a.ndim nd_max = max(nd_b, nd_a) @@ -583,6 +585,7 @@ def setdiff1d( """ if xp is None: xp = array_namespace(x1, x2) + x1, x2 = asarrays(x1, x2, xp=xp) if assume_unique: x1 = xp.reshape(x1, (-1,)) diff --git a/src/array_api_extra/_lib/_utils/_helpers.py b/src/array_api_extra/_lib/_utils/_helpers.py index 84efcc35..35908bcd 100644 --- a/src/array_api_extra/_lib/_utils/_helpers.py +++ b/src/array_api_extra/_lib/_utils/_helpers.py @@ -4,8 +4,10 @@ from __future__ import annotations from types import ModuleType +from typing import cast from . import _compat +from ._compat import is_array_api_obj, is_numpy_array from ._typing import Array __all__ = ["in1d", "mean"] @@ -91,3 +93,85 @@ def mean( mean_imag = xp.mean(x_imag, axis=axis, keepdims=keepdims) return mean_real + (mean_imag * xp.asarray(1j)) return xp.mean(x, axis=axis, keepdims=keepdims) + + +def is_python_scalar(x: object) -> bool: # numpydoc ignore=PR01,RT01 + """Return True if `x` is a Python scalar, False otherwise.""" + # isinstance(x, float) returns True for np.float64 + # isinstance(x, complex) returns True for np.complex128 + return isinstance(x, int | float | complex | bool) and not is_numpy_array(x) + + +def asarrays( + a: Array | int | float | complex | bool, + b: Array | int | float | complex | bool, + xp: ModuleType, +) -> tuple[Array, Array]: + """ + Ensure both `a` and `b` are arrays. + + If `b` is a python scalar, it is converted to the same dtype as `a`, and vice versa. + + Behavior is not specified when mixing a Python ``float`` and an array with an + integer data type; this may give ``float32``, ``float64``, or raise an exception. + Behavior is implementation-specific. + + Similarly, behavior is not specified when mixing a Python ``complex`` and an array + with a real-valued data type; this may give ``complex64``, ``complex128``, or raise + an exception. Behavior is implementation-specific. + + Parameters + ---------- + a, b : Array | int | float | complex | bool + Input arrays or scalars. At least one must be an array. + xp : ModuleType + The array API namespace. + + Returns + ------- + Array, Array + The input arrays, possibly converted to arrays if they were scalars. + + See Also + -------- + mixing-arrays-with-python-scalars : Array API specification for the behavior. + """ + a_scalar = is_python_scalar(a) + b_scalar = is_python_scalar(b) + if not a_scalar and not b_scalar: + return a, b # This includes misc. malformed input e.g. str + + swap = False + if a_scalar: + swap = True + b, a = a, b + + if is_array_api_obj(a): + # a is an Array API object + # b is a int | float | complex | bool + + # pyright doesn't like it if you reuse the same variable name + xa = cast(Array, a) + + # https://data-apis.org/array-api/draft/API_specification/type_promotion.html#mixing-arrays-with-python-scalars + same_dtype = { + bool: "bool", + int: ("integral", "real floating", "complex floating"), + float: ("real floating", "complex floating"), + complex: "complex floating", + } + kind = same_dtype[type(b)] # type: ignore[index] + if xp.isdtype(xa.dtype, kind): + xb = xp.asarray(b, dtype=xa.dtype) + else: + # Undefined behaviour. Let the function deal with it, if it can. + xb = xp.asarray(b) + + else: + # Neither a nor b are Array API objects. + # Note: we can only reach this point when one explicitly passes + # xp=xp to the calling function; otherwise we fail earlier on + # array_namespace(a, b). + xa, xb = xp.asarray(a), xp.asarray(b) + + return (xb, xa) if swap else (xa, xb) diff --git a/tests/test_funcs.py b/tests/test_funcs.py index 33a8f36e..2c265b23 100644 --- a/tests/test_funcs.py +++ b/tests/test_funcs.py @@ -401,6 +401,24 @@ def test_none_shape_bool(self, xp: ModuleType): a = a[a] xp_assert_equal(isclose(a, b), xp.asarray([True, False])) + @pytest.mark.skip_xp_backend(Backend.NUMPY_READONLY, reason="xp=xp") + @pytest.mark.skip_xp_backend(Backend.TORCH, reason="Array API 2024.12 support") + def test_python_scalar(self, xp: ModuleType): + a = xp.asarray([0.0, 0.1], dtype=xp.float32) + xp_assert_equal(isclose(a, 0.0), xp.asarray([True, False])) + xp_assert_equal(isclose(0.0, a), xp.asarray([True, False])) + + a = xp.asarray([0, 1], dtype=xp.int16) + xp_assert_equal(isclose(a, 0), xp.asarray([True, False])) + xp_assert_equal(isclose(0, a), xp.asarray([True, False])) + + xp_assert_equal(isclose(0, 0, xp=xp), xp.asarray(True)) + xp_assert_equal(isclose(0, 1, xp=xp), xp.asarray(False)) + + def test_all_python_scalars(self): + with pytest.raises(TypeError, match="Unrecognized"): + isclose(0, 0) + def test_xp(self, xp: ModuleType): a = xp.asarray([0.0, 0.0]) b = xp.asarray([1e-9, 1e-4]) @@ -413,30 +431,22 @@ def test_basic(self, xp: ModuleType): # Using 0-dimensional array a = xp.asarray(1) b = xp.asarray([[1, 2], [3, 4]]) - k = xp.asarray([[1, 2], [3, 4]]) - xp_assert_equal(kron(a, b), k) - a = xp.asarray([[1, 2], [3, 4]]) - b = xp.asarray(1) - xp_assert_equal(kron(a, b), k) + xp_assert_equal(kron(a, b), b) + xp_assert_equal(kron(b, a), b) # Using 1-dimensional array a = xp.asarray([3]) b = xp.asarray([[1, 2], [3, 4]]) k = xp.asarray([[3, 6], [9, 12]]) xp_assert_equal(kron(a, b), k) - a = xp.asarray([[1, 2], [3, 4]]) - b = xp.asarray([3]) - xp_assert_equal(kron(a, b), k) + xp_assert_equal(kron(b, a), k) # Using 3-dimensional array a = xp.asarray([[[1]], [[2]]]) b = xp.asarray([[1, 2], [3, 4]]) k = xp.asarray([[[1, 2], [3, 4]], [[2, 4], [6, 8]]]) xp_assert_equal(kron(a, b), k) - a = xp.asarray([[1, 2], [3, 4]]) - b = xp.asarray([[[1]], [[2]]]) - k = xp.asarray([[[1, 2], [3, 4]], [[2, 4], [6, 8]]]) - xp_assert_equal(kron(a, b), k) + xp_assert_equal(kron(b, a), k) def test_kron_smoke(self, xp: ModuleType): a = xp.ones((3, 3)) @@ -474,6 +484,18 @@ def test_kron_shape( k = kron(a, b) assert k.shape == expected_shape + def test_python_scalar(self, xp: ModuleType): + a = 1 + # Test no dtype promotion to xp.asarray(a); use b.dtype + b = xp.asarray([[1, 2], [3, 4]], dtype=xp.int16) + xp_assert_equal(kron(a, b), b) + xp_assert_equal(kron(b, a), b) + xp_assert_equal(kron(1, 1, xp=xp), xp.asarray(1)) + + def test_all_python_scalars(self): + with pytest.raises(TypeError, match="Unrecognized"): + kron(1, 1) + def test_device(self, xp: ModuleType, device: Device): x1 = xp.asarray([1, 2, 3], device=device) x2 = xp.asarray([4, 5], device=device) @@ -601,6 +623,28 @@ def test_shapes( actual = setdiff1d(x1, x2, assume_unique=assume_unique) xp_assert_equal(actual, xp.empty((0,))) + @pytest.mark.skip_xp_backend(Backend.NUMPY_READONLY, reason="xp=xp") + @pytest.mark.parametrize("assume_unique", [True, False]) + def test_python_scalar(self, xp: ModuleType, assume_unique: bool): + # Test no dtype promotion to xp.asarray(x2); use x1.dtype + x1 = xp.asarray([3, 1, 2], dtype=xp.int16) + x2 = 3 + actual = setdiff1d(x1, x2, assume_unique=assume_unique) + xp_assert_equal(actual, xp.asarray([1, 2], dtype=xp.int16)) + + actual = setdiff1d(x2, x1, assume_unique=assume_unique) + xp_assert_equal(actual, xp.asarray([], dtype=xp.int16)) + + xp_assert_equal( + setdiff1d(0, 0, assume_unique=assume_unique, xp=xp), + xp.asarray([0])[:0], # Default int dtype for backend + ) + + @pytest.mark.parametrize("assume_unique", [True, False]) + def test_all_python_scalars(self, assume_unique: bool): + with pytest.raises(TypeError, match="Unrecognized"): + setdiff1d(0, 0, assume_unique=assume_unique) + def test_device(self, xp: ModuleType, device: Device): x1 = xp.asarray([3, 8, 20], device=device) x2 = xp.asarray([2, 3, 4], device=device) diff --git a/tests/test_utils.py b/tests/test_utils.py index f710056b..0c2a6504 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1,11 +1,12 @@ from types import ModuleType +import numpy as np import pytest from array_api_extra._lib import Backend from array_api_extra._lib._testing import xp_assert_equal from array_api_extra._lib._utils._compat import device as get_device -from array_api_extra._lib._utils._helpers import in1d +from array_api_extra._lib._utils._helpers import asarrays, in1d from array_api_extra._lib._utils._typing import Device from array_api_extra.testing import lazy_xp_function @@ -45,3 +46,97 @@ def test_xp(self, xp: ModuleType): expected = xp.asarray([True, False]) actual = in1d(x1, x2, xp=xp) xp_assert_equal(actual, expected) + + +@pytest.mark.skip_xp_backend(Backend.SPARSE, reason="no isdtype") +@pytest.mark.parametrize( + ("dtype", "b", "defined"), + [ + # Well-defined cases of dtype promotion from Python scalar to Array + # bool vs. bool + ("bool", True, True), + # int vs. xp.*int*, xp.float*, xp.complex* + ("int16", 1, True), + ("uint8", 1, True), + ("float32", 1, True), + ("float64", 1, True), + ("complex64", 1, True), + ("complex128", 1, True), + # float vs. xp.float, xp.complex + ("float32", 1.0, True), + ("float64", 1.0, True), + ("complex64", 1.0, True), + ("complex128", 1.0, True), + # complex vs. xp.complex + ("complex64", 1.0j, True), + ("complex128", 1.0j, True), + # Undefined cases + ("bool", 1, False), + ("int64", 1.0, False), + ("float64", 1.0j, False), + ], +) +def test_asarrays_array_vs_scalar( + dtype: str, b: int | float | complex, defined: bool, xp: ModuleType +): + a = xp.asarray(1, dtype=getattr(xp, dtype)) + + xa, xb = asarrays(a, b, xp) + assert xa.dtype == a.dtype + if defined: + assert xb.dtype == a.dtype + else: + assert xb.dtype == xp.asarray(b).dtype + + xbr, xar = asarrays(b, a, xp) + assert xar.dtype == xa.dtype + assert xbr.dtype == xb.dtype + + +def test_asarrays_scalar_vs_scalar(xp: ModuleType): + a, b = asarrays(1, 2.2, xp=xp) + assert a.dtype == xp.asarray(1).dtype # Default dtype + assert b.dtype == xp.asarray(2.2).dtype # Default dtype; not broadcasted + + +ALL_TYPES = ( + "int8", + "int16", + "int32", + "int64", + "uint8", + "uint16", + "uint32", + "uint64", + "float32", + "float64", + "complex64", + "complex128", + "bool", +) + + +@pytest.mark.parametrize("a_type", ALL_TYPES) +@pytest.mark.parametrize("b_type", ALL_TYPES) +def test_asarrays_array_vs_array(a_type: str, b_type: str, xp: ModuleType): + """ + Test that when both inputs of asarray are already Array API objects, + they are returned unchanged. + """ + a = xp.asarray(1, dtype=getattr(xp, a_type)) + b = xp.asarray(1, dtype=getattr(xp, b_type)) + xa, xb = asarrays(a, b, xp) + assert xa.dtype == a.dtype + assert xb.dtype == b.dtype + + +@pytest.mark.parametrize("dtype", [np.float64, np.complex128]) +def test_asarrays_numpy_generics(dtype: type): + """ + Test special case of np.float64 and np.complex128, + which are subclasses of float and complex. + """ + a = dtype(0) + xa, xb = asarrays(a, 0, xp=np) + assert xa.dtype == dtype + assert xb.dtype == dtype From 9c2457d6e0493988da7c0e1a940de1e087d41da4 Mon Sep 17 00:00:00 2001 From: crusaderky Date: Wed, 29 Jan 2025 16:23:29 +0000 Subject: [PATCH 2/3] lint --- tests/test_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_utils.py b/tests/test_utils.py index 0c2a6504..d9f50362 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -69,7 +69,7 @@ def test_xp(self, xp: ModuleType): ("complex128", 1.0, True), # complex vs. xp.complex ("complex64", 1.0j, True), - ("complex128", 1.0j, True), + ("complex128", 1.0j, True), # Undefined cases ("bool", 1, False), ("int64", 1.0, False), From 9b71aeb35b6697def66a4e3fc1625c1c5eef91e5 Mon Sep 17 00:00:00 2001 From: Lucas Colley Date: Wed, 29 Jan 2025 19:30:04 +0000 Subject: [PATCH 3/3] Apply suggestions from code review --- src/array_api_extra/_delegation.py | 2 +- src/array_api_extra/_lib/_funcs.py | 2 +- src/array_api_extra/_lib/_utils/_helpers.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/array_api_extra/_delegation.py b/src/array_api_extra/_delegation.py index d2aec2e4..f3295c45 100644 --- a/src/array_api_extra/_delegation.py +++ b/src/array_api_extra/_delegation.py @@ -53,7 +53,7 @@ def isclose( Parameters ---------- a, b : Array | int | float | complex | bool - Input objects to compare. At least one must be an Array API object. + Input objects to compare. At least one must be an array. rtol : array_like, optional The relative tolerance parameter (see Notes). atol : array_like, optional diff --git a/src/array_api_extra/_lib/_funcs.py b/src/array_api_extra/_lib/_funcs.py index fd1c023f..f7eb8c88 100644 --- a/src/array_api_extra/_lib/_funcs.py +++ b/src/array_api_extra/_lib/_funcs.py @@ -359,7 +359,7 @@ def kron(a: Array, b: Array, /, *, xp: ModuleType | None = None) -> Array: Parameters ---------- a, b : Array | int | float | complex - Input arrays or scalars. At least one must be an Array API object. + Input arrays or scalars. At least one must be an array. xp : array_namespace, optional The standard-compatible namespace for `a` and `b`. Default: infer. diff --git a/src/array_api_extra/_lib/_utils/_helpers.py b/src/array_api_extra/_lib/_utils/_helpers.py index 35908bcd..b32a1081 100644 --- a/src/array_api_extra/_lib/_utils/_helpers.py +++ b/src/array_api_extra/_lib/_utils/_helpers.py @@ -125,7 +125,7 @@ def asarrays( a, b : Array | int | float | complex | bool Input arrays or scalars. At least one must be an array. xp : ModuleType - The array API namespace. + The standard-compatible namespace for the returned arrays. Returns -------