From f661a237264eeec08ff1ba6d282266c7ae2fade0 Mon Sep 17 00:00:00 2001 From: Matthew Barber Date: Tue, 25 Jan 2022 13:51:46 +0000 Subject: [PATCH 01/40] Updates to op/elwise tests * Use `xps` dtype strategies where possible for better repr and dtype filtering * Helper for asserting broadcasted shapes * Helper `sh.iter_indices()` to wrap `ndindex` equivalent * Update `test_equal` with `sh.iter_indices()` --- array_api_tests/shape_helpers.py | 14 +- ...est_operators_and_elementwise_functions.py | 178 ++++++++++-------- 2 files changed, 112 insertions(+), 80 deletions(-) diff --git a/array_api_tests/shape_helpers.py b/array_api_tests/shape_helpers.py index 17dd7f6e..833ea60f 100644 --- a/array_api_tests/shape_helpers.py +++ b/array_api_tests/shape_helpers.py @@ -2,6 +2,8 @@ from itertools import product from typing import Iterator, List, Optional, Tuple, Union +from ndindex import iter_indices as _iter_indices + from .typing import Scalar, Shape __all__ = ["normalise_axis", "ndindex", "axis_ndindex", "axes_ndindex", "reshape"] @@ -18,12 +20,14 @@ def normalise_axis( def ndindex(shape): - """Iterator of n-D indices to an array + # TODO: remove + return (indices[0] for indices in iter_indices(shape)) + - Yields tuples of integers to index every element of an array of shape - `shape`. Same as np.ndindex(). - """ - return product(*[range(i) for i in shape]) +def iter_indices(*shapes, skip_axes=()): + """Wrapper for ndindex.iter_indices()""" + gen = _iter_indices(*shapes, skip_axes=skip_axes) + return ([i.raw for i in indices] for indices in gen) def axis_ndindex( diff --git a/array_api_tests/test_operators_and_elementwise_functions.py b/array_api_tests/test_operators_and_elementwise_functions.py index 5da74014..7ca10831 100644 --- a/array_api_tests/test_operators_and_elementwise_functions.py +++ b/array_api_tests/test_operators_and_elementwise_functions.py @@ -11,7 +11,7 @@ import math from enum import Enum, auto -from typing import Callable, List, Optional, Sequence, Union +from typing import Callable, List, Optional, Union import pytest from hypothesis import assume, given @@ -26,10 +26,19 @@ from . import shape_helpers as sh from . import xps from .algos import broadcast_shapes -from .typing import Array, DataType, Param, Scalar +from .typing import Array, DataType, Param, Scalar, Shape pytestmark = pytest.mark.ci + +def all_integer_dtypes() -> st.SearchStrategy[DataType]: + return xps.unsigned_integer_dtypes() | xps.integer_dtypes() + + +def boolean_and_all_integer_dtypes() -> st.SearchStrategy[DataType]: + return xps.boolean_dtypes() | all_integer_dtypes() + + # When appropiate, this module tests operators alongside their respective # elementwise methods. We do this by parametrizing a generalised test method # with every relevant method and operator. @@ -53,11 +62,9 @@ def make_unary_params( - elwise_func_name: str, dtypes: Sequence[DataType] + elwise_func_name: str, dtypes_strat: st.SearchStrategy[DataType] ) -> List[UnaryParam]: - if hh.FILTER_UNDEFINED_DTYPES: - dtypes = [d for d in dtypes if not isinstance(d, xp._UndefinedStub)] - strat = xps.arrays(dtype=st.sampled_from(dtypes), shape=hh.shapes()) + strat = xps.arrays(dtype=dtypes_strat, shape=hh.shapes()) func = getattr(xp, elwise_func_name) op_name = func_to_op[elwise_func_name] op = lambda x: getattr(x, op_name)() @@ -94,13 +101,12 @@ class FuncType(Enum): IOP = auto() +shapes_kw = {"min_side": 1} + + def make_binary_params( - elwise_func_name: str, dtypes: Sequence[DataType] + elwise_func_name: str, dtypes_strat: st.SearchStrategy[DataType] ) -> List[BinaryParam]: - if hh.FILTER_UNDEFINED_DTYPES: - dtypes = [d for d in dtypes if not isinstance(d, xp._UndefinedStub)] - dtypes_strat = st.sampled_from(dtypes) - def make_param( func_name: str, func_type: FuncType, right_is_scalar: bool ) -> BinaryParam: @@ -113,17 +119,25 @@ def make_param( shared_dtypes = st.shared(dtypes_strat) if right_is_scalar: - left_strat = xps.arrays(dtype=shared_dtypes, shape=hh.shapes()) + left_strat = xps.arrays(dtype=shared_dtypes, shape=hh.shapes(**shapes_kw)) right_strat = shared_dtypes.flatmap( lambda d: xps.from_dtype(d, **finite_kw) ) else: if func_type is FuncType.IOP: - shared_shapes = st.shared(hh.shapes()) + shared_shapes = st.shared(hh.shapes(**shapes_kw)) left_strat = xps.arrays(dtype=shared_dtypes, shape=shared_shapes) right_strat = xps.arrays(dtype=shared_dtypes, shape=shared_shapes) else: - left_strat, right_strat = hh.two_mutual_arrays(dtypes) + mutual_shapes = st.shared( + hh.mutually_broadcastable_shapes(2, **shapes_kw) + ) + left_strat = xps.arrays( + dtype=shared_dtypes, shape=mutual_shapes.map(lambda pair: pair[0]) + ) + right_strat = xps.arrays( + dtype=shared_dtypes, shape=mutual_shapes.map(lambda pair: pair[1]) + ) if func_type is FuncType.FUNC: func = getattr(xp, func_name) @@ -142,9 +156,7 @@ def func(l: Array, r: Union[Scalar, Array]) -> Array: def func(l: Array, r: Union[Scalar, Array]) -> Array: locals_ = {} - locals_[left_sym] = ah.asarray( - l, copy=True - ) # prevents left mutating + locals_[left_sym] = ah.asarray(l, copy=True) # prevents mutating l locals_[right_sym] = r exec(expr, locals_) return locals_[left_sym] @@ -200,7 +212,25 @@ def assert_binary_param_dtype( ) -@pytest.mark.parametrize(unary_argnames, make_unary_params("abs", dh.numeric_dtypes)) +def assert_binary_param_shape( + func_name: str, + left: Array, + right: Union[Array, Scalar], + right_is_scalar: bool, + res: Array, + res_name: str, + expected: Optional[Shape] = None, +): + if right_is_scalar: + in_shapes = (left.shape,) + else: + in_shapes = (left.shape, right.shape) # type: ignore + ph.assert_result_shape( + func_name, in_shapes, res.shape, expected, repr_name=f"{res_name}.shape" + ) + + +@pytest.mark.parametrize(unary_argnames, make_unary_params("abs", xps.numeric_dtypes())) @given(data=st.data()) def test_abs(func_name, func, strat, data): x = data.draw(strat, label="x") @@ -258,7 +288,9 @@ def test_acosh(x): ah.assert_exactly_equal(domain, codomain) -@pytest.mark.parametrize(binary_argnames, make_binary_params("add", dh.numeric_dtypes)) +@pytest.mark.parametrize( + binary_argnames, make_binary_params("add", xps.numeric_dtypes()) +) @given(data=st.data()) def test_add( func_name, @@ -384,7 +416,7 @@ def test_atanh(x): @pytest.mark.parametrize( - binary_argnames, make_binary_params("bitwise_and", dh.bool_and_all_int_dtypes) + binary_argnames, make_binary_params("bitwise_and", boolean_and_all_integer_dtypes()) ) @given(data=st.data()) def test_bitwise_and( @@ -432,7 +464,7 @@ def test_bitwise_and( @pytest.mark.parametrize( - binary_argnames, make_binary_params("bitwise_left_shift", dh.all_int_dtypes) + binary_argnames, make_binary_params("bitwise_left_shift", all_integer_dtypes()) ) @given(data=st.data()) def test_bitwise_left_shift( @@ -478,7 +510,8 @@ def test_bitwise_left_shift( @pytest.mark.parametrize( - unary_argnames, make_unary_params("bitwise_invert", dh.bool_and_all_int_dtypes) + unary_argnames, + make_unary_params("bitwise_invert", boolean_and_all_integer_dtypes()), ) @given(data=st.data()) def test_bitwise_invert(func_name, func, strat, data): @@ -505,7 +538,7 @@ def test_bitwise_invert(func_name, func, strat, data): @pytest.mark.parametrize( - binary_argnames, make_binary_params("bitwise_or", dh.bool_and_all_int_dtypes) + binary_argnames, make_binary_params("bitwise_or", boolean_and_all_integer_dtypes()) ) @given(data=st.data()) def test_bitwise_or( @@ -553,7 +586,7 @@ def test_bitwise_or( @pytest.mark.parametrize( - binary_argnames, make_binary_params("bitwise_right_shift", dh.all_int_dtypes) + binary_argnames, make_binary_params("bitwise_right_shift", all_integer_dtypes()) ) @given(data=st.data()) def test_bitwise_right_shift( @@ -598,7 +631,7 @@ def test_bitwise_right_shift( @pytest.mark.parametrize( - binary_argnames, make_binary_params("bitwise_xor", dh.bool_and_all_int_dtypes) + binary_argnames, make_binary_params("bitwise_xor", boolean_and_all_integer_dtypes()) ) @given(data=st.data()) def test_bitwise_xor( @@ -688,7 +721,9 @@ def test_cosh(x): ah.assert_exactly_equal(domain, codomain) -@pytest.mark.parametrize(binary_argnames, make_binary_params("divide", dh.float_dtypes)) +@pytest.mark.parametrize( + binary_argnames, make_binary_params("divide", xps.floating_dtypes()) +) @given(data=st.data()) def test_divide( func_name, @@ -714,7 +749,9 @@ def test_divide( # have those sorts in general for this module. -@pytest.mark.parametrize(binary_argnames, make_binary_params("equal", dh.all_dtypes)) +@pytest.mark.parametrize( + binary_argnames, make_binary_params("equal", xps.scalar_dtypes()) +) @given(data=st.data()) def test_equal( func_name, @@ -735,45 +772,32 @@ def test_equal( assert_binary_param_dtype( func_name, left, right, right_is_scalar, out, res_name, xp.bool ) - # NOTE: ah.assert_exactly_equal() itself uses ah.equal(), so we must be careful - # not to use it here. Otherwise, the test would be circular and - # meaningless. Instead, we implement this by iterating every element of - # the arrays and comparing them. The logic here is also used for the tests - # for the other elementwise functions that accept any input dtype but - # always return bool (greater(), greater_equal(), less(), less_equal(), - # and not_equal()). + assert_binary_param_shape(func_name, left, right, right_is_scalar, out, res_name) if not right_is_scalar: - # First we broadcast the arrays so that they can be indexed uniformly. - # TODO: it should be possible to skip this step if we instead generate - # indices to x1 and x2 that correspond to the broadcasted shapes. This - # would avoid the dependence in this test on broadcast_to(). - shape = broadcast_shapes(left.shape, right.shape) - ph.assert_shape(func_name, out.shape, shape) - _left = xp.broadcast_to(left, shape) - _right = xp.broadcast_to(right, shape) - - # Second, manually promote the dtypes. This is important. If the internal - # type promotion in ah.equal() is wrong, it will not be directly visible in - # the output type, but it can lead to wrong answers. For example, - # ah.equal(array(1.0, dtype=xp.float32), array(1.00000001, dtype=xp.float64)) will - # be wrong if the float64 is downcast to float32. # be wrong if the - # xp.float64 is downcast to float32. See the comment on - # test_elementwise_function_two_arg_bool_type_promotion() in - # test_type_promotion.py. The type promotion for ah.equal() is not *really* - # tested in that file, because doing so requires doing the consistency - - # check we do here rather than just checking the res dtype. + # We manually promote the dtypes as incorrect internal type promotion + # could lead to erroneous behaviour that we don't catch. For example + # + # >>> xp.equal( + # ... xp.asarray(1.0, dtype=xp.float32), + # ... xp.asarray(1.00000001, dtype=xp.float64), + # ... ) + # + # would incorrectly be True if float64 downcasts to float32 internally. promoted_dtype = dh.promotion_table[left.dtype, right.dtype] - _left = ah.asarray(_left, dtype=promoted_dtype) - _right = ah.asarray(_right, dtype=promoted_dtype) - + _left = xp.astype(left, promoted_dtype) + _right = xp.astype(right, promoted_dtype) scalar_type = dh.get_scalar_type(promoted_dtype) - for idx in sh.ndindex(shape): - x1_idx = _left[idx] - x2_idx = _right[idx] - out_idx = out[idx] - assert out_idx.shape == x1_idx.shape == x2_idx.shape # sanity check - assert bool(out_idx) == (scalar_type(x1_idx) == scalar_type(x2_idx)) + for l_idx, r_idx, o_idx in sh.iter_indices(left.shape, right.shape, out.shape): + scalar_l = scalar_type(_left[l_idx]) + scalar_r = scalar_type(_right[r_idx]) + expected = scalar_l == scalar_r + scalar_o = bool(out[o_idx]) + assert scalar_o == expected, ( + f"out[{o_idx}]={scalar_o}, but should be " + f"{left_sym}[{l_idx}]=={right_sym}[{r_idx}]={expected} " + f"({left_sym}[{l_idx}]={scalar_l}, {right_sym}[{r_idx}]={scalar_r}) " + f"[{func_name}()]" + ) @given(xps.arrays(dtype=xps.floating_dtypes(), shape=hh.shapes())) @@ -821,7 +845,7 @@ def test_floor(x): @pytest.mark.parametrize( - binary_argnames, make_binary_params("floor_divide", dh.numeric_dtypes) + binary_argnames, make_binary_params("floor_divide", xps.numeric_dtypes()) ) @given(data=st.data()) def test_floor_divide( @@ -865,7 +889,7 @@ def test_floor_divide( @pytest.mark.parametrize( - binary_argnames, make_binary_params("greater", dh.numeric_dtypes) + binary_argnames, make_binary_params("greater", xps.numeric_dtypes()) ) @given(data=st.data()) def test_greater( @@ -908,7 +932,7 @@ def test_greater( @pytest.mark.parametrize( - binary_argnames, make_binary_params("greater_equal", dh.numeric_dtypes) + binary_argnames, make_binary_params("greater_equal", xps.numeric_dtypes()) ) @given(data=st.data()) def test_greater_equal( @@ -1007,7 +1031,9 @@ def test_isnan(x): assert bool(out[idx]) == math.isnan(s) -@pytest.mark.parametrize(binary_argnames, make_binary_params("less", dh.numeric_dtypes)) +@pytest.mark.parametrize( + binary_argnames, make_binary_params("less", xps.numeric_dtypes()) +) @given(data=st.data()) def test_less( func_name, @@ -1050,7 +1076,7 @@ def test_less( @pytest.mark.parametrize( - binary_argnames, make_binary_params("less_equal", dh.numeric_dtypes) + binary_argnames, make_binary_params("less_equal", xps.numeric_dtypes()) ) @given(data=st.data()) def test_less_equal( @@ -1209,7 +1235,7 @@ def test_logical_xor(x1, x2): @pytest.mark.parametrize( - binary_argnames, make_binary_params("multiply", dh.numeric_dtypes) + binary_argnames, make_binary_params("multiply", xps.numeric_dtypes()) ) @given(data=st.data()) def test_multiply( @@ -1236,7 +1262,7 @@ def test_multiply( @pytest.mark.parametrize( - unary_argnames, make_unary_params("negative", dh.numeric_dtypes) + unary_argnames, make_unary_params("negative", xps.numeric_dtypes()) ) @given(data=st.data()) def test_negative(func_name, func, strat, data): @@ -1263,7 +1289,7 @@ def test_negative(func_name, func, strat, data): @pytest.mark.parametrize( - binary_argnames, make_binary_params("not_equal", dh.all_dtypes) + binary_argnames, make_binary_params("not_equal", xps.scalar_dtypes()) ) @given(data=st.data()) def test_not_equal( @@ -1307,7 +1333,7 @@ def test_not_equal( @pytest.mark.parametrize( - unary_argnames, make_unary_params("positive", dh.numeric_dtypes) + unary_argnames, make_unary_params("positive", xps.numeric_dtypes()) ) @given(data=st.data()) def test_positive(func_name, func, strat, data): @@ -1321,7 +1347,9 @@ def test_positive(func_name, func, strat, data): ah.assert_exactly_equal(out, x) -@pytest.mark.parametrize(binary_argnames, make_binary_params("pow", dh.numeric_dtypes)) +@pytest.mark.parametrize( + binary_argnames, make_binary_params("pow", xps.numeric_dtypes()) +) @given(data=st.data()) def test_pow( func_name, @@ -1357,7 +1385,7 @@ def test_pow( @pytest.mark.parametrize( - binary_argnames, make_binary_params("remainder", dh.numeric_dtypes) + binary_argnames, make_binary_params("remainder", xps.numeric_dtypes()) ) @given(data=st.data()) def test_remainder( @@ -1456,7 +1484,7 @@ def test_sqrt(x): @pytest.mark.parametrize( - binary_argnames, make_binary_params("subtract", dh.numeric_dtypes) + binary_argnames, make_binary_params("subtract", xps.numeric_dtypes()) ) @given(data=st.data()) def test_subtract( From a590f8dd36f01d61face407772c12f1ea3d74413 Mon Sep 17 00:00:00 2001 From: Matthew Barber Date: Tue, 25 Jan 2022 14:27:31 +0000 Subject: [PATCH 02/40] `sh.fmt_idx()` helper --- array_api_tests/meta/test_utils.py | 19 +++++++++++++ array_api_tests/shape_helpers.py | 43 +++++++++++++++++++++++++++--- array_api_tests/typing.py | 6 ++++- 3 files changed, 64 insertions(+), 4 deletions(-) diff --git a/array_api_tests/meta/test_utils.py b/array_api_tests/meta/test_utils.py index 3b28b9a9..1b188df3 100644 --- a/array_api_tests/meta/test_utils.py +++ b/array_api_tests/meta/test_utils.py @@ -82,3 +82,22 @@ def test_axes_ndindex(shape, axes, expected): ) def test_roll_ndindex(shape, shifts, axes, expected): assert list(roll_ndindex(shape, shifts, axes)) == expected + + +@pytest.mark.parametrize( + "idx, expected", + [ + ((), "x"), + (42, "x[42]"), + ((42,), "x[42]"), + (slice(None, 2), "x[:2]"), + (slice(2, None), "x[2:]"), + (slice(0, 2), "x[0:2]"), + (slice(0, 2, -1), "x[0:2:-1]"), + (slice(None, None, -1), "x[::-1]"), + (slice(None, None), "x[:]"), + (..., "x[...]"), + ], +) +def test_fmt_idx(idx, expected): + assert sh.fmt_idx("x", idx) == expected diff --git a/array_api_tests/shape_helpers.py b/array_api_tests/shape_helpers.py index 833ea60f..8e260ff9 100644 --- a/array_api_tests/shape_helpers.py +++ b/array_api_tests/shape_helpers.py @@ -4,9 +4,16 @@ from ndindex import iter_indices as _iter_indices -from .typing import Scalar, Shape +from .typing import AtomicIndex, Index, Scalar, Shape -__all__ = ["normalise_axis", "ndindex", "axis_ndindex", "axes_ndindex", "reshape"] +__all__ = [ + "normalise_axis", + "ndindex", + "axis_ndindex", + "axes_ndindex", + "reshape", + "fmt_idx", +] def normalise_axis( @@ -64,7 +71,7 @@ def axes_ndindex(shape: Shape, axes: Tuple[int, ...]) -> Iterator[List[Shape]]: yield list(indices) -def reshape(flat_seq: List[Scalar], shape: Shape) -> Union[Scalar, List[Scalar]]: +def reshape(flat_seq: List[Scalar], shape: Shape) -> Union[Scalar, List]: """Reshape a flat sequence""" if any(s == 0 for s in shape): raise ValueError( @@ -79,3 +86,33 @@ def reshape(flat_seq: List[Scalar], shape: Shape) -> Union[Scalar, List[Scalar]] size = len(flat_seq) n = math.prod(shape[1:]) return [reshape(flat_seq[i * n : (i + 1) * n], shape[1:]) for i in range(size // n)] + + +def fmt_i(i: AtomicIndex) -> str: + if isinstance(i, int): + return str(i) + elif isinstance(i, slice): + res = "" + if i.start is not None: + res += str(i.start) + res += ":" + if i.stop is not None: + res += str(i.stop) + if i.step is not None: + res += f":{i.step}" + return res + else: + return "..." + + +def fmt_idx(sym: str, idx: Index) -> str: + if idx == (): + return sym + res = f"{sym}[" + _idx = idx if isinstance(idx, tuple) else (idx,) + if len(_idx) == 1: + res += fmt_i(_idx[0]) + else: + res += ", ".join(fmt_i(i) for i in _idx) + res += "]" + return res diff --git a/array_api_tests/typing.py b/array_api_tests/typing.py index 286ce21b..da8652ae 100644 --- a/array_api_tests/typing.py +++ b/array_api_tests/typing.py @@ -1,4 +1,4 @@ -from typing import Tuple, Type, Union, Any +from typing import Any, Tuple, Type, Union __all__ = [ "DataType", @@ -6,6 +6,8 @@ "ScalarType", "Array", "Shape", + "AtomicIndex", + "Index", "Param", ] @@ -14,4 +16,6 @@ ScalarType = Union[Type[bool], Type[int], Type[float]] Array = Any Shape = Tuple[int, ...] +AtomicIndex = Union[int, "ellipsis", slice] # noqa +Index = Union[AtomicIndex, Tuple[AtomicIndex, ...]] Param = Tuple From d7e5e639b5fcd60bd80cdf0ace1d5de9f99774bd Mon Sep 17 00:00:00 2001 From: Matthew Barber Date: Wed, 26 Jan 2022 09:32:15 +0000 Subject: [PATCH 03/40] Better values testing in `test_not_equal` --- array_api_tests/shape_helpers.py | 4 +- ...est_operators_and_elementwise_functions.py | 74 +++++++++++++------ 2 files changed, 53 insertions(+), 25 deletions(-) diff --git a/array_api_tests/shape_helpers.py b/array_api_tests/shape_helpers.py index 8e260ff9..98dbff85 100644 --- a/array_api_tests/shape_helpers.py +++ b/array_api_tests/shape_helpers.py @@ -1,4 +1,5 @@ import math +from functools import lru_cache from itertools import product from typing import Iterator, List, Optional, Tuple, Union @@ -27,7 +28,7 @@ def normalise_axis( def ndindex(shape): - # TODO: remove + """Yield every index of shape""" return (indices[0] for indices in iter_indices(shape)) @@ -105,6 +106,7 @@ def fmt_i(i: AtomicIndex) -> str: return "..." +@lru_cache def fmt_idx(sym: str, idx: Index) -> str: if idx == (): return sym diff --git a/array_api_tests/test_operators_and_elementwise_functions.py b/array_api_tests/test_operators_and_elementwise_functions.py index 7ca10831..4b2a1f51 100644 --- a/array_api_tests/test_operators_and_elementwise_functions.py +++ b/array_api_tests/test_operators_and_elementwise_functions.py @@ -312,6 +312,7 @@ def test_add( reject() assert_binary_param_dtype(func_name, left, right, right_is_scalar, res, res_name) + assert_binary_param_shape(func_name, left, right, right_is_scalar, res, res_name) if not right_is_scalar: # add is commutative expected = func(right, left) @@ -773,16 +774,28 @@ def test_equal( func_name, left, right, right_is_scalar, out, res_name, xp.bool ) assert_binary_param_shape(func_name, left, right, right_is_scalar, out, res_name) - if not right_is_scalar: + if right_is_scalar: + scalar_type = dh.get_scalar_type(left.dtype) + for idx in sh.ndindex(left.shape): + scalar_l = scalar_type(left[idx]) + expected = scalar_l == right + scalar_o = bool(out[idx]) + f_l = sh.fmt_idx(left_sym, idx) + f_o = sh.fmt_idx(res_name, idx) + assert scalar_o == expected, ( + f"{f_o}={scalar_o}, but should be ({f_l} == {right})={expected} " + f"[{func_name}()]\n{f_l}={scalar_l}" + ) + else: # We manually promote the dtypes as incorrect internal type promotion - # could lead to erroneous behaviour that we don't catch. For example + # could lead to false positives. For example # # >>> xp.equal( # ... xp.asarray(1.0, dtype=xp.float32), # ... xp.asarray(1.00000001, dtype=xp.float64), # ... ) # - # would incorrectly be True if float64 downcasts to float32 internally. + # would erroneously be True if float64 downcasted to float32. promoted_dtype = dh.promotion_table[left.dtype, right.dtype] _left = xp.astype(left, promoted_dtype) _right = xp.astype(right, promoted_dtype) @@ -792,11 +805,12 @@ def test_equal( scalar_r = scalar_type(_right[r_idx]) expected = scalar_l == scalar_r scalar_o = bool(out[o_idx]) + f_l = sh.fmt_idx(left_sym, l_idx) + f_r = sh.fmt_idx(right_sym, r_idx) + f_o = sh.fmt_idx(res_name, o_idx) assert scalar_o == expected, ( - f"out[{o_idx}]={scalar_o}, but should be " - f"{left_sym}[{l_idx}]=={right_sym}[{r_idx}]={expected} " - f"({left_sym}[{l_idx}]={scalar_l}, {right_sym}[{r_idx}]={scalar_r}) " - f"[{func_name}()]" + f"{f_o}={scalar_o}, but should be ({f_l} == {f_r})={expected} " + f"[{func_name}()]\n{f_l}={scalar_l}, {f_r}={scalar_r}" ) @@ -1311,25 +1325,37 @@ def test_not_equal( assert_binary_param_dtype( func_name, left, right, right_is_scalar, out, res_name, xp.bool ) - if not right_is_scalar: - # TODO: generate indices without broadcasting arrays (see test_equal comment) - - shape = broadcast_shapes(left.shape, right.shape) - ph.assert_shape(func_name, out.shape, shape) - _left = xp.broadcast_to(left, shape) - _right = xp.broadcast_to(right, shape) - + assert_binary_param_shape(func_name, left, right, right_is_scalar, out, res_name) + if right_is_scalar: + scalar_type = dh.get_scalar_type(left.dtype) + for idx in sh.ndindex(left.shape): + scalar_l = scalar_type(left[idx]) + expected = scalar_l != right + scalar_o = bool(out[idx]) + f_l = sh.fmt_idx(left_sym, idx) + f_o = sh.fmt_idx(res_name, idx) + assert scalar_o == expected, ( + f"{f_o}={scalar_o}, but should be ({f_l} != {right})={expected} " + f"[{func_name}()]\n{f_l}={scalar_l}" + ) + else: + # See test_equal note promoted_dtype = dh.promotion_table[left.dtype, right.dtype] - _left = ah.asarray(_left, dtype=promoted_dtype) - _right = ah.asarray(_right, dtype=promoted_dtype) - + _left = xp.astype(left, promoted_dtype) + _right = xp.astype(right, promoted_dtype) scalar_type = dh.get_scalar_type(promoted_dtype) - for idx in sh.ndindex(shape): - out_idx = out[idx] - x1_idx = _left[idx] - x2_idx = _right[idx] - assert out_idx.shape == x1_idx.shape == x2_idx.shape # sanity check - assert bool(out_idx) == (scalar_type(x1_idx) != scalar_type(x2_idx)) + for l_idx, r_idx, o_idx in sh.iter_indices(left.shape, right.shape, out.shape): + scalar_l = scalar_type(_left[l_idx]) + scalar_r = scalar_type(_right[r_idx]) + expected = scalar_l != scalar_r + scalar_o = bool(out[o_idx]) + f_l = sh.fmt_idx(left_sym, l_idx) + f_r = sh.fmt_idx(right_sym, r_idx) + f_o = sh.fmt_idx(res_name, o_idx) + assert scalar_o == expected, ( + f"{f_o}={scalar_o}, but should be ({f_l} != {f_r})={expected} " + f"[{func_name}()]\n{f_l}={scalar_l}, {f_r}={scalar_r}" + ) @pytest.mark.parametrize( From 1a54bd43d04d46fbcfe09da8832f586ea21febf2 Mon Sep 17 00:00:00 2001 From: Matthew Barber Date: Wed, 26 Jan 2022 10:18:46 +0000 Subject: [PATCH 04/40] Better values testing for bitwise op/elwise tests --- ...est_operators_and_elementwise_functions.py | 311 ++++++++++++------ 1 file changed, 202 insertions(+), 109 deletions(-) diff --git a/array_api_tests/test_operators_and_elementwise_functions.py b/array_api_tests/test_operators_and_elementwise_functions.py index 4b2a1f51..06fe8101 100644 --- a/array_api_tests/test_operators_and_elementwise_functions.py +++ b/array_api_tests/test_operators_and_elementwise_functions.py @@ -437,31 +437,52 @@ def test_bitwise_and( res = func(left, right) assert_binary_param_dtype(func_name, left, right, right_is_scalar, res, res_name) - if not right_is_scalar: - # TODO: generate indices without broadcasting arrays (see test_equal comment) - shape = broadcast_shapes(left.shape, right.shape) - ph.assert_shape(func_name, res.shape, shape, repr_name=f"{res_name}.shape") - _left = xp.broadcast_to(left, shape) - _right = xp.broadcast_to(right, shape) - - # Compare against the Python & operator. - if res.dtype == xp.bool: - for idx in sh.ndindex(res.shape): - s_left = bool(_left[idx]) - s_right = bool(_right[idx]) - s_res = bool(res[idx]) - assert (s_left and s_right) == s_res - else: - for idx in sh.ndindex(res.shape): - s_left = int(_left[idx]) - s_right = int(_right[idx]) - s_res = int(res[idx]) - s_and = ah.int_to_dtype( - s_left & s_right, + assert_binary_param_shape(func_name, left, right, right_is_scalar, res, res_name) + scalar_type = dh.get_scalar_type(res.dtype) + if right_is_scalar: + for idx in sh.ndindex(res.shape): + scalar_l = scalar_type(left[idx]) + if res.dtype == xp.bool: + expected = scalar_l and right + else: + # for mypy + assert isinstance(scalar_l, int) + assert isinstance(right, int) + expected = ah.int_to_dtype( + scalar_l & right, + dh.dtype_nbits[res.dtype], + dh.dtype_signed[res.dtype], + ) + scalar_o = scalar_type(res[idx]) + f_l = sh.fmt_idx(left_sym, idx) + f_o = sh.fmt_idx(res_name, idx) + assert scalar_o == expected, ( + f"{f_o}={scalar_o}, but should be ({f_l} & {right})={expected} " + f"[{func_name}()]\n{f_l}={scalar_l}" + ) + else: + for l_idx, r_idx, o_idx in sh.iter_indices(left.shape, right.shape, res.shape): + scalar_l = scalar_type(left[l_idx]) + scalar_r = scalar_type(right[r_idx]) + if res.dtype == xp.bool: + expected = scalar_l and scalar_r + else: + # for mypy + assert isinstance(scalar_l, int) + assert isinstance(scalar_r, int) + expected = ah.int_to_dtype( + scalar_l & scalar_r, dh.dtype_nbits[res.dtype], dh.dtype_signed[res.dtype], ) - assert s_and == s_res + scalar_o = scalar_type(res[o_idx]) + f_l = sh.fmt_idx(left_sym, l_idx) + f_r = sh.fmt_idx(right_sym, r_idx) + f_o = sh.fmt_idx(res_name, o_idx) + assert scalar_o == expected, ( + f"{f_o}={scalar_o}, but should be ({f_l} & {f_r})={expected} " + f"[{func_name}()]\n{f_l}={scalar_l}, {f_r}={scalar_r}" + ) @pytest.mark.parametrize( @@ -489,25 +510,41 @@ def test_bitwise_left_shift( res = func(left, right) assert_binary_param_dtype(func_name, left, right, right_is_scalar, res, res_name) - if not right_is_scalar: - # TODO: generate indices without broadcasting arrays (see test_equal comment) - shape = broadcast_shapes(left.shape, right.shape) - ph.assert_shape(func_name, res.shape, shape, repr_name=f"{res_name}.shape") - _left = xp.broadcast_to(left, shape) - _right = xp.broadcast_to(right, shape) - - # Compare against the Python << operator. + assert_binary_param_shape(func_name, left, right, right_is_scalar, res, res_name) + if right_is_scalar: for idx in sh.ndindex(res.shape): - s_left = int(_left[idx]) - s_right = int(_right[idx]) - s_res = int(res[idx]) - s_shift = ah.int_to_dtype( + scalar_l = int(left[idx]) + expected = ah.int_to_dtype( # We avoid shifting very large ints - s_left << s_right if s_right < dh.dtype_nbits[res.dtype] else 0, + scalar_l << right if right < dh.dtype_nbits[res.dtype] else 0, dh.dtype_nbits[res.dtype], dh.dtype_signed[res.dtype], ) - assert s_shift == s_res + scalar_o = int(res[idx]) + f_l = sh.fmt_idx(left_sym, idx) + f_o = sh.fmt_idx(res_name, idx) + assert scalar_o == expected, ( + f"{f_o}={scalar_o}, but should be ({f_l} << {right})={expected} " + f"[{func_name}()]\n{f_l}={scalar_l}" + ) + else: + for l_idx, r_idx, o_idx in sh.iter_indices(left.shape, right.shape, res.shape): + scalar_l = int(left[l_idx]) + scalar_r = int(right[r_idx]) + expected = ah.int_to_dtype( + # We avoid shifting very large ints + scalar_l << scalar_r if scalar_r < dh.dtype_nbits[res.dtype] else 0, + dh.dtype_nbits[res.dtype], + dh.dtype_signed[res.dtype], + ) + scalar_o = int(res[o_idx]) + f_l = sh.fmt_idx(left_sym, l_idx) + f_r = sh.fmt_idx(right_sym, r_idx) + f_o = sh.fmt_idx(res_name, o_idx) + assert scalar_o == expected, ( + f"{f_o}={scalar_o}, but should be ({f_l} << {f_r})={expected} " + f"[{func_name}()]\n{f_l}={scalar_l}, {f_r}={scalar_r}" + ) @pytest.mark.parametrize( @@ -522,20 +559,23 @@ def test_bitwise_invert(func_name, func, strat, data): ph.assert_dtype(func_name, x.dtype, out.dtype) ph.assert_shape(func_name, out.shape, x.shape) - # Compare against the Python ~ operator. - if out.dtype == xp.bool: - for idx in sh.ndindex(out.shape): - s_x = bool(x[idx]) - s_out = bool(out[idx]) - assert (not s_x) == s_out - else: - for idx in sh.ndindex(out.shape): - s_x = int(x[idx]) - s_out = int(out[idx]) - s_invert = ah.int_to_dtype( - ~s_x, dh.dtype_nbits[out.dtype], dh.dtype_signed[out.dtype] + for idx in sh.ndindex(out.shape): + if out.dtype == xp.bool: + scalar_x = bool(x[idx]) + scalar_o = bool(out[idx]) + expected = not scalar_x + else: + scalar_x = int(x[idx]) + scalar_o = int(out[idx]) + expected = ah.int_to_dtype( + ~scalar_x, dh.dtype_nbits[out.dtype], dh.dtype_signed[out.dtype] ) - assert s_invert == s_out + f_x = sh.fmt_idx("x", idx) + f_o = sh.fmt_idx("out", idx) + assert scalar_o == expected, ( + f"{f_o}={scalar_o}, but should be ~{f_x}={scalar_x} " + f"[{func_name}()]\n{f_x}={scalar_x}" + ) @pytest.mark.parametrize( @@ -559,31 +599,50 @@ def test_bitwise_or( res = func(left, right) assert_binary_param_dtype(func_name, left, right, right_is_scalar, res, res_name) - if not right_is_scalar: - # TODO: generate indices without broadcasting arrays (see test_equal comment) - shape = broadcast_shapes(left.shape, right.shape) - ph.assert_shape(func_name, res.shape, shape, repr_name=f"{res_name}.shape") - _left = xp.broadcast_to(left, shape) - _right = xp.broadcast_to(right, shape) - - # Compare against the Python | operator. - if res.dtype == xp.bool: - for idx in sh.ndindex(res.shape): - s_left = bool(_left[idx]) - s_right = bool(_right[idx]) - s_res = bool(res[idx]) - assert (s_left or s_right) == s_res - else: - for idx in sh.ndindex(res.shape): - s_left = int(_left[idx]) - s_right = int(_right[idx]) - s_res = int(res[idx]) - s_or = ah.int_to_dtype( - s_left | s_right, + assert_binary_param_shape(func_name, left, right, right_is_scalar, res, res_name) + if right_is_scalar: + for idx in sh.ndindex(res.shape): + if res.dtype == xp.bool: + scalar_l = bool(left[idx]) + scalar_o = bool(res[idx]) + expected = scalar_l or right + else: + scalar_l = int(left[idx]) + scalar_o = int(res[idx]) + expected = ah.int_to_dtype( + scalar_l | right, + dh.dtype_nbits[res.dtype], + dh.dtype_signed[res.dtype], + ) + f_l = sh.fmt_idx(left_sym, idx) + f_o = sh.fmt_idx(res_name, idx) + assert scalar_o == expected, ( + f"{f_o}={scalar_o}, but should be ({f_l} | {right})={expected} " + f"[{func_name}()]\n{f_l}={scalar_l}" + ) + else: + for l_idx, r_idx, o_idx in sh.iter_indices(left.shape, right.shape, res.shape): + if res.dtype == xp.bool: + scalar_l = bool(left[l_idx]) + scalar_r = bool(right[r_idx]) + scalar_o = bool(res[o_idx]) + expected = scalar_l or scalar_r + else: + scalar_l = int(left[l_idx]) + scalar_r = int(right[r_idx]) + scalar_o = int(res[o_idx]) + expected = ah.int_to_dtype( + scalar_l | scalar_r, dh.dtype_nbits[res.dtype], dh.dtype_signed[res.dtype], ) - assert s_or == s_res + f_l = sh.fmt_idx(left_sym, l_idx) + f_r = sh.fmt_idx(right_sym, r_idx) + f_o = sh.fmt_idx(res_name, o_idx) + assert scalar_o == expected, ( + f"{f_o}={scalar_o}, but should be ({f_l} | {f_r})={expected} " + f"[{func_name}()]\n{f_l}={scalar_l}, {f_r}={scalar_r}" + ) @pytest.mark.parametrize( @@ -611,24 +670,39 @@ def test_bitwise_right_shift( res = func(left, right) assert_binary_param_dtype(func_name, left, right, right_is_scalar, res, res_name) - if not right_is_scalar: - # TODO: generate indices without broadcasting arrays (see test_equal comment) - shape = broadcast_shapes(left.shape, right.shape) - ph.assert_shape( - "bitwise_right_shift", res.shape, shape, repr_name=f"{res_name}.shape" - ) - _left = xp.broadcast_to(left, shape) - _right = xp.broadcast_to(right, shape) - - # Compare against the Python >> operator. + assert_binary_param_shape(func_name, left, right, right_is_scalar, res, res_name) + if right_is_scalar: for idx in sh.ndindex(res.shape): - s_left = int(_left[idx]) - s_right = int(_right[idx]) - s_res = int(res[idx]) - s_shift = ah.int_to_dtype( - s_left >> s_right, dh.dtype_nbits[res.dtype], dh.dtype_signed[res.dtype] + scalar_l = int(left[idx]) + expected = ah.int_to_dtype( + scalar_l >> right, + dh.dtype_nbits[res.dtype], + dh.dtype_signed[res.dtype], + ) + scalar_o = int(res[idx]) + f_l = sh.fmt_idx(left_sym, idx) + f_o = sh.fmt_idx(res_name, idx) + assert scalar_o == expected, ( + f"{f_o}={scalar_o}, but should be ({f_l} >> {right})={expected} " + f"[{func_name}()]\n{f_l}={scalar_l}" + ) + else: + for l_idx, r_idx, o_idx in sh.iter_indices(left.shape, right.shape, res.shape): + scalar_l = int(left[l_idx]) + scalar_r = int(right[r_idx]) + expected = ah.int_to_dtype( + scalar_l >> scalar_r, + dh.dtype_nbits[res.dtype], + dh.dtype_signed[res.dtype], + ) + scalar_o = int(res[o_idx]) + f_l = sh.fmt_idx(left_sym, l_idx) + f_r = sh.fmt_idx(right_sym, r_idx) + f_o = sh.fmt_idx(res_name, o_idx) + assert scalar_o == expected, ( + f"{f_o}={scalar_o}, but should be ({f_l} >> {f_r})={expected} " + f"[{func_name}()]\n{f_l}={scalar_l}, {f_r}={scalar_r}" ) - assert s_shift == s_res @pytest.mark.parametrize( @@ -652,31 +726,50 @@ def test_bitwise_xor( res = func(left, right) assert_binary_param_dtype(func_name, left, right, right_is_scalar, res, res_name) - if not right_is_scalar: - # TODO: generate indices without broadcasting arrays (see test_equal comment) - shape = broadcast_shapes(left.shape, right.shape) - ph.assert_shape(func_name, res.shape, shape, repr_name=f"{res_name}.shape") - _left = xp.broadcast_to(left, shape) - _right = xp.broadcast_to(right, shape) - - # Compare against the Python ^ operator. - if res.dtype == xp.bool: - for idx in sh.ndindex(res.shape): - s_left = bool(_left[idx]) - s_right = bool(_right[idx]) - s_res = bool(res[idx]) - assert (s_left ^ s_right) == s_res - else: - for idx in sh.ndindex(res.shape): - s_left = int(_left[idx]) - s_right = int(_right[idx]) - s_res = int(res[idx]) - s_xor = ah.int_to_dtype( - s_left ^ s_right, + assert_binary_param_shape(func_name, left, right, right_is_scalar, res, res_name) + if right_is_scalar: + for idx in sh.ndindex(res.shape): + if res.dtype == xp.bool: + scalar_l = bool(left[idx]) + scalar_o = bool(res[idx]) + expected = scalar_l ^ right + else: + scalar_l = int(left[idx]) + scalar_o = int(res[idx]) + expected = ah.int_to_dtype( + scalar_l ^ right, dh.dtype_nbits[res.dtype], dh.dtype_signed[res.dtype], ) - assert s_xor == s_res + f_l = sh.fmt_idx(left_sym, idx) + f_o = sh.fmt_idx(res_name, idx) + assert scalar_o == expected, ( + f"{f_o}={scalar_o}, but should be ({f_l} ^ {right})={expected} " + f"[{func_name}()]\n{f_l}={scalar_l}" + ) + else: + for l_idx, r_idx, o_idx in sh.iter_indices(left.shape, right.shape, res.shape): + if res.dtype == xp.bool: + scalar_l = bool(left[l_idx]) + scalar_r = bool(right[r_idx]) + scalar_o = bool(res[o_idx]) + expected = scalar_l ^ scalar_r + else: + scalar_l = int(left[l_idx]) + scalar_r = int(right[r_idx]) + scalar_o = int(res[o_idx]) + expected = ah.int_to_dtype( + scalar_l ^ scalar_r, + dh.dtype_nbits[res.dtype], + dh.dtype_signed[res.dtype], + ) + f_l = sh.fmt_idx(left_sym, l_idx) + f_r = sh.fmt_idx(right_sym, r_idx) + f_o = sh.fmt_idx(res_name, o_idx) + assert scalar_o == expected, ( + f"{f_o}={scalar_o}, but should be ({f_l} ^ {f_r})={expected} " + f"[{func_name}()]\n{f_l}={scalar_l}, {f_r}={scalar_r}" + ) @given(xps.arrays(dtype=xps.numeric_dtypes(), shape=hh.shapes())) From 2f8492b8d78d1a4b02de25bd2bc718faa186974b Mon Sep 17 00:00:00 2001 From: Matthew Barber Date: Wed, 26 Jan 2022 13:24:06 +0000 Subject: [PATCH 05/40] Context objects for unary/binary params --- ...est_operators_and_elementwise_functions.py | 750 +++++++----------- 1 file changed, 269 insertions(+), 481 deletions(-) diff --git a/array_api_tests/test_operators_and_elementwise_functions.py b/array_api_tests/test_operators_and_elementwise_functions.py index 06fe8101..d4cc8199 100644 --- a/array_api_tests/test_operators_and_elementwise_functions.py +++ b/array_api_tests/test_operators_and_elementwise_functions.py @@ -11,7 +11,7 @@ import math from enum import Enum, auto -from typing import Callable, List, Optional, Union +from typing import Callable, List, NamedTuple, Optional, Union import pytest from hypothesis import assume, given @@ -57,42 +57,32 @@ def boolean_and_all_integer_dtypes() -> st.SearchStrategy[DataType]: all_op_to_symbol = {**dh.binary_op_to_symbol, **dh.inplace_op_to_symbol} finite_kw = {"allow_nan": False, "allow_infinity": False} -unary_argnames = ("func_name", "func", "strat") -UnaryParam = Param[str, Callable[[Array], Array], st.SearchStrategy[Array]] + +class UnaryParamContext(NamedTuple): + func_name: str + func: Callable[[Array], Array] + strat: st.SearchStrategy[Array] + + @property + def id(self) -> str: + return f"{self.func_name}" + + def __repr__(self): + return f"UnaryParamContext(<{self.id}>)" def make_unary_params( elwise_func_name: str, dtypes_strat: st.SearchStrategy[DataType] -) -> List[UnaryParam]: +) -> List[Param[UnaryParamContext]]: strat = xps.arrays(dtype=dtypes_strat, shape=hh.shapes()) - func = getattr(xp, elwise_func_name) + func_ctx = UnaryParamContext( + func_name=elwise_func_name, func=getattr(xp, elwise_func_name), strat=strat + ) op_name = func_to_op[elwise_func_name] - op = lambda x: getattr(x, op_name)() - return [ - pytest.param(elwise_func_name, func, strat, id=elwise_func_name), - pytest.param(op_name, op, strat, id=op_name), - ] - - -binary_argnames = ( - "func_name", - "func", - "left_sym", - "left_strat", - "right_sym", - "right_strat", - "right_is_scalar", - "res_name", -) -BinaryParam = Param[ - str, - Callable[[Array, Union[Scalar, Array]], Array], - str, - st.SearchStrategy[Array], - str, - st.SearchStrategy[Union[Scalar, Array]], - bool, -] + op_ctx = UnaryParamContext( + func_name=op_name, func=lambda x: getattr(x, op_name)(), strat=strat + ) + return [pytest.param(func_ctx, id=func_ctx.id), pytest.param(op_ctx, id=op_ctx.id)] class FuncType(Enum): @@ -104,12 +94,30 @@ class FuncType(Enum): shapes_kw = {"min_side": 1} +class BinaryParamContext(NamedTuple): + func_name: str + func: Callable[[Array, Union[Scalar, Array]], Array] + left_sym: str + left_strat: st.SearchStrategy[Array] + right_sym: str + right_strat: st.SearchStrategy[Union[Scalar, Array]] + right_is_scalar: bool + res_name: str + + @property + def id(self) -> str: + return f"{self.func_name}({self.left_sym}, {self.right_sym})" + + def __repr__(self): + return f"BinaryParamContext(<{self.id}>)" + + def make_binary_params( elwise_func_name: str, dtypes_strat: st.SearchStrategy[DataType] -) -> List[BinaryParam]: +) -> List[Param[BinaryParamContext]]: def make_param( func_name: str, func_type: FuncType, right_is_scalar: bool - ) -> BinaryParam: + ) -> Param[BinaryParamContext]: if right_is_scalar: left_sym = "x" right_sym = "s" @@ -168,7 +176,7 @@ def func(l: Array, r: Union[Scalar, Array]) -> Array: else: res_name = "out" - return pytest.param( + ctx = BinaryParamContext( func_name, func, left_sym, @@ -177,8 +185,8 @@ def func(l: Array, r: Union[Scalar, Array]) -> Array: right_strat, right_is_scalar, res_name, - id=f"{func_name}({left_sym}, {right_sym})", ) + return pytest.param(ctx, id=ctx.id) op_name = func_to_op[elwise_func_name] params = [ @@ -195,57 +203,53 @@ def func(l: Array, r: Union[Scalar, Array]) -> Array: def assert_binary_param_dtype( - func_name: str, + ctx: BinaryParamContext, left: Array, right: Union[Array, Scalar], - right_is_scalar: bool, res: Array, - res_name: str, expected: Optional[DataType] = None, ): - if right_is_scalar: + if ctx.right_is_scalar: in_dtypes = left.dtype else: in_dtypes = (left.dtype, right.dtype) # type: ignore ph.assert_dtype( - func_name, in_dtypes, res.dtype, expected, repr_name=f"{res_name}.dtype" + ctx.func_name, in_dtypes, res.dtype, expected, repr_name=f"{ctx.res_name}.dtype" ) def assert_binary_param_shape( - func_name: str, + ctx: BinaryParamContext, left: Array, right: Union[Array, Scalar], - right_is_scalar: bool, res: Array, - res_name: str, expected: Optional[Shape] = None, ): - if right_is_scalar: + if ctx.right_is_scalar: in_shapes = (left.shape,) else: in_shapes = (left.shape, right.shape) # type: ignore ph.assert_result_shape( - func_name, in_shapes, res.shape, expected, repr_name=f"{res_name}.shape" + ctx.func_name, in_shapes, res.shape, expected, repr_name=f"{ctx.res_name}.shape" ) -@pytest.mark.parametrize(unary_argnames, make_unary_params("abs", xps.numeric_dtypes())) +@pytest.mark.parametrize("ctx", make_unary_params("abs", xps.numeric_dtypes())) @given(data=st.data()) -def test_abs(func_name, func, strat, data): - x = data.draw(strat, label="x") +def test_abs(ctx, data): + x = data.draw(ctx.strat, label="x") if x.dtype in dh.int_dtypes: # abs of the smallest representable negative integer is not defined mask = xp.not_equal( x, ah.full(x.shape, dh.dtype_ranges[x.dtype].min, dtype=x.dtype) ) x = x[mask] - out = func(x) - ph.assert_dtype(func_name, x.dtype, out.dtype) - ph.assert_shape(func_name, out.shape, x.shape) + out = ctx.func(x) + ph.assert_dtype(ctx.func_name, x.dtype, out.dtype) + ph.assert_shape(ctx.func_name, out.shape, x.shape) assert ah.all( ah.logical_not(ah.negative_mathematical_sign(out)) - ), f"out elements not all positively signed [{func_name}()]\n{out=}" + ), f"out elements not all positively signed [{ctx.func_name}()]\n{out=}" less_zero = ah.negative_mathematical_sign(x) negx = ah.negative(x) # abs(x) = -x for x < 0 @@ -288,34 +292,22 @@ def test_acosh(x): ah.assert_exactly_equal(domain, codomain) -@pytest.mark.parametrize( - binary_argnames, make_binary_params("add", xps.numeric_dtypes()) -) +@pytest.mark.parametrize("ctx,", make_binary_params("add", xps.numeric_dtypes())) @given(data=st.data()) -def test_add( - func_name, - func, - left_sym, - left_strat, - right_sym, - right_strat, - right_is_scalar, - res_name, - data, -): - left = data.draw(left_strat, label=left_sym) - right = data.draw(right_strat, label=right_sym) +def test_add(ctx, data): + left = data.draw(ctx.left_strat, label=ctx.left_sym) + right = data.draw(ctx.right_strat, label=ctx.right_sym) try: - res = func(left, right) + res = ctx.func(left, right) except OverflowError: reject() - assert_binary_param_dtype(func_name, left, right, right_is_scalar, res, res_name) - assert_binary_param_shape(func_name, left, right, right_is_scalar, res, res_name) - if not right_is_scalar: + assert_binary_param_dtype(ctx, left, right, res) + assert_binary_param_shape(ctx, left, right, res) + if not ctx.right_is_scalar: # add is commutative - expected = func(right, left) + expected = ctx.func(right, left) ah.assert_exactly_equal(res, expected) @@ -417,29 +409,19 @@ def test_atanh(x): @pytest.mark.parametrize( - binary_argnames, make_binary_params("bitwise_and", boolean_and_all_integer_dtypes()) + "ctx", make_binary_params("bitwise_and", boolean_and_all_integer_dtypes()) ) @given(data=st.data()) -def test_bitwise_and( - func_name, - func, - left_sym, - left_strat, - right_sym, - right_strat, - right_is_scalar, - res_name, - data, -): - left = data.draw(left_strat, label=left_sym) - right = data.draw(right_strat, label=right_sym) +def test_bitwise_and(ctx, data): + left = data.draw(ctx.left_strat, label=ctx.left_sym) + right = data.draw(ctx.right_strat, label=ctx.right_sym) - res = func(left, right) + res = ctx.func(left, right) - assert_binary_param_dtype(func_name, left, right, right_is_scalar, res, res_name) - assert_binary_param_shape(func_name, left, right, right_is_scalar, res, res_name) + assert_binary_param_dtype(ctx, left, right, res) + assert_binary_param_shape(ctx, left, right, res) scalar_type = dh.get_scalar_type(res.dtype) - if right_is_scalar: + if ctx.right_is_scalar: for idx in sh.ndindex(res.shape): scalar_l = scalar_type(left[idx]) if res.dtype == xp.bool: @@ -454,11 +436,11 @@ def test_bitwise_and( dh.dtype_signed[res.dtype], ) scalar_o = scalar_type(res[idx]) - f_l = sh.fmt_idx(left_sym, idx) - f_o = sh.fmt_idx(res_name, idx) + f_l = sh.fmt_idx(ctx.left_sym, idx) + f_o = sh.fmt_idx(ctx.res_name, idx) assert scalar_o == expected, ( f"{f_o}={scalar_o}, but should be ({f_l} & {right})={expected} " - f"[{func_name}()]\n{f_l}={scalar_l}" + f"[{ctx.func_name}()]\n{f_l}={scalar_l}" ) else: for l_idx, r_idx, o_idx in sh.iter_indices(left.shape, right.shape, res.shape): @@ -476,42 +458,32 @@ def test_bitwise_and( dh.dtype_signed[res.dtype], ) scalar_o = scalar_type(res[o_idx]) - f_l = sh.fmt_idx(left_sym, l_idx) - f_r = sh.fmt_idx(right_sym, r_idx) - f_o = sh.fmt_idx(res_name, o_idx) + f_l = sh.fmt_idx(ctx.left_sym, l_idx) + f_r = sh.fmt_idx(ctx.right_sym, r_idx) + f_o = sh.fmt_idx(ctx.res_name, o_idx) assert scalar_o == expected, ( f"{f_o}={scalar_o}, but should be ({f_l} & {f_r})={expected} " - f"[{func_name}()]\n{f_l}={scalar_l}, {f_r}={scalar_r}" + f"[{ctx.func_name}()]\n{f_l}={scalar_l}, {f_r}={scalar_r}" ) @pytest.mark.parametrize( - binary_argnames, make_binary_params("bitwise_left_shift", all_integer_dtypes()) + "ctx", make_binary_params("bitwise_left_shift", all_integer_dtypes()) ) @given(data=st.data()) -def test_bitwise_left_shift( - func_name, - func, - left_sym, - left_strat, - right_sym, - right_strat, - right_is_scalar, - res_name, - data, -): - left = data.draw(left_strat, label=left_sym) - right = data.draw(right_strat, label=right_sym) - if right_is_scalar: +def test_bitwise_left_shift(ctx, data): + left = data.draw(ctx.left_strat, label=ctx.left_sym) + right = data.draw(ctx.right_strat, label=ctx.right_sym) + if ctx.right_is_scalar: assume(right >= 0) else: assume(not ah.any(ah.isnegative(right))) - res = func(left, right) + res = ctx.func(left, right) - assert_binary_param_dtype(func_name, left, right, right_is_scalar, res, res_name) - assert_binary_param_shape(func_name, left, right, right_is_scalar, res, res_name) - if right_is_scalar: + assert_binary_param_dtype(ctx, left, right, res) + assert_binary_param_shape(ctx, left, right, res) + if ctx.right_is_scalar: for idx in sh.ndindex(res.shape): scalar_l = int(left[idx]) expected = ah.int_to_dtype( @@ -521,11 +493,11 @@ def test_bitwise_left_shift( dh.dtype_signed[res.dtype], ) scalar_o = int(res[idx]) - f_l = sh.fmt_idx(left_sym, idx) - f_o = sh.fmt_idx(res_name, idx) + f_l = sh.fmt_idx(ctx.left_sym, idx) + f_o = sh.fmt_idx(ctx.res_name, idx) assert scalar_o == expected, ( f"{f_o}={scalar_o}, but should be ({f_l} << {right})={expected} " - f"[{func_name}()]\n{f_l}={scalar_l}" + f"[{ctx.func_name}()]\n{f_l}={scalar_l}" ) else: for l_idx, r_idx, o_idx in sh.iter_indices(left.shape, right.shape, res.shape): @@ -538,27 +510,27 @@ def test_bitwise_left_shift( dh.dtype_signed[res.dtype], ) scalar_o = int(res[o_idx]) - f_l = sh.fmt_idx(left_sym, l_idx) - f_r = sh.fmt_idx(right_sym, r_idx) - f_o = sh.fmt_idx(res_name, o_idx) + f_l = sh.fmt_idx(ctx.left_sym, l_idx) + f_r = sh.fmt_idx(ctx.right_sym, r_idx) + f_o = sh.fmt_idx(ctx.res_name, o_idx) assert scalar_o == expected, ( f"{f_o}={scalar_o}, but should be ({f_l} << {f_r})={expected} " - f"[{func_name}()]\n{f_l}={scalar_l}, {f_r}={scalar_r}" + f"[{ctx.func_name}()]\n{f_l}={scalar_l}, {f_r}={scalar_r}" ) @pytest.mark.parametrize( - unary_argnames, + "ctx", make_unary_params("bitwise_invert", boolean_and_all_integer_dtypes()), ) @given(data=st.data()) -def test_bitwise_invert(func_name, func, strat, data): - x = data.draw(strat, label="x") +def test_bitwise_invert(ctx, data): + x = data.draw(ctx.strat, label="x") - out = func(x) + out = ctx.func(x) - ph.assert_dtype(func_name, x.dtype, out.dtype) - ph.assert_shape(func_name, out.shape, x.shape) + ph.assert_dtype(ctx.func_name, x.dtype, out.dtype) + ph.assert_shape(ctx.func_name, out.shape, x.shape) for idx in sh.ndindex(out.shape): if out.dtype == xp.bool: scalar_x = bool(x[idx]) @@ -574,33 +546,23 @@ def test_bitwise_invert(func_name, func, strat, data): f_o = sh.fmt_idx("out", idx) assert scalar_o == expected, ( f"{f_o}={scalar_o}, but should be ~{f_x}={scalar_x} " - f"[{func_name}()]\n{f_x}={scalar_x}" + f"[{ctx.func_name}()]\n{f_x}={scalar_x}" ) @pytest.mark.parametrize( - binary_argnames, make_binary_params("bitwise_or", boolean_and_all_integer_dtypes()) + "ctx", make_binary_params("bitwise_or", boolean_and_all_integer_dtypes()) ) @given(data=st.data()) -def test_bitwise_or( - func_name, - func, - left_sym, - left_strat, - right_sym, - right_strat, - right_is_scalar, - res_name, - data, -): - left = data.draw(left_strat, label=left_sym) - right = data.draw(right_strat, label=right_sym) +def test_bitwise_or(ctx, data): + left = data.draw(ctx.left_strat, label=ctx.left_sym) + right = data.draw(ctx.right_strat, label=ctx.right_sym) - res = func(left, right) + res = ctx.func(left, right) - assert_binary_param_dtype(func_name, left, right, right_is_scalar, res, res_name) - assert_binary_param_shape(func_name, left, right, right_is_scalar, res, res_name) - if right_is_scalar: + assert_binary_param_dtype(ctx, left, right, res) + assert_binary_param_shape(ctx, left, right, res) + if ctx.right_is_scalar: for idx in sh.ndindex(res.shape): if res.dtype == xp.bool: scalar_l = bool(left[idx]) @@ -614,11 +576,11 @@ def test_bitwise_or( dh.dtype_nbits[res.dtype], dh.dtype_signed[res.dtype], ) - f_l = sh.fmt_idx(left_sym, idx) - f_o = sh.fmt_idx(res_name, idx) + f_l = sh.fmt_idx(ctx.left_sym, idx) + f_o = sh.fmt_idx(ctx.res_name, idx) assert scalar_o == expected, ( f"{f_o}={scalar_o}, but should be ({f_l} | {right})={expected} " - f"[{func_name}()]\n{f_l}={scalar_l}" + f"[{ctx.func_name}()]\n{f_l}={scalar_l}" ) else: for l_idx, r_idx, o_idx in sh.iter_indices(left.shape, right.shape, res.shape): @@ -636,42 +598,32 @@ def test_bitwise_or( dh.dtype_nbits[res.dtype], dh.dtype_signed[res.dtype], ) - f_l = sh.fmt_idx(left_sym, l_idx) - f_r = sh.fmt_idx(right_sym, r_idx) - f_o = sh.fmt_idx(res_name, o_idx) + f_l = sh.fmt_idx(ctx.left_sym, l_idx) + f_r = sh.fmt_idx(ctx.right_sym, r_idx) + f_o = sh.fmt_idx(ctx.res_name, o_idx) assert scalar_o == expected, ( f"{f_o}={scalar_o}, but should be ({f_l} | {f_r})={expected} " - f"[{func_name}()]\n{f_l}={scalar_l}, {f_r}={scalar_r}" + f"[{ctx.func_name}()]\n{f_l}={scalar_l}, {f_r}={scalar_r}" ) @pytest.mark.parametrize( - binary_argnames, make_binary_params("bitwise_right_shift", all_integer_dtypes()) + "ctx", make_binary_params("bitwise_right_shift", all_integer_dtypes()) ) @given(data=st.data()) -def test_bitwise_right_shift( - func_name, - func, - left_sym, - left_strat, - right_sym, - right_strat, - right_is_scalar, - res_name, - data, -): - left = data.draw(left_strat, label=left_sym) - right = data.draw(right_strat, label=right_sym) - if right_is_scalar: +def test_bitwise_right_shift(ctx, data): + left = data.draw(ctx.left_strat, label=ctx.left_sym) + right = data.draw(ctx.right_strat, label=ctx.right_sym) + if ctx.right_is_scalar: assume(right >= 0) else: assume(not ah.any(ah.isnegative(right))) - res = func(left, right) + res = ctx.func(left, right) - assert_binary_param_dtype(func_name, left, right, right_is_scalar, res, res_name) - assert_binary_param_shape(func_name, left, right, right_is_scalar, res, res_name) - if right_is_scalar: + assert_binary_param_dtype(ctx, left, right, res) + assert_binary_param_shape(ctx, left, right, res) + if ctx.right_is_scalar: for idx in sh.ndindex(res.shape): scalar_l = int(left[idx]) expected = ah.int_to_dtype( @@ -680,11 +632,11 @@ def test_bitwise_right_shift( dh.dtype_signed[res.dtype], ) scalar_o = int(res[idx]) - f_l = sh.fmt_idx(left_sym, idx) - f_o = sh.fmt_idx(res_name, idx) + f_l = sh.fmt_idx(ctx.left_sym, idx) + f_o = sh.fmt_idx(ctx.res_name, idx) assert scalar_o == expected, ( f"{f_o}={scalar_o}, but should be ({f_l} >> {right})={expected} " - f"[{func_name}()]\n{f_l}={scalar_l}" + f"[{ctx.func_name}()]\n{f_l}={scalar_l}" ) else: for l_idx, r_idx, o_idx in sh.iter_indices(left.shape, right.shape, res.shape): @@ -696,38 +648,28 @@ def test_bitwise_right_shift( dh.dtype_signed[res.dtype], ) scalar_o = int(res[o_idx]) - f_l = sh.fmt_idx(left_sym, l_idx) - f_r = sh.fmt_idx(right_sym, r_idx) - f_o = sh.fmt_idx(res_name, o_idx) + f_l = sh.fmt_idx(ctx.left_sym, l_idx) + f_r = sh.fmt_idx(ctx.right_sym, r_idx) + f_o = sh.fmt_idx(ctx.res_name, o_idx) assert scalar_o == expected, ( f"{f_o}={scalar_o}, but should be ({f_l} >> {f_r})={expected} " - f"[{func_name}()]\n{f_l}={scalar_l}, {f_r}={scalar_r}" + f"[{ctx.func_name}()]\n{f_l}={scalar_l}, {f_r}={scalar_r}" ) @pytest.mark.parametrize( - binary_argnames, make_binary_params("bitwise_xor", boolean_and_all_integer_dtypes()) + "ctx", make_binary_params("bitwise_xor", boolean_and_all_integer_dtypes()) ) @given(data=st.data()) -def test_bitwise_xor( - func_name, - func, - left_sym, - left_strat, - right_sym, - right_strat, - right_is_scalar, - res_name, - data, -): - left = data.draw(left_strat, label=left_sym) - right = data.draw(right_strat, label=right_sym) +def test_bitwise_xor(ctx, data): + left = data.draw(ctx.left_strat, label=ctx.left_sym) + right = data.draw(ctx.right_strat, label=ctx.right_sym) - res = func(left, right) + res = ctx.func(left, right) - assert_binary_param_dtype(func_name, left, right, right_is_scalar, res, res_name) - assert_binary_param_shape(func_name, left, right, right_is_scalar, res, res_name) - if right_is_scalar: + assert_binary_param_dtype(ctx, left, right, res) + assert_binary_param_shape(ctx, left, right, res) + if ctx.right_is_scalar: for idx in sh.ndindex(res.shape): if res.dtype == xp.bool: scalar_l = bool(left[idx]) @@ -741,11 +683,11 @@ def test_bitwise_xor( dh.dtype_nbits[res.dtype], dh.dtype_signed[res.dtype], ) - f_l = sh.fmt_idx(left_sym, idx) - f_o = sh.fmt_idx(res_name, idx) + f_l = sh.fmt_idx(ctx.left_sym, idx) + f_o = sh.fmt_idx(ctx.res_name, idx) assert scalar_o == expected, ( f"{f_o}={scalar_o}, but should be ({f_l} ^ {right})={expected} " - f"[{func_name}()]\n{f_l}={scalar_l}" + f"[{ctx.func_name}()]\n{f_l}={scalar_l}" ) else: for l_idx, r_idx, o_idx in sh.iter_indices(left.shape, right.shape, res.shape): @@ -763,12 +705,12 @@ def test_bitwise_xor( dh.dtype_nbits[res.dtype], dh.dtype_signed[res.dtype], ) - f_l = sh.fmt_idx(left_sym, l_idx) - f_r = sh.fmt_idx(right_sym, r_idx) - f_o = sh.fmt_idx(res_name, o_idx) + f_l = sh.fmt_idx(ctx.left_sym, l_idx) + f_r = sh.fmt_idx(ctx.right_sym, r_idx) + f_o = sh.fmt_idx(ctx.res_name, o_idx) assert scalar_o == expected, ( f"{f_o}={scalar_o}, but should be ({f_l} ^ {f_r})={expected} " - f"[{func_name}()]\n{f_l}={scalar_l}, {f_r}={scalar_r}" + f"[{ctx.func_name}()]\n{f_l}={scalar_l}, {f_r}={scalar_r}" ) @@ -815,27 +757,15 @@ def test_cosh(x): ah.assert_exactly_equal(domain, codomain) -@pytest.mark.parametrize( - binary_argnames, make_binary_params("divide", xps.floating_dtypes()) -) +@pytest.mark.parametrize("ctx", make_binary_params("divide", xps.floating_dtypes())) @given(data=st.data()) -def test_divide( - func_name, - func, - left_sym, - left_strat, - right_sym, - right_strat, - right_is_scalar, - res_name, - data, -): - left = data.draw(left_strat, label=left_sym) - right = data.draw(right_strat, label=right_sym) +def test_divide(ctx, data): + left = data.draw(ctx.left_strat, label=ctx.left_sym) + right = data.draw(ctx.right_strat, label=ctx.right_sym) - res = func(left, right) + res = ctx.func(left, right) - assert_binary_param_dtype(func_name, left, right, right_is_scalar, res, res_name) + assert_binary_param_dtype(ctx, left, right, res) # There isn't much we can test here. The spec doesn't require any behavior # beyond the special cases, and indeed, there aren't many mathematical # properties of division that strictly hold for floating-point numbers. We @@ -843,41 +773,27 @@ def test_divide( # have those sorts in general for this module. -@pytest.mark.parametrize( - binary_argnames, make_binary_params("equal", xps.scalar_dtypes()) -) +@pytest.mark.parametrize("ctx", make_binary_params("equal", xps.scalar_dtypes())) @given(data=st.data()) -def test_equal( - func_name, - func, - left_sym, - left_strat, - right_sym, - right_strat, - right_is_scalar, - res_name, - data, -): - left = data.draw(left_strat, label=left_sym) - right = data.draw(right_strat, label=right_sym) +def test_equal(ctx, data): + left = data.draw(ctx.left_strat, label=ctx.left_sym) + right = data.draw(ctx.right_strat, label=ctx.right_sym) - out = func(left, right) + out = ctx.func(left, right) - assert_binary_param_dtype( - func_name, left, right, right_is_scalar, out, res_name, xp.bool - ) - assert_binary_param_shape(func_name, left, right, right_is_scalar, out, res_name) - if right_is_scalar: + assert_binary_param_dtype(ctx, left, right, out, xp.bool) + assert_binary_param_shape(ctx, left, right, out) + if ctx.right_is_scalar: scalar_type = dh.get_scalar_type(left.dtype) for idx in sh.ndindex(left.shape): scalar_l = scalar_type(left[idx]) expected = scalar_l == right scalar_o = bool(out[idx]) - f_l = sh.fmt_idx(left_sym, idx) - f_o = sh.fmt_idx(res_name, idx) + f_l = sh.fmt_idx(ctx.left_sym, idx) + f_o = sh.fmt_idx(ctx.res_name, idx) assert scalar_o == expected, ( f"{f_o}={scalar_o}, but should be ({f_l} == {right})={expected} " - f"[{func_name}()]\n{f_l}={scalar_l}" + f"[{ctx.func_name}()]\n{f_l}={scalar_l}" ) else: # We manually promote the dtypes as incorrect internal type promotion @@ -898,12 +814,12 @@ def test_equal( scalar_r = scalar_type(_right[r_idx]) expected = scalar_l == scalar_r scalar_o = bool(out[o_idx]) - f_l = sh.fmt_idx(left_sym, l_idx) - f_r = sh.fmt_idx(right_sym, r_idx) - f_o = sh.fmt_idx(res_name, o_idx) + f_l = sh.fmt_idx(ctx.left_sym, l_idx) + f_r = sh.fmt_idx(ctx.right_sym, r_idx) + f_o = sh.fmt_idx(ctx.res_name, o_idx) assert scalar_o == expected, ( f"{f_o}={scalar_o}, but should be ({f_l} == {f_r})={expected} " - f"[{func_name}()]\n{f_l}={scalar_l}, {f_r}={scalar_r}" + f"[{ctx.func_name}()]\n{f_l}={scalar_l}, {f_r}={scalar_r}" ) @@ -952,31 +868,23 @@ def test_floor(x): @pytest.mark.parametrize( - binary_argnames, make_binary_params("floor_divide", xps.numeric_dtypes()) + "ctx", make_binary_params("floor_divide", xps.numeric_dtypes()) ) @given(data=st.data()) -def test_floor_divide( - func_name, - func, - left_sym, - left_strat, - right_sym, - right_strat, - right_is_scalar, - res_name, - data, -): - left = data.draw(left_strat.filter(lambda x: not ah.any(x == 0)), label=left_sym) - right = data.draw(right_strat, label=right_sym) - if right_is_scalar: +def test_floor_divide(ctx, data): + left = data.draw( + ctx.left_strat.filter(lambda x: not ah.any(x == 0)), label=ctx.left_sym + ) + right = data.draw(ctx.right_strat, label=ctx.right_sym) + if ctx.right_is_scalar: assume(right != 0) else: assume(not ah.any(right == 0)) - res = func(left, right) + res = ctx.func(left, right) - assert_binary_param_dtype(func_name, left, right, right_is_scalar, res, res_name) - if not right_is_scalar: + assert_binary_param_dtype(ctx, left, right, res) + if not ctx.right_is_scalar: if dh.is_int_dtype(left.dtype): # The spec does not specify the behavior for division by 0 for integer # dtypes. A library may choose to raise an exception in this case, so @@ -995,33 +903,19 @@ def test_floor_divide( # TODO: Test the exact output for floor_divide. -@pytest.mark.parametrize( - binary_argnames, make_binary_params("greater", xps.numeric_dtypes()) -) +@pytest.mark.parametrize("ctx", make_binary_params("greater", xps.numeric_dtypes())) @given(data=st.data()) -def test_greater( - func_name, - func, - left_sym, - left_strat, - right_sym, - right_strat, - right_is_scalar, - res_name, - data, -): - left = data.draw(left_strat, label=left_sym) - right = data.draw(right_strat, label=right_sym) +def test_greater(ctx, data): + left = data.draw(ctx.left_strat, label=ctx.left_sym) + right = data.draw(ctx.right_strat, label=ctx.right_sym) - out = func(left, right) + out = ctx.func(left, right) - assert_binary_param_dtype( - func_name, left, right, right_is_scalar, out, res_name, xp.bool - ) - if not right_is_scalar: + assert_binary_param_dtype(ctx, left, right, out, xp.bool) + if not ctx.right_is_scalar: # TODO: generate indices without broadcasting arrays (see test_equal comment) shape = broadcast_shapes(left.shape, right.shape) - ph.assert_shape(func_name, out.shape, shape) + ph.assert_shape(ctx.func_name, out.shape, shape) _left = xp.broadcast_to(left, shape) _right = xp.broadcast_to(right, shape) @@ -1039,33 +933,21 @@ def test_greater( @pytest.mark.parametrize( - binary_argnames, make_binary_params("greater_equal", xps.numeric_dtypes()) + "ctx", make_binary_params("greater_equal", xps.numeric_dtypes()) ) @given(data=st.data()) -def test_greater_equal( - func_name, - func, - left_sym, - left_strat, - right_sym, - right_strat, - right_is_scalar, - res_name, - data, -): - left = data.draw(left_strat, label=left_sym) - right = data.draw(right_strat, label=right_sym) +def test_greater_equal(ctx, data): + left = data.draw(ctx.left_strat, label=ctx.left_sym) + right = data.draw(ctx.right_strat, label=ctx.right_sym) - out = func(left, right) + out = ctx.func(left, right) - assert_binary_param_dtype( - func_name, left, right, right_is_scalar, out, res_name, xp.bool - ) - if not right_is_scalar: + assert_binary_param_dtype(ctx, left, right, out, xp.bool) + if not ctx.right_is_scalar: # TODO: generate indices without broadcasting arrays (see test_equal comment) shape = broadcast_shapes(left.shape, right.shape) - ph.assert_shape(func_name, out.shape, shape) + ph.assert_shape(ctx.func_name, out.shape, shape) _left = xp.broadcast_to(left, shape) _right = xp.broadcast_to(right, shape) @@ -1138,34 +1020,20 @@ def test_isnan(x): assert bool(out[idx]) == math.isnan(s) -@pytest.mark.parametrize( - binary_argnames, make_binary_params("less", xps.numeric_dtypes()) -) +@pytest.mark.parametrize("ctx", make_binary_params("less", xps.numeric_dtypes())) @given(data=st.data()) -def test_less( - func_name, - func, - left_sym, - left_strat, - right_sym, - right_strat, - right_is_scalar, - res_name, - data, -): - left = data.draw(left_strat, label=left_sym) - right = data.draw(right_strat, label=right_sym) +def test_less(ctx, data): + left = data.draw(ctx.left_strat, label=ctx.left_sym) + right = data.draw(ctx.right_strat, label=ctx.right_sym) - out = func(left, right) + out = ctx.func(left, right) - assert_binary_param_dtype( - func_name, left, right, right_is_scalar, out, res_name, xp.bool - ) - if not right_is_scalar: + assert_binary_param_dtype(ctx, left, right, out, xp.bool) + if not ctx.right_is_scalar: # TODO: generate indices without broadcasting arrays (see test_equal comment) shape = broadcast_shapes(left.shape, right.shape) - ph.assert_shape(func_name, out.shape, shape) + ph.assert_shape(ctx.func_name, out.shape, shape) _left = xp.broadcast_to(left, shape) _right = xp.broadcast_to(right, shape) @@ -1182,34 +1050,20 @@ def test_less( assert bool(out_idx) == (scalar_type(x1_idx) < scalar_type(x2_idx)) -@pytest.mark.parametrize( - binary_argnames, make_binary_params("less_equal", xps.numeric_dtypes()) -) +@pytest.mark.parametrize("ctx", make_binary_params("less_equal", xps.numeric_dtypes())) @given(data=st.data()) -def test_less_equal( - func_name, - func, - left_sym, - left_strat, - right_sym, - right_strat, - right_is_scalar, - res_name, - data, -): - left = data.draw(left_strat, label=left_sym) - right = data.draw(right_strat, label=right_sym) +def test_less_equal(ctx, data): + left = data.draw(ctx.left_strat, label=ctx.left_sym) + right = data.draw(ctx.right_strat, label=ctx.right_sym) - out = func(left, right) + out = ctx.func(left, right) - assert_binary_param_dtype( - func_name, left, right, right_is_scalar, out, res_name, xp.bool - ) - if not right_is_scalar: + assert_binary_param_dtype(ctx, left, right, out, xp.bool) + if not ctx.right_is_scalar: # TODO: generate indices without broadcasting arrays (see test_equal comment) shape = broadcast_shapes(left.shape, right.shape) - ph.assert_shape(func_name, out.shape, shape) + ph.assert_shape(ctx.func_name, out.shape, shape) _left = xp.broadcast_to(left, shape) _right = xp.broadcast_to(right, shape) @@ -1341,47 +1195,33 @@ def test_logical_xor(x1, x2): assert out[idx] == (bool(_x1[idx]) ^ bool(_x2[idx])) -@pytest.mark.parametrize( - binary_argnames, make_binary_params("multiply", xps.numeric_dtypes()) -) +@pytest.mark.parametrize("ctx", make_binary_params("multiply", xps.numeric_dtypes())) @given(data=st.data()) -def test_multiply( - func_name, - func, - left_sym, - left_strat, - right_sym, - right_strat, - right_is_scalar, - res_name, - data, -): - left = data.draw(left_strat, label=left_sym) - right = data.draw(right_strat, label=right_sym) +def test_multiply(ctx, data): + left = data.draw(ctx.left_strat, label=ctx.left_sym) + right = data.draw(ctx.right_strat, label=ctx.right_sym) - res = func(left, right) + res = ctx.func(left, right) - assert_binary_param_dtype(func_name, left, right, right_is_scalar, res, res_name) - if not right_is_scalar: + assert_binary_param_dtype(ctx, left, right, res) + if not ctx.right_is_scalar: # multiply is commutative - expected = func(right, left) + expected = ctx.func(right, left) ah.assert_exactly_equal(res, expected) -@pytest.mark.parametrize( - unary_argnames, make_unary_params("negative", xps.numeric_dtypes()) -) +@pytest.mark.parametrize("ctx", make_unary_params("negative", xps.numeric_dtypes())) @given(data=st.data()) -def test_negative(func_name, func, strat, data): - x = data.draw(strat, label="x") +def test_negative(ctx, data): + x = data.draw(ctx.strat, label="x") - out = func(x) + out = ctx.func(x) - ph.assert_dtype(func_name, x.dtype, out.dtype) - ph.assert_shape(func_name, out.shape, x.shape) + ph.assert_dtype(ctx.func_name, x.dtype, out.dtype) + ph.assert_shape(ctx.func_name, out.shape, x.shape) # Negation is an involution - ah.assert_exactly_equal(x, func(out)) + ah.assert_exactly_equal(x, ctx.func(out)) mask = ah.isfinite(x) if dh.is_int_dtype(x.dtype): @@ -1395,41 +1235,27 @@ def test_negative(func_name, func, strat, data): ah.assert_exactly_equal(y, ah.zero(x[mask].shape, x.dtype)) -@pytest.mark.parametrize( - binary_argnames, make_binary_params("not_equal", xps.scalar_dtypes()) -) +@pytest.mark.parametrize("ctx", make_binary_params("not_equal", xps.scalar_dtypes())) @given(data=st.data()) -def test_not_equal( - func_name, - func, - left_sym, - left_strat, - right_sym, - right_strat, - right_is_scalar, - res_name, - data, -): - left = data.draw(left_strat, label=left_sym) - right = data.draw(right_strat, label=right_sym) +def test_not_equal(ctx, data): + left = data.draw(ctx.left_strat, label=ctx.left_sym) + right = data.draw(ctx.right_strat, label=ctx.right_sym) - out = func(left, right) + out = ctx.func(left, right) - assert_binary_param_dtype( - func_name, left, right, right_is_scalar, out, res_name, xp.bool - ) - assert_binary_param_shape(func_name, left, right, right_is_scalar, out, res_name) - if right_is_scalar: + assert_binary_param_dtype(ctx, left, right, out, xp.bool) + assert_binary_param_shape(ctx, left, right, out) + if ctx.right_is_scalar: scalar_type = dh.get_scalar_type(left.dtype) for idx in sh.ndindex(left.shape): scalar_l = scalar_type(left[idx]) expected = scalar_l != right scalar_o = bool(out[idx]) - f_l = sh.fmt_idx(left_sym, idx) - f_o = sh.fmt_idx(res_name, idx) + f_l = sh.fmt_idx(ctx.left_sym, idx) + f_o = sh.fmt_idx(ctx.res_name, idx) assert scalar_o == expected, ( f"{f_o}={scalar_o}, but should be ({f_l} != {right})={expected} " - f"[{func_name}()]\n{f_l}={scalar_l}" + f"[{ctx.func_name}()]\n{f_l}={scalar_l}" ) else: # See test_equal note @@ -1442,48 +1268,34 @@ def test_not_equal( scalar_r = scalar_type(_right[r_idx]) expected = scalar_l != scalar_r scalar_o = bool(out[o_idx]) - f_l = sh.fmt_idx(left_sym, l_idx) - f_r = sh.fmt_idx(right_sym, r_idx) - f_o = sh.fmt_idx(res_name, o_idx) + f_l = sh.fmt_idx(ctx.left_sym, l_idx) + f_r = sh.fmt_idx(ctx.right_sym, r_idx) + f_o = sh.fmt_idx(ctx.res_name, o_idx) assert scalar_o == expected, ( f"{f_o}={scalar_o}, but should be ({f_l} != {f_r})={expected} " - f"[{func_name}()]\n{f_l}={scalar_l}, {f_r}={scalar_r}" + f"[{ctx.func_name}()]\n{f_l}={scalar_l}, {f_r}={scalar_r}" ) -@pytest.mark.parametrize( - unary_argnames, make_unary_params("positive", xps.numeric_dtypes()) -) +@pytest.mark.parametrize("ctx", make_unary_params("positive", xps.numeric_dtypes())) @given(data=st.data()) -def test_positive(func_name, func, strat, data): - x = data.draw(strat, label="x") +def test_positive(ctx, data): + x = data.draw(ctx.strat, label="x") - out = func(x) + out = ctx.func(x) - ph.assert_dtype(func_name, x.dtype, out.dtype) - ph.assert_shape(func_name, out.shape, x.shape) + ph.assert_dtype(ctx.func_name, x.dtype, out.dtype) + ph.assert_shape(ctx.func_name, out.shape, x.shape) # Positive does nothing ah.assert_exactly_equal(out, x) -@pytest.mark.parametrize( - binary_argnames, make_binary_params("pow", xps.numeric_dtypes()) -) +@pytest.mark.parametrize("ctx", make_binary_params("pow", xps.numeric_dtypes())) @given(data=st.data()) -def test_pow( - func_name, - func, - left_sym, - left_strat, - right_sym, - right_strat, - right_is_scalar, - res_name, - data, -): - left = data.draw(left_strat, label=left_sym) - right = data.draw(right_strat, label=right_sym) - if right_is_scalar: +def test_pow(ctx, data): + left = data.draw(ctx.left_strat, label=ctx.left_sym) + right = data.draw(ctx.right_strat, label=ctx.right_sym) + if ctx.right_is_scalar: if isinstance(right, int): assume(right >= 0) else: @@ -1491,11 +1303,11 @@ def test_pow( assume(xp.all(right >= 0)) try: - res = func(left, right) + res = ctx.func(left, right) except OverflowError: reject() - assert_binary_param_dtype(func_name, left, right, right_is_scalar, res, res_name) + assert_binary_param_dtype(ctx, left, right, res) # There isn't much we can test here. The spec doesn't require any behavior # beyond the special cases, and indeed, there aren't many mathematical # properties of exponentiation that strictly hold for floating-point @@ -1503,36 +1315,24 @@ def test_pow( # don't yet have those sorts in general for this module. -@pytest.mark.parametrize( - binary_argnames, make_binary_params("remainder", xps.numeric_dtypes()) -) +@pytest.mark.parametrize("ctx", make_binary_params("remainder", xps.numeric_dtypes())) @given(data=st.data()) -def test_remainder( - func_name, - func, - left_sym, - left_strat, - right_sym, - right_strat, - right_is_scalar, - res_name, - data, -): - left = data.draw(left_strat, label=left_sym) - right = data.draw(right_strat, label=right_sym) - if right_is_scalar: +def test_remainder(ctx, data): + left = data.draw(ctx.left_strat, label=ctx.left_sym) + right = data.draw(ctx.right_strat, label=ctx.right_sym) + if ctx.right_is_scalar: out_dtype = left.dtype else: out_dtype = dh.result_type(left.dtype, right.dtype) if dh.is_int_dtype(out_dtype): - if right_is_scalar: + if ctx.right_is_scalar: assume(right != 0) else: assume(not ah.any(right == 0)) - res = func(left, right) + res = ctx.func(left, right) - assert_binary_param_dtype(func_name, left, right, right_is_scalar, res, res_name) + assert_binary_param_dtype(ctx, left, right, res) # TODO: test results @@ -1602,30 +1402,18 @@ def test_sqrt(x): ph.assert_shape("sqrt", out.shape, x.shape) -@pytest.mark.parametrize( - binary_argnames, make_binary_params("subtract", xps.numeric_dtypes()) -) +@pytest.mark.parametrize("ctx", make_binary_params("subtract", xps.numeric_dtypes())) @given(data=st.data()) -def test_subtract( - func_name, - func, - left_sym, - left_strat, - right_sym, - right_strat, - right_is_scalar, - res_name, - data, -): - left = data.draw(left_strat, label=left_sym) - right = data.draw(right_strat, label=right_sym) +def test_subtract(ctx, data): + left = data.draw(ctx.left_strat, label=ctx.left_sym) + right = data.draw(ctx.right_strat, label=ctx.right_sym) try: - res = func(left, right) + res = ctx.func(left, right) except OverflowError: reject() - assert_binary_param_dtype(func_name, left, right, right_is_scalar, res, res_name) + assert_binary_param_dtype(ctx, left, right, res) # TODO From 4623214036d6d079fbaa087742e363ad840c3a21 Mon Sep 17 00:00:00 2001 From: Matthew Barber Date: Wed, 26 Jan 2022 17:49:23 +0000 Subject: [PATCH 06/40] Apply `iter_indices()` logic to binary op/elwise tests --- ...est_operators_and_elementwise_functions.py | 241 ++++++++++++------ 1 file changed, 158 insertions(+), 83 deletions(-) diff --git a/array_api_tests/test_operators_and_elementwise_functions.py b/array_api_tests/test_operators_and_elementwise_functions.py index d4cc8199..6e26de67 100644 --- a/array_api_tests/test_operators_and_elementwise_functions.py +++ b/array_api_tests/test_operators_and_elementwise_functions.py @@ -39,6 +39,12 @@ def boolean_and_all_integer_dtypes() -> st.SearchStrategy[DataType]: return xps.boolean_dtypes() | all_integer_dtypes() +def isclose(n1: float, n2: float): + if not (math.isfinite(n1) and math.isfinite(n2)): + raise ValueError(f"{n1=} and {n1=}, but input must be finite") + return math.isclose(n1, n2, rel_tol=0.25, abs_tol=1) + + # When appropiate, this module tests operators alongside their respective # elementwise methods. We do this by parametrizing a generalised test method # with every relevant method and operator. @@ -766,6 +772,7 @@ def test_divide(ctx, data): res = ctx.func(left, right) assert_binary_param_dtype(ctx, left, right, res) + assert_binary_param_shape(ctx, left, right, res) # There isn't much we can test here. The spec doesn't require any behavior # beyond the special cases, and indeed, there aren't many mathematical # properties of division that strictly hold for floating-point numbers. We @@ -884,23 +891,38 @@ def test_floor_divide(ctx, data): res = ctx.func(left, right) assert_binary_param_dtype(ctx, left, right, res) - if not ctx.right_is_scalar: - if dh.is_int_dtype(left.dtype): - # The spec does not specify the behavior for division by 0 for integer - # dtypes. A library may choose to raise an exception in this case, so - # we avoid passing it in entirely. - div = xp.divide( - ah.asarray(left, dtype=xp.float64), - ah.asarray(right, dtype=xp.float64), + assert_binary_param_shape(ctx, left, right, res) + scalar_type = dh.get_scalar_type(res.dtype) + if ctx.right_is_scalar: + for idx in sh.ndindex(res.shape): + scalar_l = scalar_type(left[idx]) + expected = scalar_l // right + scalar_o = scalar_type(res[idx]) + if not all(math.isfinite(n) for n in [scalar_l, right, scalar_o, expected]): + continue + f_l = sh.fmt_idx(ctx.left_sym, idx) + f_o = sh.fmt_idx(ctx.res_name, idx) + assert isclose(scalar_o, expected), ( + f"{f_o}={scalar_o}, but should be roughly ({f_l} // {right})={expected} " + f"[{ctx.func_name}()]\n{f_l}={scalar_l}" + ) + else: + for l_idx, r_idx, o_idx in sh.iter_indices(left.shape, right.shape, res.shape): + scalar_l = scalar_type(left[l_idx]) + scalar_r = scalar_type(right[r_idx]) + expected = scalar_l // scalar_r + scalar_o = scalar_type(res[o_idx]) + if not all( + math.isfinite(n) for n in [scalar_l, scalar_r, scalar_o, expected] + ): + continue + f_l = sh.fmt_idx(ctx.left_sym, l_idx) + f_r = sh.fmt_idx(ctx.right_sym, r_idx) + f_o = sh.fmt_idx(ctx.res_name, o_idx) + assert isclose(scalar_o, expected), ( + f"{f_o}={scalar_o}, but should be roughly ({f_l} // {f_r})={expected} " + f"[{ctx.func_name}()]\n{f_l}={scalar_l}, {f_r}={scalar_r}" ) - else: - div = xp.divide(left, right) - - # TODO: The spec doesn't clearly specify the behavior of floor_divide on - # infinities. See https://github.com/data-apis/array-api/issues/199. - finite = ah.isfinite(div) - ah.assert_integral(res[finite]) - # TODO: Test the exact output for floor_divide. @pytest.mark.parametrize("ctx", make_binary_params("greater", xps.numeric_dtypes())) @@ -912,24 +934,37 @@ def test_greater(ctx, data): out = ctx.func(left, right) assert_binary_param_dtype(ctx, left, right, out, xp.bool) - if not ctx.right_is_scalar: - # TODO: generate indices without broadcasting arrays (see test_equal comment) - shape = broadcast_shapes(left.shape, right.shape) - ph.assert_shape(ctx.func_name, out.shape, shape) - _left = xp.broadcast_to(left, shape) - _right = xp.broadcast_to(right, shape) - + assert_binary_param_shape(ctx, left, right, out) + if ctx.right_is_scalar: + scalar_type = dh.get_scalar_type(left.dtype) + for idx in sh.ndindex(left.shape): + scalar_l = scalar_type(left[idx]) + expected = scalar_l > right + scalar_o = bool(out[idx]) + f_l = sh.fmt_idx(ctx.left_sym, idx) + f_o = sh.fmt_idx(ctx.res_name, idx) + assert scalar_o == expected, ( + f"{f_o}={scalar_o}, but should be ({f_l} > {right})={expected} " + f"[{ctx.func_name}()]\n{f_l}={scalar_l}" + ) + else: + # See test_equal note promoted_dtype = dh.promotion_table[left.dtype, right.dtype] - _left = ah.asarray(_left, dtype=promoted_dtype) - _right = ah.asarray(_right, dtype=promoted_dtype) - + _left = xp.astype(left, promoted_dtype) + _right = xp.astype(right, promoted_dtype) scalar_type = dh.get_scalar_type(promoted_dtype) - for idx in sh.ndindex(shape): - out_idx = out[idx] - x1_idx = _left[idx] - x2_idx = _right[idx] - assert out_idx.shape == x1_idx.shape == x2_idx.shape # sanity check - assert bool(out_idx) == (scalar_type(x1_idx) > scalar_type(x2_idx)) + for l_idx, r_idx, o_idx in sh.iter_indices(left.shape, right.shape, out.shape): + scalar_l = scalar_type(_left[l_idx]) + scalar_r = scalar_type(_right[r_idx]) + expected = scalar_l > scalar_r + scalar_o = bool(out[o_idx]) + f_l = sh.fmt_idx(ctx.left_sym, l_idx) + f_r = sh.fmt_idx(ctx.right_sym, r_idx) + f_o = sh.fmt_idx(ctx.res_name, o_idx) + assert scalar_o == expected, ( + f"{f_o}={scalar_o}, but should be ({f_l} > {f_r})={expected} " + f"[{ctx.func_name}()]\n{f_l}={scalar_l}, {f_r}={scalar_r}" + ) @pytest.mark.parametrize( @@ -943,25 +978,37 @@ def test_greater_equal(ctx, data): out = ctx.func(left, right) assert_binary_param_dtype(ctx, left, right, out, xp.bool) - if not ctx.right_is_scalar: - # TODO: generate indices without broadcasting arrays (see test_equal comment) - - shape = broadcast_shapes(left.shape, right.shape) - ph.assert_shape(ctx.func_name, out.shape, shape) - _left = xp.broadcast_to(left, shape) - _right = xp.broadcast_to(right, shape) - + assert_binary_param_shape(ctx, left, right, out) + if ctx.right_is_scalar: + scalar_type = dh.get_scalar_type(left.dtype) + for idx in sh.ndindex(left.shape): + scalar_l = scalar_type(left[idx]) + expected = scalar_l >= right + scalar_o = bool(out[idx]) + f_l = sh.fmt_idx(ctx.left_sym, idx) + f_o = sh.fmt_idx(ctx.res_name, idx) + assert scalar_o == expected, ( + f"{f_o}={scalar_o}, but should be ({f_l} >= {right})={expected} " + f"[{ctx.func_name}()]\n{f_l}={scalar_l}" + ) + else: + # See test_equal note promoted_dtype = dh.promotion_table[left.dtype, right.dtype] - _left = ah.asarray(_left, dtype=promoted_dtype) - _right = ah.asarray(_right, dtype=promoted_dtype) - + _left = xp.astype(left, promoted_dtype) + _right = xp.astype(right, promoted_dtype) scalar_type = dh.get_scalar_type(promoted_dtype) - for idx in sh.ndindex(shape): - out_idx = out[idx] - x1_idx = _left[idx] - x2_idx = _right[idx] - assert out_idx.shape == x1_idx.shape == x2_idx.shape # sanity check - assert bool(out_idx) == (scalar_type(x1_idx) >= scalar_type(x2_idx)) + for l_idx, r_idx, o_idx in sh.iter_indices(left.shape, right.shape, out.shape): + scalar_l = scalar_type(_left[l_idx]) + scalar_r = scalar_type(_right[r_idx]) + expected = scalar_l >= scalar_r + scalar_o = bool(out[o_idx]) + f_l = sh.fmt_idx(ctx.left_sym, l_idx) + f_r = sh.fmt_idx(ctx.right_sym, r_idx) + f_o = sh.fmt_idx(ctx.res_name, o_idx) + assert scalar_o == expected, ( + f"{f_o}={scalar_o}, but should be ({f_l} >= {f_r})={expected} " + f"[{ctx.func_name}()]\n{f_l}={scalar_l}, {f_r}={scalar_r}" + ) @given(xps.arrays(dtype=xps.numeric_dtypes(), shape=hh.shapes())) @@ -1029,25 +1076,37 @@ def test_less(ctx, data): out = ctx.func(left, right) assert_binary_param_dtype(ctx, left, right, out, xp.bool) - if not ctx.right_is_scalar: - # TODO: generate indices without broadcasting arrays (see test_equal comment) - - shape = broadcast_shapes(left.shape, right.shape) - ph.assert_shape(ctx.func_name, out.shape, shape) - _left = xp.broadcast_to(left, shape) - _right = xp.broadcast_to(right, shape) - + assert_binary_param_shape(ctx, left, right, out) + if ctx.right_is_scalar: + scalar_type = dh.get_scalar_type(left.dtype) + for idx in sh.ndindex(left.shape): + scalar_l = scalar_type(left[idx]) + expected = scalar_l < right + scalar_o = bool(out[idx]) + f_l = sh.fmt_idx(ctx.left_sym, idx) + f_o = sh.fmt_idx(ctx.res_name, idx) + assert scalar_o == expected, ( + f"{f_o}={scalar_o}, but should be ({f_l} < {right})={expected} " + f"[{ctx.func_name}()]\n{f_l}={scalar_l}" + ) + else: + # See test_equal note promoted_dtype = dh.promotion_table[left.dtype, right.dtype] - _left = ah.asarray(_left, dtype=promoted_dtype) - _right = ah.asarray(_right, dtype=promoted_dtype) - + _left = xp.astype(left, promoted_dtype) + _right = xp.astype(right, promoted_dtype) scalar_type = dh.get_scalar_type(promoted_dtype) - for idx in sh.ndindex(shape): - x1_idx = _left[idx] - x2_idx = _right[idx] - out_idx = out[idx] - assert out_idx.shape == x1_idx.shape == x2_idx.shape # sanity check - assert bool(out_idx) == (scalar_type(x1_idx) < scalar_type(x2_idx)) + for l_idx, r_idx, o_idx in sh.iter_indices(left.shape, right.shape, out.shape): + scalar_l = scalar_type(_left[l_idx]) + scalar_r = scalar_type(_right[r_idx]) + expected = scalar_l < scalar_r + scalar_o = bool(out[o_idx]) + f_l = sh.fmt_idx(ctx.left_sym, l_idx) + f_r = sh.fmt_idx(ctx.right_sym, r_idx) + f_o = sh.fmt_idx(ctx.res_name, o_idx) + assert scalar_o == expected, ( + f"{f_o}={scalar_o}, but should be ({f_l} < {f_r})={expected} " + f"[{ctx.func_name}()]\n{f_l}={scalar_l}, {f_r}={scalar_r}" + ) @pytest.mark.parametrize("ctx", make_binary_params("less_equal", xps.numeric_dtypes())) @@ -1059,25 +1118,37 @@ def test_less_equal(ctx, data): out = ctx.func(left, right) assert_binary_param_dtype(ctx, left, right, out, xp.bool) - if not ctx.right_is_scalar: - # TODO: generate indices without broadcasting arrays (see test_equal comment) - - shape = broadcast_shapes(left.shape, right.shape) - ph.assert_shape(ctx.func_name, out.shape, shape) - _left = xp.broadcast_to(left, shape) - _right = xp.broadcast_to(right, shape) - + assert_binary_param_shape(ctx, left, right, out) + if ctx.right_is_scalar: + scalar_type = dh.get_scalar_type(left.dtype) + for idx in sh.ndindex(left.shape): + scalar_l = scalar_type(left[idx]) + expected = scalar_l <= right + scalar_o = bool(out[idx]) + f_l = sh.fmt_idx(ctx.left_sym, idx) + f_o = sh.fmt_idx(ctx.res_name, idx) + assert scalar_o == expected, ( + f"{f_o}={scalar_o}, but should be ({f_l} <= {right})={expected} " + f"[{ctx.func_name}()]\n{f_l}={scalar_l}" + ) + else: + # See test_equal note promoted_dtype = dh.promotion_table[left.dtype, right.dtype] - _left = ah.asarray(_left, dtype=promoted_dtype) - _right = ah.asarray(_right, dtype=promoted_dtype) - + _left = xp.astype(left, promoted_dtype) + _right = xp.astype(right, promoted_dtype) scalar_type = dh.get_scalar_type(promoted_dtype) - for idx in sh.ndindex(shape): - x1_idx = _left[idx] - x2_idx = _right[idx] - out_idx = out[idx] - assert out_idx.shape == x1_idx.shape == x2_idx.shape # sanity check - assert bool(out_idx) == (scalar_type(x1_idx) <= scalar_type(x2_idx)) + for l_idx, r_idx, o_idx in sh.iter_indices(left.shape, right.shape, out.shape): + scalar_l = scalar_type(_left[l_idx]) + scalar_r = scalar_type(_right[r_idx]) + expected = scalar_l <= scalar_r + scalar_o = bool(out[o_idx]) + f_l = sh.fmt_idx(ctx.left_sym, l_idx) + f_r = sh.fmt_idx(ctx.right_sym, r_idx) + f_o = sh.fmt_idx(ctx.res_name, o_idx) + assert scalar_o == expected, ( + f"{f_o}={scalar_o}, but should be ({f_l} <= {f_r})={expected} " + f"[{ctx.func_name}()]\n{f_l}={scalar_l}, {f_r}={scalar_r}" + ) @given(xps.arrays(dtype=xps.floating_dtypes(), shape=hh.shapes())) @@ -1204,6 +1275,7 @@ def test_multiply(ctx, data): res = ctx.func(left, right) assert_binary_param_dtype(ctx, left, right, res) + assert_binary_param_shape(ctx, left, right, res) if not ctx.right_is_scalar: # multiply is commutative expected = ctx.func(right, left) @@ -1308,6 +1380,7 @@ def test_pow(ctx, data): reject() assert_binary_param_dtype(ctx, left, right, res) + assert_binary_param_shape(ctx, left, right, res) # There isn't much we can test here. The spec doesn't require any behavior # beyond the special cases, and indeed, there aren't many mathematical # properties of exponentiation that strictly hold for floating-point @@ -1333,6 +1406,7 @@ def test_remainder(ctx, data): res = ctx.func(left, right) assert_binary_param_dtype(ctx, left, right, res) + assert_binary_param_shape(ctx, left, right, res) # TODO: test results @@ -1414,6 +1488,7 @@ def test_subtract(ctx, data): reject() assert_binary_param_dtype(ctx, left, right, res) + assert_binary_param_shape(ctx, left, right, res) # TODO From 4b2c41e9b5e3656b9d6cecc9e29107d04d83e516 Mon Sep 17 00:00:00 2001 From: Matthew Barber Date: Thu, 27 Jan 2022 09:52:01 +0000 Subject: [PATCH 07/40] Update `test_remainder` --- ...est_operators_and_elementwise_functions.py | 43 +++++++++++++++---- 1 file changed, 34 insertions(+), 9 deletions(-) diff --git a/array_api_tests/test_operators_and_elementwise_functions.py b/array_api_tests/test_operators_and_elementwise_functions.py index 6e26de67..345256a8 100644 --- a/array_api_tests/test_operators_and_elementwise_functions.py +++ b/array_api_tests/test_operators_and_elementwise_functions.py @@ -39,7 +39,7 @@ def boolean_and_all_integer_dtypes() -> st.SearchStrategy[DataType]: return xps.boolean_dtypes() | all_integer_dtypes() -def isclose(n1: float, n2: float): +def isclose(n1: Union[int, float], n2: Union[int, float]): if not (math.isfinite(n1) and math.isfinite(n2)): raise ValueError(f"{n1=} and {n1=}, but input must be finite") return math.isclose(n1, n2, rel_tol=0.25, abs_tol=1) @@ -1394,20 +1394,45 @@ def test_remainder(ctx, data): left = data.draw(ctx.left_strat, label=ctx.left_sym) right = data.draw(ctx.right_strat, label=ctx.right_sym) if ctx.right_is_scalar: - out_dtype = left.dtype + assume(right != 0) else: - out_dtype = dh.result_type(left.dtype, right.dtype) - if dh.is_int_dtype(out_dtype): - if ctx.right_is_scalar: - assume(right != 0) - else: - assume(not ah.any(right == 0)) + assume(not ah.any(right == 0)) res = ctx.func(left, right) assert_binary_param_dtype(ctx, left, right, res) assert_binary_param_shape(ctx, left, right, res) - # TODO: test results + scalar_type = dh.get_scalar_type(res.dtype) + if ctx.right_is_scalar: + for idx in sh.ndindex(res.shape): + scalar_l = scalar_type(left[idx]) + expected = scalar_l % right + scalar_o = scalar_type(res[idx]) + if not all(math.isfinite(n) for n in [scalar_l, right, scalar_o, expected]): + continue + f_l = sh.fmt_idx(ctx.left_sym, idx) + f_o = sh.fmt_idx(ctx.res_name, idx) + assert isclose(scalar_o, expected), ( + f"{f_o}={scalar_o}, but should be roughly ({f_l} % {right})={expected} " + f"[{ctx.func_name}()]\n{f_l}={scalar_l}" + ) + else: + for l_idx, r_idx, o_idx in sh.iter_indices(left.shape, right.shape, res.shape): + scalar_l = scalar_type(left[l_idx]) + scalar_r = scalar_type(right[r_idx]) + expected = scalar_l % scalar_r + scalar_o = scalar_type(res[o_idx]) + if not all( + math.isfinite(n) for n in [scalar_l, scalar_r, scalar_o, expected] + ): + continue + f_l = sh.fmt_idx(ctx.left_sym, l_idx) + f_r = sh.fmt_idx(ctx.right_sym, r_idx) + f_o = sh.fmt_idx(ctx.res_name, o_idx) + assert isclose(scalar_o, expected), ( + f"{f_o}={scalar_o}, but should be roughly ({f_l} % {f_r})={expected} " + f"[{ctx.func_name}()]\n{f_l}={scalar_l}, {f_r}={scalar_r}" + ) @given(xps.arrays(dtype=xps.numeric_dtypes(), shape=hh.shapes())) From 1927c1094514c0ad1fdacd1a2925a4ccca104a90 Mon Sep 17 00:00:00 2001 From: Matthew Barber Date: Thu, 27 Jan 2022 10:50:59 +0000 Subject: [PATCH 08/40] Move `broadcast_shapes()` to `shape_helpers.py` --- array_api_tests/hypothesis_helpers.py | 3 +- array_api_tests/meta/test_broadcasting.py | 8 +-- .../meta/test_hypothesis_helpers.py | 4 +- array_api_tests/pytest_helpers.py | 4 +- array_api_tests/shape_helpers.py | 49 +++++++++++++++++++ array_api_tests/test_data_type_functions.py | 6 +-- array_api_tests/test_linalg.py | 4 +- ...est_operators_and_elementwise_functions.py | 7 ++- array_api_tests/test_searching_functions.py | 3 +- array_api_tests/test_set_functions.py | 2 +- array_api_tests/test_sorting_functions.py | 2 +- array_api_tests/test_statistical_functions.py | 2 +- 12 files changed, 69 insertions(+), 25 deletions(-) diff --git a/array_api_tests/hypothesis_helpers.py b/array_api_tests/hypothesis_helpers.py index a0adc8c9..f77301da 100644 --- a/array_api_tests/hypothesis_helpers.py +++ b/array_api_tests/hypothesis_helpers.py @@ -16,7 +16,6 @@ from ._array_module import _UndefinedStub from ._array_module import bool as bool_dtype from ._array_module import broadcast_to, eye, float32, float64, full -from .algos import broadcast_shapes from .function_stubs import elementwise_functions from .pytest_helpers import nargs from .typing import Array, DataType, Shape @@ -243,7 +242,7 @@ def two_broadcastable_shapes(draw): broadcast to shape1. """ shape1, shape2 = draw(two_mutually_broadcastable_shapes) - assume(broadcast_shapes(shape1, shape2) == shape1) + assume(sh.broadcast_shapes(shape1, shape2) == shape1) return (shape1, shape2) sizes = integers(0, MAX_ARRAY_SIZE) diff --git a/array_api_tests/meta/test_broadcasting.py b/array_api_tests/meta/test_broadcasting.py index e347e525..72de61cf 100644 --- a/array_api_tests/meta/test_broadcasting.py +++ b/array_api_tests/meta/test_broadcasting.py @@ -4,7 +4,7 @@ import pytest -from ..algos import BroadcastError, _broadcast_shapes +from .. import shape_helpers as sh @pytest.mark.parametrize( @@ -19,7 +19,7 @@ ], ) def test_broadcast_shapes(shape1, shape2, expected): - assert _broadcast_shapes(shape1, shape2) == expected + assert sh._broadcast_shapes(shape1, shape2) == expected @pytest.mark.parametrize( @@ -31,5 +31,5 @@ def test_broadcast_shapes(shape1, shape2, expected): ], ) def test_broadcast_shapes_fails_on_bad_shapes(shape1, shape2): - with pytest.raises(BroadcastError): - _broadcast_shapes(shape1, shape2) + with pytest.raises(sh.BroadcastError): + sh._broadcast_shapes(shape1, shape2) diff --git a/array_api_tests/meta/test_hypothesis_helpers.py b/array_api_tests/meta/test_hypothesis_helpers.py index b4cb6e96..647cc145 100644 --- a/array_api_tests/meta/test_hypothesis_helpers.py +++ b/array_api_tests/meta/test_hypothesis_helpers.py @@ -8,9 +8,9 @@ from .. import array_helpers as ah from .. import dtype_helpers as dh from .. import hypothesis_helpers as hh +from .. import shape_helpers as sh from .. import xps from .._array_module import _UndefinedStub -from ..algos import broadcast_shapes UNDEFINED_DTYPES = any(isinstance(d, _UndefinedStub) for d in dh.all_dtypes) pytestmark = [pytest.mark.skipif(UNDEFINED_DTYPES, reason="undefined dtypes")] @@ -62,7 +62,7 @@ def test_two_mutually_broadcastable_shapes(pair): def test_two_broadcastable_shapes(pair): for shape in pair: assert valid_shape(shape) - assert broadcast_shapes(pair[0], pair[1]) == pair[0] + assert sh.broadcast_shapes(pair[0], pair[1]) == pair[0] @given(*hh.two_mutual_arrays()) diff --git a/array_api_tests/pytest_helpers.py b/array_api_tests/pytest_helpers.py index d1b48830..d07c09f5 100644 --- a/array_api_tests/pytest_helpers.py +++ b/array_api_tests/pytest_helpers.py @@ -6,7 +6,7 @@ from . import array_helpers as ah from . import dtype_helpers as dh from . import function_stubs -from .algos import broadcast_shapes +from . import shape_helpers as sh from .typing import Array, DataType, Scalar, ScalarType, Shape __all__ = [ @@ -159,7 +159,7 @@ def assert_result_shape( **kw, ): if expected is None: - expected = broadcast_shapes(*in_shapes) + expected = sh.broadcast_shapes(*in_shapes) f_in_shapes = " . ".join(str(s) for s in in_shapes) f_sig = f" {f_in_shapes} " if kw: diff --git a/array_api_tests/shape_helpers.py b/array_api_tests/shape_helpers.py index 98dbff85..c55fbaac 100644 --- a/array_api_tests/shape_helpers.py +++ b/array_api_tests/shape_helpers.py @@ -8,6 +8,7 @@ from .typing import AtomicIndex, Index, Scalar, Shape __all__ = [ + "broadcast_shapes", "normalise_axis", "ndindex", "axis_ndindex", @@ -17,6 +18,54 @@ ] +class BroadcastError(ValueError): + """Shapes do not broadcast with eachother""" + + +def _broadcast_shapes(shape1: Shape, shape2: Shape) -> Shape: + """Broadcasts `shape1` and `shape2`""" + N1 = len(shape1) + N2 = len(shape2) + N = max(N1, N2) + shape = [None for _ in range(N)] + i = N - 1 + while i >= 0: + n1 = N1 - N + i + if N1 - N + i >= 0: + d1 = shape1[n1] + else: + d1 = 1 + n2 = N2 - N + i + if N2 - N + i >= 0: + d2 = shape2[n2] + else: + d2 = 1 + + if d1 == 1: + shape[i] = d2 + elif d2 == 1: + shape[i] = d1 + elif d1 == d2: + shape[i] = d1 + else: + raise BroadcastError() + + i = i - 1 + + return tuple(shape) + + +def broadcast_shapes(*shapes: Shape): + if len(shapes) == 0: + raise ValueError("shapes=[] must be non-empty") + elif len(shapes) == 1: + return shapes[0] + result = _broadcast_shapes(shapes[0], shapes[1]) + for i in range(2, len(shapes)): + result = _broadcast_shapes(result, shapes[i]) + return result + + def normalise_axis( axis: Optional[Union[int, Tuple[int, ...]]], ndim: int ) -> Tuple[int, ...]: diff --git a/array_api_tests/test_data_type_functions.py b/array_api_tests/test_data_type_functions.py index ded82682..763c71a4 100644 --- a/array_api_tests/test_data_type_functions.py +++ b/array_api_tests/test_data_type_functions.py @@ -9,8 +9,8 @@ from . import dtype_helpers as dh from . import hypothesis_helpers as hh from . import pytest_helpers as ph +from . import shape_helpers as sh from . import xps -from .algos import broadcast_shapes from .typing import DataType pytestmark = pytest.mark.ci @@ -70,7 +70,7 @@ def test_broadcast_arrays(shapes, data): out = xp.broadcast_arrays(*arrays) - out_shape = broadcast_shapes(*shapes) + out_shape = sh.broadcast_shapes(*shapes) for i, x in enumerate(arrays): ph.assert_dtype( "broadcast_arrays", x.dtype, out[i].dtype, repr_name=f"out[{i}].dtype" @@ -90,7 +90,7 @@ def test_broadcast_to(x, data): shape = data.draw( hh.mutually_broadcastable_shapes(1, base_shape=x.shape) .map(lambda S: S[0]) - .filter(lambda s: broadcast_shapes(x.shape, s) == s), + .filter(lambda s: sh.broadcast_shapes(x.shape, s) == s), label="shape", ) diff --git a/array_api_tests/test_linalg.py b/array_api_tests/test_linalg.py index 62c93562..5f5ce2bd 100644 --- a/array_api_tests/test_linalg.py +++ b/array_api_tests/test_linalg.py @@ -31,8 +31,6 @@ from . import pytest_helpers as ph from . import shape_helpers as sh -from .algos import broadcast_shapes - from . import _array_module from . import _array_module as xp from ._array_module import linalg @@ -310,7 +308,7 @@ def test_matmul(x1, x2): assert res.shape == x1.shape[:-1] _test_stacks(_array_module.matmul, x1, x2, res=res, dims=1) else: - stack_shape = broadcast_shapes(x1.shape[:-2], x2.shape[:-2]) + stack_shape = sh.broadcast_shapes(x1.shape[:-2], x2.shape[:-2]) assert res.shape == stack_shape + (x1.shape[-2], x2.shape[-1]) _test_stacks(_array_module.matmul, x1, x2, res=res) diff --git a/array_api_tests/test_operators_and_elementwise_functions.py b/array_api_tests/test_operators_and_elementwise_functions.py index 345256a8..3bb06258 100644 --- a/array_api_tests/test_operators_and_elementwise_functions.py +++ b/array_api_tests/test_operators_and_elementwise_functions.py @@ -25,7 +25,6 @@ from . import pytest_helpers as ph from . import shape_helpers as sh from . import xps -from .algos import broadcast_shapes from .typing import Array, DataType, Param, Scalar, Shape pytestmark = pytest.mark.ci @@ -1223,7 +1222,7 @@ def test_logical_and(x1, x2): out = ah.logical_and(x1, x2) ph.assert_dtype("logical_and", (x1.dtype, x2.dtype), out.dtype) # See the comments in test_equal - shape = broadcast_shapes(x1.shape, x2.shape) + shape = sh.broadcast_shapes(x1.shape, x2.shape) ph.assert_shape("logical_and", out.shape, shape) _x1 = xp.broadcast_to(x1, shape) _x2 = xp.broadcast_to(x2, shape) @@ -1245,7 +1244,7 @@ def test_logical_or(x1, x2): out = ah.logical_or(x1, x2) ph.assert_dtype("logical_or", (x1.dtype, x2.dtype), out.dtype) # See the comments in test_equal - shape = broadcast_shapes(x1.shape, x2.shape) + shape = sh.broadcast_shapes(x1.shape, x2.shape) ph.assert_shape("logical_or", out.shape, shape) _x1 = xp.broadcast_to(x1, shape) _x2 = xp.broadcast_to(x2, shape) @@ -1258,7 +1257,7 @@ def test_logical_xor(x1, x2): out = xp.logical_xor(x1, x2) ph.assert_dtype("logical_xor", (x1.dtype, x2.dtype), out.dtype) # See the comments in test_equal - shape = broadcast_shapes(x1.shape, x2.shape) + shape = sh.broadcast_shapes(x1.shape, x2.shape) ph.assert_shape("logical_xor", out.shape, shape) _x1 = xp.broadcast_to(x1, shape) _x2 = xp.broadcast_to(x2, shape) diff --git a/array_api_tests/test_searching_functions.py b/array_api_tests/test_searching_functions.py index b6a66086..01c26d0c 100644 --- a/array_api_tests/test_searching_functions.py +++ b/array_api_tests/test_searching_functions.py @@ -8,7 +8,6 @@ from . import pytest_helpers as ph from . import shape_helpers as sh from . import xps -from .algos import broadcast_shapes pytestmark = pytest.mark.ci @@ -134,7 +133,7 @@ def test_where(shapes, dtypes, data): out = xp.where(cond, x1, x2) - shape = broadcast_shapes(*shapes) + shape = sh.broadcast_shapes(*shapes) ph.assert_shape("where", out.shape, shape) # TODO: generate indices without broadcasting arrays _cond = xp.broadcast_to(cond, shape) diff --git a/array_api_tests/test_set_functions.py b/array_api_tests/test_set_functions.py index 9679eaac..5ceceb54 100644 --- a/array_api_tests/test_set_functions.py +++ b/array_api_tests/test_set_functions.py @@ -1,8 +1,8 @@ # TODO: disable if opted out, refactor things import math -import pytest from collections import Counter, defaultdict +import pytest from hypothesis import assume, given from . import _array_module as xp diff --git a/array_api_tests/test_sorting_functions.py b/array_api_tests/test_sorting_functions.py index ea375b57..7c5a1411 100644 --- a/array_api_tests/test_sorting_functions.py +++ b/array_api_tests/test_sorting_functions.py @@ -1,7 +1,7 @@ import math -import pytest from typing import Set +import pytest from hypothesis import given from hypothesis import strategies as st from hypothesis.control import assume diff --git a/array_api_tests/test_statistical_functions.py b/array_api_tests/test_statistical_functions.py index c955b570..c86111a0 100644 --- a/array_api_tests/test_statistical_functions.py +++ b/array_api_tests/test_statistical_functions.py @@ -1,7 +1,7 @@ import math -import pytest from typing import Optional +import pytest from hypothesis import assume, given from hypothesis import strategies as st from hypothesis.control import reject From bb836b7fd850fe6ca9aaebb3f34e46e37308c89f Mon Sep 17 00:00:00 2001 From: Matthew Barber Date: Thu, 27 Jan 2022 10:55:11 +0000 Subject: [PATCH 09/40] Skip `sh.iter_indices()` generation for 0-sided shapes Also updates `test_logical_and` --- array_api_tests/shape_helpers.py | 16 ++++++++++----- ...est_operators_and_elementwise_functions.py | 20 ++++++++++++------- 2 files changed, 24 insertions(+), 12 deletions(-) diff --git a/array_api_tests/shape_helpers.py b/array_api_tests/shape_helpers.py index c55fbaac..2fcb671a 100644 --- a/array_api_tests/shape_helpers.py +++ b/array_api_tests/shape_helpers.py @@ -76,15 +76,21 @@ def normalise_axis( return axes -def ndindex(shape): - """Yield every index of shape""" +def ndindex(shape: Shape) -> Iterator[Index]: + """Yield every index of a shape""" return (indices[0] for indices in iter_indices(shape)) -def iter_indices(*shapes, skip_axes=()): +def iter_indices( + *shapes: Shape, skip_axes: Tuple[int, ...] = () +) -> Iterator[Tuple[Index, ...]]: """Wrapper for ndindex.iter_indices()""" - gen = _iter_indices(*shapes, skip_axes=skip_axes) - return ([i.raw for i in indices] for indices in gen) + # Prevent iterations if any shape has 0-sides + for shape in shapes: + if 0 in shape: + return + for indices in _iter_indices(*shapes, skip_axes=skip_axes): + yield tuple(i.raw for i in indices) # type: ignore def axis_ndindex( diff --git a/array_api_tests/test_operators_and_elementwise_functions.py b/array_api_tests/test_operators_and_elementwise_functions.py index 3bb06258..41a62f55 100644 --- a/array_api_tests/test_operators_and_elementwise_functions.py +++ b/array_api_tests/test_operators_and_elementwise_functions.py @@ -1221,13 +1221,19 @@ def test_logaddexp(x1, x2): def test_logical_and(x1, x2): out = ah.logical_and(x1, x2) ph.assert_dtype("logical_and", (x1.dtype, x2.dtype), out.dtype) - # See the comments in test_equal - shape = sh.broadcast_shapes(x1.shape, x2.shape) - ph.assert_shape("logical_and", out.shape, shape) - _x1 = xp.broadcast_to(x1, shape) - _x2 = xp.broadcast_to(x2, shape) - for idx in sh.ndindex(shape): - assert out[idx] == (bool(_x1[idx]) and bool(_x2[idx])) + ph.assert_result_shape("logical_and", (x1.shape, x2.shape), out.shape) + for l_idx, r_idx, o_idx in sh.iter_indices(x1.shape, x2.shape, out.shape): + scalar_l = bool(x1[l_idx]) + scalar_r = bool(x2[r_idx]) + expected = scalar_l and scalar_r + scalar_o = bool(out[o_idx]) + f_l = sh.fmt_idx("x1", l_idx) + f_r = sh.fmt_idx("x2", r_idx) + f_o = sh.fmt_idx("out", o_idx) + assert scalar_o == expected, ( + f"{f_o}={scalar_o}, but should be ({f_l} and {f_r})={expected} " + f"[logical_and()]\n{f_l}={scalar_l}, {f_r}={scalar_r}" + ) @given(xps.arrays(dtype=xp.bool, shape=hh.shapes())) From f11a6d0aa2001818ea05b09bcb45160873647bca Mon Sep 17 00:00:00 2001 From: Matthew Barber Date: Thu, 27 Jan 2022 12:10:35 +0000 Subject: [PATCH 10/40] Values testing for `test_sign` --- ...test_operators_and_elementwise_functions.py | 18 +++++++++++++++++- 1 file changed, 17 insertions(+), 1 deletion(-) diff --git a/array_api_tests/test_operators_and_elementwise_functions.py b/array_api_tests/test_operators_and_elementwise_functions.py index 41a62f55..c688ac28 100644 --- a/array_api_tests/test_operators_and_elementwise_functions.py +++ b/array_api_tests/test_operators_and_elementwise_functions.py @@ -1473,7 +1473,23 @@ def test_sign(x): out = xp.sign(x) ph.assert_dtype("sign", x.dtype, out.dtype) ph.assert_shape("sign", out.shape, x.shape) - # TODO + scalar_type = dh.get_scalar_type(x.dtype) + for idx in sh.ndindex(x.shape): + scalar_x = scalar_type(x[idx]) + f_x = sh.fmt_idx("x", idx) + if math.isnan(scalar_x): + continue + if scalar_x == 0: + expected = 0 + expr = f"{f_x}=0" + else: + expected = 1 if scalar_x > 0 else -1 + expr = f"({f_x} / |{f_x}|)={expected}" + scalar_o = scalar_type(out[idx]) + f_o = sh.fmt_idx("out", idx) + assert scalar_o == expected, ( + f"{f_o}={scalar_o}, but should be {expr} [sign()]\n{f_x}={scalar_x}" + ) @given(xps.arrays(dtype=xps.floating_dtypes(), shape=hh.shapes())) From 47424e8c14b1736c4c21a64378a8a8bd0cfe3648 Mon Sep 17 00:00:00 2001 From: Matthew Barber Date: Thu, 27 Jan 2022 12:46:13 +0000 Subject: [PATCH 11/40] Values testing for `test_add` and `test_subtract` --- ...est_operators_and_elementwise_functions.py | 72 ++++++++++++++++--- 1 file changed, 64 insertions(+), 8 deletions(-) diff --git a/array_api_tests/test_operators_and_elementwise_functions.py b/array_api_tests/test_operators_and_elementwise_functions.py index c688ac28..34cb32ac 100644 --- a/array_api_tests/test_operators_and_elementwise_functions.py +++ b/array_api_tests/test_operators_and_elementwise_functions.py @@ -310,10 +310,37 @@ def test_add(ctx, data): assert_binary_param_dtype(ctx, left, right, res) assert_binary_param_shape(ctx, left, right, res) - if not ctx.right_is_scalar: - # add is commutative - expected = ctx.func(right, left) - ah.assert_exactly_equal(res, expected) + m, M = dh.dtype_ranges[res.dtype] + scalar_type = dh.get_scalar_type(res.dtype) + if ctx.right_is_scalar: + for idx in sh.ndindex(res.shape): + scalar_l = scalar_type(left[idx]) + expected = scalar_l + right + if not math.isfinite(expected) or expected <= m or expected >= M: + continue + scalar_o = scalar_type(res[idx]) + f_l = sh.fmt_idx(ctx.left_sym, idx) + f_o = sh.fmt_idx(ctx.res_name, idx) + assert isclose(scalar_o, expected), ( + f"{f_o}={scalar_o}, but should be roughly ({f_l} + {right})={expected} " + f"[{ctx.func_name}()]\n{f_l}={scalar_l}" + ) + else: + ph.assert_array(ctx.func_name, res, ctx.func(right, left)) # cumulative + for l_idx, r_idx, o_idx in sh.iter_indices(left.shape, right.shape, res.shape): + scalar_l = scalar_type(left[l_idx]) + scalar_r = scalar_type(right[r_idx]) + expected = scalar_l + scalar_r + if not math.isfinite(expected) or expected <= m or expected >= M: + continue + scalar_o = scalar_type(res[o_idx]) + f_l = sh.fmt_idx(ctx.left_sym, l_idx) + f_r = sh.fmt_idx(ctx.right_sym, r_idx) + f_o = sh.fmt_idx(ctx.res_name, o_idx) + assert isclose(scalar_o, expected), ( + f"{f_o}={scalar_o}, but should be roughly ({f_l} + {f_r})={expected} " + f"[{ctx.func_name}()]\n{f_l}={scalar_l}, {f_r}={scalar_r}" + ) @given(xps.arrays(dtype=xps.floating_dtypes(), shape=hh.shapes())) @@ -1487,9 +1514,9 @@ def test_sign(x): expr = f"({f_x} / |{f_x}|)={expected}" scalar_o = scalar_type(out[idx]) f_o = sh.fmt_idx("out", idx) - assert scalar_o == expected, ( - f"{f_o}={scalar_o}, but should be {expr} [sign()]\n{f_x}={scalar_x}" - ) + assert ( + scalar_o == expected + ), f"{f_o}={scalar_o}, but should be {expr} [sign()]\n{f_x}={scalar_x}" @given(xps.arrays(dtype=xps.floating_dtypes(), shape=hh.shapes())) @@ -1535,7 +1562,36 @@ def test_subtract(ctx, data): assert_binary_param_dtype(ctx, left, right, res) assert_binary_param_shape(ctx, left, right, res) - # TODO + m, M = dh.dtype_ranges[res.dtype] + scalar_type = dh.get_scalar_type(res.dtype) + if ctx.right_is_scalar: + for idx in sh.ndindex(res.shape): + scalar_l = scalar_type(left[idx]) + expected = scalar_l - right + if not math.isfinite(expected) or expected <= m or expected >= M: + continue + scalar_o = scalar_type(res[idx]) + f_l = sh.fmt_idx(ctx.left_sym, idx) + f_o = sh.fmt_idx(ctx.res_name, idx) + assert isclose(scalar_o, expected), ( + f"{f_o}={scalar_o}, but should be roughly ({f_l} - {right})={expected} " + f"[{ctx.func_name}()]\n{f_l}={scalar_l}" + ) + else: + for l_idx, r_idx, o_idx in sh.iter_indices(left.shape, right.shape, res.shape): + scalar_l = scalar_type(left[l_idx]) + scalar_r = scalar_type(right[r_idx]) + expected = scalar_l - scalar_r + if not math.isfinite(expected) or expected <= m or expected >= M: + continue + scalar_o = scalar_type(res[o_idx]) + f_l = sh.fmt_idx(ctx.left_sym, l_idx) + f_r = sh.fmt_idx(ctx.right_sym, r_idx) + f_o = sh.fmt_idx(ctx.res_name, o_idx) + assert isclose(scalar_o, expected), ( + f"{f_o}={scalar_o}, but should be roughly ({f_l} - {f_r})={expected} " + f"[{ctx.func_name}()]\n{f_l}={scalar_l}, {f_r}={scalar_r}" + ) @given(xps.arrays(dtype=xps.floating_dtypes(), shape=hh.shapes())) From 20779860504a692dd797d5b8d63192eb91db7f50 Mon Sep 17 00:00:00 2001 From: Matthew Barber Date: Fri, 28 Jan 2022 10:40:24 +0000 Subject: [PATCH 12/40] Rudimentary values testing refactor, updates to logical elwise tests --- ...est_operators_and_elementwise_functions.py | 108 +++++++++++++----- 1 file changed, 78 insertions(+), 30 deletions(-) diff --git a/array_api_tests/test_operators_and_elementwise_functions.py b/array_api_tests/test_operators_and_elementwise_functions.py index 34cb32ac..75617479 100644 --- a/array_api_tests/test_operators_and_elementwise_functions.py +++ b/array_api_tests/test_operators_and_elementwise_functions.py @@ -25,7 +25,7 @@ from . import pytest_helpers as ph from . import shape_helpers as sh from . import xps -from .typing import Array, DataType, Param, Scalar, Shape +from .typing import Array, DataType, Param, Scalar, ScalarType, Shape pytestmark = pytest.mark.ci @@ -38,12 +38,68 @@ def boolean_and_all_integer_dtypes() -> st.SearchStrategy[DataType]: return xps.boolean_dtypes() | all_integer_dtypes() -def isclose(n1: Union[int, float], n2: Union[int, float]): +def isclose(n1: Union[int, float], n2: Union[int, float]) -> bool: if not (math.isfinite(n1) and math.isfinite(n2)): raise ValueError(f"{n1=} and {n1=}, but input must be finite") return math.isclose(n1, n2, rel_tol=0.25, abs_tol=1) +def unary_assert_against_refimpl( + func_name: str, + in_stype: ScalarType, + in_: Array, + res: Array, + refimpl: Callable[[Scalar], Scalar], + expr_template: str, + res_stype: Optional[ScalarType] = None, +): + if in_.shape != res.shape: + raise ValueError(f"{res.shape=}, but should be {in_.shape=}") + if res_stype is None: + res_stype = in_stype + for idx in sh.ndindex(in_.shape): + scalar_i = in_stype(in_[idx]) + expected = refimpl(scalar_i) + scalar_o = res_stype(res[idx]) + f_i = sh.fmt_idx("x", idx) + f_o = sh.fmt_idx("out", idx) + expr = expr_template.format(scalar_i, expected) + assert scalar_o == expected, ( + f"{f_o}={scalar_o}, but should be {expr} [{func_name}()]\n" + f"{f_i}={scalar_i}" + ) + + +def binary_assert_against_refimpl( + func_name: str, + in_stype: ScalarType, + left: Array, + right: Array, + res: Array, + refimpl: Callable[[Scalar, Scalar], Scalar], + expr_template: str, + res_stype: Optional[ScalarType] = None, + left_sym: str = "x1", + right_sym: str = "x2", + res_sym: str = "out", +): + if res_stype is None: + res_stype = in_stype + for l_idx, r_idx, o_idx in sh.iter_indices(left.shape, right.shape, res.shape): + scalar_l = in_stype(left[l_idx]) + scalar_r = in_stype(right[r_idx]) + expected = refimpl(scalar_l, scalar_r) + scalar_o = res_stype(res[o_idx]) + f_l = sh.fmt_idx(left_sym, l_idx) + f_r = sh.fmt_idx(right_sym, r_idx) + f_o = sh.fmt_idx(res_sym, o_idx) + expr = expr_template.format(scalar_l, scalar_r, expected) + assert scalar_o == expected, ( + f"{f_o}={scalar_o}, but should be {expr} [{func_name}()]\n" + f"{f_l}={scalar_l}, {f_r}={scalar_r}" + ) + + # When appropiate, this module tests operators alongside their respective # elementwise methods. We do this by parametrizing a generalised test method # with every relevant method and operator. @@ -1249,18 +1305,15 @@ def test_logical_and(x1, x2): out = ah.logical_and(x1, x2) ph.assert_dtype("logical_and", (x1.dtype, x2.dtype), out.dtype) ph.assert_result_shape("logical_and", (x1.shape, x2.shape), out.shape) - for l_idx, r_idx, o_idx in sh.iter_indices(x1.shape, x2.shape, out.shape): - scalar_l = bool(x1[l_idx]) - scalar_r = bool(x2[r_idx]) - expected = scalar_l and scalar_r - scalar_o = bool(out[o_idx]) - f_l = sh.fmt_idx("x1", l_idx) - f_r = sh.fmt_idx("x2", r_idx) - f_o = sh.fmt_idx("out", o_idx) - assert scalar_o == expected, ( - f"{f_o}={scalar_o}, but should be ({f_l} and {f_r})={expected} " - f"[logical_and()]\n{f_l}={scalar_l}, {f_r}={scalar_r}" - ) + binary_assert_against_refimpl( + "logical_and", + bool, + x1, + x2, + out, + lambda l, r: l and r, + "({} and {})={}", + ) @given(xps.arrays(dtype=xp.bool, shape=hh.shapes())) @@ -1268,34 +1321,29 @@ def test_logical_not(x): out = ah.logical_not(x) ph.assert_dtype("logical_not", x.dtype, out.dtype) ph.assert_shape("logical_not", out.shape, x.shape) - for idx in sh.ndindex(x.shape): - assert out[idx] == (not bool(x[idx])) + unary_assert_against_refimpl( + "logical_not", bool, x, out, lambda i: not i, "(not {})={}" + ) @given(*hh.two_mutual_arrays([xp.bool])) def test_logical_or(x1, x2): out = ah.logical_or(x1, x2) ph.assert_dtype("logical_or", (x1.dtype, x2.dtype), out.dtype) - # See the comments in test_equal - shape = sh.broadcast_shapes(x1.shape, x2.shape) - ph.assert_shape("logical_or", out.shape, shape) - _x1 = xp.broadcast_to(x1, shape) - _x2 = xp.broadcast_to(x2, shape) - for idx in sh.ndindex(shape): - assert out[idx] == (bool(_x1[idx]) or bool(_x2[idx])) + ph.assert_result_shape("logical_or", (x1.shape, x2.shape), out.shape) + binary_assert_against_refimpl( + "logical_or", bool, x1, x2, out, lambda l, r: l or r, "({} or {})={}" + ) @given(*hh.two_mutual_arrays([xp.bool])) def test_logical_xor(x1, x2): out = xp.logical_xor(x1, x2) ph.assert_dtype("logical_xor", (x1.dtype, x2.dtype), out.dtype) - # See the comments in test_equal - shape = sh.broadcast_shapes(x1.shape, x2.shape) - ph.assert_shape("logical_xor", out.shape, shape) - _x1 = xp.broadcast_to(x1, shape) - _x2 = xp.broadcast_to(x2, shape) - for idx in sh.ndindex(shape): - assert out[idx] == (bool(_x1[idx]) ^ bool(_x2[idx])) + ph.assert_result_shape("logical_xor", (x1.shape, x2.shape), out.shape) + binary_assert_against_refimpl( + "logical_xor", bool, x1, x2, out, lambda l, r: l ^ r, "({} ^ {})={}" + ) @pytest.mark.parametrize("ctx", make_binary_params("multiply", xps.numeric_dtypes())) From 66a1fd45e22d4c3d63ac707408336033a55cf336 Mon Sep 17 00:00:00 2001 From: Matthew Barber Date: Fri, 28 Jan 2022 11:15:28 +0000 Subject: [PATCH 13/40] Favour lists compared to tuples for `ph.assert_dtypes()` Tuples give the impression of `in_dtype` being hetereogenous --- array_api_tests/meta/test_pytest_helpers.py | 10 +++++----- array_api_tests/pytest_helpers.py | 9 ++++----- array_api_tests/test_creation_functions.py | 12 ++++++------ array_api_tests/test_linalg.py | 2 +- .../test_operators_and_elementwise_functions.py | 12 ++++++------ array_api_tests/test_type_promotion.py | 4 ++-- 6 files changed, 24 insertions(+), 25 deletions(-) diff --git a/array_api_tests/meta/test_pytest_helpers.py b/array_api_tests/meta/test_pytest_helpers.py index 9b0f4fad..21da2264 100644 --- a/array_api_tests/meta/test_pytest_helpers.py +++ b/array_api_tests/meta/test_pytest_helpers.py @@ -5,9 +5,9 @@ def test_assert_dtype(): - ph.assert_dtype("promoted_func", (xp.uint8, xp.int8), xp.int16) + ph.assert_dtype("promoted_func", [xp.uint8, xp.int8], xp.int16) with raises(AssertionError): - ph.assert_dtype("bad_func", (xp.uint8, xp.int8), xp.float32) - ph.assert_dtype("bool_func", (xp.uint8, xp.int8), xp.bool, xp.bool) - ph.assert_dtype("single_promoted_func", (xp.uint8,), xp.uint8) - ph.assert_dtype("single_bool_func", (xp.uint8,), xp.bool, xp.bool) + ph.assert_dtype("bad_func", [xp.uint8, xp.int8], xp.float32) + ph.assert_dtype("bool_func", [xp.uint8, xp.int8], xp.bool, xp.bool) + ph.assert_dtype("single_promoted_func", [xp.uint8], xp.uint8) + ph.assert_dtype("single_bool_func", [xp.uint8], xp.bool, xp.bool) diff --git a/array_api_tests/pytest_helpers.py b/array_api_tests/pytest_helpers.py index d07c09f5..bcd44513 100644 --- a/array_api_tests/pytest_helpers.py +++ b/array_api_tests/pytest_helpers.py @@ -1,6 +1,6 @@ import math from inspect import getfullargspec -from typing import Any, Dict, Optional, Tuple, Union +from typing import Any, Dict, Optional, Sequence, Tuple, Union from . import _array_module as xp from . import array_helpers as ah @@ -71,15 +71,14 @@ def fmt_kw(kw: Dict[str, Any]) -> str: def assert_dtype( func_name: str, - in_dtypes: Union[DataType, Tuple[DataType, ...]], + in_dtype: Union[DataType, Sequence[DataType]], out_dtype: DataType, expected: Optional[DataType] = None, *, repr_name: str = "out.dtype", ): - if not isinstance(in_dtypes, tuple): - in_dtypes = (in_dtypes,) - f_in_dtypes = dh.fmt_types(in_dtypes) + in_dtypes = in_dtype if isinstance(in_dtype, Sequence) else [in_dtype] + f_in_dtypes = dh.fmt_types(tuple(in_dtypes)) f_out_dtype = dh.dtype_to_name[out_dtype] if expected is None: expected = dh.result_type(*in_dtypes) diff --git a/array_api_tests/test_creation_functions.py b/array_api_tests/test_creation_functions.py index 9d6e7fe1..a81339d0 100644 --- a/array_api_tests/test_creation_functions.py +++ b/array_api_tests/test_creation_functions.py @@ -152,7 +152,7 @@ def test_arange(dtype, data): else: ph.assert_default_float("arange", out.dtype) else: - ph.assert_dtype("arange", (out.dtype,), dtype) + ph.assert_kw_dtype("arange", dtype, out.dtype) f_sig = ", ".join(str(n) for n in args) if len(kwargs) > 0: f_sig += f", {ph.fmt_kw(kwargs)}" @@ -302,7 +302,7 @@ def test_empty(shape, kw): def test_empty_like(x, kw): out = xp.empty_like(x, **kw) if kw.get("dtype", None) is None: - ph.assert_dtype("empty_like", (x.dtype,), out.dtype) + ph.assert_dtype("empty_like", x.dtype, out.dtype) else: ph.assert_kw_dtype("empty_like", kw["dtype"], out.dtype) ph.assert_shape("empty_like", out.shape, x.shape) @@ -399,7 +399,7 @@ def test_full_like(x, fill_value, kw): out = xp.full_like(x, fill_value, **kw) dtype = kw.get("dtype", None) or x.dtype if kw.get("dtype", None) is None: - ph.assert_dtype("full_like", (x.dtype,), out.dtype) + ph.assert_dtype("full_like", x.dtype, out.dtype) else: ph.assert_kw_dtype("full_like", kw["dtype"], out.dtype) ph.assert_shape("full_like", out.shape, x.shape) @@ -459,7 +459,7 @@ def test_linspace(num, dtype, endpoint, data): if dtype is None: ph.assert_default_float("linspace", out.dtype) else: - ph.assert_dtype("linspace", (out.dtype,), dtype) + ph.assert_kw_dtype("linspace", dtype, out.dtype) ph.assert_shape("linspace", out.shape, num, start=stop, stop=stop, num=num) f_func = f"[linspace({start}, {stop}, {num})]" if num > 0: @@ -529,7 +529,7 @@ def test_ones(shape, kw): def test_ones_like(x, kw): out = xp.ones_like(x, **kw) if kw.get("dtype", None) is None: - ph.assert_dtype("ones_like", (x.dtype,), out.dtype) + ph.assert_dtype("ones_like", x.dtype, out.dtype) else: ph.assert_kw_dtype("ones_like", kw["dtype"], out.dtype) ph.assert_shape("ones_like", out.shape, x.shape) @@ -565,7 +565,7 @@ def test_zeros(shape, kw): def test_zeros_like(x, kw): out = xp.zeros_like(x, **kw) if kw.get("dtype", None) is None: - ph.assert_dtype("zeros_like", (x.dtype,), out.dtype) + ph.assert_dtype("zeros_like", x.dtype, out.dtype) else: ph.assert_kw_dtype("zeros_like", kw["dtype"], out.dtype) ph.assert_shape("zeros_like", out.shape, x.shape) diff --git a/array_api_tests/test_linalg.py b/array_api_tests/test_linalg.py index 5f5ce2bd..7117c20b 100644 --- a/array_api_tests/test_linalg.py +++ b/array_api_tests/test_linalg.py @@ -297,7 +297,7 @@ def test_matmul(x1, x2): else: res = _array_module.matmul(x1, x2) - ph.assert_dtype("matmul", (x1.dtype, x2.dtype), res.dtype) + ph.assert_dtype("matmul", [x1.dtype, x2.dtype], res.dtype) if len(x1.shape) == len(x2.shape) == 1: assert res.shape == () diff --git a/array_api_tests/test_operators_and_elementwise_functions.py b/array_api_tests/test_operators_and_elementwise_functions.py index 75617479..61f23e7d 100644 --- a/array_api_tests/test_operators_and_elementwise_functions.py +++ b/array_api_tests/test_operators_and_elementwise_functions.py @@ -273,7 +273,7 @@ def assert_binary_param_dtype( if ctx.right_is_scalar: in_dtypes = left.dtype else: - in_dtypes = (left.dtype, right.dtype) # type: ignore + in_dtypes = [left.dtype, right.dtype] # type: ignore ph.assert_dtype( ctx.func_name, in_dtypes, res.dtype, expected, repr_name=f"{ctx.res_name}.dtype" ) @@ -443,7 +443,7 @@ def test_atan(x): @given(*hh.two_mutual_arrays(dh.float_dtypes)) def test_atan2(x1, x2): out = xp.atan2(x1, x2) - ph.assert_dtype("atan2", (x1.dtype, x2.dtype), out.dtype) + ph.assert_dtype("atan2", [x1.dtype, x2.dtype], out.dtype) ph.assert_result_shape("atan2", (x1.shape, x2.shape), out.shape) INFINITY1 = ah.infinity(x1.shape, x1.dtype) INFINITY2 = ah.infinity(x2.shape, x2.dtype) @@ -1294,7 +1294,7 @@ def test_log10(x): @given(*hh.two_mutual_arrays(dh.float_dtypes)) def test_logaddexp(x1, x2): out = xp.logaddexp(x1, x2) - ph.assert_dtype("logaddexp", (x1.dtype, x2.dtype), out.dtype) + ph.assert_dtype("logaddexp", [x1.dtype, x2.dtype], out.dtype) # The spec doesn't require any behavior for this function. We could test # that this is indeed an approximation of log(exp(x1) + exp(x2)), but we # don't have tests for this sort of thing for any functions yet. @@ -1303,7 +1303,7 @@ def test_logaddexp(x1, x2): @given(*hh.two_mutual_arrays([xp.bool])) def test_logical_and(x1, x2): out = ah.logical_and(x1, x2) - ph.assert_dtype("logical_and", (x1.dtype, x2.dtype), out.dtype) + ph.assert_dtype("logical_and", [x1.dtype, x2.dtype], out.dtype) ph.assert_result_shape("logical_and", (x1.shape, x2.shape), out.shape) binary_assert_against_refimpl( "logical_and", @@ -1329,7 +1329,7 @@ def test_logical_not(x): @given(*hh.two_mutual_arrays([xp.bool])) def test_logical_or(x1, x2): out = ah.logical_or(x1, x2) - ph.assert_dtype("logical_or", (x1.dtype, x2.dtype), out.dtype) + ph.assert_dtype("logical_or", [x1.dtype, x2.dtype], out.dtype) ph.assert_result_shape("logical_or", (x1.shape, x2.shape), out.shape) binary_assert_against_refimpl( "logical_or", bool, x1, x2, out, lambda l, r: l or r, "({} or {})={}" @@ -1339,7 +1339,7 @@ def test_logical_or(x1, x2): @given(*hh.two_mutual_arrays([xp.bool])) def test_logical_xor(x1, x2): out = xp.logical_xor(x1, x2) - ph.assert_dtype("logical_xor", (x1.dtype, x2.dtype), out.dtype) + ph.assert_dtype("logical_xor", [x1.dtype, x2.dtype], out.dtype) ph.assert_result_shape("logical_xor", (x1.shape, x2.shape), out.shape) binary_assert_against_refimpl( "logical_xor", bool, x1, x2, out, lambda l, r: l ^ r, "({} ^ {})={}" diff --git a/array_api_tests/test_type_promotion.py b/array_api_tests/test_type_promotion.py index b1e5a09b..575e9011 100644 --- a/array_api_tests/test_type_promotion.py +++ b/array_api_tests/test_type_promotion.py @@ -271,7 +271,7 @@ def test_op_scalar_promotion(op, expr, in_dtype, in_stype, out_dtype, data): out = eval(expr, {"x": x, "s": s}) except OverflowError: reject() - ph.assert_dtype(op, (in_dtype, in_stype), out.dtype, out_dtype) + ph.assert_dtype(op, [in_dtype, in_stype], out.dtype, out_dtype) inplace_scalar_params: List[Param[str, str, DataType, ScalarType]] = [] @@ -307,7 +307,7 @@ def test_inplace_op_scalar_promotion(op, expr, dtype, in_stype, data): reject() x = locals_["x"] assert x.dtype == dtype, f"{x.dtype=!s}, but should be {dtype}" - ph.assert_dtype(op, (dtype, in_stype), x.dtype, dtype, repr_name="x.dtype") + ph.assert_dtype(op, [dtype, in_stype], x.dtype, dtype, repr_name="x.dtype") if __name__ == "__main__": From b6d05dae7a378eff2206c8fbacf7e77f98503d93 Mon Sep 17 00:00:00 2001 From: Matthew Barber Date: Fri, 28 Jan 2022 11:21:07 +0000 Subject: [PATCH 14/40] Favour lists for `ph.assert_result_shape()` --- array_api_tests/pytest_helpers.py | 2 +- array_api_tests/test_manipulation_functions.py | 10 +++++----- .../test_operators_and_elementwise_functions.py | 12 ++++++------ 3 files changed, 12 insertions(+), 12 deletions(-) diff --git a/array_api_tests/pytest_helpers.py b/array_api_tests/pytest_helpers.py index bcd44513..9a5ffbb2 100644 --- a/array_api_tests/pytest_helpers.py +++ b/array_api_tests/pytest_helpers.py @@ -149,7 +149,7 @@ def assert_shape( def assert_result_shape( func_name: str, - in_shapes: Tuple[Shape], + in_shapes: Sequence[Shape], out_shape: Shape, /, expected: Optional[Shape] = None, diff --git a/array_api_tests/test_manipulation_functions.py b/array_api_tests/test_manipulation_functions.py index 1ae28919..b9d9e03d 100644 --- a/array_api_tests/test_manipulation_functions.py +++ b/array_api_tests/test_manipulation_functions.py @@ -142,7 +142,7 @@ def test_expand_dims(x, axis): index = axis if axis >= 0 else x.ndim + axis + 1 shape.insert(index, 1) shape = tuple(shape) - ph.assert_result_shape("expand_dims", (x.shape,), out.shape, shape) + ph.assert_result_shape("expand_dims", [x.shape], out.shape, shape) assert_array_ndindex( "expand_dims", x, sh.ndindex(x.shape), out, sh.ndindex(out.shape) @@ -181,7 +181,7 @@ def test_squeeze(x, data): if i not in axes: shape.append(side) shape = tuple(shape) - ph.assert_result_shape("squeeze", (x.shape,), out.shape, shape, axis=axis) + ph.assert_result_shape("squeeze", [x.shape], out.shape, shape, axis=axis) assert_array_ndindex("squeeze", x, sh.ndindex(x.shape), out, sh.ndindex(out.shape)) @@ -230,7 +230,7 @@ def test_permute_dims(x, axes): side = x.shape[dim] shape[i] = side shape = tuple(shape) - ph.assert_result_shape("permute_dims", (x.shape,), out.shape, shape, axes=axes) + ph.assert_result_shape("permute_dims", [x.shape], out.shape, shape, axes=axes) indices = list(sh.ndindex(x.shape)) permuted_indices = [tuple(idx[axis] for axis in axes) for idx in indices] @@ -265,7 +265,7 @@ def test_reshape(x, data): rsize = math.prod(shape) * -1 _shape[shape.index(-1)] = size / rsize _shape = tuple(_shape) - ph.assert_result_shape("reshape", (x.shape,), out.shape, _shape, shape=shape) + ph.assert_result_shape("reshape", [x.shape], out.shape, _shape, shape=shape) assert_array_ndindex("reshape", x, sh.ndindex(x.shape), out, sh.ndindex(out.shape)) @@ -303,7 +303,7 @@ def test_roll(x, data): ph.assert_dtype("roll", x.dtype, out.dtype) - ph.assert_result_shape("roll", (x.shape,), out.shape) + ph.assert_result_shape("roll", [x.shape], out.shape) if kw.get("axis", None) is None: assert isinstance(shift, int) # sanity check diff --git a/array_api_tests/test_operators_and_elementwise_functions.py b/array_api_tests/test_operators_and_elementwise_functions.py index 61f23e7d..a053db5e 100644 --- a/array_api_tests/test_operators_and_elementwise_functions.py +++ b/array_api_tests/test_operators_and_elementwise_functions.py @@ -287,9 +287,9 @@ def assert_binary_param_shape( expected: Optional[Shape] = None, ): if ctx.right_is_scalar: - in_shapes = (left.shape,) + in_shapes = [left.shape] else: - in_shapes = (left.shape, right.shape) # type: ignore + in_shapes = [left.shape, right.shape] # type: ignore ph.assert_result_shape( ctx.func_name, in_shapes, res.shape, expected, repr_name=f"{ctx.res_name}.shape" ) @@ -444,7 +444,7 @@ def test_atan(x): def test_atan2(x1, x2): out = xp.atan2(x1, x2) ph.assert_dtype("atan2", [x1.dtype, x2.dtype], out.dtype) - ph.assert_result_shape("atan2", (x1.shape, x2.shape), out.shape) + ph.assert_result_shape("atan2", [x1.shape, x2.shape], out.shape) INFINITY1 = ah.infinity(x1.shape, x1.dtype) INFINITY2 = ah.infinity(x2.shape, x2.dtype) PI = ah.π(out.shape, out.dtype) @@ -1304,7 +1304,7 @@ def test_logaddexp(x1, x2): def test_logical_and(x1, x2): out = ah.logical_and(x1, x2) ph.assert_dtype("logical_and", [x1.dtype, x2.dtype], out.dtype) - ph.assert_result_shape("logical_and", (x1.shape, x2.shape), out.shape) + ph.assert_result_shape("logical_and", [x1.shape, x2.shape], out.shape) binary_assert_against_refimpl( "logical_and", bool, @@ -1330,7 +1330,7 @@ def test_logical_not(x): def test_logical_or(x1, x2): out = ah.logical_or(x1, x2) ph.assert_dtype("logical_or", [x1.dtype, x2.dtype], out.dtype) - ph.assert_result_shape("logical_or", (x1.shape, x2.shape), out.shape) + ph.assert_result_shape("logical_or", [x1.shape, x2.shape], out.shape) binary_assert_against_refimpl( "logical_or", bool, x1, x2, out, lambda l, r: l or r, "({} or {})={}" ) @@ -1340,7 +1340,7 @@ def test_logical_or(x1, x2): def test_logical_xor(x1, x2): out = xp.logical_xor(x1, x2) ph.assert_dtype("logical_xor", [x1.dtype, x2.dtype], out.dtype) - ph.assert_result_shape("logical_xor", (x1.shape, x2.shape), out.shape) + ph.assert_result_shape("logical_xor", [x1.shape, x2.shape], out.shape) binary_assert_against_refimpl( "logical_xor", bool, x1, x2, out, lambda l, r: l ^ r, "({} ^ {})={}" ) From af6d15020af9689f34f240d2340d2f55365e2aef Mon Sep 17 00:00:00 2001 From: Matthew Barber Date: Fri, 28 Jan 2022 13:08:56 +0000 Subject: [PATCH 15/40] Remove `lru_cache` use in `sh.fmt_idx()` Slices are not hashable! --- array_api_tests/shape_helpers.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/array_api_tests/shape_helpers.py b/array_api_tests/shape_helpers.py index 2fcb671a..9b3d001b 100644 --- a/array_api_tests/shape_helpers.py +++ b/array_api_tests/shape_helpers.py @@ -1,5 +1,4 @@ import math -from functools import lru_cache from itertools import product from typing import Iterator, List, Optional, Tuple, Union @@ -161,7 +160,6 @@ def fmt_i(i: AtomicIndex) -> str: return "..." -@lru_cache def fmt_idx(sym: str, idx: Index) -> str: if idx == (): return sym From 799b4e6badc561c79d5a8ac965b4c7c228648085 Mon Sep 17 00:00:00 2001 From: Matthew Barber Date: Fri, 28 Jan 2022 15:29:26 +0000 Subject: [PATCH 16/40] Refactor parametrized unary tests Also moves `ah.int_to_dtype()` and renames it `mock_int_dtype()` --- array_api_tests/array_helpers.py | 11 -- array_api_tests/meta/test_array_helpers.py | 16 +-- array_api_tests/meta/test_utils.py | 14 +++ ...est_operators_and_elementwise_functions.py | 114 +++++++++--------- 4 files changed, 69 insertions(+), 86 deletions(-) diff --git a/array_api_tests/array_helpers.py b/array_api_tests/array_helpers.py index b3ae583c..ef4f719a 100644 --- a/array_api_tests/array_helpers.py +++ b/array_api_tests/array_helpers.py @@ -306,14 +306,3 @@ def same_sign(x, y): def assert_same_sign(x, y): assert all(same_sign(x, y)), "The input arrays do not have the same sign" -def int_to_dtype(x, n, signed): - """ - Convert the Python integer x into an n bit signed or unsigned number. - """ - mask = (1 << n) - 1 - x &= mask - if signed: - highest_bit = 1 << (n-1) - if x & highest_bit: - x = -((~x & mask) + 1) - return x diff --git a/array_api_tests/meta/test_array_helpers.py b/array_api_tests/meta/test_array_helpers.py index 6a6b4849..68f96910 100644 --- a/array_api_tests/meta/test_array_helpers.py +++ b/array_api_tests/meta/test_array_helpers.py @@ -1,10 +1,5 @@ -from hypothesis import given, assume -from hypothesis.strategies import integers - -from ..array_helpers import exactly_equal, notequal, int_to_dtype -from ..hypothesis_helpers import integer_dtypes -from ..dtype_helpers import dtype_nbits, dtype_signed from .. import _array_module as xp +from ..array_helpers import exactly_equal, notequal # TODO: These meta-tests currently only work with NumPy @@ -22,12 +17,3 @@ def test_notequal(): res = xp.asarray([False, True, False, False, False, True, False, True]) assert xp.all(xp.equal(notequal(a, b), res)) -@given(integers(), integer_dtypes) -def test_int_to_dtype(x, dtype): - n = dtype_nbits[dtype] - signed = dtype_signed[dtype] - try: - d = xp.asarray(x, dtype=dtype) - except OverflowError: - assume(False) - assert int_to_dtype(x, n, signed) == d diff --git a/array_api_tests/meta/test_utils.py b/array_api_tests/meta/test_utils.py index 1b188df3..3cd819b4 100644 --- a/array_api_tests/meta/test_utils.py +++ b/array_api_tests/meta/test_utils.py @@ -1,8 +1,13 @@ import pytest +from hypothesis import given, reject +from hypothesis import strategies as st +from .. import _array_module as xp +from .. import xps from .. import shape_helpers as sh from ..test_creation_functions import frange from ..test_manipulation_functions import roll_ndindex +from ..test_operators_and_elementwise_functions import mock_int_dtype from ..test_signatures import extension_module @@ -101,3 +106,12 @@ def test_roll_ndindex(shape, shifts, axes, expected): ) def test_fmt_idx(idx, expected): assert sh.fmt_idx("x", idx) == expected + + +@given(x=st.integers(), dtype=xps.unsigned_integer_dtypes() | xps.integer_dtypes()) +def test_int_to_dtype(x, dtype): + try: + d = xp.asarray(x, dtype=dtype) + except OverflowError: + reject() + assert mock_int_dtype(x, dtype) == d diff --git a/array_api_tests/test_operators_and_elementwise_functions.py b/array_api_tests/test_operators_and_elementwise_functions.py index a053db5e..e791c26b 100644 --- a/array_api_tests/test_operators_and_elementwise_functions.py +++ b/array_api_tests/test_operators_and_elementwise_functions.py @@ -10,6 +10,7 @@ """ import math +import operator from enum import Enum, auto from typing import Callable, List, NamedTuple, Optional, Union @@ -44,6 +45,18 @@ def isclose(n1: Union[int, float], n2: Union[int, float]) -> bool: return math.isclose(n1, n2, rel_tol=0.25, abs_tol=1) +def mock_int_dtype(n: int, dtype: DataType) -> int: + """Returns equivalent of `n` that mocks `dtype` behaviour""" + nbits = dh.dtype_nbits[dtype] + mask = (1 << nbits) - 1 + n &= mask + if dh.dtype_signed[dtype]: + highest_bit = 1 << (nbits - 1) + if n & highest_bit: + n = -((~n & mask) + 1) + return n + + def unary_assert_against_refimpl( func_name: str, in_stype: ScalarType, @@ -52,6 +65,7 @@ def unary_assert_against_refimpl( refimpl: Callable[[Scalar], Scalar], expr_template: str, res_stype: Optional[ScalarType] = None, + ignorer: Callable[[Scalar], bool] = bool, ): if in_.shape != res.shape: raise ValueError(f"{res.shape=}, but should be {in_.shape=}") @@ -59,6 +73,8 @@ def unary_assert_against_refimpl( res_stype = in_stype for idx in sh.ndindex(in_.shape): scalar_i = in_stype(in_[idx]) + if ignorer(scalar_i): + continue expected = refimpl(scalar_i) scalar_o = res_stype(res[idx]) f_i = sh.fmt_idx("x", idx) @@ -299,25 +315,22 @@ def assert_binary_param_shape( @given(data=st.data()) def test_abs(ctx, data): x = data.draw(ctx.strat, label="x") + # abs of the smallest negative integer is out-of-scope if x.dtype in dh.int_dtypes: - # abs of the smallest representable negative integer is not defined - mask = xp.not_equal( - x, ah.full(x.shape, dh.dtype_ranges[x.dtype].min, dtype=x.dtype) - ) - x = x[mask] + assume(xp.all(x > dh.dtype_ranges[x.dtype].min)) + out = ctx.func(x) + ph.assert_dtype(ctx.func_name, x.dtype, out.dtype) ph.assert_shape(ctx.func_name, out.shape, x.shape) - assert ah.all( - ah.logical_not(ah.negative_mathematical_sign(out)) - ), f"out elements not all positively signed [{ctx.func_name}()]\n{out=}" - less_zero = ah.negative_mathematical_sign(x) - negx = ah.negative(x) - # abs(x) = -x for x < 0 - ah.assert_exactly_equal(out[less_zero], negx[less_zero]) - # abs(x) = x for x >= 0 - ah.assert_exactly_equal( - out[ah.logical_not(less_zero)], x[ah.logical_not(less_zero)] + unary_assert_against_refimpl( + ctx.func_name, + dh.get_scalar_type(x.dtype), + x, + out, + abs, + "abs({})={}", + ignorer=lambda s: math.isnan(s) or s is -0.0 or s == float("-infinity"), ) @@ -518,7 +531,7 @@ def test_bitwise_and(ctx, data): # for mypy assert isinstance(scalar_l, int) assert isinstance(right, int) - expected = ah.int_to_dtype( + expected = ah.mock_int_dtype( scalar_l & right, dh.dtype_nbits[res.dtype], dh.dtype_signed[res.dtype], @@ -540,7 +553,7 @@ def test_bitwise_and(ctx, data): # for mypy assert isinstance(scalar_l, int) assert isinstance(scalar_r, int) - expected = ah.int_to_dtype( + expected = ah.mock_int_dtype( scalar_l & scalar_r, dh.dtype_nbits[res.dtype], dh.dtype_signed[res.dtype], @@ -574,7 +587,7 @@ def test_bitwise_left_shift(ctx, data): if ctx.right_is_scalar: for idx in sh.ndindex(res.shape): scalar_l = int(left[idx]) - expected = ah.int_to_dtype( + expected = ah.mock_int_dtype( # We avoid shifting very large ints scalar_l << right if right < dh.dtype_nbits[res.dtype] else 0, dh.dtype_nbits[res.dtype], @@ -591,7 +604,7 @@ def test_bitwise_left_shift(ctx, data): for l_idx, r_idx, o_idx in sh.iter_indices(left.shape, right.shape, res.shape): scalar_l = int(left[l_idx]) scalar_r = int(right[r_idx]) - expected = ah.int_to_dtype( + expected = ah.mock_int_dtype( # We avoid shifting very large ints scalar_l << scalar_r if scalar_r < dh.dtype_nbits[res.dtype] else 0, dh.dtype_nbits[res.dtype], @@ -608,8 +621,7 @@ def test_bitwise_left_shift(ctx, data): @pytest.mark.parametrize( - "ctx", - make_unary_params("bitwise_invert", boolean_and_all_integer_dtypes()), + "ctx", make_unary_params("bitwise_invert", boolean_and_all_integer_dtypes()) ) @given(data=st.data()) def test_bitwise_invert(ctx, data): @@ -619,23 +631,14 @@ def test_bitwise_invert(ctx, data): ph.assert_dtype(ctx.func_name, x.dtype, out.dtype) ph.assert_shape(ctx.func_name, out.shape, x.shape) - for idx in sh.ndindex(out.shape): - if out.dtype == xp.bool: - scalar_x = bool(x[idx]) - scalar_o = bool(out[idx]) - expected = not scalar_x - else: - scalar_x = int(x[idx]) - scalar_o = int(out[idx]) - expected = ah.int_to_dtype( - ~scalar_x, dh.dtype_nbits[out.dtype], dh.dtype_signed[out.dtype] - ) - f_x = sh.fmt_idx("x", idx) - f_o = sh.fmt_idx("out", idx) - assert scalar_o == expected, ( - f"{f_o}={scalar_o}, but should be ~{f_x}={scalar_x} " - f"[{ctx.func_name}()]\n{f_x}={scalar_x}" - ) + if x.dtype == xp.bool: + # invert op for booleans is weird, so use not + refimpl = lambda s: not s + else: + refimpl = lambda s: mock_int_dtype(~s, x.dtype) + unary_assert_against_refimpl( + ctx.func_name, dh.get_scalar_type(x.dtype), x, out, refimpl, "~{}={}" + ) @pytest.mark.parametrize( @@ -659,7 +662,7 @@ def test_bitwise_or(ctx, data): else: scalar_l = int(left[idx]) scalar_o = int(res[idx]) - expected = ah.int_to_dtype( + expected = ah.mock_int_dtype( scalar_l | right, dh.dtype_nbits[res.dtype], dh.dtype_signed[res.dtype], @@ -681,7 +684,7 @@ def test_bitwise_or(ctx, data): scalar_l = int(left[l_idx]) scalar_r = int(right[r_idx]) scalar_o = int(res[o_idx]) - expected = ah.int_to_dtype( + expected = ah.mock_int_dtype( scalar_l | scalar_r, dh.dtype_nbits[res.dtype], dh.dtype_signed[res.dtype], @@ -714,7 +717,7 @@ def test_bitwise_right_shift(ctx, data): if ctx.right_is_scalar: for idx in sh.ndindex(res.shape): scalar_l = int(left[idx]) - expected = ah.int_to_dtype( + expected = ah.mock_int_dtype( scalar_l >> right, dh.dtype_nbits[res.dtype], dh.dtype_signed[res.dtype], @@ -730,7 +733,7 @@ def test_bitwise_right_shift(ctx, data): for l_idx, r_idx, o_idx in sh.iter_indices(left.shape, right.shape, res.shape): scalar_l = int(left[l_idx]) scalar_r = int(right[r_idx]) - expected = ah.int_to_dtype( + expected = ah.mock_int_dtype( scalar_l >> scalar_r, dh.dtype_nbits[res.dtype], dh.dtype_signed[res.dtype], @@ -766,7 +769,7 @@ def test_bitwise_xor(ctx, data): else: scalar_l = int(left[idx]) scalar_o = int(res[idx]) - expected = ah.int_to_dtype( + expected = ah.mock_int_dtype( scalar_l ^ right, dh.dtype_nbits[res.dtype], dh.dtype_signed[res.dtype], @@ -788,7 +791,7 @@ def test_bitwise_xor(ctx, data): scalar_l = int(left[l_idx]) scalar_r = int(right[r_idx]) scalar_o = int(res[o_idx]) - expected = ah.int_to_dtype( + expected = ah.mock_int_dtype( scalar_l ^ scalar_r, dh.dtype_nbits[res.dtype], dh.dtype_signed[res.dtype], @@ -1366,25 +1369,17 @@ def test_multiply(ctx, data): @given(data=st.data()) def test_negative(ctx, data): x = data.draw(ctx.strat, label="x") + # negative of the smallest negative integer is out-of-scope + if x.dtype in dh.int_dtypes: + assume(xp.all(x > dh.dtype_ranges[x.dtype].min)) out = ctx.func(x) ph.assert_dtype(ctx.func_name, x.dtype, out.dtype) ph.assert_shape(ctx.func_name, out.shape, x.shape) - - # Negation is an involution - ah.assert_exactly_equal(x, ctx.func(out)) - - mask = ah.isfinite(x) - if dh.is_int_dtype(x.dtype): - minval = dh.dtype_ranges[x.dtype][0] - if minval < 0: - # negative of the smallest representable negative integer is not defined - mask = xp.not_equal(x, ah.full(x.shape, minval, dtype=x.dtype)) - - # Additive inverse - y = xp.add(x[mask], out[mask]) - ah.assert_exactly_equal(y, ah.zero(x[mask].shape, x.dtype)) + unary_assert_against_refimpl( + ctx.func_name, dh.get_scalar_type(x.dtype), x, out, operator.neg, "-({})={}" + ) @pytest.mark.parametrize("ctx", make_binary_params("not_equal", xps.scalar_dtypes())) @@ -1438,8 +1433,7 @@ def test_positive(ctx, data): ph.assert_dtype(ctx.func_name, x.dtype, out.dtype) ph.assert_shape(ctx.func_name, out.shape, x.shape) - # Positive does nothing - ah.assert_exactly_equal(out, x) + ph.assert_array(ctx.func_name, out, x) @pytest.mark.parametrize("ctx", make_binary_params("pow", xps.numeric_dtypes())) From e2b69df50ee25e47c22b2b1d0109be21fc065ffd Mon Sep 17 00:00:00 2001 From: Matthew Barber Date: Mon, 31 Jan 2022 12:17:47 +0000 Subject: [PATCH 17/40] Op/elwise fixes and improvements - Fix old usage of `mock_int_dtype` - Infer `in_stype` - Allow scalar `right` for `binary_assert_against_refimpl()` - Use util in `test_add` --- ...est_operators_and_elementwise_functions.py | 217 +++++++++--------- 1 file changed, 107 insertions(+), 110 deletions(-) diff --git a/array_api_tests/test_operators_and_elementwise_functions.py b/array_api_tests/test_operators_and_elementwise_functions.py index e791c26b..e530f50c 100644 --- a/array_api_tests/test_operators_and_elementwise_functions.py +++ b/array_api_tests/test_operators_and_elementwise_functions.py @@ -59,16 +59,18 @@ def mock_int_dtype(n: int, dtype: DataType) -> int: def unary_assert_against_refimpl( func_name: str, - in_stype: ScalarType, in_: Array, res: Array, refimpl: Callable[[Scalar], Scalar], expr_template: str, + in_stype: Optional[ScalarType] = None, res_stype: Optional[ScalarType] = None, ignorer: Callable[[Scalar], bool] = bool, ): if in_.shape != res.shape: raise ValueError(f"{res.shape=}, but should be {in_.shape=}") + if in_stype is None: + in_stype = dh.get_scalar_type(in_.dtype) if res_stype is None: res_stype = in_stype for idx in sh.ndindex(in_.shape): @@ -88,32 +90,77 @@ def unary_assert_against_refimpl( def binary_assert_against_refimpl( func_name: str, - in_stype: ScalarType, left: Array, - right: Array, + right: Union[Scalar, Array], res: Array, refimpl: Callable[[Scalar, Scalar], Scalar], expr_template: str, + in_stype: Optional[ScalarType] = None, res_stype: Optional[ScalarType] = None, left_sym: str = "x1", right_sym: str = "x2", - res_sym: str = "out", + right_is_scalar: bool = False, + res_name: str = "out", + ignorer: Callable[[Scalar, Scalar], bool] = bool, ): + if in_stype is None: + in_stype = dh.get_scalar_type(left.dtype) if res_stype is None: res_stype = in_stype - for l_idx, r_idx, o_idx in sh.iter_indices(left.shape, right.shape, res.shape): - scalar_l = in_stype(left[l_idx]) - scalar_r = in_stype(right[r_idx]) - expected = refimpl(scalar_l, scalar_r) - scalar_o = res_stype(res[o_idx]) - f_l = sh.fmt_idx(left_sym, l_idx) - f_r = sh.fmt_idx(right_sym, r_idx) - f_o = sh.fmt_idx(res_sym, o_idx) - expr = expr_template.format(scalar_l, scalar_r, expected) - assert scalar_o == expected, ( - f"{f_o}={scalar_o}, but should be {expr} [{func_name}()]\n" - f"{f_l}={scalar_l}, {f_r}={scalar_r}" - ) + if right_is_scalar: + if left.dtype != xp.bool: + m, M = dh.dtype_ranges[left.dtype] + for idx in sh.ndindex(res.shape): + scalar_l = in_stype(left[idx]) + if any(ignorer(s) for s in [scalar_l, right]): + continue + expected = refimpl(scalar_l, right) + if left.dtype != xp.bool: + if expected <= m or expected >= M: + continue + scalar_o = res_stype(res[idx]) + f_l = sh.fmt_idx(left_sym, idx) + f_o = sh.fmt_idx(res_name, idx) + expr = expr_template.format(scalar_l, right, expected) + if dh.is_float_dtype(left.dtype): + assert isclose(scalar_o, expected), ( + f"{f_o}={scalar_o}, but should be roughly {expr} [{func_name}()]\n" + f"{f_l}={scalar_l}" + ) + + else: + assert scalar_o == expected, ( + f"{f_o}={scalar_o}, but should be {expr} [{func_name}()]\n" + f"{f_l}={scalar_l}" + ) + else: + result_dtype = dh.result_type(left.dtype, right.dtype) + if result_dtype != xp.bool: + m, M = dh.dtype_ranges[result_dtype] + for l_idx, r_idx, o_idx in sh.iter_indices(left.shape, right.shape, res.shape): + scalar_l = in_stype(left[l_idx]) + scalar_r = in_stype(right[r_idx]) + if any(ignorer(s) for s in [scalar_l, scalar_r]): + continue + expected = refimpl(scalar_l, scalar_r) + if result_dtype != xp.bool: + if expected <= m or expected >= M: + continue + scalar_o = res_stype(res[o_idx]) + f_l = sh.fmt_idx(left_sym, l_idx) + f_r = sh.fmt_idx(right_sym, r_idx) + f_o = sh.fmt_idx(res_name, o_idx) + expr = expr_template.format(scalar_l, scalar_r, expected) + if dh.is_float_dtype(result_dtype): + assert isclose(scalar_o, expected), ( + f"{f_o}={scalar_o}, but should be roughly {expr} [{func_name}()]\n" + f"{f_l}={scalar_l}, {f_r}={scalar_r}" + ) + else: + assert scalar_o == expected, ( + f"{f_o}={scalar_o}, but should be {expr} [{func_name}()]\n" + f"{f_l}={scalar_l}, {f_r}={scalar_r}" + ) # When appropiate, this module tests operators alongside their respective @@ -325,7 +372,6 @@ def test_abs(ctx, data): ph.assert_shape(ctx.func_name, out.shape, x.shape) unary_assert_against_refimpl( ctx.func_name, - dh.get_scalar_type(x.dtype), x, out, abs, @@ -379,37 +425,34 @@ def test_add(ctx, data): assert_binary_param_dtype(ctx, left, right, res) assert_binary_param_shape(ctx, left, right, res) - m, M = dh.dtype_ranges[res.dtype] - scalar_type = dh.get_scalar_type(res.dtype) if ctx.right_is_scalar: - for idx in sh.ndindex(res.shape): - scalar_l = scalar_type(left[idx]) - expected = scalar_l + right - if not math.isfinite(expected) or expected <= m or expected >= M: - continue - scalar_o = scalar_type(res[idx]) - f_l = sh.fmt_idx(ctx.left_sym, idx) - f_o = sh.fmt_idx(ctx.res_name, idx) - assert isclose(scalar_o, expected), ( - f"{f_o}={scalar_o}, but should be roughly ({f_l} + {right})={expected} " - f"[{ctx.func_name}()]\n{f_l}={scalar_l}" - ) + binary_assert_against_refimpl( + func_name=ctx.func_name, + left_sym=ctx.left_sym, + left=left, + right_sym=ctx.right_sym, + right=right, + right_is_scalar=True, + res_name=ctx.res_name, + res=res, + refimpl=operator.add, + expr_template="({} + {})={}", + ignorer=lambda s: not math.isfinite(s), + ) else: ph.assert_array(ctx.func_name, res, ctx.func(right, left)) # cumulative - for l_idx, r_idx, o_idx in sh.iter_indices(left.shape, right.shape, res.shape): - scalar_l = scalar_type(left[l_idx]) - scalar_r = scalar_type(right[r_idx]) - expected = scalar_l + scalar_r - if not math.isfinite(expected) or expected <= m or expected >= M: - continue - scalar_o = scalar_type(res[o_idx]) - f_l = sh.fmt_idx(ctx.left_sym, l_idx) - f_r = sh.fmt_idx(ctx.right_sym, r_idx) - f_o = sh.fmt_idx(ctx.res_name, o_idx) - assert isclose(scalar_o, expected), ( - f"{f_o}={scalar_o}, but should be roughly ({f_l} + {f_r})={expected} " - f"[{ctx.func_name}()]\n{f_l}={scalar_l}, {f_r}={scalar_r}" - ) + binary_assert_against_refimpl( + func_name=ctx.func_name, + left_sym=ctx.left_sym, + left=left, + right_sym=ctx.right_sym, + right=right, + res_name=ctx.res_name, + res=res, + refimpl=operator.add, + expr_template="({} + {})={}", + ignorer=lambda s: not math.isfinite(s), + ) @given(xps.arrays(dtype=xps.floating_dtypes(), shape=hh.shapes())) @@ -531,11 +574,7 @@ def test_bitwise_and(ctx, data): # for mypy assert isinstance(scalar_l, int) assert isinstance(right, int) - expected = ah.mock_int_dtype( - scalar_l & right, - dh.dtype_nbits[res.dtype], - dh.dtype_signed[res.dtype], - ) + expected = mock_int_dtype(scalar_l & right, res.dtype) scalar_o = scalar_type(res[idx]) f_l = sh.fmt_idx(ctx.left_sym, idx) f_o = sh.fmt_idx(ctx.res_name, idx) @@ -553,11 +592,7 @@ def test_bitwise_and(ctx, data): # for mypy assert isinstance(scalar_l, int) assert isinstance(scalar_r, int) - expected = ah.mock_int_dtype( - scalar_l & scalar_r, - dh.dtype_nbits[res.dtype], - dh.dtype_signed[res.dtype], - ) + expected = mock_int_dtype(scalar_l & scalar_r, res.dtype) scalar_o = scalar_type(res[o_idx]) f_l = sh.fmt_idx(ctx.left_sym, l_idx) f_r = sh.fmt_idx(ctx.right_sym, r_idx) @@ -587,11 +622,10 @@ def test_bitwise_left_shift(ctx, data): if ctx.right_is_scalar: for idx in sh.ndindex(res.shape): scalar_l = int(left[idx]) - expected = ah.mock_int_dtype( + expected = mock_int_dtype( # We avoid shifting very large ints scalar_l << right if right < dh.dtype_nbits[res.dtype] else 0, - dh.dtype_nbits[res.dtype], - dh.dtype_signed[res.dtype], + res.dtype, ) scalar_o = int(res[idx]) f_l = sh.fmt_idx(ctx.left_sym, idx) @@ -604,11 +638,10 @@ def test_bitwise_left_shift(ctx, data): for l_idx, r_idx, o_idx in sh.iter_indices(left.shape, right.shape, res.shape): scalar_l = int(left[l_idx]) scalar_r = int(right[r_idx]) - expected = ah.mock_int_dtype( + expected = mock_int_dtype( # We avoid shifting very large ints scalar_l << scalar_r if scalar_r < dh.dtype_nbits[res.dtype] else 0, - dh.dtype_nbits[res.dtype], - dh.dtype_signed[res.dtype], + res.dtype, ) scalar_o = int(res[o_idx]) f_l = sh.fmt_idx(ctx.left_sym, l_idx) @@ -636,9 +669,7 @@ def test_bitwise_invert(ctx, data): refimpl = lambda s: not s else: refimpl = lambda s: mock_int_dtype(~s, x.dtype) - unary_assert_against_refimpl( - ctx.func_name, dh.get_scalar_type(x.dtype), x, out, refimpl, "~{}={}" - ) + unary_assert_against_refimpl(ctx.func_name, x, out, refimpl, "~{}={}") @pytest.mark.parametrize( @@ -662,11 +693,7 @@ def test_bitwise_or(ctx, data): else: scalar_l = int(left[idx]) scalar_o = int(res[idx]) - expected = ah.mock_int_dtype( - scalar_l | right, - dh.dtype_nbits[res.dtype], - dh.dtype_signed[res.dtype], - ) + expected = mock_int_dtype(scalar_l | right, res.dtype) f_l = sh.fmt_idx(ctx.left_sym, idx) f_o = sh.fmt_idx(ctx.res_name, idx) assert scalar_o == expected, ( @@ -684,11 +711,7 @@ def test_bitwise_or(ctx, data): scalar_l = int(left[l_idx]) scalar_r = int(right[r_idx]) scalar_o = int(res[o_idx]) - expected = ah.mock_int_dtype( - scalar_l | scalar_r, - dh.dtype_nbits[res.dtype], - dh.dtype_signed[res.dtype], - ) + expected = mock_int_dtype(scalar_l | scalar_r, res.dtype) f_l = sh.fmt_idx(ctx.left_sym, l_idx) f_r = sh.fmt_idx(ctx.right_sym, r_idx) f_o = sh.fmt_idx(ctx.res_name, o_idx) @@ -717,11 +740,7 @@ def test_bitwise_right_shift(ctx, data): if ctx.right_is_scalar: for idx in sh.ndindex(res.shape): scalar_l = int(left[idx]) - expected = ah.mock_int_dtype( - scalar_l >> right, - dh.dtype_nbits[res.dtype], - dh.dtype_signed[res.dtype], - ) + expected = mock_int_dtype(scalar_l >> right, res.dtype) scalar_o = int(res[idx]) f_l = sh.fmt_idx(ctx.left_sym, idx) f_o = sh.fmt_idx(ctx.res_name, idx) @@ -733,11 +752,7 @@ def test_bitwise_right_shift(ctx, data): for l_idx, r_idx, o_idx in sh.iter_indices(left.shape, right.shape, res.shape): scalar_l = int(left[l_idx]) scalar_r = int(right[r_idx]) - expected = ah.mock_int_dtype( - scalar_l >> scalar_r, - dh.dtype_nbits[res.dtype], - dh.dtype_signed[res.dtype], - ) + expected = mock_int_dtype(scalar_l >> scalar_r, res.dtype) scalar_o = int(res[o_idx]) f_l = sh.fmt_idx(ctx.left_sym, l_idx) f_r = sh.fmt_idx(ctx.right_sym, r_idx) @@ -769,11 +784,7 @@ def test_bitwise_xor(ctx, data): else: scalar_l = int(left[idx]) scalar_o = int(res[idx]) - expected = ah.mock_int_dtype( - scalar_l ^ right, - dh.dtype_nbits[res.dtype], - dh.dtype_signed[res.dtype], - ) + expected = mock_int_dtype(scalar_l ^ right, res.dtype) f_l = sh.fmt_idx(ctx.left_sym, idx) f_o = sh.fmt_idx(ctx.res_name, idx) assert scalar_o == expected, ( @@ -791,11 +802,7 @@ def test_bitwise_xor(ctx, data): scalar_l = int(left[l_idx]) scalar_r = int(right[r_idx]) scalar_o = int(res[o_idx]) - expected = ah.mock_int_dtype( - scalar_l ^ scalar_r, - dh.dtype_nbits[res.dtype], - dh.dtype_signed[res.dtype], - ) + expected = mock_int_dtype(scalar_l ^ scalar_r, res.dtype) f_l = sh.fmt_idx(ctx.left_sym, l_idx) f_r = sh.fmt_idx(ctx.right_sym, r_idx) f_o = sh.fmt_idx(ctx.res_name, o_idx) @@ -1309,13 +1316,7 @@ def test_logical_and(x1, x2): ph.assert_dtype("logical_and", [x1.dtype, x2.dtype], out.dtype) ph.assert_result_shape("logical_and", [x1.shape, x2.shape], out.shape) binary_assert_against_refimpl( - "logical_and", - bool, - x1, - x2, - out, - lambda l, r: l and r, - "({} and {})={}", + "logical_and", x1, x2, out, lambda l, r: l and r, "({} and {})={}" ) @@ -1324,9 +1325,7 @@ def test_logical_not(x): out = ah.logical_not(x) ph.assert_dtype("logical_not", x.dtype, out.dtype) ph.assert_shape("logical_not", out.shape, x.shape) - unary_assert_against_refimpl( - "logical_not", bool, x, out, lambda i: not i, "(not {})={}" - ) + unary_assert_against_refimpl("logical_not", x, out, lambda i: not i, "(not {})={}") @given(*hh.two_mutual_arrays([xp.bool])) @@ -1335,7 +1334,7 @@ def test_logical_or(x1, x2): ph.assert_dtype("logical_or", [x1.dtype, x2.dtype], out.dtype) ph.assert_result_shape("logical_or", [x1.shape, x2.shape], out.shape) binary_assert_against_refimpl( - "logical_or", bool, x1, x2, out, lambda l, r: l or r, "({} or {})={}" + "logical_or", x1, x2, out, lambda l, r: l or r, "({} or {})={}" ) @@ -1345,7 +1344,7 @@ def test_logical_xor(x1, x2): ph.assert_dtype("logical_xor", [x1.dtype, x2.dtype], out.dtype) ph.assert_result_shape("logical_xor", [x1.shape, x2.shape], out.shape) binary_assert_against_refimpl( - "logical_xor", bool, x1, x2, out, lambda l, r: l ^ r, "({} ^ {})={}" + "logical_xor", x1, x2, out, lambda l, r: l ^ r, "({} ^ {})={}" ) @@ -1377,9 +1376,7 @@ def test_negative(ctx, data): ph.assert_dtype(ctx.func_name, x.dtype, out.dtype) ph.assert_shape(ctx.func_name, out.shape, x.shape) - unary_assert_against_refimpl( - ctx.func_name, dh.get_scalar_type(x.dtype), x, out, operator.neg, "-({})={}" - ) + unary_assert_against_refimpl(ctx.func_name, x, out, operator.neg, "-({})={}") @pytest.mark.parametrize("ctx", make_binary_params("not_equal", xps.scalar_dtypes())) From 3dfd66553e0108672442aed7bd6deb8a074cdf03 Mon Sep 17 00:00:00 2001 From: Matthew Barber Date: Mon, 31 Jan 2022 15:39:01 +0000 Subject: [PATCH 18/40] `binary_param_assert_against_refimpl()` to refactor elwise+op tests --- ...est_operators_and_elementwise_functions.py | 250 +++++++++--------- 1 file changed, 130 insertions(+), 120 deletions(-) diff --git a/array_api_tests/test_operators_and_elementwise_functions.py b/array_api_tests/test_operators_and_elementwise_functions.py index e530f50c..256a4e81 100644 --- a/array_api_tests/test_operators_and_elementwise_functions.py +++ b/array_api_tests/test_operators_and_elementwise_functions.py @@ -99,68 +99,40 @@ def binary_assert_against_refimpl( res_stype: Optional[ScalarType] = None, left_sym: str = "x1", right_sym: str = "x2", - right_is_scalar: bool = False, res_name: str = "out", - ignorer: Callable[[Scalar, Scalar], bool] = bool, + ignorer: Callable[[Scalar], bool] = bool, ): if in_stype is None: in_stype = dh.get_scalar_type(left.dtype) if res_stype is None: res_stype = in_stype - if right_is_scalar: - if left.dtype != xp.bool: - m, M = dh.dtype_ranges[left.dtype] - for idx in sh.ndindex(res.shape): - scalar_l = in_stype(left[idx]) - if any(ignorer(s) for s in [scalar_l, right]): - continue - expected = refimpl(scalar_l, right) - if left.dtype != xp.bool: - if expected <= m or expected >= M: - continue - scalar_o = res_stype(res[idx]) - f_l = sh.fmt_idx(left_sym, idx) - f_o = sh.fmt_idx(res_name, idx) - expr = expr_template.format(scalar_l, right, expected) - if dh.is_float_dtype(left.dtype): - assert isclose(scalar_o, expected), ( - f"{f_o}={scalar_o}, but should be roughly {expr} [{func_name}()]\n" - f"{f_l}={scalar_l}" - ) - - else: - assert scalar_o == expected, ( - f"{f_o}={scalar_o}, but should be {expr} [{func_name}()]\n" - f"{f_l}={scalar_l}" - ) - else: - result_dtype = dh.result_type(left.dtype, right.dtype) + result_dtype = dh.result_type(left.dtype, right.dtype) + if result_dtype != xp.bool: + m, M = dh.dtype_ranges[result_dtype] + for l_idx, r_idx, o_idx in sh.iter_indices(left.shape, right.shape, res.shape): + scalar_l = in_stype(left[l_idx]) + scalar_r = in_stype(right[r_idx]) + if any(ignorer(s) for s in [scalar_l, scalar_r]): + continue + expected = refimpl(scalar_l, scalar_r) if result_dtype != xp.bool: - m, M = dh.dtype_ranges[result_dtype] - for l_idx, r_idx, o_idx in sh.iter_indices(left.shape, right.shape, res.shape): - scalar_l = in_stype(left[l_idx]) - scalar_r = in_stype(right[r_idx]) - if any(ignorer(s) for s in [scalar_l, scalar_r]): + if expected <= m or expected >= M: continue - expected = refimpl(scalar_l, scalar_r) - if result_dtype != xp.bool: - if expected <= m or expected >= M: - continue - scalar_o = res_stype(res[o_idx]) - f_l = sh.fmt_idx(left_sym, l_idx) - f_r = sh.fmt_idx(right_sym, r_idx) - f_o = sh.fmt_idx(res_name, o_idx) - expr = expr_template.format(scalar_l, scalar_r, expected) - if dh.is_float_dtype(result_dtype): - assert isclose(scalar_o, expected), ( - f"{f_o}={scalar_o}, but should be roughly {expr} [{func_name}()]\n" - f"{f_l}={scalar_l}, {f_r}={scalar_r}" - ) - else: - assert scalar_o == expected, ( - f"{f_o}={scalar_o}, but should be {expr} [{func_name}()]\n" - f"{f_l}={scalar_l}, {f_r}={scalar_r}" - ) + scalar_o = res_stype(res[o_idx]) + f_l = sh.fmt_idx(left_sym, l_idx) + f_r = sh.fmt_idx(right_sym, r_idx) + f_o = sh.fmt_idx(res_name, o_idx) + expr = expr_template.format(scalar_l, scalar_r, expected) + if dh.is_float_dtype(result_dtype): + assert isclose(scalar_o, expected), ( + f"{f_o}={scalar_o}, but should be roughly {expr} [{func_name}()]\n" + f"{f_l}={scalar_l}, {f_r}={scalar_r}" + ) + else: + assert scalar_o == expected, ( + f"{f_o}={scalar_o}, but should be {expr} [{func_name}()]\n" + f"{f_l}={scalar_l}, {f_r}={scalar_r}" + ) # When appropiate, this module tests operators alongside their respective @@ -326,7 +298,7 @@ def func(l: Array, r: Union[Scalar, Array]) -> Array: return params -def assert_binary_param_dtype( +def binary_param_assert_dtype( ctx: BinaryParamContext, left: Array, right: Union[Array, Scalar], @@ -342,7 +314,7 @@ def assert_binary_param_dtype( ) -def assert_binary_param_shape( +def binary_param_assert_shape( ctx: BinaryParamContext, left: Array, right: Union[Array, Scalar], @@ -358,6 +330,63 @@ def assert_binary_param_shape( ) +def binary_param_assert_against_refimpl( + ctx: BinaryParamContext, + left: Array, + right: Union[Array, Scalar], + res: Array, + refimpl: Callable[[Scalar, Scalar], Scalar], + expr_template: str, + in_stype: Optional[ScalarType] = None, + res_stype: Optional[ScalarType] = None, + ignorer: Callable[[Scalar], Scalar] = bool, +): + if ctx.right_is_scalar: + if left.dtype != xp.bool: + m, M = dh.dtype_ranges[left.dtype] + if in_stype is None: + in_stype = dh.get_scalar_type(left.dtype) + if res_stype is None: + res_stype = in_stype + for idx in sh.ndindex(res.shape): + scalar_l = in_stype(left[idx]) + if any(ignorer(s) for s in [scalar_l, right]): + continue + expected = refimpl(scalar_l, right) + if left.dtype != xp.bool: + if expected <= m or expected >= M: + continue + scalar_o = res_stype(res[idx]) + f_l = sh.fmt_idx(ctx.left_sym, idx) + f_o = sh.fmt_idx(ctx.res_name, idx) + expr = expr_template.format(scalar_l, right, expected) + if dh.is_float_dtype(left.dtype): + assert isclose(scalar_o, expected), ( + f"{f_o}={scalar_o}, but should be roughly {expr} " + f"[{ctx.func_name}()]\n" + f"{f_l}={scalar_l}" + ) + else: + assert scalar_o == expected, ( + f"{f_o}={scalar_o}, but should be {expr} " + f"[{ctx.func_name}()]\n" + f"{f_l}={scalar_l}" + ) + else: + binary_assert_against_refimpl( + func_name=ctx.func_name, + left_sym=ctx.left_sym, + left=left, + right_sym=ctx.right_sym, + right=right, + res_name=ctx.res_name, + res=res, + refimpl=operator.add, + expr_template=expr_template, + ignorer=lambda s: not math.isfinite(s), + ) + + @pytest.mark.parametrize("ctx", make_unary_params("abs", xps.numeric_dtypes())) @given(data=st.data()) def test_abs(ctx, data): @@ -423,36 +452,17 @@ def test_add(ctx, data): except OverflowError: reject() - assert_binary_param_dtype(ctx, left, right, res) - assert_binary_param_shape(ctx, left, right, res) - if ctx.right_is_scalar: - binary_assert_against_refimpl( - func_name=ctx.func_name, - left_sym=ctx.left_sym, - left=left, - right_sym=ctx.right_sym, - right=right, - right_is_scalar=True, - res_name=ctx.res_name, - res=res, - refimpl=operator.add, - expr_template="({} + {})={}", - ignorer=lambda s: not math.isfinite(s), - ) - else: - ph.assert_array(ctx.func_name, res, ctx.func(right, left)) # cumulative - binary_assert_against_refimpl( - func_name=ctx.func_name, - left_sym=ctx.left_sym, - left=left, - right_sym=ctx.right_sym, - right=right, - res_name=ctx.res_name, - res=res, - refimpl=operator.add, - expr_template="({} + {})={}", - ignorer=lambda s: not math.isfinite(s), - ) + binary_param_assert_dtype(ctx, left, right, res) + binary_param_assert_shape(ctx, left, right, res) + binary_param_assert_against_refimpl( + ctx, + left, + right, + res, + operator.add, + "({} + {})={}", + ignorer=lambda s: not math.isfinite(s), + ) @given(xps.arrays(dtype=xps.floating_dtypes(), shape=hh.shapes())) @@ -562,8 +572,8 @@ def test_bitwise_and(ctx, data): res = ctx.func(left, right) - assert_binary_param_dtype(ctx, left, right, res) - assert_binary_param_shape(ctx, left, right, res) + binary_param_assert_dtype(ctx, left, right, res) + binary_param_assert_shape(ctx, left, right, res) scalar_type = dh.get_scalar_type(res.dtype) if ctx.right_is_scalar: for idx in sh.ndindex(res.shape): @@ -617,8 +627,8 @@ def test_bitwise_left_shift(ctx, data): res = ctx.func(left, right) - assert_binary_param_dtype(ctx, left, right, res) - assert_binary_param_shape(ctx, left, right, res) + binary_param_assert_dtype(ctx, left, right, res) + binary_param_assert_shape(ctx, left, right, res) if ctx.right_is_scalar: for idx in sh.ndindex(res.shape): scalar_l = int(left[idx]) @@ -682,8 +692,8 @@ def test_bitwise_or(ctx, data): res = ctx.func(left, right) - assert_binary_param_dtype(ctx, left, right, res) - assert_binary_param_shape(ctx, left, right, res) + binary_param_assert_dtype(ctx, left, right, res) + binary_param_assert_shape(ctx, left, right, res) if ctx.right_is_scalar: for idx in sh.ndindex(res.shape): if res.dtype == xp.bool: @@ -735,8 +745,8 @@ def test_bitwise_right_shift(ctx, data): res = ctx.func(left, right) - assert_binary_param_dtype(ctx, left, right, res) - assert_binary_param_shape(ctx, left, right, res) + binary_param_assert_dtype(ctx, left, right, res) + binary_param_assert_shape(ctx, left, right, res) if ctx.right_is_scalar: for idx in sh.ndindex(res.shape): scalar_l = int(left[idx]) @@ -773,8 +783,8 @@ def test_bitwise_xor(ctx, data): res = ctx.func(left, right) - assert_binary_param_dtype(ctx, left, right, res) - assert_binary_param_shape(ctx, left, right, res) + binary_param_assert_dtype(ctx, left, right, res) + binary_param_assert_shape(ctx, left, right, res) if ctx.right_is_scalar: for idx in sh.ndindex(res.shape): if res.dtype == xp.bool: @@ -863,8 +873,8 @@ def test_divide(ctx, data): res = ctx.func(left, right) - assert_binary_param_dtype(ctx, left, right, res) - assert_binary_param_shape(ctx, left, right, res) + binary_param_assert_dtype(ctx, left, right, res) + binary_param_assert_shape(ctx, left, right, res) # There isn't much we can test here. The spec doesn't require any behavior # beyond the special cases, and indeed, there aren't many mathematical # properties of division that strictly hold for floating-point numbers. We @@ -880,8 +890,8 @@ def test_equal(ctx, data): out = ctx.func(left, right) - assert_binary_param_dtype(ctx, left, right, out, xp.bool) - assert_binary_param_shape(ctx, left, right, out) + binary_param_assert_dtype(ctx, left, right, out, xp.bool) + binary_param_assert_shape(ctx, left, right, out) if ctx.right_is_scalar: scalar_type = dh.get_scalar_type(left.dtype) for idx in sh.ndindex(left.shape): @@ -982,8 +992,8 @@ def test_floor_divide(ctx, data): res = ctx.func(left, right) - assert_binary_param_dtype(ctx, left, right, res) - assert_binary_param_shape(ctx, left, right, res) + binary_param_assert_dtype(ctx, left, right, res) + binary_param_assert_shape(ctx, left, right, res) scalar_type = dh.get_scalar_type(res.dtype) if ctx.right_is_scalar: for idx in sh.ndindex(res.shape): @@ -1025,8 +1035,8 @@ def test_greater(ctx, data): out = ctx.func(left, right) - assert_binary_param_dtype(ctx, left, right, out, xp.bool) - assert_binary_param_shape(ctx, left, right, out) + binary_param_assert_dtype(ctx, left, right, out, xp.bool) + binary_param_assert_shape(ctx, left, right, out) if ctx.right_is_scalar: scalar_type = dh.get_scalar_type(left.dtype) for idx in sh.ndindex(left.shape): @@ -1069,8 +1079,8 @@ def test_greater_equal(ctx, data): out = ctx.func(left, right) - assert_binary_param_dtype(ctx, left, right, out, xp.bool) - assert_binary_param_shape(ctx, left, right, out) + binary_param_assert_dtype(ctx, left, right, out, xp.bool) + binary_param_assert_shape(ctx, left, right, out) if ctx.right_is_scalar: scalar_type = dh.get_scalar_type(left.dtype) for idx in sh.ndindex(left.shape): @@ -1167,8 +1177,8 @@ def test_less(ctx, data): out = ctx.func(left, right) - assert_binary_param_dtype(ctx, left, right, out, xp.bool) - assert_binary_param_shape(ctx, left, right, out) + binary_param_assert_dtype(ctx, left, right, out, xp.bool) + binary_param_assert_shape(ctx, left, right, out) if ctx.right_is_scalar: scalar_type = dh.get_scalar_type(left.dtype) for idx in sh.ndindex(left.shape): @@ -1209,8 +1219,8 @@ def test_less_equal(ctx, data): out = ctx.func(left, right) - assert_binary_param_dtype(ctx, left, right, out, xp.bool) - assert_binary_param_shape(ctx, left, right, out) + binary_param_assert_dtype(ctx, left, right, out, xp.bool) + binary_param_assert_shape(ctx, left, right, out) if ctx.right_is_scalar: scalar_type = dh.get_scalar_type(left.dtype) for idx in sh.ndindex(left.shape): @@ -1356,8 +1366,8 @@ def test_multiply(ctx, data): res = ctx.func(left, right) - assert_binary_param_dtype(ctx, left, right, res) - assert_binary_param_shape(ctx, left, right, res) + binary_param_assert_dtype(ctx, left, right, res) + binary_param_assert_shape(ctx, left, right, res) if not ctx.right_is_scalar: # multiply is commutative expected = ctx.func(right, left) @@ -1387,8 +1397,8 @@ def test_not_equal(ctx, data): out = ctx.func(left, right) - assert_binary_param_dtype(ctx, left, right, out, xp.bool) - assert_binary_param_shape(ctx, left, right, out) + binary_param_assert_dtype(ctx, left, right, out, xp.bool) + binary_param_assert_shape(ctx, left, right, out) if ctx.right_is_scalar: scalar_type = dh.get_scalar_type(left.dtype) for idx in sh.ndindex(left.shape): @@ -1450,8 +1460,8 @@ def test_pow(ctx, data): except OverflowError: reject() - assert_binary_param_dtype(ctx, left, right, res) - assert_binary_param_shape(ctx, left, right, res) + binary_param_assert_dtype(ctx, left, right, res) + binary_param_assert_shape(ctx, left, right, res) # There isn't much we can test here. The spec doesn't require any behavior # beyond the special cases, and indeed, there aren't many mathematical # properties of exponentiation that strictly hold for floating-point @@ -1471,8 +1481,8 @@ def test_remainder(ctx, data): res = ctx.func(left, right) - assert_binary_param_dtype(ctx, left, right, res) - assert_binary_param_shape(ctx, left, right, res) + binary_param_assert_dtype(ctx, left, right, res) + binary_param_assert_shape(ctx, left, right, res) scalar_type = dh.get_scalar_type(res.dtype) if ctx.right_is_scalar: for idx in sh.ndindex(res.shape): @@ -1599,8 +1609,8 @@ def test_subtract(ctx, data): except OverflowError: reject() - assert_binary_param_dtype(ctx, left, right, res) - assert_binary_param_shape(ctx, left, right, res) + binary_param_assert_dtype(ctx, left, right, res) + binary_param_assert_shape(ctx, left, right, res) m, M = dh.dtype_ranges[res.dtype] scalar_type = dh.get_scalar_type(res.dtype) if ctx.right_is_scalar: From a4a7e048bfdea7cfb057ce20ec979fff50c6d338 Mon Sep 17 00:00:00 2001 From: Matthew Barber Date: Tue, 1 Feb 2022 10:10:13 +0000 Subject: [PATCH 19/40] Refactor remaining parametrized elwise+op tests Also `ignorer` -> `filter_` --- ...est_operators_and_elementwise_functions.py | 555 ++++-------------- 1 file changed, 106 insertions(+), 449 deletions(-) diff --git a/array_api_tests/test_operators_and_elementwise_functions.py b/array_api_tests/test_operators_and_elementwise_functions.py index 256a4e81..de216764 100644 --- a/array_api_tests/test_operators_and_elementwise_functions.py +++ b/array_api_tests/test_operators_and_elementwise_functions.py @@ -32,17 +32,20 @@ def all_integer_dtypes() -> st.SearchStrategy[DataType]: + """Returns a strategy for signed and unsigned integer dtype objects.""" return xps.unsigned_integer_dtypes() | xps.integer_dtypes() def boolean_and_all_integer_dtypes() -> st.SearchStrategy[DataType]: + """Returns a strategy for boolean and all integer dtype objects.""" return xps.boolean_dtypes() | all_integer_dtypes() -def isclose(n1: Union[int, float], n2: Union[int, float]) -> bool: - if not (math.isfinite(n1) and math.isfinite(n2)): - raise ValueError(f"{n1=} and {n1=}, but input must be finite") - return math.isclose(n1, n2, rel_tol=0.25, abs_tol=1) +def isclose(a: float, b: float, rel_tol: float = 0.25, abs_tol: float = 1) -> bool: + """Wraps math.isclose with more generous defaults.""" + if not (math.isfinite(a) and math.isfinite(b)): + raise ValueError(f"{a=} and {b=}, but input must be finite") + return math.isclose(a, b, rel_tol=rel_tol, abs_tol=abs_tol) def mock_int_dtype(n: int, dtype: DataType) -> int: @@ -65,7 +68,7 @@ def unary_assert_against_refimpl( expr_template: str, in_stype: Optional[ScalarType] = None, res_stype: Optional[ScalarType] = None, - ignorer: Callable[[Scalar], bool] = bool, + filter_: Callable[[Scalar], bool] = math.isfinite, ): if in_.shape != res.shape: raise ValueError(f"{res.shape=}, but should be {in_.shape=}") @@ -75,13 +78,13 @@ def unary_assert_against_refimpl( res_stype = in_stype for idx in sh.ndindex(in_.shape): scalar_i = in_stype(in_[idx]) - if ignorer(scalar_i): + if not filter_(scalar_i): continue expected = refimpl(scalar_i) scalar_o = res_stype(res[idx]) f_i = sh.fmt_idx("x", idx) f_o = sh.fmt_idx("out", idx) - expr = expr_template.format(scalar_i, expected) + expr = expr_template.format(f_i, expected) assert scalar_o == expected, ( f"{f_o}={scalar_o}, but should be {expr} [{func_name}()]\n" f"{f_i}={scalar_i}" @@ -100,7 +103,7 @@ def binary_assert_against_refimpl( left_sym: str = "x1", right_sym: str = "x2", res_name: str = "out", - ignorer: Callable[[Scalar], bool] = bool, + filter_: Callable[[Scalar], bool] = math.isfinite, ): if in_stype is None: in_stype = dh.get_scalar_type(left.dtype) @@ -112,7 +115,7 @@ def binary_assert_against_refimpl( for l_idx, r_idx, o_idx in sh.iter_indices(left.shape, right.shape, res.shape): scalar_l = in_stype(left[l_idx]) scalar_r = in_stype(right[r_idx]) - if any(ignorer(s) for s in [scalar_l, scalar_r]): + if not (filter_(scalar_l) and filter_(scalar_r)): continue expected = refimpl(scalar_l, scalar_r) if result_dtype != xp.bool: @@ -122,7 +125,7 @@ def binary_assert_against_refimpl( f_l = sh.fmt_idx(left_sym, l_idx) f_r = sh.fmt_idx(right_sym, r_idx) f_o = sh.fmt_idx(res_name, o_idx) - expr = expr_template.format(scalar_l, scalar_r, expected) + expr = expr_template.format(f_l, f_r, expected) if dh.is_float_dtype(result_dtype): assert isclose(scalar_o, expected), ( f"{f_o}={scalar_o}, but should be roughly {expr} [{func_name}()]\n" @@ -339,9 +342,10 @@ def binary_param_assert_against_refimpl( expr_template: str, in_stype: Optional[ScalarType] = None, res_stype: Optional[ScalarType] = None, - ignorer: Callable[[Scalar], Scalar] = bool, + filter_: Callable[[Scalar], bool] = math.isfinite, ): if ctx.right_is_scalar: + assert filter_(right) # sanity check if left.dtype != xp.bool: m, M = dh.dtype_ranges[left.dtype] if in_stype is None: @@ -350,7 +354,7 @@ def binary_param_assert_against_refimpl( res_stype = in_stype for idx in sh.ndindex(res.shape): scalar_l = in_stype(left[idx]) - if any(ignorer(s) for s in [scalar_l, right]): + if not filter_(scalar_l): continue expected = refimpl(scalar_l, right) if left.dtype != xp.bool: @@ -359,7 +363,7 @@ def binary_param_assert_against_refimpl( scalar_o = res_stype(res[idx]) f_l = sh.fmt_idx(ctx.left_sym, idx) f_o = sh.fmt_idx(ctx.res_name, idx) - expr = expr_template.format(scalar_l, right, expected) + expr = expr_template.format(f_l, right, expected) if dh.is_float_dtype(left.dtype): assert isclose(scalar_o, expected), ( f"{f_o}={scalar_o}, but should be roughly {expr} " @@ -375,15 +379,17 @@ def binary_param_assert_against_refimpl( else: binary_assert_against_refimpl( func_name=ctx.func_name, + in_stype=in_stype, left_sym=ctx.left_sym, left=left, right_sym=ctx.right_sym, right=right, + res_stype=res_stype, res_name=ctx.res_name, res=res, - refimpl=operator.add, + refimpl=refimpl, expr_template=expr_template, - ignorer=lambda s: not math.isfinite(s), + filter_=filter_, ) @@ -405,7 +411,9 @@ def test_abs(ctx, data): out, abs, "abs({})={}", - ignorer=lambda s: math.isnan(s) or s is -0.0 or s == float("-infinity"), + filter_=lambda s: ( + s == float("infinity") or (math.isfinite(s) and s is not -0.0) + ), ) @@ -455,13 +463,7 @@ def test_add(ctx, data): binary_param_assert_dtype(ctx, left, right, res) binary_param_assert_shape(ctx, left, right, res) binary_param_assert_against_refimpl( - ctx, - left, - right, - res, - operator.add, - "({} + {})={}", - ignorer=lambda s: not math.isfinite(s), + ctx, left, right, res, operator.add, "({} + {})={}" ) @@ -574,43 +576,11 @@ def test_bitwise_and(ctx, data): binary_param_assert_dtype(ctx, left, right, res) binary_param_assert_shape(ctx, left, right, res) - scalar_type = dh.get_scalar_type(res.dtype) - if ctx.right_is_scalar: - for idx in sh.ndindex(res.shape): - scalar_l = scalar_type(left[idx]) - if res.dtype == xp.bool: - expected = scalar_l and right - else: - # for mypy - assert isinstance(scalar_l, int) - assert isinstance(right, int) - expected = mock_int_dtype(scalar_l & right, res.dtype) - scalar_o = scalar_type(res[idx]) - f_l = sh.fmt_idx(ctx.left_sym, idx) - f_o = sh.fmt_idx(ctx.res_name, idx) - assert scalar_o == expected, ( - f"{f_o}={scalar_o}, but should be ({f_l} & {right})={expected} " - f"[{ctx.func_name}()]\n{f_l}={scalar_l}" - ) + if left.dtype == xp.bool: + refimpl = lambda l, r: l and r else: - for l_idx, r_idx, o_idx in sh.iter_indices(left.shape, right.shape, res.shape): - scalar_l = scalar_type(left[l_idx]) - scalar_r = scalar_type(right[r_idx]) - if res.dtype == xp.bool: - expected = scalar_l and scalar_r - else: - # for mypy - assert isinstance(scalar_l, int) - assert isinstance(scalar_r, int) - expected = mock_int_dtype(scalar_l & scalar_r, res.dtype) - scalar_o = scalar_type(res[o_idx]) - f_l = sh.fmt_idx(ctx.left_sym, l_idx) - f_r = sh.fmt_idx(ctx.right_sym, r_idx) - f_o = sh.fmt_idx(ctx.res_name, o_idx) - assert scalar_o == expected, ( - f"{f_o}={scalar_o}, but should be ({f_l} & {f_r})={expected} " - f"[{ctx.func_name}()]\n{f_l}={scalar_l}, {f_r}={scalar_r}" - ) + refimpl = lambda l, r: mock_int_dtype(l & r, res.dtype) + binary_param_assert_against_refimpl(ctx, left, right, res, refimpl, "({} & {})={}") @pytest.mark.parametrize( @@ -629,38 +599,16 @@ def test_bitwise_left_shift(ctx, data): binary_param_assert_dtype(ctx, left, right, res) binary_param_assert_shape(ctx, left, right, res) - if ctx.right_is_scalar: - for idx in sh.ndindex(res.shape): - scalar_l = int(left[idx]) - expected = mock_int_dtype( - # We avoid shifting very large ints - scalar_l << right if right < dh.dtype_nbits[res.dtype] else 0, - res.dtype, - ) - scalar_o = int(res[idx]) - f_l = sh.fmt_idx(ctx.left_sym, idx) - f_o = sh.fmt_idx(ctx.res_name, idx) - assert scalar_o == expected, ( - f"{f_o}={scalar_o}, but should be ({f_l} << {right})={expected} " - f"[{ctx.func_name}()]\n{f_l}={scalar_l}" - ) - else: - for l_idx, r_idx, o_idx in sh.iter_indices(left.shape, right.shape, res.shape): - scalar_l = int(left[l_idx]) - scalar_r = int(right[r_idx]) - expected = mock_int_dtype( - # We avoid shifting very large ints - scalar_l << scalar_r if scalar_r < dh.dtype_nbits[res.dtype] else 0, - res.dtype, - ) - scalar_o = int(res[o_idx]) - f_l = sh.fmt_idx(ctx.left_sym, l_idx) - f_r = sh.fmt_idx(ctx.right_sym, r_idx) - f_o = sh.fmt_idx(ctx.res_name, o_idx) - assert scalar_o == expected, ( - f"{f_o}={scalar_o}, but should be ({f_l} << {f_r})={expected} " - f"[{ctx.func_name}()]\n{f_l}={scalar_l}, {f_r}={scalar_r}" - ) + binary_param_assert_against_refimpl( + ctx, + left, + right, + res, + lambda l, r: mock_int_dtype(l << r, res.dtype) + if r < dh.dtype_nbits[res.dtype] + else 0, + "({} << {})={}", + ) @pytest.mark.parametrize( @@ -675,7 +623,6 @@ def test_bitwise_invert(ctx, data): ph.assert_dtype(ctx.func_name, x.dtype, out.dtype) ph.assert_shape(ctx.func_name, out.shape, x.shape) if x.dtype == xp.bool: - # invert op for booleans is weird, so use not refimpl = lambda s: not s else: refimpl = lambda s: mock_int_dtype(~s, x.dtype) @@ -694,41 +641,11 @@ def test_bitwise_or(ctx, data): binary_param_assert_dtype(ctx, left, right, res) binary_param_assert_shape(ctx, left, right, res) - if ctx.right_is_scalar: - for idx in sh.ndindex(res.shape): - if res.dtype == xp.bool: - scalar_l = bool(left[idx]) - scalar_o = bool(res[idx]) - expected = scalar_l or right - else: - scalar_l = int(left[idx]) - scalar_o = int(res[idx]) - expected = mock_int_dtype(scalar_l | right, res.dtype) - f_l = sh.fmt_idx(ctx.left_sym, idx) - f_o = sh.fmt_idx(ctx.res_name, idx) - assert scalar_o == expected, ( - f"{f_o}={scalar_o}, but should be ({f_l} | {right})={expected} " - f"[{ctx.func_name}()]\n{f_l}={scalar_l}" - ) + if left.dtype == xp.bool: + refimpl = lambda l, r: l or r else: - for l_idx, r_idx, o_idx in sh.iter_indices(left.shape, right.shape, res.shape): - if res.dtype == xp.bool: - scalar_l = bool(left[l_idx]) - scalar_r = bool(right[r_idx]) - scalar_o = bool(res[o_idx]) - expected = scalar_l or scalar_r - else: - scalar_l = int(left[l_idx]) - scalar_r = int(right[r_idx]) - scalar_o = int(res[o_idx]) - expected = mock_int_dtype(scalar_l | scalar_r, res.dtype) - f_l = sh.fmt_idx(ctx.left_sym, l_idx) - f_r = sh.fmt_idx(ctx.right_sym, r_idx) - f_o = sh.fmt_idx(ctx.res_name, o_idx) - assert scalar_o == expected, ( - f"{f_o}={scalar_o}, but should be ({f_l} | {f_r})={expected} " - f"[{ctx.func_name}()]\n{f_l}={scalar_l}, {f_r}={scalar_r}" - ) + refimpl = lambda l, r: mock_int_dtype(l | r, res.dtype) + binary_param_assert_against_refimpl(ctx, left, right, res, refimpl, "({} | {})={}") @pytest.mark.parametrize( @@ -747,30 +664,14 @@ def test_bitwise_right_shift(ctx, data): binary_param_assert_dtype(ctx, left, right, res) binary_param_assert_shape(ctx, left, right, res) - if ctx.right_is_scalar: - for idx in sh.ndindex(res.shape): - scalar_l = int(left[idx]) - expected = mock_int_dtype(scalar_l >> right, res.dtype) - scalar_o = int(res[idx]) - f_l = sh.fmt_idx(ctx.left_sym, idx) - f_o = sh.fmt_idx(ctx.res_name, idx) - assert scalar_o == expected, ( - f"{f_o}={scalar_o}, but should be ({f_l} >> {right})={expected} " - f"[{ctx.func_name}()]\n{f_l}={scalar_l}" - ) - else: - for l_idx, r_idx, o_idx in sh.iter_indices(left.shape, right.shape, res.shape): - scalar_l = int(left[l_idx]) - scalar_r = int(right[r_idx]) - expected = mock_int_dtype(scalar_l >> scalar_r, res.dtype) - scalar_o = int(res[o_idx]) - f_l = sh.fmt_idx(ctx.left_sym, l_idx) - f_r = sh.fmt_idx(ctx.right_sym, r_idx) - f_o = sh.fmt_idx(ctx.res_name, o_idx) - assert scalar_o == expected, ( - f"{f_o}={scalar_o}, but should be ({f_l} >> {f_r})={expected} " - f"[{ctx.func_name}()]\n{f_l}={scalar_l}, {f_r}={scalar_r}" - ) + binary_param_assert_against_refimpl( + ctx, + left, + right, + res, + lambda l, r: mock_int_dtype(l >> r, res.dtype), + "({} >> {})={}", + ) @pytest.mark.parametrize( @@ -785,41 +686,11 @@ def test_bitwise_xor(ctx, data): binary_param_assert_dtype(ctx, left, right, res) binary_param_assert_shape(ctx, left, right, res) - if ctx.right_is_scalar: - for idx in sh.ndindex(res.shape): - if res.dtype == xp.bool: - scalar_l = bool(left[idx]) - scalar_o = bool(res[idx]) - expected = scalar_l ^ right - else: - scalar_l = int(left[idx]) - scalar_o = int(res[idx]) - expected = mock_int_dtype(scalar_l ^ right, res.dtype) - f_l = sh.fmt_idx(ctx.left_sym, idx) - f_o = sh.fmt_idx(ctx.res_name, idx) - assert scalar_o == expected, ( - f"{f_o}={scalar_o}, but should be ({f_l} ^ {right})={expected} " - f"[{ctx.func_name}()]\n{f_l}={scalar_l}" - ) + if left.dtype == xp.bool: + refimpl = lambda l, r: l ^ r else: - for l_idx, r_idx, o_idx in sh.iter_indices(left.shape, right.shape, res.shape): - if res.dtype == xp.bool: - scalar_l = bool(left[l_idx]) - scalar_r = bool(right[r_idx]) - scalar_o = bool(res[o_idx]) - expected = scalar_l ^ scalar_r - else: - scalar_l = int(left[l_idx]) - scalar_r = int(right[r_idx]) - scalar_o = int(res[o_idx]) - expected = mock_int_dtype(scalar_l ^ scalar_r, res.dtype) - f_l = sh.fmt_idx(ctx.left_sym, l_idx) - f_r = sh.fmt_idx(ctx.right_sym, r_idx) - f_o = sh.fmt_idx(ctx.res_name, o_idx) - assert scalar_o == expected, ( - f"{f_o}={scalar_o}, but should be ({f_l} ^ {f_r})={expected} " - f"[{ctx.func_name}()]\n{f_l}={scalar_l}, {f_r}={scalar_r}" - ) + refimpl = lambda l, r: mock_int_dtype(l ^ r, res.dtype) + binary_param_assert_against_refimpl(ctx, left, right, res, refimpl, "({} ^ {})={}") @given(xps.arrays(dtype=xps.numeric_dtypes(), shape=hh.shapes())) @@ -892,19 +763,7 @@ def test_equal(ctx, data): binary_param_assert_dtype(ctx, left, right, out, xp.bool) binary_param_assert_shape(ctx, left, right, out) - if ctx.right_is_scalar: - scalar_type = dh.get_scalar_type(left.dtype) - for idx in sh.ndindex(left.shape): - scalar_l = scalar_type(left[idx]) - expected = scalar_l == right - scalar_o = bool(out[idx]) - f_l = sh.fmt_idx(ctx.left_sym, idx) - f_o = sh.fmt_idx(ctx.res_name, idx) - assert scalar_o == expected, ( - f"{f_o}={scalar_o}, but should be ({f_l} == {right})={expected} " - f"[{ctx.func_name}()]\n{f_l}={scalar_l}" - ) - else: + if not ctx.right_is_scalar: # We manually promote the dtypes as incorrect internal type promotion # could lead to false positives. For example # @@ -915,21 +774,11 @@ def test_equal(ctx, data): # # would erroneously be True if float64 downcasted to float32. promoted_dtype = dh.promotion_table[left.dtype, right.dtype] - _left = xp.astype(left, promoted_dtype) - _right = xp.astype(right, promoted_dtype) - scalar_type = dh.get_scalar_type(promoted_dtype) - for l_idx, r_idx, o_idx in sh.iter_indices(left.shape, right.shape, out.shape): - scalar_l = scalar_type(_left[l_idx]) - scalar_r = scalar_type(_right[r_idx]) - expected = scalar_l == scalar_r - scalar_o = bool(out[o_idx]) - f_l = sh.fmt_idx(ctx.left_sym, l_idx) - f_r = sh.fmt_idx(ctx.right_sym, r_idx) - f_o = sh.fmt_idx(ctx.res_name, o_idx) - assert scalar_o == expected, ( - f"{f_o}={scalar_o}, but should be ({f_l} == {f_r})={expected} " - f"[{ctx.func_name}()]\n{f_l}={scalar_l}, {f_r}={scalar_r}" - ) + left = xp.astype(left, promoted_dtype) + right = xp.astype(right, promoted_dtype) + binary_param_assert_against_refimpl( + ctx, left, right, out, operator.eq, "({} == {})={}", res_stype=bool + ) @given(xps.arrays(dtype=xps.floating_dtypes(), shape=hh.shapes())) @@ -994,37 +843,9 @@ def test_floor_divide(ctx, data): binary_param_assert_dtype(ctx, left, right, res) binary_param_assert_shape(ctx, left, right, res) - scalar_type = dh.get_scalar_type(res.dtype) - if ctx.right_is_scalar: - for idx in sh.ndindex(res.shape): - scalar_l = scalar_type(left[idx]) - expected = scalar_l // right - scalar_o = scalar_type(res[idx]) - if not all(math.isfinite(n) for n in [scalar_l, right, scalar_o, expected]): - continue - f_l = sh.fmt_idx(ctx.left_sym, idx) - f_o = sh.fmt_idx(ctx.res_name, idx) - assert isclose(scalar_o, expected), ( - f"{f_o}={scalar_o}, but should be roughly ({f_l} // {right})={expected} " - f"[{ctx.func_name}()]\n{f_l}={scalar_l}" - ) - else: - for l_idx, r_idx, o_idx in sh.iter_indices(left.shape, right.shape, res.shape): - scalar_l = scalar_type(left[l_idx]) - scalar_r = scalar_type(right[r_idx]) - expected = scalar_l // scalar_r - scalar_o = scalar_type(res[o_idx]) - if not all( - math.isfinite(n) for n in [scalar_l, scalar_r, scalar_o, expected] - ): - continue - f_l = sh.fmt_idx(ctx.left_sym, l_idx) - f_r = sh.fmt_idx(ctx.right_sym, r_idx) - f_o = sh.fmt_idx(ctx.res_name, o_idx) - assert isclose(scalar_o, expected), ( - f"{f_o}={scalar_o}, but should be roughly ({f_l} // {f_r})={expected} " - f"[{ctx.func_name}()]\n{f_l}={scalar_l}, {f_r}={scalar_r}" - ) + binary_param_assert_against_refimpl( + ctx, left, right, res, operator.floordiv, "({} // {})={}" + ) @pytest.mark.parametrize("ctx", make_binary_params("greater", xps.numeric_dtypes())) @@ -1037,36 +858,14 @@ def test_greater(ctx, data): binary_param_assert_dtype(ctx, left, right, out, xp.bool) binary_param_assert_shape(ctx, left, right, out) - if ctx.right_is_scalar: - scalar_type = dh.get_scalar_type(left.dtype) - for idx in sh.ndindex(left.shape): - scalar_l = scalar_type(left[idx]) - expected = scalar_l > right - scalar_o = bool(out[idx]) - f_l = sh.fmt_idx(ctx.left_sym, idx) - f_o = sh.fmt_idx(ctx.res_name, idx) - assert scalar_o == expected, ( - f"{f_o}={scalar_o}, but should be ({f_l} > {right})={expected} " - f"[{ctx.func_name}()]\n{f_l}={scalar_l}" - ) - else: + if not ctx.right_is_scalar: # See test_equal note promoted_dtype = dh.promotion_table[left.dtype, right.dtype] - _left = xp.astype(left, promoted_dtype) - _right = xp.astype(right, promoted_dtype) - scalar_type = dh.get_scalar_type(promoted_dtype) - for l_idx, r_idx, o_idx in sh.iter_indices(left.shape, right.shape, out.shape): - scalar_l = scalar_type(_left[l_idx]) - scalar_r = scalar_type(_right[r_idx]) - expected = scalar_l > scalar_r - scalar_o = bool(out[o_idx]) - f_l = sh.fmt_idx(ctx.left_sym, l_idx) - f_r = sh.fmt_idx(ctx.right_sym, r_idx) - f_o = sh.fmt_idx(ctx.res_name, o_idx) - assert scalar_o == expected, ( - f"{f_o}={scalar_o}, but should be ({f_l} > {f_r})={expected} " - f"[{ctx.func_name}()]\n{f_l}={scalar_l}, {f_r}={scalar_r}" - ) + left = xp.astype(left, promoted_dtype) + right = xp.astype(right, promoted_dtype) + binary_param_assert_against_refimpl( + ctx, left, right, out, operator.gt, "({} > {})={}", res_stype=bool + ) @pytest.mark.parametrize( @@ -1081,36 +880,14 @@ def test_greater_equal(ctx, data): binary_param_assert_dtype(ctx, left, right, out, xp.bool) binary_param_assert_shape(ctx, left, right, out) - if ctx.right_is_scalar: - scalar_type = dh.get_scalar_type(left.dtype) - for idx in sh.ndindex(left.shape): - scalar_l = scalar_type(left[idx]) - expected = scalar_l >= right - scalar_o = bool(out[idx]) - f_l = sh.fmt_idx(ctx.left_sym, idx) - f_o = sh.fmt_idx(ctx.res_name, idx) - assert scalar_o == expected, ( - f"{f_o}={scalar_o}, but should be ({f_l} >= {right})={expected} " - f"[{ctx.func_name}()]\n{f_l}={scalar_l}" - ) - else: + if not ctx.right_is_scalar: # See test_equal note promoted_dtype = dh.promotion_table[left.dtype, right.dtype] - _left = xp.astype(left, promoted_dtype) - _right = xp.astype(right, promoted_dtype) - scalar_type = dh.get_scalar_type(promoted_dtype) - for l_idx, r_idx, o_idx in sh.iter_indices(left.shape, right.shape, out.shape): - scalar_l = scalar_type(_left[l_idx]) - scalar_r = scalar_type(_right[r_idx]) - expected = scalar_l >= scalar_r - scalar_o = bool(out[o_idx]) - f_l = sh.fmt_idx(ctx.left_sym, l_idx) - f_r = sh.fmt_idx(ctx.right_sym, r_idx) - f_o = sh.fmt_idx(ctx.res_name, o_idx) - assert scalar_o == expected, ( - f"{f_o}={scalar_o}, but should be ({f_l} >= {f_r})={expected} " - f"[{ctx.func_name}()]\n{f_l}={scalar_l}, {f_r}={scalar_r}" - ) + left = xp.astype(left, promoted_dtype) + right = xp.astype(right, promoted_dtype) + binary_param_assert_against_refimpl( + ctx, left, right, out, operator.ge, "({} >= {})={}", res_stype=bool + ) @given(xps.arrays(dtype=xps.numeric_dtypes(), shape=hh.shapes())) @@ -1179,36 +956,14 @@ def test_less(ctx, data): binary_param_assert_dtype(ctx, left, right, out, xp.bool) binary_param_assert_shape(ctx, left, right, out) - if ctx.right_is_scalar: - scalar_type = dh.get_scalar_type(left.dtype) - for idx in sh.ndindex(left.shape): - scalar_l = scalar_type(left[idx]) - expected = scalar_l < right - scalar_o = bool(out[idx]) - f_l = sh.fmt_idx(ctx.left_sym, idx) - f_o = sh.fmt_idx(ctx.res_name, idx) - assert scalar_o == expected, ( - f"{f_o}={scalar_o}, but should be ({f_l} < {right})={expected} " - f"[{ctx.func_name}()]\n{f_l}={scalar_l}" - ) - else: + if not ctx.right_is_scalar: # See test_equal note promoted_dtype = dh.promotion_table[left.dtype, right.dtype] - _left = xp.astype(left, promoted_dtype) - _right = xp.astype(right, promoted_dtype) - scalar_type = dh.get_scalar_type(promoted_dtype) - for l_idx, r_idx, o_idx in sh.iter_indices(left.shape, right.shape, out.shape): - scalar_l = scalar_type(_left[l_idx]) - scalar_r = scalar_type(_right[r_idx]) - expected = scalar_l < scalar_r - scalar_o = bool(out[o_idx]) - f_l = sh.fmt_idx(ctx.left_sym, l_idx) - f_r = sh.fmt_idx(ctx.right_sym, r_idx) - f_o = sh.fmt_idx(ctx.res_name, o_idx) - assert scalar_o == expected, ( - f"{f_o}={scalar_o}, but should be ({f_l} < {f_r})={expected} " - f"[{ctx.func_name}()]\n{f_l}={scalar_l}, {f_r}={scalar_r}" - ) + left = xp.astype(left, promoted_dtype) + right = xp.astype(right, promoted_dtype) + binary_param_assert_against_refimpl( + ctx, left, right, out, operator.lt, "({} < {})={}", res_stype=bool + ) @pytest.mark.parametrize("ctx", make_binary_params("less_equal", xps.numeric_dtypes())) @@ -1221,36 +976,14 @@ def test_less_equal(ctx, data): binary_param_assert_dtype(ctx, left, right, out, xp.bool) binary_param_assert_shape(ctx, left, right, out) - if ctx.right_is_scalar: - scalar_type = dh.get_scalar_type(left.dtype) - for idx in sh.ndindex(left.shape): - scalar_l = scalar_type(left[idx]) - expected = scalar_l <= right - scalar_o = bool(out[idx]) - f_l = sh.fmt_idx(ctx.left_sym, idx) - f_o = sh.fmt_idx(ctx.res_name, idx) - assert scalar_o == expected, ( - f"{f_o}={scalar_o}, but should be ({f_l} <= {right})={expected} " - f"[{ctx.func_name}()]\n{f_l}={scalar_l}" - ) - else: + if not ctx.right_is_scalar: # See test_equal note promoted_dtype = dh.promotion_table[left.dtype, right.dtype] - _left = xp.astype(left, promoted_dtype) - _right = xp.astype(right, promoted_dtype) - scalar_type = dh.get_scalar_type(promoted_dtype) - for l_idx, r_idx, o_idx in sh.iter_indices(left.shape, right.shape, out.shape): - scalar_l = scalar_type(_left[l_idx]) - scalar_r = scalar_type(_right[r_idx]) - expected = scalar_l <= scalar_r - scalar_o = bool(out[o_idx]) - f_l = sh.fmt_idx(ctx.left_sym, l_idx) - f_r = sh.fmt_idx(ctx.right_sym, r_idx) - f_o = sh.fmt_idx(ctx.res_name, o_idx) - assert scalar_o == expected, ( - f"{f_o}={scalar_o}, but should be ({f_l} <= {f_r})={expected} " - f"[{ctx.func_name}()]\n{f_l}={scalar_l}, {f_r}={scalar_r}" - ) + left = xp.astype(left, promoted_dtype) + right = xp.astype(right, promoted_dtype) + binary_param_assert_against_refimpl( + ctx, left, right, out, operator.le, "({} <= {})={}", res_stype=bool + ) @given(xps.arrays(dtype=xps.floating_dtypes(), shape=hh.shapes())) @@ -1368,13 +1101,14 @@ def test_multiply(ctx, data): binary_param_assert_dtype(ctx, left, right, res) binary_param_assert_shape(ctx, left, right, res) - if not ctx.right_is_scalar: - # multiply is commutative - expected = ctx.func(right, left) - ah.assert_exactly_equal(res, expected) + binary_param_assert_against_refimpl( + ctx, left, right, res, operator.mul, "({} * {})={}" + ) -@pytest.mark.parametrize("ctx", make_unary_params("negative", xps.numeric_dtypes())) +@pytest.mark.parametrize( + "ctx", make_unary_params("negative", xps.integer_dtypes() | xps.floating_dtypes()) +) @given(data=st.data()) def test_negative(ctx, data): x = data.draw(ctx.strat, label="x") @@ -1399,36 +1133,14 @@ def test_not_equal(ctx, data): binary_param_assert_dtype(ctx, left, right, out, xp.bool) binary_param_assert_shape(ctx, left, right, out) - if ctx.right_is_scalar: - scalar_type = dh.get_scalar_type(left.dtype) - for idx in sh.ndindex(left.shape): - scalar_l = scalar_type(left[idx]) - expected = scalar_l != right - scalar_o = bool(out[idx]) - f_l = sh.fmt_idx(ctx.left_sym, idx) - f_o = sh.fmt_idx(ctx.res_name, idx) - assert scalar_o == expected, ( - f"{f_o}={scalar_o}, but should be ({f_l} != {right})={expected} " - f"[{ctx.func_name}()]\n{f_l}={scalar_l}" - ) - else: + if not ctx.right_is_scalar: # See test_equal note promoted_dtype = dh.promotion_table[left.dtype, right.dtype] - _left = xp.astype(left, promoted_dtype) - _right = xp.astype(right, promoted_dtype) - scalar_type = dh.get_scalar_type(promoted_dtype) - for l_idx, r_idx, o_idx in sh.iter_indices(left.shape, right.shape, out.shape): - scalar_l = scalar_type(_left[l_idx]) - scalar_r = scalar_type(_right[r_idx]) - expected = scalar_l != scalar_r - scalar_o = bool(out[o_idx]) - f_l = sh.fmt_idx(ctx.left_sym, l_idx) - f_r = sh.fmt_idx(ctx.right_sym, r_idx) - f_o = sh.fmt_idx(ctx.res_name, o_idx) - assert scalar_o == expected, ( - f"{f_o}={scalar_o}, but should be ({f_l} != {f_r})={expected} " - f"[{ctx.func_name}()]\n{f_l}={scalar_l}, {f_r}={scalar_r}" - ) + left = xp.astype(left, promoted_dtype) + right = xp.astype(right, promoted_dtype) + binary_param_assert_against_refimpl( + ctx, left, right, out, operator.ne, "({} != {})={}", res_stype=bool + ) @pytest.mark.parametrize("ctx", make_unary_params("positive", xps.numeric_dtypes())) @@ -1483,37 +1195,9 @@ def test_remainder(ctx, data): binary_param_assert_dtype(ctx, left, right, res) binary_param_assert_shape(ctx, left, right, res) - scalar_type = dh.get_scalar_type(res.dtype) - if ctx.right_is_scalar: - for idx in sh.ndindex(res.shape): - scalar_l = scalar_type(left[idx]) - expected = scalar_l % right - scalar_o = scalar_type(res[idx]) - if not all(math.isfinite(n) for n in [scalar_l, right, scalar_o, expected]): - continue - f_l = sh.fmt_idx(ctx.left_sym, idx) - f_o = sh.fmt_idx(ctx.res_name, idx) - assert isclose(scalar_o, expected), ( - f"{f_o}={scalar_o}, but should be roughly ({f_l} % {right})={expected} " - f"[{ctx.func_name}()]\n{f_l}={scalar_l}" - ) - else: - for l_idx, r_idx, o_idx in sh.iter_indices(left.shape, right.shape, res.shape): - scalar_l = scalar_type(left[l_idx]) - scalar_r = scalar_type(right[r_idx]) - expected = scalar_l % scalar_r - scalar_o = scalar_type(res[o_idx]) - if not all( - math.isfinite(n) for n in [scalar_l, scalar_r, scalar_o, expected] - ): - continue - f_l = sh.fmt_idx(ctx.left_sym, l_idx) - f_r = sh.fmt_idx(ctx.right_sym, r_idx) - f_o = sh.fmt_idx(ctx.res_name, o_idx) - assert isclose(scalar_o, expected), ( - f"{f_o}={scalar_o}, but should be roughly ({f_l} % {f_r})={expected} " - f"[{ctx.func_name}()]\n{f_l}={scalar_l}, {f_r}={scalar_r}" - ) + binary_param_assert_against_refimpl( + ctx, left, right, res, operator.mod, "({} % {})={}" + ) @given(xps.arrays(dtype=xps.numeric_dtypes(), shape=hh.shapes())) @@ -1611,36 +1295,9 @@ def test_subtract(ctx, data): binary_param_assert_dtype(ctx, left, right, res) binary_param_assert_shape(ctx, left, right, res) - m, M = dh.dtype_ranges[res.dtype] - scalar_type = dh.get_scalar_type(res.dtype) - if ctx.right_is_scalar: - for idx in sh.ndindex(res.shape): - scalar_l = scalar_type(left[idx]) - expected = scalar_l - right - if not math.isfinite(expected) or expected <= m or expected >= M: - continue - scalar_o = scalar_type(res[idx]) - f_l = sh.fmt_idx(ctx.left_sym, idx) - f_o = sh.fmt_idx(ctx.res_name, idx) - assert isclose(scalar_o, expected), ( - f"{f_o}={scalar_o}, but should be roughly ({f_l} - {right})={expected} " - f"[{ctx.func_name}()]\n{f_l}={scalar_l}" - ) - else: - for l_idx, r_idx, o_idx in sh.iter_indices(left.shape, right.shape, res.shape): - scalar_l = scalar_type(left[l_idx]) - scalar_r = scalar_type(right[r_idx]) - expected = scalar_l - scalar_r - if not math.isfinite(expected) or expected <= m or expected >= M: - continue - scalar_o = scalar_type(res[o_idx]) - f_l = sh.fmt_idx(ctx.left_sym, l_idx) - f_r = sh.fmt_idx(ctx.right_sym, r_idx) - f_o = sh.fmt_idx(ctx.res_name, o_idx) - assert isclose(scalar_o, expected), ( - f"{f_o}={scalar_o}, but should be roughly ({f_l} - {f_r})={expected} " - f"[{ctx.func_name}()]\n{f_l}={scalar_l}, {f_r}={scalar_r}" - ) + binary_param_assert_against_refimpl( + ctx, left, right, res, operator.sub, "({} - {})={}" + ) @given(xps.arrays(dtype=xps.floating_dtypes(), shape=hh.shapes())) From 4d849f17484da3c0ed2356f491c4fcc0526f5f8c Mon Sep 17 00:00:00 2001 From: Matthew Barber Date: Tue, 1 Feb 2022 10:37:46 +0000 Subject: [PATCH 20/40] Finish elwise TODOs --- ...est_operators_and_elementwise_functions.py | 35 ++++++++++++++----- 1 file changed, 26 insertions(+), 9 deletions(-) diff --git a/array_api_tests/test_operators_and_elementwise_functions.py b/array_api_tests/test_operators_and_elementwise_functions.py index de216764..68c6f5b1 100644 --- a/array_api_tests/test_operators_and_elementwise_functions.py +++ b/array_api_tests/test_operators_and_elementwise_functions.py @@ -76,19 +76,30 @@ def unary_assert_against_refimpl( in_stype = dh.get_scalar_type(in_.dtype) if res_stype is None: res_stype = in_stype + if res.dtype != xp.bool: + m, M = dh.dtype_ranges[res.dtype] for idx in sh.ndindex(in_.shape): scalar_i = in_stype(in_[idx]) if not filter_(scalar_i): continue expected = refimpl(scalar_i) + if res.dtype != xp.bool: + if expected <= m or expected >= M: + continue scalar_o = res_stype(res[idx]) f_i = sh.fmt_idx("x", idx) f_o = sh.fmt_idx("out", idx) expr = expr_template.format(f_i, expected) - assert scalar_o == expected, ( - f"{f_o}={scalar_o}, but should be {expr} [{func_name}()]\n" - f"{f_i}={scalar_i}" - ) + if dh.is_float_dtype(res.dtype): + assert isclose(scalar_o, expected), ( + f"{f_o}={scalar_o}, but should be roughly {expr} [{func_name}()]\n" + f"{f_i}={scalar_i}" + ) + else: + assert scalar_o == expected, ( + f"{f_o}={scalar_o}, but should be {expr} [{func_name}()]\n" + f"{f_i}={scalar_i}" + ) def binary_assert_against_refimpl( @@ -1257,7 +1268,7 @@ def test_sin(x): out = xp.sin(x) ph.assert_dtype("sin", x.dtype, out.dtype) ph.assert_shape("sin", out.shape, x.shape) - # TODO + unary_assert_against_refimpl("sin", x, out, math.sin, "sin({})={}") @given(xps.arrays(dtype=xps.floating_dtypes(), shape=hh.shapes())) @@ -1265,7 +1276,7 @@ def test_sinh(x): out = xp.sinh(x) ph.assert_dtype("sinh", x.dtype, out.dtype) ph.assert_shape("sinh", out.shape, x.shape) - # TODO + unary_assert_against_refimpl("sinh", x, out, math.sinh, "sinh({})={}") @given(xps.arrays(dtype=xps.numeric_dtypes(), shape=hh.shapes())) @@ -1273,13 +1284,19 @@ def test_square(x): out = xp.square(x) ph.assert_dtype("square", x.dtype, out.dtype) ph.assert_shape("square", out.shape, x.shape) + unary_assert_against_refimpl("square", x, out, lambda s: s ** 2, "{}²={}") -@given(xps.arrays(dtype=xps.floating_dtypes(), shape=hh.shapes())) +@given( + xps.arrays( + dtype=xps.floating_dtypes(), shape=hh.shapes(), elements={"min_value": 0} + ) +) def test_sqrt(x): out = xp.sqrt(x) ph.assert_dtype("sqrt", x.dtype, out.dtype) ph.assert_shape("sqrt", out.shape, x.shape) + unary_assert_against_refimpl("sqrt", x, out, math.sqrt, "sqrt({})={}") @pytest.mark.parametrize("ctx", make_binary_params("subtract", xps.numeric_dtypes())) @@ -1305,7 +1322,7 @@ def test_tan(x): out = xp.tan(x) ph.assert_dtype("tan", x.dtype, out.dtype) ph.assert_shape("tan", out.shape, x.shape) - # TODO + unary_assert_against_refimpl("tan", x, out, math.tan, "tan({})={}") @given(xps.arrays(dtype=xps.floating_dtypes(), shape=hh.shapes())) @@ -1313,7 +1330,7 @@ def test_tanh(x): out = xp.tanh(x) ph.assert_dtype("tanh", x.dtype, out.dtype) ph.assert_shape("tanh", out.shape, x.shape) - # TODO + unary_assert_against_refimpl("tanh", x, out, math.tanh, "tanh({})={}") @given(xps.arrays(dtype=hh.numeric_dtypes, shape=xps.array_shapes())) From 5a82a3369097a3e97de4ea920f0fbb85e4c2e73a Mon Sep 17 00:00:00 2001 From: Matthew Barber Date: Tue, 1 Feb 2022 10:47:29 +0000 Subject: [PATCH 21/40] Fix typing issues with refimpl utils --- ...est_operators_and_elementwise_functions.py | 19 +++++++++---------- 1 file changed, 9 insertions(+), 10 deletions(-) diff --git a/array_api_tests/test_operators_and_elementwise_functions.py b/array_api_tests/test_operators_and_elementwise_functions.py index 68c6f5b1..2ff46be2 100644 --- a/array_api_tests/test_operators_and_elementwise_functions.py +++ b/array_api_tests/test_operators_and_elementwise_functions.py @@ -76,14 +76,14 @@ def unary_assert_against_refimpl( in_stype = dh.get_scalar_type(in_.dtype) if res_stype is None: res_stype = in_stype - if res.dtype != xp.bool: - m, M = dh.dtype_ranges[res.dtype] + m, M = dh.dtype_ranges.get(res.dtype, (None, None)) for idx in sh.ndindex(in_.shape): scalar_i = in_stype(in_[idx]) if not filter_(scalar_i): continue expected = refimpl(scalar_i) if res.dtype != xp.bool: + assert m is not None and M is not None # for mypy if expected <= m or expected >= M: continue scalar_o = res_stype(res[idx]) @@ -105,7 +105,7 @@ def unary_assert_against_refimpl( def binary_assert_against_refimpl( func_name: str, left: Array, - right: Union[Scalar, Array], + right: Array, res: Array, refimpl: Callable[[Scalar, Scalar], Scalar], expr_template: str, @@ -120,16 +120,15 @@ def binary_assert_against_refimpl( in_stype = dh.get_scalar_type(left.dtype) if res_stype is None: res_stype = in_stype - result_dtype = dh.result_type(left.dtype, right.dtype) - if result_dtype != xp.bool: - m, M = dh.dtype_ranges[result_dtype] + m, M = dh.dtype_ranges.get(res.dtype, (None, None)) for l_idx, r_idx, o_idx in sh.iter_indices(left.shape, right.shape, res.shape): scalar_l = in_stype(left[l_idx]) scalar_r = in_stype(right[r_idx]) if not (filter_(scalar_l) and filter_(scalar_r)): continue expected = refimpl(scalar_l, scalar_r) - if result_dtype != xp.bool: + if res.dtype != xp.bool: + assert m is not None and M is not None # for mypy if expected <= m or expected >= M: continue scalar_o = res_stype(res[o_idx]) @@ -137,7 +136,7 @@ def binary_assert_against_refimpl( f_r = sh.fmt_idx(right_sym, r_idx) f_o = sh.fmt_idx(res_name, o_idx) expr = expr_template.format(f_l, f_r, expected) - if dh.is_float_dtype(result_dtype): + if dh.is_float_dtype(res.dtype): assert isclose(scalar_o, expected), ( f"{f_o}={scalar_o}, but should be roughly {expr} [{func_name}()]\n" f"{f_l}={scalar_l}, {f_r}={scalar_r}" @@ -357,18 +356,18 @@ def binary_param_assert_against_refimpl( ): if ctx.right_is_scalar: assert filter_(right) # sanity check - if left.dtype != xp.bool: - m, M = dh.dtype_ranges[left.dtype] if in_stype is None: in_stype = dh.get_scalar_type(left.dtype) if res_stype is None: res_stype = in_stype + m, M = dh.dtype_ranges.get(left.dtype, (None, None)) for idx in sh.ndindex(res.shape): scalar_l = in_stype(left[idx]) if not filter_(scalar_l): continue expected = refimpl(scalar_l, right) if left.dtype != xp.bool: + assert m is not None and M is not None # for mypy if expected <= m or expected >= M: continue scalar_o = res_stype(res[idx]) From 73866157cb333db8d59939a19ceb31d8c8b4d0e4 Mon Sep 17 00:00:00 2001 From: Matthew Barber Date: Tue, 1 Feb 2022 10:50:11 +0000 Subject: [PATCH 22/40] Remove redundant `in_stype` arg in refimpl utils --- .../test_operators_and_elementwise_functions.py | 13 +++---------- 1 file changed, 3 insertions(+), 10 deletions(-) diff --git a/array_api_tests/test_operators_and_elementwise_functions.py b/array_api_tests/test_operators_and_elementwise_functions.py index 2ff46be2..aa50846c 100644 --- a/array_api_tests/test_operators_and_elementwise_functions.py +++ b/array_api_tests/test_operators_and_elementwise_functions.py @@ -66,14 +66,12 @@ def unary_assert_against_refimpl( res: Array, refimpl: Callable[[Scalar], Scalar], expr_template: str, - in_stype: Optional[ScalarType] = None, res_stype: Optional[ScalarType] = None, filter_: Callable[[Scalar], bool] = math.isfinite, ): if in_.shape != res.shape: raise ValueError(f"{res.shape=}, but should be {in_.shape=}") - if in_stype is None: - in_stype = dh.get_scalar_type(in_.dtype) + in_stype = dh.get_scalar_type(in_.dtype) if res_stype is None: res_stype = in_stype m, M = dh.dtype_ranges.get(res.dtype, (None, None)) @@ -109,15 +107,13 @@ def binary_assert_against_refimpl( res: Array, refimpl: Callable[[Scalar, Scalar], Scalar], expr_template: str, - in_stype: Optional[ScalarType] = None, res_stype: Optional[ScalarType] = None, left_sym: str = "x1", right_sym: str = "x2", res_name: str = "out", filter_: Callable[[Scalar], bool] = math.isfinite, ): - if in_stype is None: - in_stype = dh.get_scalar_type(left.dtype) + in_stype = dh.get_scalar_type(left.dtype) if res_stype is None: res_stype = in_stype m, M = dh.dtype_ranges.get(res.dtype, (None, None)) @@ -350,14 +346,12 @@ def binary_param_assert_against_refimpl( res: Array, refimpl: Callable[[Scalar, Scalar], Scalar], expr_template: str, - in_stype: Optional[ScalarType] = None, res_stype: Optional[ScalarType] = None, filter_: Callable[[Scalar], bool] = math.isfinite, ): if ctx.right_is_scalar: assert filter_(right) # sanity check - if in_stype is None: - in_stype = dh.get_scalar_type(left.dtype) + in_stype = dh.get_scalar_type(left.dtype) if res_stype is None: res_stype = in_stype m, M = dh.dtype_ranges.get(left.dtype, (None, None)) @@ -389,7 +383,6 @@ def binary_param_assert_against_refimpl( else: binary_assert_against_refimpl( func_name=ctx.func_name, - in_stype=in_stype, left_sym=ctx.left_sym, left=left, right_sym=ctx.right_sym, From 80d29092401f9063d3c4cccbf457152b00a64cbb Mon Sep 17 00:00:00 2001 From: Matthew Barber Date: Tue, 1 Feb 2022 10:52:39 +0000 Subject: [PATCH 23/40] Skip when refimpl overflows --- .../test_operators_and_elementwise_functions.py | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/array_api_tests/test_operators_and_elementwise_functions.py b/array_api_tests/test_operators_and_elementwise_functions.py index aa50846c..81d3dba2 100644 --- a/array_api_tests/test_operators_and_elementwise_functions.py +++ b/array_api_tests/test_operators_and_elementwise_functions.py @@ -79,7 +79,10 @@ def unary_assert_against_refimpl( scalar_i = in_stype(in_[idx]) if not filter_(scalar_i): continue - expected = refimpl(scalar_i) + try: + expected = refimpl(scalar_i) + except OverflowError: + continue if res.dtype != xp.bool: assert m is not None and M is not None # for mypy if expected <= m or expected >= M: @@ -122,7 +125,10 @@ def binary_assert_against_refimpl( scalar_r = in_stype(right[r_idx]) if not (filter_(scalar_l) and filter_(scalar_r)): continue - expected = refimpl(scalar_l, scalar_r) + try: + expected = refimpl(scalar_l, scalar_r) + except OverflowError: + continue if res.dtype != xp.bool: assert m is not None and M is not None # for mypy if expected <= m or expected >= M: @@ -359,7 +365,10 @@ def binary_param_assert_against_refimpl( scalar_l = in_stype(left[idx]) if not filter_(scalar_l): continue - expected = refimpl(scalar_l, right) + try: + expected = refimpl(scalar_l, right) + except OverflowError: + continue if left.dtype != xp.bool: assert m is not None and M is not None # for mypy if expected <= m or expected >= M: From 9521f6be71e1c1489bcad884a0a03e98bda6f7f7 Mon Sep 17 00:00:00 2001 From: Matthew Barber Date: Tue, 1 Feb 2022 12:16:12 +0000 Subject: [PATCH 24/40] Values testing for remaining tests for elwise funcs starting with a --- ...est_operators_and_elementwise_functions.py | 143 ++++++------------ 1 file changed, 50 insertions(+), 93 deletions(-) diff --git a/array_api_tests/test_operators_and_elementwise_functions.py b/array_api_tests/test_operators_and_elementwise_functions.py index 81d3dba2..a8baffc8 100644 --- a/array_api_tests/test_operators_and_elementwise_functions.py +++ b/array_api_tests/test_operators_and_elementwise_functions.py @@ -60,6 +60,14 @@ def mock_int_dtype(n: int, dtype: DataType) -> int: return n +def default_filter(s: Scalar) -> bool: + """Returns False when s is a non-finite or a signed zero. + + Used by default as these values are typically special-cased. + """ + return math.isfinite(s) and s is not -0.0 and s is not +0.0 + + def unary_assert_against_refimpl( func_name: str, in_: Array, @@ -67,7 +75,7 @@ def unary_assert_against_refimpl( refimpl: Callable[[Scalar], Scalar], expr_template: str, res_stype: Optional[ScalarType] = None, - filter_: Callable[[Scalar], bool] = math.isfinite, + filter_: Callable[[Scalar], bool] = default_filter, ): if in_.shape != res.shape: raise ValueError(f"{res.shape=}, but should be {in_.shape=}") @@ -114,7 +122,7 @@ def binary_assert_against_refimpl( left_sym: str = "x1", right_sym: str = "x2", res_name: str = "out", - filter_: Callable[[Scalar], bool] = math.isfinite, + filter_: Callable[[Scalar], bool] = default_filter, ): in_stype = dh.get_scalar_type(left.dtype) if res_stype is None: @@ -353,7 +361,7 @@ def binary_param_assert_against_refimpl( refimpl: Callable[[Scalar, Scalar], Scalar], expr_template: str, res_stype: Optional[ScalarType] = None, - filter_: Callable[[Scalar], bool] = math.isfinite, + filter_: Callable[[Scalar], bool] = default_filter, ): if ctx.right_is_scalar: assert filter_(right) # sanity check @@ -429,36 +437,30 @@ def test_abs(ctx, data): ) -@given(xps.arrays(dtype=xps.floating_dtypes(), shape=hh.shapes())) +@given( + xps.arrays( + dtype=xps.floating_dtypes(), + shape=hh.shapes(), + elements={"min_value": -1, "max_value": 1}, + ) +) def test_acos(x): - res = xp.acos(x) - ph.assert_dtype("acos", x.dtype, res.dtype) - ph.assert_shape("acos", res.shape, x.shape) - ONE = ah.one(x.shape, x.dtype) - # Here (and elsewhere), should technically be res.dtype, but this is the - # same as x.dtype, as tested by the type_promotion tests. - PI = ah.π(x.shape, x.dtype) - ZERO = ah.zero(x.shape, x.dtype) - domain = ah.inrange(x, -ONE, ONE) - codomain = ah.inrange(res, ZERO, PI) - # acos maps [-1, 1] to [0, pi]. Values outside this domain are mapped to - # nan, which is already tested in the special cases. - ah.assert_exactly_equal(domain, codomain) + out = xp.acos(x) + ph.assert_dtype("acos", x.dtype, out.dtype) + ph.assert_shape("acos", out.shape, x.shape) + unary_assert_against_refimpl("acos", x, out, math.acos, "acos({})={}") -@given(xps.arrays(dtype=xps.floating_dtypes(), shape=hh.shapes())) +@given( + xps.arrays( + dtype=xps.floating_dtypes(), shape=hh.shapes(), elements={"min_value": 1} + ) +) def test_acosh(x): - res = xp.acosh(x) - ph.assert_dtype("acosh", x.dtype, res.dtype) - ph.assert_shape("acosh", res.shape, x.shape) - ONE = ah.one(x.shape, x.dtype) - INFINITY = ah.infinity(x.shape, x.dtype) - ZERO = ah.zero(x.shape, x.dtype) - domain = ah.inrange(x, ONE, INFINITY) - codomain = ah.inrange(res, ZERO, INFINITY) - # acosh maps [-1, inf] to [0, inf]. Values outside this domain are mapped - # to nan, which is already tested in the special cases. - ah.assert_exactly_equal(domain, codomain) + out = xp.acosh(x) + ph.assert_dtype("acosh", x.dtype, out.dtype) + ph.assert_shape("acosh", out.shape, x.shape) + unary_assert_against_refimpl("acosh", x, out, math.acosh, "acosh({})={}") @pytest.mark.parametrize("ctx,", make_binary_params("add", xps.numeric_dtypes())) @@ -479,18 +481,18 @@ def test_add(ctx, data): ) -@given(xps.arrays(dtype=xps.floating_dtypes(), shape=hh.shapes())) +@given( + xps.arrays( + dtype=xps.floating_dtypes(), + shape=hh.shapes(), + elements={"min_value": -1, "max_value": 1}, + ) +) def test_asin(x): out = xp.asin(x) ph.assert_dtype("asin", x.dtype, out.dtype) ph.assert_shape("asin", out.shape, x.shape) - ONE = ah.one(x.shape, x.dtype) - PI = ah.π(x.shape, x.dtype) - domain = ah.inrange(x, -ONE, ONE) - codomain = ah.inrange(out, -PI / 2, PI / 2) - # asin maps [-1, 1] to [-pi/2, pi/2]. Values outside this domain are - # mapped to nan, which is already tested in the special cases. - ah.assert_exactly_equal(domain, codomain) + unary_assert_against_refimpl("asin", x, out, math.asin, "asin({})={}") @given(xps.arrays(dtype=xps.floating_dtypes(), shape=hh.shapes())) @@ -498,12 +500,7 @@ def test_asinh(x): out = xp.asinh(x) ph.assert_dtype("asinh", x.dtype, out.dtype) ph.assert_shape("asinh", out.shape, x.shape) - INFINITY = ah.infinity(x.shape, x.dtype) - domain = ah.inrange(x, -INFINITY, INFINITY) - codomain = ah.inrange(out, -INFINITY, INFINITY) - # asinh maps [-inf, inf] to [-inf, inf]. Values outside this domain are - # mapped to nan, which is already tested in the special cases. - ah.assert_exactly_equal(domain, codomain) + unary_assert_against_refimpl("asinh", x, out, math.asinh, "asinh({})={}") @given(xps.arrays(dtype=xps.floating_dtypes(), shape=hh.shapes())) @@ -511,13 +508,7 @@ def test_atan(x): out = xp.atan(x) ph.assert_dtype("atan", x.dtype, out.dtype) ph.assert_shape("atan", out.shape, x.shape) - INFINITY = ah.infinity(x.shape, x.dtype) - PI = ah.π(x.shape, x.dtype) - domain = ah.inrange(x, -INFINITY, INFINITY) - codomain = ah.inrange(out, -PI / 2, PI / 2) - # atan maps [-inf, inf] to [-pi/2, pi/2]. Values outside this domain are - # mapped to nan, which is already tested in the special cases. - ah.assert_exactly_equal(domain, codomain) + unary_assert_against_refimpl("atan", x, out, math.atan, "atan({})={}") @given(*hh.two_mutual_arrays(dh.float_dtypes)) @@ -525,55 +516,21 @@ def test_atan2(x1, x2): out = xp.atan2(x1, x2) ph.assert_dtype("atan2", [x1.dtype, x2.dtype], out.dtype) ph.assert_result_shape("atan2", [x1.shape, x2.shape], out.shape) - INFINITY1 = ah.infinity(x1.shape, x1.dtype) - INFINITY2 = ah.infinity(x2.shape, x2.dtype) - PI = ah.π(out.shape, out.dtype) - domainx1 = ah.inrange(x1, -INFINITY1, INFINITY1) - domainx2 = ah.inrange(x2, -INFINITY2, INFINITY2) - # codomain = ah.inrange(out, -PI, PI, 1e-5) - codomain = ah.inrange(out, -PI, PI) - # atan2 maps [-inf, inf] x [-inf, inf] to [-pi, pi]. Values outside - # this domain are mapped to nan, which is already tested in the special - # cases. - ah.assert_exactly_equal(ah.logical_and(domainx1, domainx2), codomain) - # From the spec: - # - # The mathematical signs of `x1_i` and `x2_i` determine the quadrant of - # each element-wise out. The quadrant (i.e., branch) is chosen such - # that each element-wise out is the signed angle in radians between the - # ray ending at the origin and passing through the point `(1,0)` and the - # ray ending at the origin and passing through the point `(x2_i, x1_i)`. - - # This is equivalent to atan2(x1, x2) has the same sign as x1 when x2 is - # finite. - pos_x1 = ah.positive_mathematical_sign(x1) - neg_x1 = ah.negative_mathematical_sign(x1) - pos_x2 = ah.positive_mathematical_sign(x2) - neg_x2 = ah.negative_mathematical_sign(x2) - pos_out = ah.positive_mathematical_sign(out) - neg_out = ah.negative_mathematical_sign(out) - ah.assert_exactly_equal( - ah.logical_or(ah.logical_and(pos_x1, pos_x2), ah.logical_and(pos_x1, neg_x2)), - pos_out, - ) - ah.assert_exactly_equal( - ah.logical_or(ah.logical_and(neg_x1, pos_x2), ah.logical_and(neg_x1, neg_x2)), - neg_out, - ) + binary_assert_against_refimpl("atan2", x1, x2, out, math.atan2, "atan2({})={}") -@given(xps.arrays(dtype=xps.floating_dtypes(), shape=hh.shapes())) +@given( + xps.arrays( + dtype=xps.floating_dtypes(), + shape=hh.shapes(), + elements={"min_value": -1, "max_value": 1}, + ) +) def test_atanh(x): out = xp.atanh(x) ph.assert_dtype("atanh", x.dtype, out.dtype) ph.assert_shape("atanh", out.shape, x.shape) - ONE = ah.one(x.shape, x.dtype) - INFINITY = ah.infinity(x.shape, x.dtype) - domain = ah.inrange(x, -ONE, ONE) - codomain = ah.inrange(out, -INFINITY, INFINITY) - # atanh maps [-1, 1] to [-inf, inf]. Values outside this domain are - # mapped to nan, which is already tested in the special cases. - ah.assert_exactly_equal(domain, codomain) + unary_assert_against_refimpl("atanh", x, out, math.atanh, "atanh({})={}") @pytest.mark.parametrize( From e50fc1a3cbeb6da2bc8b2f080b168db746b5d7cf Mon Sep 17 00:00:00 2001 From: Matthew Barber Date: Tue, 1 Feb 2022 12:55:32 +0000 Subject: [PATCH 25/40] Defaults for `expr_template` in refimpl utils --- ...est_operators_and_elementwise_functions.py | 108 +++++++++--------- 1 file changed, 55 insertions(+), 53 deletions(-) diff --git a/array_api_tests/test_operators_and_elementwise_functions.py b/array_api_tests/test_operators_and_elementwise_functions.py index a8baffc8..9fb41ca9 100644 --- a/array_api_tests/test_operators_and_elementwise_functions.py +++ b/array_api_tests/test_operators_and_elementwise_functions.py @@ -73,12 +73,14 @@ def unary_assert_against_refimpl( in_: Array, res: Array, refimpl: Callable[[Scalar], Scalar], - expr_template: str, + expr_template: Optional[str] = None, res_stype: Optional[ScalarType] = None, filter_: Callable[[Scalar], bool] = default_filter, ): if in_.shape != res.shape: raise ValueError(f"{res.shape=}, but should be {in_.shape=}") + if expr_template is None: + expr_template = func_name + "({})={}" in_stype = dh.get_scalar_type(in_.dtype) if res_stype is None: res_stype = in_stype @@ -117,13 +119,15 @@ def binary_assert_against_refimpl( right: Array, res: Array, refimpl: Callable[[Scalar, Scalar], Scalar], - expr_template: str, + expr_template: Optional[str] = None, res_stype: Optional[ScalarType] = None, left_sym: str = "x1", right_sym: str = "x2", res_name: str = "out", filter_: Callable[[Scalar], bool] = default_filter, ): + if expr_template is None: + expr_template = func_name + "({}, {})={}" in_stype = dh.get_scalar_type(left.dtype) if res_stype is None: res_stype = in_stype @@ -358,11 +362,12 @@ def binary_param_assert_against_refimpl( left: Array, right: Union[Array, Scalar], res: Array, + op_sym: str, refimpl: Callable[[Scalar, Scalar], Scalar], - expr_template: str, res_stype: Optional[ScalarType] = None, filter_: Callable[[Scalar], bool] = default_filter, ): + expr_template = "({} " + op_sym + " {})={}" if ctx.right_is_scalar: assert filter_(right) # sanity check in_stype = dh.get_scalar_type(left.dtype) @@ -430,7 +435,7 @@ def test_abs(ctx, data): x, out, abs, - "abs({})={}", + expr_template="abs({})={}", filter_=lambda s: ( s == float("infinity") or (math.isfinite(s) and s is not -0.0) ), @@ -448,7 +453,7 @@ def test_acos(x): out = xp.acos(x) ph.assert_dtype("acos", x.dtype, out.dtype) ph.assert_shape("acos", out.shape, x.shape) - unary_assert_against_refimpl("acos", x, out, math.acos, "acos({})={}") + unary_assert_against_refimpl("acos", x, out, math.acos) @given( @@ -460,7 +465,7 @@ def test_acosh(x): out = xp.acosh(x) ph.assert_dtype("acosh", x.dtype, out.dtype) ph.assert_shape("acosh", out.shape, x.shape) - unary_assert_against_refimpl("acosh", x, out, math.acosh, "acosh({})={}") + unary_assert_against_refimpl("acosh", x, out, math.acosh) @pytest.mark.parametrize("ctx,", make_binary_params("add", xps.numeric_dtypes())) @@ -476,9 +481,7 @@ def test_add(ctx, data): binary_param_assert_dtype(ctx, left, right, res) binary_param_assert_shape(ctx, left, right, res) - binary_param_assert_against_refimpl( - ctx, left, right, res, operator.add, "({} + {})={}" - ) + binary_param_assert_against_refimpl(ctx, left, right, res, "+", operator.add) @given( @@ -492,7 +495,7 @@ def test_asin(x): out = xp.asin(x) ph.assert_dtype("asin", x.dtype, out.dtype) ph.assert_shape("asin", out.shape, x.shape) - unary_assert_against_refimpl("asin", x, out, math.asin, "asin({})={}") + unary_assert_against_refimpl("asin", x, out, math.asin) @given(xps.arrays(dtype=xps.floating_dtypes(), shape=hh.shapes())) @@ -500,7 +503,7 @@ def test_asinh(x): out = xp.asinh(x) ph.assert_dtype("asinh", x.dtype, out.dtype) ph.assert_shape("asinh", out.shape, x.shape) - unary_assert_against_refimpl("asinh", x, out, math.asinh, "asinh({})={}") + unary_assert_against_refimpl("asinh", x, out, math.asinh) @given(xps.arrays(dtype=xps.floating_dtypes(), shape=hh.shapes())) @@ -508,7 +511,7 @@ def test_atan(x): out = xp.atan(x) ph.assert_dtype("atan", x.dtype, out.dtype) ph.assert_shape("atan", out.shape, x.shape) - unary_assert_against_refimpl("atan", x, out, math.atan, "atan({})={}") + unary_assert_against_refimpl("atan", x, out, math.atan) @given(*hh.two_mutual_arrays(dh.float_dtypes)) @@ -516,7 +519,7 @@ def test_atan2(x1, x2): out = xp.atan2(x1, x2) ph.assert_dtype("atan2", [x1.dtype, x2.dtype], out.dtype) ph.assert_result_shape("atan2", [x1.shape, x2.shape], out.shape) - binary_assert_against_refimpl("atan2", x1, x2, out, math.atan2, "atan2({})={}") + binary_assert_against_refimpl("atan2", x1, x2, out, math.atan2) @given( @@ -530,7 +533,7 @@ def test_atanh(x): out = xp.atanh(x) ph.assert_dtype("atanh", x.dtype, out.dtype) ph.assert_shape("atanh", out.shape, x.shape) - unary_assert_against_refimpl("atanh", x, out, math.atanh, "atanh({})={}") + unary_assert_against_refimpl("atanh", x, out, math.atanh) @pytest.mark.parametrize( @@ -549,7 +552,7 @@ def test_bitwise_and(ctx, data): refimpl = lambda l, r: l and r else: refimpl = lambda l, r: mock_int_dtype(l & r, res.dtype) - binary_param_assert_against_refimpl(ctx, left, right, res, refimpl, "({} & {})={}") + binary_param_assert_against_refimpl(ctx, left, right, res, "&", refimpl) @pytest.mark.parametrize( @@ -573,10 +576,10 @@ def test_bitwise_left_shift(ctx, data): left, right, res, - lambda l, r: mock_int_dtype(l << r, res.dtype) - if r < dh.dtype_nbits[res.dtype] - else 0, - "({} << {})={}", + "<<", + lambda l, r: ( + mock_int_dtype(l << r, res.dtype) if r < dh.dtype_nbits[res.dtype] else 0 + ), ) @@ -595,7 +598,7 @@ def test_bitwise_invert(ctx, data): refimpl = lambda s: not s else: refimpl = lambda s: mock_int_dtype(~s, x.dtype) - unary_assert_against_refimpl(ctx.func_name, x, out, refimpl, "~{}={}") + unary_assert_against_refimpl(ctx.func_name, x, out, refimpl, expr_template="~{}={}") @pytest.mark.parametrize( @@ -614,7 +617,7 @@ def test_bitwise_or(ctx, data): refimpl = lambda l, r: l or r else: refimpl = lambda l, r: mock_int_dtype(l | r, res.dtype) - binary_param_assert_against_refimpl(ctx, left, right, res, refimpl, "({} | {})={}") + binary_param_assert_against_refimpl(ctx, left, right, res, "|", refimpl) @pytest.mark.parametrize( @@ -638,8 +641,8 @@ def test_bitwise_right_shift(ctx, data): left, right, res, + ">>", lambda l, r: mock_int_dtype(l >> r, res.dtype), - "({} >> {})={}", ) @@ -656,10 +659,10 @@ def test_bitwise_xor(ctx, data): binary_param_assert_dtype(ctx, left, right, res) binary_param_assert_shape(ctx, left, right, res) if left.dtype == xp.bool: - refimpl = lambda l, r: l ^ r + refimpl = operator.xor else: refimpl = lambda l, r: mock_int_dtype(l ^ r, res.dtype) - binary_param_assert_against_refimpl(ctx, left, right, res, refimpl, "({} ^ {})={}") + binary_param_assert_against_refimpl(ctx, left, right, res, "^", refimpl) @given(xps.arrays(dtype=xps.numeric_dtypes(), shape=hh.shapes())) @@ -746,7 +749,7 @@ def test_equal(ctx, data): left = xp.astype(left, promoted_dtype) right = xp.astype(right, promoted_dtype) binary_param_assert_against_refimpl( - ctx, left, right, out, operator.eq, "({} == {})={}", res_stype=bool + ctx, left, right, out, "==", operator.eq, res_stype=bool ) @@ -812,9 +815,7 @@ def test_floor_divide(ctx, data): binary_param_assert_dtype(ctx, left, right, res) binary_param_assert_shape(ctx, left, right, res) - binary_param_assert_against_refimpl( - ctx, left, right, res, operator.floordiv, "({} // {})={}" - ) + binary_param_assert_against_refimpl(ctx, left, right, res, "//", operator.floordiv) @pytest.mark.parametrize("ctx", make_binary_params("greater", xps.numeric_dtypes())) @@ -833,7 +834,7 @@ def test_greater(ctx, data): left = xp.astype(left, promoted_dtype) right = xp.astype(right, promoted_dtype) binary_param_assert_against_refimpl( - ctx, left, right, out, operator.gt, "({} > {})={}", res_stype=bool + ctx, left, right, out, ">", operator.gt, res_stype=bool ) @@ -855,7 +856,7 @@ def test_greater_equal(ctx, data): left = xp.astype(left, promoted_dtype) right = xp.astype(right, promoted_dtype) binary_param_assert_against_refimpl( - ctx, left, right, out, operator.ge, "({} >= {})={}", res_stype=bool + ctx, left, right, out, ">=", operator.ge, res_stype=bool ) @@ -931,7 +932,7 @@ def test_less(ctx, data): left = xp.astype(left, promoted_dtype) right = xp.astype(right, promoted_dtype) binary_param_assert_against_refimpl( - ctx, left, right, out, operator.lt, "({} < {})={}", res_stype=bool + ctx, left, right, out, "<", operator.lt, res_stype=bool ) @@ -951,7 +952,7 @@ def test_less_equal(ctx, data): left = xp.astype(left, promoted_dtype) right = xp.astype(right, promoted_dtype) binary_param_assert_against_refimpl( - ctx, left, right, out, operator.le, "({} <= {})={}", res_stype=bool + ctx, left, right, out, "<=", operator.le, res_stype=bool ) @@ -1028,7 +1029,7 @@ def test_logical_and(x1, x2): ph.assert_dtype("logical_and", [x1.dtype, x2.dtype], out.dtype) ph.assert_result_shape("logical_and", [x1.shape, x2.shape], out.shape) binary_assert_against_refimpl( - "logical_and", x1, x2, out, lambda l, r: l and r, "({} and {})={}" + "logical_and", x1, x2, out, lambda l, r: l and r, expr_template="({} and {})={}" ) @@ -1037,7 +1038,9 @@ def test_logical_not(x): out = ah.logical_not(x) ph.assert_dtype("logical_not", x.dtype, out.dtype) ph.assert_shape("logical_not", out.shape, x.shape) - unary_assert_against_refimpl("logical_not", x, out, lambda i: not i, "(not {})={}") + unary_assert_against_refimpl( + "logical_not", x, out, lambda i: not i, expr_template="(not {})={}" + ) @given(*hh.two_mutual_arrays([xp.bool])) @@ -1046,7 +1049,7 @@ def test_logical_or(x1, x2): ph.assert_dtype("logical_or", [x1.dtype, x2.dtype], out.dtype) ph.assert_result_shape("logical_or", [x1.shape, x2.shape], out.shape) binary_assert_against_refimpl( - "logical_or", x1, x2, out, lambda l, r: l or r, "({} or {})={}" + "logical_or", x1, x2, out, lambda l, r: l or r, expr_template="({} or {})={}" ) @@ -1056,7 +1059,7 @@ def test_logical_xor(x1, x2): ph.assert_dtype("logical_xor", [x1.dtype, x2.dtype], out.dtype) ph.assert_result_shape("logical_xor", [x1.shape, x2.shape], out.shape) binary_assert_against_refimpl( - "logical_xor", x1, x2, out, lambda l, r: l ^ r, "({} ^ {})={}" + "logical_xor", x1, x2, out, operator.xor, expr_template="({} ^ {})={}" ) @@ -1070,11 +1073,10 @@ def test_multiply(ctx, data): binary_param_assert_dtype(ctx, left, right, res) binary_param_assert_shape(ctx, left, right, res) - binary_param_assert_against_refimpl( - ctx, left, right, res, operator.mul, "({} * {})={}" - ) + binary_param_assert_against_refimpl(ctx, left, right, res, "*", operator.mul) +# TODO: clarify if uints are acceptable, adjust accordingly @pytest.mark.parametrize( "ctx", make_unary_params("negative", xps.integer_dtypes() | xps.floating_dtypes()) ) @@ -1089,7 +1091,9 @@ def test_negative(ctx, data): ph.assert_dtype(ctx.func_name, x.dtype, out.dtype) ph.assert_shape(ctx.func_name, out.shape, x.shape) - unary_assert_against_refimpl(ctx.func_name, x, out, operator.neg, "-({})={}") + unary_assert_against_refimpl( + ctx.func_name, x, out, operator.neg, expr_template="-({})={}" + ) @pytest.mark.parametrize("ctx", make_binary_params("not_equal", xps.scalar_dtypes())) @@ -1108,7 +1112,7 @@ def test_not_equal(ctx, data): left = xp.astype(left, promoted_dtype) right = xp.astype(right, promoted_dtype) binary_param_assert_against_refimpl( - ctx, left, right, out, operator.ne, "({} != {})={}", res_stype=bool + ctx, left, right, out, "!=", operator.ne, res_stype=bool ) @@ -1164,9 +1168,7 @@ def test_remainder(ctx, data): binary_param_assert_dtype(ctx, left, right, res) binary_param_assert_shape(ctx, left, right, res) - binary_param_assert_against_refimpl( - ctx, left, right, res, operator.mod, "({} % {})={}" - ) + binary_param_assert_against_refimpl(ctx, left, right, res, "%", operator.mod) @given(xps.arrays(dtype=xps.numeric_dtypes(), shape=hh.shapes())) @@ -1226,7 +1228,7 @@ def test_sin(x): out = xp.sin(x) ph.assert_dtype("sin", x.dtype, out.dtype) ph.assert_shape("sin", out.shape, x.shape) - unary_assert_against_refimpl("sin", x, out, math.sin, "sin({})={}") + unary_assert_against_refimpl("sin", x, out, math.sin) @given(xps.arrays(dtype=xps.floating_dtypes(), shape=hh.shapes())) @@ -1234,7 +1236,7 @@ def test_sinh(x): out = xp.sinh(x) ph.assert_dtype("sinh", x.dtype, out.dtype) ph.assert_shape("sinh", out.shape, x.shape) - unary_assert_against_refimpl("sinh", x, out, math.sinh, "sinh({})={}") + unary_assert_against_refimpl("sinh", x, out, math.sinh) @given(xps.arrays(dtype=xps.numeric_dtypes(), shape=hh.shapes())) @@ -1242,7 +1244,9 @@ def test_square(x): out = xp.square(x) ph.assert_dtype("square", x.dtype, out.dtype) ph.assert_shape("square", out.shape, x.shape) - unary_assert_against_refimpl("square", x, out, lambda s: s ** 2, "{}²={}") + unary_assert_against_refimpl( + "square", x, out, lambda s: s ** 2, expr_template="{}²={}" + ) @given( @@ -1254,7 +1258,7 @@ def test_sqrt(x): out = xp.sqrt(x) ph.assert_dtype("sqrt", x.dtype, out.dtype) ph.assert_shape("sqrt", out.shape, x.shape) - unary_assert_against_refimpl("sqrt", x, out, math.sqrt, "sqrt({})={}") + unary_assert_against_refimpl("sqrt", x, out, math.sqrt) @pytest.mark.parametrize("ctx", make_binary_params("subtract", xps.numeric_dtypes())) @@ -1270,9 +1274,7 @@ def test_subtract(ctx, data): binary_param_assert_dtype(ctx, left, right, res) binary_param_assert_shape(ctx, left, right, res) - binary_param_assert_against_refimpl( - ctx, left, right, res, operator.sub, "({} - {})={}" - ) + binary_param_assert_against_refimpl(ctx, left, right, res, "-", operator.sub) @given(xps.arrays(dtype=xps.floating_dtypes(), shape=hh.shapes())) @@ -1280,7 +1282,7 @@ def test_tan(x): out = xp.tan(x) ph.assert_dtype("tan", x.dtype, out.dtype) ph.assert_shape("tan", out.shape, x.shape) - unary_assert_against_refimpl("tan", x, out, math.tan, "tan({})={}") + unary_assert_against_refimpl("tan", x, out, math.tan) @given(xps.arrays(dtype=xps.floating_dtypes(), shape=hh.shapes())) @@ -1288,7 +1290,7 @@ def test_tanh(x): out = xp.tanh(x) ph.assert_dtype("tanh", x.dtype, out.dtype) ph.assert_shape("tanh", out.shape, x.shape) - unary_assert_against_refimpl("tanh", x, out, math.tanh, "tanh({})={}") + unary_assert_against_refimpl("tanh", x, out, math.tanh) @given(xps.arrays(dtype=hh.numeric_dtypes, shape=xps.array_shapes())) From 4a364a54a3d4f4148de47dcd17d061c5f018c2b9 Mon Sep 17 00:00:00 2001 From: Matthew Barber Date: Tue, 1 Feb 2022 17:52:40 +0000 Subject: [PATCH 26/40] Refactor majority of elwise tests with refimpl utils --- ...est_operators_and_elementwise_functions.py | 122 +++++------------- 1 file changed, 35 insertions(+), 87 deletions(-) diff --git a/array_api_tests/test_operators_and_elementwise_functions.py b/array_api_tests/test_operators_and_elementwise_functions.py index 9fb41ca9..8197447f 100644 --- a/array_api_tests/test_operators_and_elementwise_functions.py +++ b/array_api_tests/test_operators_and_elementwise_functions.py @@ -667,7 +667,6 @@ def test_bitwise_xor(ctx, data): @given(xps.arrays(dtype=xps.numeric_dtypes(), shape=hh.shapes())) def test_ceil(x): - # This test is almost identical to test_floor() out = xp.ceil(x) ph.assert_dtype("ceil", x.dtype, out.dtype) ph.assert_shape("ceil", out.shape, x.shape) @@ -686,13 +685,7 @@ def test_cos(x): out = xp.cos(x) ph.assert_dtype("cos", x.dtype, out.dtype) ph.assert_shape("cos", out.shape, x.shape) - ONE = ah.one(x.shape, x.dtype) - INFINITY = ah.infinity(x.shape, x.dtype) - domain = ah.inrange(x, -INFINITY, INFINITY, open=True) - codomain = ah.inrange(out, -ONE, ONE) - # cos maps (-inf, inf) to [-1, 1]. Values outside this domain are mapped - # to nan, which is already tested in the special cases. - ah.assert_exactly_equal(domain, codomain) + unary_assert_against_refimpl("cos", x, out, math.cos) @given(xps.arrays(dtype=xps.floating_dtypes(), shape=hh.shapes())) @@ -700,12 +693,7 @@ def test_cosh(x): out = xp.cosh(x) ph.assert_dtype("cosh", x.dtype, out.dtype) ph.assert_shape("cosh", out.shape, x.shape) - INFINITY = ah.infinity(x.shape, x.dtype) - domain = ah.inrange(x, -INFINITY, INFINITY) - codomain = ah.inrange(out, -INFINITY, INFINITY) - # cosh maps [-inf, inf] to [-inf, inf]. Values outside this domain are - # mapped to nan, which is already tested in the special cases. - ah.assert_exactly_equal(domain, codomain) + unary_assert_against_refimpl("cosh", x, out, math.cosh) @pytest.mark.parametrize("ctx", make_binary_params("divide", xps.floating_dtypes())) @@ -758,13 +746,7 @@ def test_exp(x): out = xp.exp(x) ph.assert_dtype("exp", x.dtype, out.dtype) ph.assert_shape("exp", out.shape, x.shape) - INFINITY = ah.infinity(x.shape, x.dtype) - ZERO = ah.zero(x.shape, x.dtype) - domain = ah.inrange(x, -INFINITY, INFINITY) - codomain = ah.inrange(out, ZERO, INFINITY) - # exp maps [-inf, inf] to [0, inf]. Values outside this domain are - # mapped to nan, which is already tested in the special cases. - ah.assert_exactly_equal(domain, codomain) + unary_assert_against_refimpl("exp", x, out, math.exp) @given(xps.arrays(dtype=xps.floating_dtypes(), shape=hh.shapes())) @@ -772,13 +754,7 @@ def test_expm1(x): out = xp.expm1(x) ph.assert_dtype("expm1", x.dtype, out.dtype) ph.assert_shape("expm1", out.shape, x.shape) - INFINITY = ah.infinity(x.shape, x.dtype) - NEGONE = -ah.one(x.shape, x.dtype) - domain = ah.inrange(x, -INFINITY, INFINITY) - codomain = ah.inrange(out, NEGONE, INFINITY) - # expm1 maps [-inf, inf] to [1, inf]. Values outside this domain are - # mapped to nan, which is already tested in the special cases. - ah.assert_exactly_equal(domain, codomain) + unary_assert_against_refimpl("expm1", x, out, math.expm1) @given(xps.arrays(dtype=xps.numeric_dtypes(), shape=hh.shapes())) @@ -881,39 +857,17 @@ def test_isfinite(x): @given(xps.arrays(dtype=xps.numeric_dtypes(), shape=hh.shapes())) def test_isinf(x): out = xp.isinf(x) - ph.assert_dtype("isfinite", x.dtype, out.dtype, xp.bool) ph.assert_shape("isinf", out.shape, x.shape) - - if dh.is_int_dtype(x.dtype): - ah.assert_exactly_equal(out, ah.false(x.shape)) - finite_or_nan = ah.logical_or(ah.isfinite(x), ah.isnan(x)) - ah.assert_exactly_equal(out, ah.logical_not(finite_or_nan)) - - # Test the exact value by comparing to the math version - if dh.is_float_dtype(x.dtype): - for idx in sh.ndindex(x.shape): - s = float(x[idx]) - assert bool(out[idx]) == math.isinf(s) + unary_assert_against_refimpl("isinf", x, out, math.isinf, res_stype=bool) @given(xps.arrays(dtype=xps.numeric_dtypes(), shape=hh.shapes())) def test_isnan(x): out = ah.isnan(x) - ph.assert_dtype("isnan", x.dtype, out.dtype, xp.bool) ph.assert_shape("isnan", out.shape, x.shape) - - if dh.is_int_dtype(x.dtype): - ah.assert_exactly_equal(out, ah.false(x.shape)) - finite_or_inf = ah.logical_or(ah.isfinite(x), xp.isinf(x)) - ah.assert_exactly_equal(out, ah.logical_not(finite_or_inf)) - - # Test the exact value by comparing to the math version - if dh.is_float_dtype(x.dtype): - for idx in sh.ndindex(x.shape): - s = float(x[idx]) - assert bool(out[idx]) == math.isnan(s) + unary_assert_against_refimpl("isnan", x, out, math.isnan, res_stype=bool) @pytest.mark.parametrize("ctx", make_binary_params("less", xps.numeric_dtypes())) @@ -956,62 +910,56 @@ def test_less_equal(ctx, data): ) -@given(xps.arrays(dtype=xps.floating_dtypes(), shape=hh.shapes())) +@given( + xps.arrays( + dtype=xps.floating_dtypes(), shape=hh.shapes(), elements={"min_value": 1} + ) +) def test_log(x): out = xp.log(x) - ph.assert_dtype("log", x.dtype, out.dtype) ph.assert_shape("log", out.shape, x.shape) - - INFINITY = ah.infinity(x.shape, x.dtype) - ZERO = ah.zero(x.shape, x.dtype) - domain = ah.inrange(x, ZERO, INFINITY) - codomain = ah.inrange(out, -INFINITY, INFINITY) - # log maps [0, inf] to [-inf, inf]. Values outside this domain are - # mapped to nan, which is already tested in the special cases. - ah.assert_exactly_equal(domain, codomain) + unary_assert_against_refimpl("log", x, out, math.log) -@given(xps.arrays(dtype=xps.floating_dtypes(), shape=hh.shapes())) +@given( + xps.arrays( + dtype=xps.floating_dtypes(), shape=hh.shapes(), elements={"min_value": 1} + ) +) def test_log1p(x): out = xp.log1p(x) ph.assert_dtype("log1p", x.dtype, out.dtype) ph.assert_shape("log1p", out.shape, x.shape) - INFINITY = ah.infinity(x.shape, x.dtype) - NEGONE = -ah.one(x.shape, x.dtype) - codomain = ah.inrange(x, NEGONE, INFINITY) - domain = ah.inrange(out, -INFINITY, INFINITY) - # log1p maps [1, inf] to [-inf, inf]. Values outside this domain are - # mapped to nan, which is already tested in the special cases. - ah.assert_exactly_equal(domain, codomain) + unary_assert_against_refimpl("log1p", x, out, math.log1p) -@given(xps.arrays(dtype=xps.floating_dtypes(), shape=hh.shapes())) +@given( + xps.arrays( + dtype=xps.floating_dtypes(), + shape=hh.shapes(), + elements={"min_value": 0, "exclude_min": True}, + ) +) def test_log2(x): out = xp.log2(x) ph.assert_dtype("log2", x.dtype, out.dtype) ph.assert_shape("log2", out.shape, x.shape) - INFINITY = ah.infinity(x.shape, x.dtype) - ZERO = ah.zero(x.shape, x.dtype) - domain = ah.inrange(x, ZERO, INFINITY) - codomain = ah.inrange(out, -INFINITY, INFINITY) - # log2 maps [0, inf] to [-inf, inf]. Values outside this domain are - # mapped to nan, which is already tested in the special cases. - ah.assert_exactly_equal(domain, codomain) + unary_assert_against_refimpl("log2", x, out, math.log2) -@given(xps.arrays(dtype=xps.floating_dtypes(), shape=hh.shapes())) +@given( + xps.arrays( + dtype=xps.floating_dtypes(), + shape=hh.shapes(), + elements={"min_value": 0, "exclude_min": True}, + ) +) def test_log10(x): out = xp.log10(x) ph.assert_dtype("log10", x.dtype, out.dtype) ph.assert_shape("log10", out.shape, x.shape) - INFINITY = ah.infinity(x.shape, x.dtype) - ZERO = ah.zero(x.shape, x.dtype) - domain = ah.inrange(x, ZERO, INFINITY) - codomain = ah.inrange(out, -INFINITY, INFINITY) - # log10 maps [0, inf] to [-inf, inf]. Values outside this domain are - # mapped to nan, which is already tested in the special cases. - ah.assert_exactly_equal(domain, codomain) + unary_assert_against_refimpl("log10", x, out, math.log10) @given(*hh.two_mutual_arrays(dh.float_dtypes)) @@ -1204,7 +1152,7 @@ def test_sign(x): out = xp.sign(x) ph.assert_dtype("sign", x.dtype, out.dtype) ph.assert_shape("sign", out.shape, x.shape) - scalar_type = dh.get_scalar_type(x.dtype) + scalar_type = dh.get_scalar_type(out.dtype) for idx in sh.ndindex(x.shape): scalar_x = scalar_type(x[idx]) f_x = sh.fmt_idx("x", idx) From 56aa06dee332510e8ba0a717f7dfeec4a68eb79d Mon Sep 17 00:00:00 2001 From: Matthew Barber Date: Tue, 1 Feb 2022 18:59:29 +0000 Subject: [PATCH 27/40] `strict_check` kwarg for refiml utils for testing integrals --- ...est_operators_and_elementwise_functions.py | 59 ++++--------------- 1 file changed, 12 insertions(+), 47 deletions(-) diff --git a/array_api_tests/test_operators_and_elementwise_functions.py b/array_api_tests/test_operators_and_elementwise_functions.py index 8197447f..6be0e321 100644 --- a/array_api_tests/test_operators_and_elementwise_functions.py +++ b/array_api_tests/test_operators_and_elementwise_functions.py @@ -49,7 +49,7 @@ def isclose(a: float, b: float, rel_tol: float = 0.25, abs_tol: float = 1) -> bo def mock_int_dtype(n: int, dtype: DataType) -> int: - """Returns equivalent of `n` that mocks `dtype` behaviour""" + """Returns equivalent of `n` that mocks `dtype` behaviour.""" nbits = dh.dtype_nbits[dtype] mask = (1 << nbits) - 1 n &= mask @@ -76,6 +76,7 @@ def unary_assert_against_refimpl( expr_template: Optional[str] = None, res_stype: Optional[ScalarType] = None, filter_: Callable[[Scalar], bool] = default_filter, + strict_check: bool = False, ): if in_.shape != res.shape: raise ValueError(f"{res.shape=}, but should be {in_.shape=}") @@ -101,7 +102,7 @@ def unary_assert_against_refimpl( f_i = sh.fmt_idx("x", idx) f_o = sh.fmt_idx("out", idx) expr = expr_template.format(f_i, expected) - if dh.is_float_dtype(res.dtype): + if not strict_check and dh.is_float_dtype(res.dtype): assert isclose(scalar_o, expected), ( f"{f_o}={scalar_o}, but should be roughly {expr} [{func_name}()]\n" f"{f_i}={scalar_i}" @@ -125,6 +126,7 @@ def binary_assert_against_refimpl( right_sym: str = "x2", res_name: str = "out", filter_: Callable[[Scalar], bool] = default_filter, + strict_check: bool = False, ): if expr_template is None: expr_template = func_name + "({}, {})={}" @@ -150,7 +152,7 @@ def binary_assert_against_refimpl( f_r = sh.fmt_idx(right_sym, r_idx) f_o = sh.fmt_idx(res_name, o_idx) expr = expr_template.format(f_l, f_r, expected) - if dh.is_float_dtype(res.dtype): + if not strict_check and dh.is_float_dtype(res.dtype): assert isclose(scalar_o, expected), ( f"{f_o}={scalar_o}, but should be roughly {expr} [{func_name}()]\n" f"{f_l}={scalar_l}, {f_r}={scalar_r}" @@ -366,6 +368,7 @@ def binary_param_assert_against_refimpl( refimpl: Callable[[Scalar, Scalar], Scalar], res_stype: Optional[ScalarType] = None, filter_: Callable[[Scalar], bool] = default_filter, + strict_check: bool = False, ): expr_template = "({} " + op_sym + " {})={}" if ctx.right_is_scalar: @@ -390,7 +393,7 @@ def binary_param_assert_against_refimpl( f_l = sh.fmt_idx(ctx.left_sym, idx) f_o = sh.fmt_idx(ctx.res_name, idx) expr = expr_template.format(f_l, right, expected) - if dh.is_float_dtype(left.dtype): + if not strict_check and dh.is_float_dtype(left.dtype): assert isclose(scalar_o, expected), ( f"{f_o}={scalar_o}, but should be roughly {expr} " f"[{ctx.func_name}()]\n" @@ -415,6 +418,7 @@ def binary_param_assert_against_refimpl( refimpl=refimpl, expr_template=expr_template, filter_=filter_, + strict_check=strict_check, ) @@ -670,14 +674,7 @@ def test_ceil(x): out = xp.ceil(x) ph.assert_dtype("ceil", x.dtype, out.dtype) ph.assert_shape("ceil", out.shape, x.shape) - finite = ah.isfinite(x) - ah.assert_integral(out[finite]) - assert ah.all(ah.less_equal(x[finite], out[finite])) - assert ah.all( - ah.less_equal(out[finite] - x[finite], ah.one(x[finite].shape, x.dtype)) - ) - integers = ah.isintegral(x) - ah.assert_exactly_equal(out[integers], x[integers]) + unary_assert_against_refimpl("ceil", x, out, math.ceil, strict_check=True) @given(xps.arrays(dtype=xps.floating_dtypes(), shape=hh.shapes())) @@ -759,18 +756,10 @@ def test_expm1(x): @given(xps.arrays(dtype=xps.numeric_dtypes(), shape=hh.shapes())) def test_floor(x): - # This test is almost identical to test_ceil out = xp.floor(x) ph.assert_dtype("floor", x.dtype, out.dtype) ph.assert_shape("floor", out.shape, x.shape) - finite = ah.isfinite(x) - ah.assert_integral(out[finite]) - assert ah.all(ah.less_equal(out[finite], x[finite])) - assert ah.all( - ah.less_equal(x[finite] - out[finite], ah.one(x[finite].shape, x.dtype)) - ) - integers = ah.isintegral(x) - ah.assert_exactly_equal(out[integers], x[integers]) + unary_assert_against_refimpl("floor", x, out, math.floor, strict_check=True) @pytest.mark.parametrize( @@ -1122,29 +1111,9 @@ def test_remainder(ctx, data): @given(xps.arrays(dtype=xps.numeric_dtypes(), shape=hh.shapes())) def test_round(x): out = xp.round(x) - ph.assert_dtype("round", x.dtype, out.dtype) - ph.assert_shape("round", out.shape, x.shape) - - # Test that the out is integral - finite = ah.isfinite(x) - ah.assert_integral(out[finite]) - - # round(x) should be the neaoutt integer to x. The case where there is a - # tie (round to even) is already handled by the special cases tests. - - # This is the same strategy used in the mask in the - # test_round_special_cases_one_arg_two_integers_equally_close special - # cases test. - floor = xp.floor(x) - ceil = xp.ceil(x) - over = xp.subtract(x, floor) - under = xp.subtract(ceil, x) - round_down = ah.less(over, under) - round_up = ah.less(under, over) - ah.assert_exactly_equal(out[round_down], floor[round_down]) - ah.assert_exactly_equal(out[round_up], ceil[round_up]) + unary_assert_against_refimpl("round", x, out, round, strict_check=True) @given(xps.arrays(dtype=xps.numeric_dtypes(), shape=hh.shapes())) @@ -1246,8 +1215,4 @@ def test_trunc(x): out = xp.trunc(x) ph.assert_dtype("trunc", x.dtype, out.dtype) ph.assert_shape("trunc", out.shape, x.shape) - if dh.is_int_dtype(x.dtype): - ah.assert_exactly_equal(out, x) - else: - finite = ah.isfinite(x) - ah.assert_integral(out[finite]) + unary_assert_against_refimpl("trunc", x, out, math.trunc, strict_check=True) From dfda4f545c7dcf588b6f74c86695e2f8be5e36de Mon Sep 17 00:00:00 2001 From: Matthew Barber Date: Tue, 1 Feb 2022 19:13:37 +0000 Subject: [PATCH 28/40] Pass but filter out-of-range values for trig function tests --- ...est_operators_and_elementwise_functions.py | 104 +++++++----------- 1 file changed, 40 insertions(+), 64 deletions(-) diff --git a/array_api_tests/test_operators_and_elementwise_functions.py b/array_api_tests/test_operators_and_elementwise_functions.py index 6be0e321..8e225a4e 100644 --- a/array_api_tests/test_operators_and_elementwise_functions.py +++ b/array_api_tests/test_operators_and_elementwise_functions.py @@ -446,30 +446,24 @@ def test_abs(ctx, data): ) -@given( - xps.arrays( - dtype=xps.floating_dtypes(), - shape=hh.shapes(), - elements={"min_value": -1, "max_value": 1}, - ) -) +@given(xps.arrays(dtype=xps.floating_dtypes(), shape=hh.shapes())) def test_acos(x): out = xp.acos(x) ph.assert_dtype("acos", x.dtype, out.dtype) ph.assert_shape("acos", out.shape, x.shape) - unary_assert_against_refimpl("acos", x, out, math.acos) + unary_assert_against_refimpl( + "acos", x, out, math.acos, filter_=lambda s: default_filter(s) and -1 <= s <= 1 + ) -@given( - xps.arrays( - dtype=xps.floating_dtypes(), shape=hh.shapes(), elements={"min_value": 1} - ) -) +@given(xps.arrays(dtype=xps.floating_dtypes(), shape=hh.shapes())) def test_acosh(x): out = xp.acosh(x) ph.assert_dtype("acosh", x.dtype, out.dtype) ph.assert_shape("acosh", out.shape, x.shape) - unary_assert_against_refimpl("acosh", x, out, math.acosh) + unary_assert_against_refimpl( + "acosh", x, out, math.acosh, filter_=lambda s: default_filter(s) and s >= 1 + ) @pytest.mark.parametrize("ctx,", make_binary_params("add", xps.numeric_dtypes())) @@ -488,18 +482,14 @@ def test_add(ctx, data): binary_param_assert_against_refimpl(ctx, left, right, res, "+", operator.add) -@given( - xps.arrays( - dtype=xps.floating_dtypes(), - shape=hh.shapes(), - elements={"min_value": -1, "max_value": 1}, - ) -) +@given(xps.arrays(dtype=xps.floating_dtypes(), shape=hh.shapes())) def test_asin(x): out = xp.asin(x) ph.assert_dtype("asin", x.dtype, out.dtype) ph.assert_shape("asin", out.shape, x.shape) - unary_assert_against_refimpl("asin", x, out, math.asin) + unary_assert_against_refimpl( + "asin", x, out, math.asin, filter_=lambda s: default_filter(s) and -1 <= s <= 1 + ) @given(xps.arrays(dtype=xps.floating_dtypes(), shape=hh.shapes())) @@ -526,18 +516,18 @@ def test_atan2(x1, x2): binary_assert_against_refimpl("atan2", x1, x2, out, math.atan2) -@given( - xps.arrays( - dtype=xps.floating_dtypes(), - shape=hh.shapes(), - elements={"min_value": -1, "max_value": 1}, - ) -) +@given(xps.arrays(dtype=xps.floating_dtypes(), shape=hh.shapes())) def test_atanh(x): out = xp.atanh(x) ph.assert_dtype("atanh", x.dtype, out.dtype) ph.assert_shape("atanh", out.shape, x.shape) - unary_assert_against_refimpl("atanh", x, out, math.atanh) + unary_assert_against_refimpl( + "atanh", + x, + out, + math.atanh, + filter_=lambda s: default_filter(s) and -1 <= s <= 1, + ) @pytest.mark.parametrize( @@ -899,56 +889,44 @@ def test_less_equal(ctx, data): ) -@given( - xps.arrays( - dtype=xps.floating_dtypes(), shape=hh.shapes(), elements={"min_value": 1} - ) -) +@given(xps.arrays(dtype=xps.floating_dtypes(), shape=hh.shapes())) def test_log(x): out = xp.log(x) ph.assert_dtype("log", x.dtype, out.dtype) ph.assert_shape("log", out.shape, x.shape) - unary_assert_against_refimpl("log", x, out, math.log) + unary_assert_against_refimpl( + "log", x, out, math.log, filter_=lambda s: default_filter(s) and s >= 1 + ) -@given( - xps.arrays( - dtype=xps.floating_dtypes(), shape=hh.shapes(), elements={"min_value": 1} - ) -) +@given(xps.arrays(dtype=xps.floating_dtypes(), shape=hh.shapes())) def test_log1p(x): out = xp.log1p(x) ph.assert_dtype("log1p", x.dtype, out.dtype) ph.assert_shape("log1p", out.shape, x.shape) - unary_assert_against_refimpl("log1p", x, out, math.log1p) + unary_assert_against_refimpl( + "log1p", x, out, math.log1p, filter_=lambda s: default_filter(s) and s >= 1 + ) -@given( - xps.arrays( - dtype=xps.floating_dtypes(), - shape=hh.shapes(), - elements={"min_value": 0, "exclude_min": True}, - ) -) +@given(xps.arrays(dtype=xps.floating_dtypes(), shape=hh.shapes())) def test_log2(x): out = xp.log2(x) ph.assert_dtype("log2", x.dtype, out.dtype) ph.assert_shape("log2", out.shape, x.shape) - unary_assert_against_refimpl("log2", x, out, math.log2) + unary_assert_against_refimpl( + "log2", x, out, math.log2, filter_=lambda s: default_filter(s) and s > 1 + ) -@given( - xps.arrays( - dtype=xps.floating_dtypes(), - shape=hh.shapes(), - elements={"min_value": 0, "exclude_min": True}, - ) -) +@given(xps.arrays(dtype=xps.floating_dtypes(), shape=hh.shapes())) def test_log10(x): out = xp.log10(x) ph.assert_dtype("log10", x.dtype, out.dtype) ph.assert_shape("log10", out.shape, x.shape) - unary_assert_against_refimpl("log10", x, out, math.log10) + unary_assert_against_refimpl( + "log10", x, out, math.log10, filter_=lambda s: default_filter(s) and s > 0 + ) @given(*hh.two_mutual_arrays(dh.float_dtypes)) @@ -1166,16 +1144,14 @@ def test_square(x): ) -@given( - xps.arrays( - dtype=xps.floating_dtypes(), shape=hh.shapes(), elements={"min_value": 0} - ) -) +@given(xps.arrays(dtype=xps.floating_dtypes(), shape=hh.shapes())) def test_sqrt(x): out = xp.sqrt(x) ph.assert_dtype("sqrt", x.dtype, out.dtype) ph.assert_shape("sqrt", out.shape, x.shape) - unary_assert_against_refimpl("sqrt", x, out, math.sqrt) + unary_assert_against_refimpl( + "sqrt", x, out, math.sqrt, filter_=lambda s: default_filter(s) and s >= 0 + ) @pytest.mark.parametrize("ctx", make_binary_params("subtract", xps.numeric_dtypes())) From 9d1f4da9ad508a00313a56574387b6178e8ca7cc Mon Sep 17 00:00:00 2001 From: Matthew Barber Date: Tue, 1 Feb 2022 19:54:59 +0000 Subject: [PATCH 29/40] Extend note on refimpl utils --- ...est_operators_and_elementwise_functions.py | 60 ++++++++++++------- 1 file changed, 38 insertions(+), 22 deletions(-) diff --git a/array_api_tests/test_operators_and_elementwise_functions.py b/array_api_tests/test_operators_and_elementwise_functions.py index 8e225a4e..1e390a6a 100644 --- a/array_api_tests/test_operators_and_elementwise_functions.py +++ b/array_api_tests/test_operators_and_elementwise_functions.py @@ -1,14 +1,3 @@ -""" -Tests for elementwise functions - -https://data-apis.github.io/array-api/latest/API_specification/elementwise_functions.html - -This tests behavior that is explicitly mentioned in the spec. Note that the -spec does not make any accuracy requirements for functions, so this does not -test that. Tests for the special cases are generated and tested separately in -special_cases/ -""" - import math import operator from enum import Enum, auto @@ -41,13 +30,6 @@ def boolean_and_all_integer_dtypes() -> st.SearchStrategy[DataType]: return xps.boolean_dtypes() | all_integer_dtypes() -def isclose(a: float, b: float, rel_tol: float = 0.25, abs_tol: float = 1) -> bool: - """Wraps math.isclose with more generous defaults.""" - if not (math.isfinite(a) and math.isfinite(b)): - raise ValueError(f"{a=} and {b=}, but input must be finite") - return math.isclose(a, b, rel_tol=rel_tol, abs_tol=abs_tol) - - def mock_int_dtype(n: int, dtype: DataType) -> int: """Returns equivalent of `n` that mocks `dtype` behaviour.""" nbits = dh.dtype_nbits[dtype] @@ -60,6 +42,40 @@ def mock_int_dtype(n: int, dtype: DataType) -> int: return n +# This module tests elementwise functions/operators against a reference +# implementation. We iterate through the input array(s) and resulting array, +# casting the indexed arrays to Python scalars and calculating the expected +# output with `refimpl` function. +# +# This is finicky to refactor, but possible and ultimately worthwhile - hence +# why these *_assert_again_refimpl() utilities exist. +# +# Values which are special-cased are generated and passed, but are filtered by +# the `filter_` callable before they can be asserted against `refimpl`. We +# automatically generate tests for special cases in the special_cases/ dir. We +# still pass them here so as to ensure their presence doesn't affect the outputs +# respective to non-special-cased elements. +# +# By default, results are casted to scalars the same way that the inputs are. +# You can specify a cast via `res_stype, i.e. when a function accepts numerical +# inputs but returns boolean arrays. +# +# By default, floating-point functions/methods are loosely asserted against. Use +# `strict_check=True` when they should be strictly asserted against, i.e. +# when a function should return intergrals. + + +def isclose(a: float, b: float, rel_tol: float = 0.25, abs_tol: float = 1) -> bool: + """Wraps math.isclose with very generous defaults. + + This is useful for many floating-point operations where the spec does not + make accuracy requirements. + """ + if not (math.isfinite(a) and math.isfinite(b)): + raise ValueError(f"{a=} and {b=}, but input must be finite") + return math.isclose(a, b, rel_tol=rel_tol, abs_tol=abs_tol) + + def default_filter(s: Scalar) -> bool: """Returns False when s is a non-finite or a signed zero. @@ -168,14 +184,14 @@ def binary_assert_against_refimpl( # elementwise methods. We do this by parametrizing a generalised test method # with every relevant method and operator. # -# Notable arguments in the parameter: +# Notable arguments in the parameter's context object: # - The function object, which for operator test cases is a wrapper that allows # test logic to be generalised. # - The argument strategies, which can be used to draw arguments for the test # case. They may require additional filtering for certain test cases. -# - right_is_scalar (binary parameters), which denotes if the right argument is -# a scalar in a test case. This can be used to appropiately adjust draw -# filtering and test logic. +# - right_is_scalar (binary parameters only), which denotes if the right +# argument is a scalar in a test case. This can be used to appropiately adjust +# draw filtering and test logic. func_to_op = {v: k for k, v in dh.op_to_func.items()} From e72184e54bb111fd95d1e0514ef7cb7b6980bdc3 Mon Sep 17 00:00:00 2001 From: Matthew Barber Date: Wed, 2 Feb 2022 10:02:17 +0000 Subject: [PATCH 30/40] Refactor remaining elwise/op tests --- ...est_operators_and_elementwise_functions.py | 89 ++++++++----------- 1 file changed, 36 insertions(+), 53 deletions(-) diff --git a/array_api_tests/test_operators_and_elementwise_functions.py b/array_api_tests/test_operators_and_elementwise_functions.py index 1e390a6a..7c0f45f3 100644 --- a/array_api_tests/test_operators_and_elementwise_functions.py +++ b/array_api_tests/test_operators_and_elementwise_functions.py @@ -62,7 +62,8 @@ def mock_int_dtype(n: int, dtype: DataType) -> int: # # By default, floating-point functions/methods are loosely asserted against. Use # `strict_check=True` when they should be strictly asserted against, i.e. -# when a function should return intergrals. +# when a function should return intergrals. Likewise, use `strict_check=False` +# when integer function/methods should be loosely asserted against. def isclose(a: float, b: float, rel_tol: float = 0.25, abs_tol: float = 1) -> bool: @@ -92,7 +93,7 @@ def unary_assert_against_refimpl( expr_template: Optional[str] = None, res_stype: Optional[ScalarType] = None, filter_: Callable[[Scalar], bool] = default_filter, - strict_check: bool = False, + strict_check: Optional[bool] = None, ): if in_.shape != res.shape: raise ValueError(f"{res.shape=}, but should be {in_.shape=}") @@ -108,7 +109,7 @@ def unary_assert_against_refimpl( continue try: expected = refimpl(scalar_i) - except OverflowError: + except Exception: continue if res.dtype != xp.bool: assert m is not None and M is not None # for mypy @@ -118,7 +119,7 @@ def unary_assert_against_refimpl( f_i = sh.fmt_idx("x", idx) f_o = sh.fmt_idx("out", idx) expr = expr_template.format(f_i, expected) - if not strict_check and dh.is_float_dtype(res.dtype): + if strict_check == False or dh.is_float_dtype(res.dtype): assert isclose(scalar_o, expected), ( f"{f_o}={scalar_o}, but should be roughly {expr} [{func_name}()]\n" f"{f_i}={scalar_i}" @@ -142,7 +143,7 @@ def binary_assert_against_refimpl( right_sym: str = "x2", res_name: str = "out", filter_: Callable[[Scalar], bool] = default_filter, - strict_check: bool = False, + strict_check: Optional[bool] = None, ): if expr_template is None: expr_template = func_name + "({}, {})={}" @@ -157,7 +158,7 @@ def binary_assert_against_refimpl( continue try: expected = refimpl(scalar_l, scalar_r) - except OverflowError: + except Exception: continue if res.dtype != xp.bool: assert m is not None and M is not None # for mypy @@ -168,7 +169,7 @@ def binary_assert_against_refimpl( f_r = sh.fmt_idx(right_sym, r_idx) f_o = sh.fmt_idx(res_name, o_idx) expr = expr_template.format(f_l, f_r, expected) - if not strict_check and dh.is_float_dtype(res.dtype): + if strict_check == False or dh.is_float_dtype(res.dtype): assert isclose(scalar_o, expected), ( f"{f_o}={scalar_o}, but should be roughly {expr} [{func_name}()]\n" f"{f_l}={scalar_l}, {f_r}={scalar_r}" @@ -384,11 +385,12 @@ def binary_param_assert_against_refimpl( refimpl: Callable[[Scalar, Scalar], Scalar], res_stype: Optional[ScalarType] = None, filter_: Callable[[Scalar], bool] = default_filter, - strict_check: bool = False, + strict_check: Optional[bool] = None, ): expr_template = "({} " + op_sym + " {})={}" if ctx.right_is_scalar: - assert filter_(right) # sanity check + if filter_(right): + return # short-circuit here as there will be nothing to test in_stype = dh.get_scalar_type(left.dtype) if res_stype is None: res_stype = in_stype @@ -399,7 +401,7 @@ def binary_param_assert_against_refimpl( continue try: expected = refimpl(scalar_l, right) - except OverflowError: + except Exception: continue if left.dtype != xp.bool: assert m is not None and M is not None # for mypy @@ -409,7 +411,7 @@ def binary_param_assert_against_refimpl( f_l = sh.fmt_idx(ctx.left_sym, idx) f_o = sh.fmt_idx(ctx.res_name, idx) expr = expr_template.format(f_l, right, expected) - if not strict_check and dh.is_float_dtype(left.dtype): + if strict_check == False or dh.is_float_dtype(res.dtype): assert isclose(scalar_o, expected), ( f"{f_o}={scalar_o}, but should be roughly {expr} " f"[{ctx.func_name}()]\n" @@ -704,16 +706,22 @@ def test_cosh(x): def test_divide(ctx, data): left = data.draw(ctx.left_strat, label=ctx.left_sym) right = data.draw(ctx.right_strat, label=ctx.right_sym) + if ctx.right_is_scalar: + assume res = ctx.func(left, right) binary_param_assert_dtype(ctx, left, right, res) binary_param_assert_shape(ctx, left, right, res) - # There isn't much we can test here. The spec doesn't require any behavior - # beyond the special cases, and indeed, there aren't many mathematical - # properties of division that strictly hold for floating-point numbers. We - # could test that this does implement IEEE 754 division, but we don't yet - # have those sorts in general for this module. + binary_param_assert_against_refimpl( + ctx, + left, + right, + res, + "/", + operator.truediv, + filter_=lambda s: math.isfinite(s) and s != 0, + ) @pytest.mark.parametrize("ctx", make_binary_params("equal", xps.scalar_dtypes())) @@ -836,17 +844,7 @@ def test_isfinite(x): out = ah.isfinite(x) ph.assert_dtype("isfinite", x.dtype, out.dtype, xp.bool) ph.assert_shape("isfinite", out.shape, x.shape) - if dh.is_int_dtype(x.dtype): - ah.assert_exactly_equal(out, ah.true(x.shape)) - # Test that isfinite, isinf, and isnan are self-consistent. - inf = ah.logical_or(xp.isinf(x), ah.isnan(x)) - ah.assert_exactly_equal(out, ah.logical_not(inf)) - - # Test the exact value by comparing to the math version - if dh.is_float_dtype(x.dtype): - for idx in sh.ndindex(x.shape): - s = float(x[idx]) - assert bool(out[idx]) == math.isfinite(s) + unary_assert_against_refimpl("isfinite", x, out, math.isfinite, res_stype=bool) @given(xps.arrays(dtype=xps.numeric_dtypes(), shape=hh.shapes())) @@ -949,9 +947,10 @@ def test_log10(x): def test_logaddexp(x1, x2): out = xp.logaddexp(x1, x2) ph.assert_dtype("logaddexp", [x1.dtype, x2.dtype], out.dtype) - # The spec doesn't require any behavior for this function. We could test - # that this is indeed an approximation of log(exp(x1) + exp(x2)), but we - # don't have tests for this sort of thing for any functions yet. + ph.assert_result_shape("logaddexp", [x1.shape, x2.shape], out.shape) + binary_assert_against_refimpl( + "logaddexp", x1, x2, out, lambda l, r: math.log(math.exp(l) + math.exp(r)) + ) @given(*hh.two_mutual_arrays([xp.bool])) @@ -1078,11 +1077,9 @@ def test_pow(ctx, data): binary_param_assert_dtype(ctx, left, right, res) binary_param_assert_shape(ctx, left, right, res) - # There isn't much we can test here. The spec doesn't require any behavior - # beyond the special cases, and indeed, there aren't many mathematical - # properties of exponentiation that strictly hold for floating-point - # numbers. We could test that this does implement IEEE 754 pow, but we - # don't yet have those sorts in general for this module. + binary_param_assert_against_refimpl( + ctx, left, right, res, "**", math.pow, strict_check=False + ) @pytest.mark.parametrize("ctx", make_binary_params("remainder", xps.numeric_dtypes())) @@ -1110,28 +1107,14 @@ def test_round(x): unary_assert_against_refimpl("round", x, out, round, strict_check=True) -@given(xps.arrays(dtype=xps.numeric_dtypes(), shape=hh.shapes())) +@given(xps.arrays(dtype=xps.numeric_dtypes(), shape=hh.shapes(), elements=finite_kw)) def test_sign(x): out = xp.sign(x) ph.assert_dtype("sign", x.dtype, out.dtype) ph.assert_shape("sign", out.shape, x.shape) - scalar_type = dh.get_scalar_type(out.dtype) - for idx in sh.ndindex(x.shape): - scalar_x = scalar_type(x[idx]) - f_x = sh.fmt_idx("x", idx) - if math.isnan(scalar_x): - continue - if scalar_x == 0: - expected = 0 - expr = f"{f_x}=0" - else: - expected = 1 if scalar_x > 0 else -1 - expr = f"({f_x} / |{f_x}|)={expected}" - scalar_o = scalar_type(out[idx]) - f_o = sh.fmt_idx("out", idx) - assert ( - scalar_o == expected - ), f"{f_o}={scalar_o}, but should be {expr} [sign()]\n{f_x}={scalar_x}" + unary_assert_against_refimpl( + "sign", x, out, lambda s: math.copysign(1, s), filter_=lambda s: s != 0 + ) @given(xps.arrays(dtype=xps.floating_dtypes(), shape=hh.shapes())) From 9edcfccffd6de22e1d33e4681cee00c945f91f1a Mon Sep 17 00:00:00 2001 From: Matthew Barber Date: Wed, 2 Feb 2022 10:31:51 +0000 Subject: [PATCH 31/40] Favour use of `operator` for `refimpl` --- ...est_operators_and_elementwise_functions.py | 37 +++++++------------ 1 file changed, 14 insertions(+), 23 deletions(-) diff --git a/array_api_tests/test_operators_and_elementwise_functions.py b/array_api_tests/test_operators_and_elementwise_functions.py index 7c0f45f3..7fc2a10c 100644 --- a/array_api_tests/test_operators_and_elementwise_functions.py +++ b/array_api_tests/test_operators_and_elementwise_functions.py @@ -561,7 +561,7 @@ def test_bitwise_and(ctx, data): binary_param_assert_dtype(ctx, left, right, res) binary_param_assert_shape(ctx, left, right, res) if left.dtype == xp.bool: - refimpl = lambda l, r: l and r + refimpl = operator.and_ else: refimpl = lambda l, r: mock_int_dtype(l & r, res.dtype) binary_param_assert_against_refimpl(ctx, left, right, res, "&", refimpl) @@ -583,15 +583,9 @@ def test_bitwise_left_shift(ctx, data): binary_param_assert_dtype(ctx, left, right, res) binary_param_assert_shape(ctx, left, right, res) + nbits = res.dtype binary_param_assert_against_refimpl( - ctx, - left, - right, - res, - "<<", - lambda l, r: ( - mock_int_dtype(l << r, res.dtype) if r < dh.dtype_nbits[res.dtype] else 0 - ), + ctx, left, right, res, "<<", lambda l, r: l << r if r < nbits else 0 ) @@ -607,7 +601,7 @@ def test_bitwise_invert(ctx, data): ph.assert_dtype(ctx.func_name, x.dtype, out.dtype) ph.assert_shape(ctx.func_name, out.shape, x.shape) if x.dtype == xp.bool: - refimpl = lambda s: not s + refimpl = operator.not_ else: refimpl = lambda s: mock_int_dtype(~s, x.dtype) unary_assert_against_refimpl(ctx.func_name, x, out, refimpl, expr_template="~{}={}") @@ -626,7 +620,7 @@ def test_bitwise_or(ctx, data): binary_param_assert_dtype(ctx, left, right, res) binary_param_assert_shape(ctx, left, right, res) if left.dtype == xp.bool: - refimpl = lambda l, r: l or r + refimpl = operator.or_ else: refimpl = lambda l, r: mock_int_dtype(l | r, res.dtype) binary_param_assert_against_refimpl(ctx, left, right, res, "|", refimpl) @@ -649,12 +643,7 @@ def test_bitwise_right_shift(ctx, data): binary_param_assert_dtype(ctx, left, right, res) binary_param_assert_shape(ctx, left, right, res) binary_param_assert_against_refimpl( - ctx, - left, - right, - res, - ">>", - lambda l, r: mock_int_dtype(l >> r, res.dtype), + ctx, left, right, res, ">>", lambda l, r: mock_int_dtype(l >> r, res.dtype) ) @@ -943,14 +932,16 @@ def test_log10(x): ) +def logaddexp(l: float, r: float) -> float: + return math.log(math.exp(l) + math.exp(r)) + + @given(*hh.two_mutual_arrays(dh.float_dtypes)) def test_logaddexp(x1, x2): out = xp.logaddexp(x1, x2) ph.assert_dtype("logaddexp", [x1.dtype, x2.dtype], out.dtype) ph.assert_result_shape("logaddexp", [x1.shape, x2.shape], out.shape) - binary_assert_against_refimpl( - "logaddexp", x1, x2, out, lambda l, r: math.log(math.exp(l) + math.exp(r)) - ) + binary_assert_against_refimpl("logaddexp", x1, x2, out, logaddexp) @given(*hh.two_mutual_arrays([xp.bool])) @@ -959,7 +950,7 @@ def test_logical_and(x1, x2): ph.assert_dtype("logical_and", [x1.dtype, x2.dtype], out.dtype) ph.assert_result_shape("logical_and", [x1.shape, x2.shape], out.shape) binary_assert_against_refimpl( - "logical_and", x1, x2, out, lambda l, r: l and r, expr_template="({} and {})={}" + "logical_and", x1, x2, out, operator.and_, expr_template="({} and {})={}" ) @@ -969,7 +960,7 @@ def test_logical_not(x): ph.assert_dtype("logical_not", x.dtype, out.dtype) ph.assert_shape("logical_not", out.shape, x.shape) unary_assert_against_refimpl( - "logical_not", x, out, lambda i: not i, expr_template="(not {})={}" + "logical_not", x, out, operator.not_, expr_template="(not {})={}" ) @@ -979,7 +970,7 @@ def test_logical_or(x1, x2): ph.assert_dtype("logical_or", [x1.dtype, x2.dtype], out.dtype) ph.assert_result_shape("logical_or", [x1.shape, x2.shape], out.shape) binary_assert_against_refimpl( - "logical_or", x1, x2, out, lambda l, r: l or r, expr_template="({} or {})={}" + "logical_or", x1, x2, out, operator.or_, expr_template="({} or {})={}" ) From 6e8cda6d979dea20f82a6b5e4ff0bfe1b7190b49 Mon Sep 17 00:00:00 2001 From: Matthew Barber Date: Wed, 2 Feb 2022 11:36:08 +0000 Subject: [PATCH 32/40] Filter undefined dtypes in `hh.two_mutual_arrays()` --- array_api_tests/hypothesis_helpers.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/array_api_tests/hypothesis_helpers.py b/array_api_tests/hypothesis_helpers.py index f77301da..38771225 100644 --- a/array_api_tests/hypothesis_helpers.py +++ b/array_api_tests/hypothesis_helpers.py @@ -369,6 +369,9 @@ def two_mutual_arrays( ) -> Tuple[SearchStrategy[Array], SearchStrategy[Array]]: if not isinstance(dtypes, Sequence): raise TypeError(f"{dtypes=} not a sequence") + if FILTER_UNDEFINED_DTYPES: + dtypes = [d for d in dtypes if not isinstance(d, _UndefinedStub)] + assert len(dtypes) > 0 # sanity check mutual_dtypes = shared(mutually_promotable_dtypes(dtypes=dtypes)) mutual_shapes = shared(two_shapes) arrays1 = xps.arrays( From 493f669699254d390ed1e62060fac0034a8e9614 Mon Sep 17 00:00:00 2001 From: Matthew Barber Date: Wed, 2 Feb 2022 11:40:31 +0000 Subject: [PATCH 33/40] Generic type hint for `refimpl` args --- .../test_operators_and_elementwise_functions.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/array_api_tests/test_operators_and_elementwise_functions.py b/array_api_tests/test_operators_and_elementwise_functions.py index 7fc2a10c..5703aabb 100644 --- a/array_api_tests/test_operators_and_elementwise_functions.py +++ b/array_api_tests/test_operators_and_elementwise_functions.py @@ -1,7 +1,7 @@ import math import operator from enum import Enum, auto -from typing import Callable, List, NamedTuple, Optional, Union +from typing import Callable, List, NamedTuple, Optional, TypeVar, Union import pytest from hypothesis import assume, given @@ -85,11 +85,14 @@ def default_filter(s: Scalar) -> bool: return math.isfinite(s) and s is not -0.0 and s is not +0.0 +T = TypeVar("T") + + def unary_assert_against_refimpl( func_name: str, in_: Array, res: Array, - refimpl: Callable[[Scalar], Scalar], + refimpl: Callable[[T], T], expr_template: Optional[str] = None, res_stype: Optional[ScalarType] = None, filter_: Callable[[Scalar], bool] = default_filter, @@ -136,7 +139,7 @@ def binary_assert_against_refimpl( left: Array, right: Array, res: Array, - refimpl: Callable[[Scalar, Scalar], Scalar], + refimpl: Callable[[T, T], T], expr_template: Optional[str] = None, res_stype: Optional[ScalarType] = None, left_sym: str = "x1", @@ -382,7 +385,7 @@ def binary_param_assert_against_refimpl( right: Union[Array, Scalar], res: Array, op_sym: str, - refimpl: Callable[[Scalar, Scalar], Scalar], + refimpl: Callable[[T, T], T], res_stype: Optional[ScalarType] = None, filter_: Callable[[Scalar], bool] = default_filter, strict_check: Optional[bool] = None, @@ -456,7 +459,7 @@ def test_abs(ctx, data): ctx.func_name, x, out, - abs, + abs, # type: ignore expr_template="abs({})={}", filter_=lambda s: ( s == float("infinity") or (math.isfinite(s) and s is not -0.0) @@ -1013,7 +1016,7 @@ def test_negative(ctx, data): ph.assert_dtype(ctx.func_name, x.dtype, out.dtype) ph.assert_shape(ctx.func_name, out.shape, x.shape) unary_assert_against_refimpl( - ctx.func_name, x, out, operator.neg, expr_template="-({})={}" + ctx.func_name, x, out, operator.neg, expr_template="-({})={}" # type: ignore ) From d924ce4151c34586a0192e5b8ec44efa6c9b60cf Mon Sep 17 00:00:00 2001 From: Matthew Barber Date: Wed, 2 Feb 2022 11:47:57 +0000 Subject: [PATCH 34/40] Introduce `right_scalar_assert_against_refimpl()` Keeps all refimpl logic near eachother --- ...est_operators_and_elementwise_functions.py | 94 ++++++++++++------- 1 file changed, 60 insertions(+), 34 deletions(-) diff --git a/array_api_tests/test_operators_and_elementwise_functions.py b/array_api_tests/test_operators_and_elementwise_functions.py index 5703aabb..53526d79 100644 --- a/array_api_tests/test_operators_and_elementwise_functions.py +++ b/array_api_tests/test_operators_and_elementwise_functions.py @@ -184,6 +184,53 @@ def binary_assert_against_refimpl( ) +def right_scalar_assert_against_refimpl( + func_name: str, + left: Array, + right: Scalar, + res: Array, + refimpl: Callable[[T, T], T], + expr_template: str = None, + res_stype: Optional[ScalarType] = None, + left_sym: str = "x1", + res_name: str = "out", + filter_: Callable[[Scalar], bool] = default_filter, + strict_check: Optional[bool] = None, +): + if filter_(right): + return # short-circuit here as there will be nothing to test + in_stype = dh.get_scalar_type(left.dtype) + if res_stype is None: + res_stype = in_stype + m, M = dh.dtype_ranges.get(left.dtype, (None, None)) + for idx in sh.ndindex(res.shape): + scalar_l = in_stype(left[idx]) + if not filter_(scalar_l): + continue + try: + expected = refimpl(scalar_l, right) + except Exception: + continue + if left.dtype != xp.bool: + assert m is not None and M is not None # for mypy + if expected <= m or expected >= M: + continue + scalar_o = res_stype(res[idx]) + f_l = sh.fmt_idx(left_sym, idx) + f_o = sh.fmt_idx(res_name, idx) + expr = expr_template.format(f_l, right, expected) + if strict_check == False or dh.is_float_dtype(res.dtype): + assert isclose(scalar_o, expected), ( + f"{f_o}={scalar_o}, but should be roughly {expr} [{func_name}()]\n" + f"{f_l}={scalar_l}" + ) + else: + assert scalar_o == expected, ( + f"{f_o}={scalar_o}, but should be {expr} [{func_name}()]\n" + f"{f_l}={scalar_l}" + ) + + # When appropiate, this module tests operators alongside their respective # elementwise methods. We do this by parametrizing a generalised test method # with every relevant method and operator. @@ -392,40 +439,19 @@ def binary_param_assert_against_refimpl( ): expr_template = "({} " + op_sym + " {})={}" if ctx.right_is_scalar: - if filter_(right): - return # short-circuit here as there will be nothing to test - in_stype = dh.get_scalar_type(left.dtype) - if res_stype is None: - res_stype = in_stype - m, M = dh.dtype_ranges.get(left.dtype, (None, None)) - for idx in sh.ndindex(res.shape): - scalar_l = in_stype(left[idx]) - if not filter_(scalar_l): - continue - try: - expected = refimpl(scalar_l, right) - except Exception: - continue - if left.dtype != xp.bool: - assert m is not None and M is not None # for mypy - if expected <= m or expected >= M: - continue - scalar_o = res_stype(res[idx]) - f_l = sh.fmt_idx(ctx.left_sym, idx) - f_o = sh.fmt_idx(ctx.res_name, idx) - expr = expr_template.format(f_l, right, expected) - if strict_check == False or dh.is_float_dtype(res.dtype): - assert isclose(scalar_o, expected), ( - f"{f_o}={scalar_o}, but should be roughly {expr} " - f"[{ctx.func_name}()]\n" - f"{f_l}={scalar_l}" - ) - else: - assert scalar_o == expected, ( - f"{f_o}={scalar_o}, but should be {expr} " - f"[{ctx.func_name}()]\n" - f"{f_l}={scalar_l}" - ) + right_scalar_assert_against_refimpl( + func_name=ctx.func_name, + left_sym=ctx.left_sym, + left=left, + right=right, + res_stype=res_stype, + res_name=ctx.res_name, + res=res, + refimpl=refimpl, + expr_template=expr_template, + filter_=filter_, + strict_check=strict_check, + ) else: binary_assert_against_refimpl( func_name=ctx.func_name, From 3c85cae7cf6c4e5059a2036467a4219c2ff0f09d Mon Sep 17 00:00:00 2001 From: Matthew Barber Date: Wed, 2 Feb 2022 11:54:16 +0000 Subject: [PATCH 35/40] Note why you'd want to not strictly check int outputs --- array_api_tests/test_operators_and_elementwise_functions.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/array_api_tests/test_operators_and_elementwise_functions.py b/array_api_tests/test_operators_and_elementwise_functions.py index 53526d79..0eb15462 100644 --- a/array_api_tests/test_operators_and_elementwise_functions.py +++ b/array_api_tests/test_operators_and_elementwise_functions.py @@ -63,7 +63,8 @@ def mock_int_dtype(n: int, dtype: DataType) -> int: # By default, floating-point functions/methods are loosely asserted against. Use # `strict_check=True` when they should be strictly asserted against, i.e. # when a function should return intergrals. Likewise, use `strict_check=False` -# when integer function/methods should be loosely asserted against. +# when integer function/methods should be loosely asserted against, i.e. when +# floats are used internally for optimisation or legacy reasons. def isclose(a: float, b: float, rel_tol: float = 0.25, abs_tol: float = 1) -> bool: From 3364e48cc8b0e2f7ebffaf6c156acc23487c1ffe Mon Sep 17 00:00:00 2001 From: Matthew Barber Date: Wed, 2 Feb 2022 18:20:34 +0000 Subject: [PATCH 36/40] Test broadcastable shapes for in-place operators --- array_api_tests/meta/test_utils.py | 12 +++++-- ...est_operators_and_elementwise_functions.py | 32 +++++++++++++++++-- 2 files changed, 39 insertions(+), 5 deletions(-) diff --git a/array_api_tests/meta/test_utils.py b/array_api_tests/meta/test_utils.py index 3cd819b4..e51d0a02 100644 --- a/array_api_tests/meta/test_utils.py +++ b/array_api_tests/meta/test_utils.py @@ -3,11 +3,14 @@ from hypothesis import strategies as st from .. import _array_module as xp -from .. import xps from .. import shape_helpers as sh +from .. import xps from ..test_creation_functions import frange from ..test_manipulation_functions import roll_ndindex -from ..test_operators_and_elementwise_functions import mock_int_dtype +from ..test_operators_and_elementwise_functions import ( + mock_int_dtype, + oneway_broadcastable_shapes, +) from ..test_signatures import extension_module @@ -115,3 +118,8 @@ def test_int_to_dtype(x, dtype): except OverflowError: reject() assert mock_int_dtype(x, dtype) == d + + +@given(oneway_broadcastable_shapes()) +def test_oneway_broadcastable_shapes(S): + assert sh.broadcast_shapes(*S) == S.result_shape diff --git a/array_api_tests/test_operators_and_elementwise_functions.py b/array_api_tests/test_operators_and_elementwise_functions.py index 0eb15462..4fde437f 100644 --- a/array_api_tests/test_operators_and_elementwise_functions.py +++ b/array_api_tests/test_operators_and_elementwise_functions.py @@ -30,6 +30,26 @@ def boolean_and_all_integer_dtypes() -> st.SearchStrategy[DataType]: return xps.boolean_dtypes() | all_integer_dtypes() +class OnewayBroadcastableShapes(NamedTuple): + input_shape: Shape + result_shape: Shape + + +@st.composite +def oneway_broadcastable_shapes(draw) -> st.SearchStrategy[OnewayBroadcastableShapes]: + """Return a strategy for input shapes that broadcast to result shapes.""" + result_shape = draw(hh.shapes(min_side=1)) + input_shape = draw( + xps.broadcastable_shapes( + result_shape, + # Override defaults so bad shapes are less likely to be generated. + max_side=None if result_shape == () else max(result_shape), + max_dims=len(result_shape), + ).filter(lambda s: sh.broadcast_shapes(result_shape, s) == result_shape) + ) + return OnewayBroadcastableShapes(input_shape, result_shape) + + def mock_int_dtype(n: int, dtype: DataType) -> int: """Returns equivalent of `n` that mocks `dtype` behaviour.""" nbits = dh.dtype_nbits[dtype] @@ -326,9 +346,15 @@ def make_param( ) else: if func_type is FuncType.IOP: - shared_shapes = st.shared(hh.shapes(**shapes_kw)) - left_strat = xps.arrays(dtype=shared_dtypes, shape=shared_shapes) - right_strat = xps.arrays(dtype=shared_dtypes, shape=shared_shapes) + shared_oneway_shapes = st.shared(oneway_broadcastable_shapes()) + left_strat = xps.arrays( + dtype=shared_dtypes, + shape=shared_oneway_shapes.map(lambda S: S.result_shape), + ) + right_strat = xps.arrays( + dtype=shared_dtypes, + shape=shared_oneway_shapes.map(lambda S: S.input_shape), + ) else: mutual_shapes = st.shared( hh.mutually_broadcastable_shapes(2, **shapes_kw) From 263b764299a8047ea98eabc95512c4e767f79047 Mon Sep 17 00:00:00 2001 From: Matthew Barber Date: Wed, 2 Feb 2022 19:45:12 +0000 Subject: [PATCH 37/40] Generate oneway promotable dtypes for elwise/op tests --- array_api_tests/meta/test_utils.py | 9 +- ...est_operators_and_elementwise_functions.py | 83 ++++++++++++------- 2 files changed, 59 insertions(+), 33 deletions(-) diff --git a/array_api_tests/meta/test_utils.py b/array_api_tests/meta/test_utils.py index e51d0a02..588cfb1b 100644 --- a/array_api_tests/meta/test_utils.py +++ b/array_api_tests/meta/test_utils.py @@ -3,6 +3,7 @@ from hypothesis import strategies as st from .. import _array_module as xp +from .. import dtype_helpers as dh from .. import shape_helpers as sh from .. import xps from ..test_creation_functions import frange @@ -10,6 +11,7 @@ from ..test_operators_and_elementwise_functions import ( mock_int_dtype, oneway_broadcastable_shapes, + oneway_promotable_dtypes, ) from ..test_signatures import extension_module @@ -120,6 +122,11 @@ def test_int_to_dtype(x, dtype): assert mock_int_dtype(x, dtype) == d +@given(oneway_promotable_dtypes(dh.all_dtypes)) +def test_oneway_promotable_dtypes(D): + assert D.result_dtype == dh.result_type(*D) + + @given(oneway_broadcastable_shapes()) def test_oneway_broadcastable_shapes(S): - assert sh.broadcast_shapes(*S) == S.result_shape + assert S.result_shape == sh.broadcast_shapes(*S) diff --git a/array_api_tests/test_operators_and_elementwise_functions.py b/array_api_tests/test_operators_and_elementwise_functions.py index 4fde437f..6947c061 100644 --- a/array_api_tests/test_operators_and_elementwise_functions.py +++ b/array_api_tests/test_operators_and_elementwise_functions.py @@ -30,6 +30,26 @@ def boolean_and_all_integer_dtypes() -> st.SearchStrategy[DataType]: return xps.boolean_dtypes() | all_integer_dtypes() +class OnewayPromotableDtypes(NamedTuple): + input_dtype: DataType + result_dtype: DataType + + +@st.composite +def oneway_promotable_dtypes( + draw, dtypes: List[DataType] +) -> st.SearchStrategy[OnewayPromotableDtypes]: + """Return a strategy for input dtypes that promote to result dtypes.""" + d1, d2 = draw(hh.mutually_promotable_dtypes(dtypes=dtypes)) + result_dtype = dh.result_type(d1, d2) + if d1 == result_dtype: + return OnewayPromotableDtypes(d2, d1) + elif d2 == result_dtype: + return OnewayPromotableDtypes(d1, d2) + else: + reject() + + class OnewayBroadcastableShapes(NamedTuple): input_shape: Shape result_shape: Shape @@ -326,8 +346,14 @@ def __repr__(self): def make_binary_params( - elwise_func_name: str, dtypes_strat: st.SearchStrategy[DataType] + elwise_func_name: str, dtypes: List[DataType] ) -> List[Param[BinaryParamContext]]: + if hh.FILTER_UNDEFINED_DTYPES: + dtypes = [d for d in dtypes if not isinstance(d, xp._UndefinedStub)] + shared_oneway_dtypes = st.shared(oneway_promotable_dtypes(dtypes)) + left_dtypes = shared_oneway_dtypes.map(lambda D: D.result_dtype) + right_dtypes = shared_oneway_dtypes.map(lambda D: D.input_dtype) + def make_param( func_name: str, func_type: FuncType, right_is_scalar: bool ) -> Param[BinaryParamContext]: @@ -338,21 +364,18 @@ def make_param( left_sym = "x1" right_sym = "x2" - shared_dtypes = st.shared(dtypes_strat) if right_is_scalar: - left_strat = xps.arrays(dtype=shared_dtypes, shape=hh.shapes(**shapes_kw)) - right_strat = shared_dtypes.flatmap( - lambda d: xps.from_dtype(d, **finite_kw) - ) + left_strat = xps.arrays(dtype=left_dtypes, shape=hh.shapes(**shapes_kw)) + right_strat = right_dtypes.flatmap(lambda d: xps.from_dtype(d, **finite_kw)) else: if func_type is FuncType.IOP: shared_oneway_shapes = st.shared(oneway_broadcastable_shapes()) left_strat = xps.arrays( - dtype=shared_dtypes, + dtype=left_dtypes, shape=shared_oneway_shapes.map(lambda S: S.result_shape), ) right_strat = xps.arrays( - dtype=shared_dtypes, + dtype=right_dtypes, shape=shared_oneway_shapes.map(lambda S: S.input_shape), ) else: @@ -360,10 +383,10 @@ def make_param( hh.mutually_broadcastable_shapes(2, **shapes_kw) ) left_strat = xps.arrays( - dtype=shared_dtypes, shape=mutual_shapes.map(lambda pair: pair[0]) + dtype=left_dtypes, shape=mutual_shapes.map(lambda pair: pair[0]) ) right_strat = xps.arrays( - dtype=shared_dtypes, shape=mutual_shapes.map(lambda pair: pair[1]) + dtype=right_dtypes, shape=mutual_shapes.map(lambda pair: pair[1]) ) if func_type is FuncType.FUNC: @@ -540,7 +563,7 @@ def test_acosh(x): ) -@pytest.mark.parametrize("ctx,", make_binary_params("add", xps.numeric_dtypes())) +@pytest.mark.parametrize("ctx,", make_binary_params("add", dh.numeric_dtypes)) @given(data=st.data()) def test_add(ctx, data): left = data.draw(ctx.left_strat, label=ctx.left_sym) @@ -605,7 +628,7 @@ def test_atanh(x): @pytest.mark.parametrize( - "ctx", make_binary_params("bitwise_and", boolean_and_all_integer_dtypes()) + "ctx", make_binary_params("bitwise_and", dh.bool_and_all_int_dtypes) ) @given(data=st.data()) def test_bitwise_and(ctx, data): @@ -624,7 +647,7 @@ def test_bitwise_and(ctx, data): @pytest.mark.parametrize( - "ctx", make_binary_params("bitwise_left_shift", all_integer_dtypes()) + "ctx", make_binary_params("bitwise_left_shift", dh.all_int_dtypes) ) @given(data=st.data()) def test_bitwise_left_shift(ctx, data): @@ -664,7 +687,7 @@ def test_bitwise_invert(ctx, data): @pytest.mark.parametrize( - "ctx", make_binary_params("bitwise_or", boolean_and_all_integer_dtypes()) + "ctx", make_binary_params("bitwise_or", dh.bool_and_all_int_dtypes) ) @given(data=st.data()) def test_bitwise_or(ctx, data): @@ -683,7 +706,7 @@ def test_bitwise_or(ctx, data): @pytest.mark.parametrize( - "ctx", make_binary_params("bitwise_right_shift", all_integer_dtypes()) + "ctx", make_binary_params("bitwise_right_shift", dh.all_int_dtypes) ) @given(data=st.data()) def test_bitwise_right_shift(ctx, data): @@ -704,7 +727,7 @@ def test_bitwise_right_shift(ctx, data): @pytest.mark.parametrize( - "ctx", make_binary_params("bitwise_xor", boolean_and_all_integer_dtypes()) + "ctx", make_binary_params("bitwise_xor", dh.bool_and_all_int_dtypes) ) @given(data=st.data()) def test_bitwise_xor(ctx, data): @@ -746,7 +769,7 @@ def test_cosh(x): unary_assert_against_refimpl("cosh", x, out, math.cosh) -@pytest.mark.parametrize("ctx", make_binary_params("divide", xps.floating_dtypes())) +@pytest.mark.parametrize("ctx", make_binary_params("divide", dh.float_dtypes)) @given(data=st.data()) def test_divide(ctx, data): left = data.draw(ctx.left_strat, label=ctx.left_sym) @@ -769,7 +792,7 @@ def test_divide(ctx, data): ) -@pytest.mark.parametrize("ctx", make_binary_params("equal", xps.scalar_dtypes())) +@pytest.mark.parametrize("ctx", make_binary_params("equal", dh.all_dtypes)) @given(data=st.data()) def test_equal(ctx, data): left = data.draw(ctx.left_strat, label=ctx.left_sym) @@ -821,9 +844,7 @@ def test_floor(x): unary_assert_against_refimpl("floor", x, out, math.floor, strict_check=True) -@pytest.mark.parametrize( - "ctx", make_binary_params("floor_divide", xps.numeric_dtypes()) -) +@pytest.mark.parametrize("ctx", make_binary_params("floor_divide", dh.numeric_dtypes)) @given(data=st.data()) def test_floor_divide(ctx, data): left = data.draw( @@ -842,7 +863,7 @@ def test_floor_divide(ctx, data): binary_param_assert_against_refimpl(ctx, left, right, res, "//", operator.floordiv) -@pytest.mark.parametrize("ctx", make_binary_params("greater", xps.numeric_dtypes())) +@pytest.mark.parametrize("ctx", make_binary_params("greater", dh.numeric_dtypes)) @given(data=st.data()) def test_greater(ctx, data): left = data.draw(ctx.left_strat, label=ctx.left_sym) @@ -862,9 +883,7 @@ def test_greater(ctx, data): ) -@pytest.mark.parametrize( - "ctx", make_binary_params("greater_equal", xps.numeric_dtypes()) -) +@pytest.mark.parametrize("ctx", make_binary_params("greater_equal", dh.numeric_dtypes)) @given(data=st.data()) def test_greater_equal(ctx, data): left = data.draw(ctx.left_strat, label=ctx.left_sym) @@ -908,7 +927,7 @@ def test_isnan(x): unary_assert_against_refimpl("isnan", x, out, math.isnan, res_stype=bool) -@pytest.mark.parametrize("ctx", make_binary_params("less", xps.numeric_dtypes())) +@pytest.mark.parametrize("ctx", make_binary_params("less", dh.numeric_dtypes)) @given(data=st.data()) def test_less(ctx, data): left = data.draw(ctx.left_strat, label=ctx.left_sym) @@ -928,7 +947,7 @@ def test_less(ctx, data): ) -@pytest.mark.parametrize("ctx", make_binary_params("less_equal", xps.numeric_dtypes())) +@pytest.mark.parametrize("ctx", make_binary_params("less_equal", dh.numeric_dtypes)) @given(data=st.data()) def test_less_equal(ctx, data): left = data.draw(ctx.left_strat, label=ctx.left_sym) @@ -1040,7 +1059,7 @@ def test_logical_xor(x1, x2): ) -@pytest.mark.parametrize("ctx", make_binary_params("multiply", xps.numeric_dtypes())) +@pytest.mark.parametrize("ctx", make_binary_params("multiply", dh.numeric_dtypes)) @given(data=st.data()) def test_multiply(ctx, data): left = data.draw(ctx.left_strat, label=ctx.left_sym) @@ -1073,7 +1092,7 @@ def test_negative(ctx, data): ) -@pytest.mark.parametrize("ctx", make_binary_params("not_equal", xps.scalar_dtypes())) +@pytest.mark.parametrize("ctx", make_binary_params("not_equal", dh.all_dtypes)) @given(data=st.data()) def test_not_equal(ctx, data): left = data.draw(ctx.left_strat, label=ctx.left_sym) @@ -1105,7 +1124,7 @@ def test_positive(ctx, data): ph.assert_array(ctx.func_name, out, x) -@pytest.mark.parametrize("ctx", make_binary_params("pow", xps.numeric_dtypes())) +@pytest.mark.parametrize("ctx", make_binary_params("pow", dh.numeric_dtypes)) @given(data=st.data()) def test_pow(ctx, data): left = data.draw(ctx.left_strat, label=ctx.left_sym) @@ -1129,7 +1148,7 @@ def test_pow(ctx, data): ) -@pytest.mark.parametrize("ctx", make_binary_params("remainder", xps.numeric_dtypes())) +@pytest.mark.parametrize("ctx", make_binary_params("remainder", dh.numeric_dtypes)) @given(data=st.data()) def test_remainder(ctx, data): left = data.draw(ctx.left_strat, label=ctx.left_sym) @@ -1200,7 +1219,7 @@ def test_sqrt(x): ) -@pytest.mark.parametrize("ctx", make_binary_params("subtract", xps.numeric_dtypes())) +@pytest.mark.parametrize("ctx", make_binary_params("subtract", dh.numeric_dtypes)) @given(data=st.data()) def test_subtract(ctx, data): left = data.draw(ctx.left_strat, label=ctx.left_sym) From 9deff00bb133db4b25cf3d9e6815bade01d99abf Mon Sep 17 00:00:00 2001 From: Matthew Barber Date: Thu, 3 Feb 2022 09:16:10 +0000 Subject: [PATCH 38/40] Remove values testing from `test_mean()` --- array_api_tests/test_statistical_functions.py | 10 +--------- 1 file changed, 1 insertion(+), 9 deletions(-) diff --git a/array_api_tests/test_statistical_functions.py b/array_api_tests/test_statistical_functions.py index c86111a0..c7d0e842 100644 --- a/array_api_tests/test_statistical_functions.py +++ b/array_api_tests/test_statistical_functions.py @@ -69,15 +69,7 @@ def test_mean(x, data): ph.assert_keepdimable_shape( "mean", out.shape, x.shape, _axes, kw.get("keepdims", False), **kw ) - for indices, out_idx in zip(sh.axes_ndindex(x.shape, _axes), sh.ndindex(out.shape)): - mean = float(out[out_idx]) - assume(not math.isinf(mean)) # mean may become inf due to internal overflows - elements = [] - for idx in indices: - s = float(x[idx]) - elements.append(s) - expected = sum(elements) / len(elements) - ph.assert_scalar_equals("mean", float, out_idx, mean, expected) + # Values testing mean is too finicky @given( From 87cd96d6d94827480901a4d14350ab00a2b09266 Mon Sep 17 00:00:00 2001 From: Matthew Barber Date: Thu, 3 Feb 2022 09:20:31 +0000 Subject: [PATCH 39/40] Skip instead of xfail on workflow --- .github/workflows/numpy.yml | 4 ++-- conftest.py | 16 ++++++++-------- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/.github/workflows/numpy.yml b/.github/workflows/numpy.yml index 3581edee..7ec3b5d3 100644 --- a/.github/workflows/numpy.yml +++ b/.github/workflows/numpy.yml @@ -25,8 +25,8 @@ jobs: env: ARRAY_API_TESTS_MODULE: numpy.array_api run: | - # Mark some known issues as XFAIL - cat << EOF >> xfails.txt + # Skip test cases with known issues + cat << EOF >> skips.txt # copy not implemented array_api_tests/test_creation_functions.py::test_asarray_arrays diff --git a/conftest.py b/conftest.py index efa3a46a..2af3fef1 100644 --- a/conftest.py +++ b/conftest.py @@ -71,14 +71,14 @@ def xp_has_ext(ext: str) -> bool: return False -xfail_ids = [] -xfails_path = Path(__file__).parent / "xfails.txt" -if xfails_path.exists(): - with open(xfails_path) as f: +skip_ids = [] +skips_path = Path(__file__).parent / "skips.txt" +if skips_path.exists(): + with open(skips_path) as f: for line in f: if line.startswith("array_api_tests"): id_ = line.strip("\n") - xfail_ids.append(id_) + skip_ids.append(id_) def pytest_collection_modifyitems(config, items): @@ -96,10 +96,10 @@ def pytest_collection_modifyitems(config, items): ) elif not xp_has_ext(ext): item.add_marker(mark.skip(reason=f"{ext} not found in array module")) - # xfail if specified in xfails.txt - for id_ in xfail_ids: + # skip if specified in skips.txt + for id_ in skip_ids: if item.nodeid.startswith(id_): - item.add_marker(mark.xfail(reason="xfails.txt")) + item.add_marker(mark.skip(reason="skips.txt")) break # skip if test not appropiate for CI if ci: From 8ce379158a2e2925ecdb2cc6eea484ef33e0aede Mon Sep 17 00:00:00 2001 From: Matthew Barber Date: Thu, 3 Feb 2022 09:21:41 +0000 Subject: [PATCH 40/40] Skip `test_eigh` --- array_api_tests/test_linalg.py | 1 + 1 file changed, 1 insertion(+) diff --git a/array_api_tests/test_linalg.py b/array_api_tests/test_linalg.py index 7117c20b..764d0df4 100644 --- a/array_api_tests/test_linalg.py +++ b/array_api_tests/test_linalg.py @@ -229,6 +229,7 @@ def true_diag(x_stack): _test_stacks(linalg.diagonal, x, **kw, res=res, dims=1, true_val=true_diag) +@pytest.mark.skip(reason="Inputs need to be restricted") # TODO @pytest.mark.xp_extension('linalg') @given(x=symmetric_matrices(finite=True)) def test_eigh(x):