diff --git a/array_api_tests/array_helpers.py b/array_api_tests/array_helpers.py index 46de8a00..398f1994 100644 --- a/array_api_tests/array_helpers.py +++ b/array_api_tests/array_helpers.py @@ -5,14 +5,12 @@ zeros, ones, full, bool, int8, int16, int32, int64, uint8, uint16, uint32, uint64, float32, float64, nan, inf, pi, remainder, divide, isinf, - negative, _integer_dtypes, _floating_dtypes, - _numeric_dtypes, _boolean_dtypes, _dtypes, - asarray) -from . import _array_module - + negative, asarray) # These are exported here so that they can be included in the special cases # tests from this file. from ._array_module import logical_not, subtract, floor, ceil, where +from . import dtype_helpers as dh + __all__ = ['all', 'any', 'logical_and', 'logical_or', 'logical_not', 'less', 'less_equal', 'greater', 'subtract', 'negative', 'floor', 'ceil', @@ -25,9 +23,8 @@ 'assert_isinf', 'positive_mathematical_sign', 'assert_positive_mathematical_sign', 'negative_mathematical_sign', 'assert_negative_mathematical_sign', 'same_sign', - 'assert_same_sign', 'ndindex', 'promote_dtypes', 'float64', - 'asarray', 'is_integer_dtype', 'is_float_dtype', 'dtype_ranges', - 'full', 'true', 'false', 'isnan'] + 'assert_same_sign', 'ndindex', 'float64', + 'asarray', 'full', 'true', 'false', 'isnan'] def zero(shape, dtype): """ @@ -111,7 +108,7 @@ def isnegzero(x): # TODO: If copysign or signbit are added to the spec, use those instead. shape = x.shape dtype = x.dtype - if is_integer_dtype(dtype): + if dh.is_int_dtype(dtype): return false(shape) return equal(divide(one(shape, dtype), x), -infinity(shape, dtype)) @@ -122,7 +119,7 @@ def isposzero(x): # TODO: If copysign or signbit are added to the spec, use those instead. shape = x.shape dtype = x.dtype - if is_integer_dtype(dtype): + if dh.is_int_dtype(dtype): return true(shape) return equal(divide(one(shape, dtype), x), infinity(shape, dtype)) @@ -311,37 +308,6 @@ def same_sign(x, y): def assert_same_sign(x, y): assert all(same_sign(x, y)), "The input arrays do not have the same sign" -integer_dtype_objects = [getattr(_array_module, t) for t in _integer_dtypes] -floating_dtype_objects = [getattr(_array_module, t) for t in _floating_dtypes] -numeric_dtype_objects = [getattr(_array_module, t) for t in _numeric_dtypes] -boolean_dtype_objects = [getattr(_array_module, t) for t in _boolean_dtypes] -integer_or_boolean_dtype_objects = integer_dtype_objects + boolean_dtype_objects -dtype_objects = [getattr(_array_module, t) for t in _dtypes] - -def is_integer_dtype(dtype): - if dtype is None: - return False - return dtype in [int8, int16, int32, int64, uint8, uint16, uint32, uint64] - -def is_float_dtype(dtype): - if dtype is None: - # numpy.dtype('float64') == None gives True - return False - # TODO: Return True even for floating point dtypes that aren't part of the - # spec, like np.float16 - return dtype in [float32, float64] - -dtype_ranges = { - int8: [-128, +127], - int16: [-32_768, +32_767], - int32: [-2_147_483_648, +2_147_483_647], - int64: [-9_223_372_036_854_775_808, +9_223_372_036_854_775_807], - uint8: [0, +255], - uint16: [0, +65_535], - uint32: [0, +4_294_967_295], - uint64: [0, +18_446_744_073_709_551_615], -} - def int_to_dtype(x, n, signed): """ Convert the Python integer x into an n bit signed or unsigned number. @@ -363,22 +329,3 @@ def ndindex(shape): """ return itertools.product(*[range(i) for i in shape]) - -def promote_dtypes(dtype1, dtype2): - """ - Special case of result_type() which uses the exact type promotion table - from the spec. - """ - from .test_type_promotion import dtype_mapping, promotion_table - - # Equivalent to this, but some libraries may not work properly with using - # dtype objects as dict keys - # - # d1, d2 = reverse_dtype_mapping[dtype1], reverse_dtype_mapping[dtype2] - - d1 = [i for i in dtype_mapping if dtype_mapping[i] == dtype1][0] - d2 = [i for i in dtype_mapping if dtype_mapping[i] == dtype2][0] - - if (d1, d2) not in promotion_table: - raise ValueError(f"{d1} and {d2} are not type promotable according to the spec (this may indicate a bug in the test suite).") - return dtype_mapping[promotion_table[d1, d2]] diff --git a/array_api_tests/dtype_helpers.py b/array_api_tests/dtype_helpers.py new file mode 100644 index 00000000..b81edb7b --- /dev/null +++ b/array_api_tests/dtype_helpers.py @@ -0,0 +1,338 @@ +from typing import NamedTuple + +from . import _array_module as xp + + +__all__ = [ + 'int_dtypes', + 'uint_dtypes', + 'all_int_dtypes', + 'float_dtypes', + 'numeric_dtypes', + 'all_dtypes', + 'dtype_to_name', + 'bool_and_all_int_dtypes', + 'dtype_to_scalars', + 'is_int_dtype', + 'is_float_dtype', + 'dtype_ranges', + 'promotion_table', + 'dtype_nbits', + 'dtype_signed', + 'func_in_dtypes', + 'func_returns_bool', + 'binary_op_to_symbol', + 'unary_op_to_symbol', + 'inplace_op_to_symbol', +] + + +_int_names = ('int8', 'int16', 'int32', 'int64') +_uint_names = ('uint8', 'uint16', 'uint32', 'uint64') +_float_names = ('float32', 'float64') +_dtype_names = ('bool',) + _int_names + _uint_names + _float_names + + +int_dtypes = tuple(getattr(xp, name) for name in _int_names) +uint_dtypes = tuple(getattr(xp, name) for name in _uint_names) +float_dtypes = tuple(getattr(xp, name) for name in _float_names) +all_int_dtypes = int_dtypes + uint_dtypes +numeric_dtypes = all_int_dtypes + float_dtypes +all_dtypes = (xp.bool,) + numeric_dtypes +bool_and_all_int_dtypes = (xp.bool,) + all_int_dtypes + + +dtype_to_name = {getattr(xp, name): name for name in _dtype_names} + + +dtype_to_scalars = { + xp.bool: [bool], + **{d: [int] for d in all_int_dtypes}, + **{d: [int, float] for d in float_dtypes}, +} + + +def is_int_dtype(dtype): + return dtype in all_int_dtypes + + +def is_float_dtype(dtype): + # None equals NumPy's xp.float64 object, so we specifically check it here. + # xp.float64 is in fact an alias of np.dtype('float64'), and its equality + # with None is meant to be deprecated at some point. + # See https://github.com/numpy/numpy/issues/18434 + if dtype is None: + return False + # TODO: Return True for float dtypes that aren't part of the spec e.g. np.float16 + return dtype in float_dtypes + + +class MinMax(NamedTuple): + min: int + max: int + + +dtype_ranges = { + xp.int8: MinMax(-128, +127), + xp.int16: MinMax(-32_768, +32_767), + xp.int32: MinMax(-2_147_483_648, +2_147_483_647), + xp.int64: MinMax(-9_223_372_036_854_775_808, +9_223_372_036_854_775_807), + xp.uint8: MinMax(0, +255), + xp.uint16: MinMax(0, +65_535), + xp.uint32: MinMax(0, +4_294_967_295), + xp.uint64: MinMax(0, +18_446_744_073_709_551_615), +} + + +_numeric_promotions = { + # ints + (xp.int8, xp.int8): xp.int8, + (xp.int8, xp.int16): xp.int16, + (xp.int8, xp.int32): xp.int32, + (xp.int8, xp.int64): xp.int64, + (xp.int16, xp.int16): xp.int16, + (xp.int16, xp.int32): xp.int32, + (xp.int16, xp.int64): xp.int64, + (xp.int32, xp.int32): xp.int32, + (xp.int32, xp.int64): xp.int64, + (xp.int64, xp.int64): xp.int64, + # uints + (xp.uint8, xp.uint8): xp.uint8, + (xp.uint8, xp.uint16): xp.uint16, + (xp.uint8, xp.uint32): xp.uint32, + (xp.uint8, xp.uint64): xp.uint64, + (xp.uint16, xp.uint16): xp.uint16, + (xp.uint16, xp.uint32): xp.uint32, + (xp.uint16, xp.uint64): xp.uint64, + (xp.uint32, xp.uint32): xp.uint32, + (xp.uint32, xp.uint64): xp.uint64, + (xp.uint64, xp.uint64): xp.uint64, + # ints and uints (mixed sign) + (xp.int8, xp.uint8): xp.int16, + (xp.int8, xp.uint16): xp.int32, + (xp.int8, xp.uint32): xp.int64, + (xp.int16, xp.uint8): xp.int16, + (xp.int16, xp.uint16): xp.int32, + (xp.int16, xp.uint32): xp.int64, + (xp.int32, xp.uint8): xp.int32, + (xp.int32, xp.uint16): xp.int32, + (xp.int32, xp.uint32): xp.int64, + (xp.int64, xp.uint8): xp.int64, + (xp.int64, xp.uint16): xp.int64, + (xp.int64, xp.uint32): xp.int64, + # floats + (xp.float32, xp.float32): xp.float32, + (xp.float32, xp.float64): xp.float64, + (xp.float64, xp.float64): xp.float64, +} +promotion_table = { + (xp.bool, xp.bool): xp.bool, + **_numeric_promotions, + **{(d2, d1): res for (d1, d2), res in _numeric_promotions.items()}, +} + + +dtype_nbits = { + **{d: 8 for d in [xp.int8, xp.uint8]}, + **{d: 16 for d in [xp.int16, xp.uint16]}, + **{d: 32 for d in [xp.int32, xp.uint32, xp.float32]}, + **{d: 64 for d in [xp.int64, xp.uint64, xp.float64]}, +} + + +dtype_signed = { + **{d: True for d in int_dtypes}, + **{d: False for d in uint_dtypes}, +} + + +func_in_dtypes = { + 'abs': numeric_dtypes, + 'acos': float_dtypes, + 'acosh': float_dtypes, + 'add': numeric_dtypes, + 'asin': float_dtypes, + 'asinh': float_dtypes, + 'atan': float_dtypes, + 'atan2': float_dtypes, + 'atanh': float_dtypes, + 'bitwise_and': bool_and_all_int_dtypes, + 'bitwise_invert': bool_and_all_int_dtypes, + 'bitwise_left_shift': all_int_dtypes, + 'bitwise_or': bool_and_all_int_dtypes, + 'bitwise_right_shift': all_int_dtypes, + 'bitwise_xor': bool_and_all_int_dtypes, + 'ceil': numeric_dtypes, + 'cos': float_dtypes, + 'cosh': float_dtypes, + 'divide': float_dtypes, + 'equal': all_dtypes, + 'exp': float_dtypes, + 'expm1': float_dtypes, + 'floor': numeric_dtypes, + 'floor_divide': numeric_dtypes, + 'greater': numeric_dtypes, + 'greater_equal': numeric_dtypes, + 'isfinite': numeric_dtypes, + 'isinf': numeric_dtypes, + 'isnan': numeric_dtypes, + 'less': numeric_dtypes, + 'less_equal': numeric_dtypes, + 'log': float_dtypes, + 'logaddexp': float_dtypes, + 'log10': float_dtypes, + 'log1p': float_dtypes, + 'log2': float_dtypes, + 'logical_and': (xp.bool,), + 'logical_not': (xp.bool,), + 'logical_or': (xp.bool,), + 'logical_xor': (xp.bool,), + 'multiply': numeric_dtypes, + 'negative': numeric_dtypes, + 'not_equal': all_dtypes, + 'positive': numeric_dtypes, + 'pow': float_dtypes, + 'remainder': numeric_dtypes, + 'round': numeric_dtypes, + 'sign': numeric_dtypes, + 'sin': float_dtypes, + 'sinh': float_dtypes, + 'sqrt': float_dtypes, + 'square': numeric_dtypes, + 'subtract': numeric_dtypes, + 'tan': float_dtypes, + 'tanh': float_dtypes, + 'trunc': numeric_dtypes, +} + + +func_returns_bool = { + 'abs': False, + 'acos': False, + 'acosh': False, + 'add': False, + 'asin': False, + 'asinh': False, + 'atan': False, + 'atan2': False, + 'atanh': False, + 'bitwise_and': False, + 'bitwise_invert': False, + 'bitwise_left_shift': False, + 'bitwise_or': False, + 'bitwise_right_shift': False, + 'bitwise_xor': False, + 'ceil': False, + 'cos': False, + 'cosh': False, + 'divide': False, + 'equal': True, + 'exp': False, + 'expm1': False, + 'floor': False, + 'floor_divide': False, + 'greater': True, + 'greater_equal': True, + 'isfinite': True, + 'isinf': True, + 'isnan': True, + 'less': True, + 'less_equal': True, + 'log': False, + 'logaddexp': False, + 'log10': False, + 'log1p': False, + 'log2': False, + 'logical_and': True, + 'logical_not': True, + 'logical_or': True, + 'logical_xor': True, + 'multiply': False, + 'negative': False, + 'not_equal': True, + 'positive': False, + 'pow': False, + 'remainder': False, + 'round': False, + 'sign': False, + 'sin': False, + 'sinh': False, + 'sqrt': False, + 'square': False, + 'subtract': False, + 'tan': False, + 'tanh': False, + 'trunc': False, +} + + +unary_op_to_symbol = { + '__invert__': '~', + '__neg__': '-', + '__pos__': '+', +} + + +binary_op_to_symbol = { + '__add__': '+', + '__and__': '&', + '__eq__': '==', + '__floordiv__': '//', + '__ge__': '>=', + '__gt__': '>', + '__le__': '<=', + '__lshift__': '<<', + '__lt__': '<', + '__matmul__': '@', + '__mod__': '%', + '__mul__': '*', + '__ne__': '!=', + '__or__': '|', + '__pow__': '**', + '__rshift__': '>>', + '__sub__': '-', + '__truediv__': '/', + '__xor__': '^', +} + + +_op_to_func = { + '__abs__': 'abs', + '__add__': 'add', + '__and__': 'bitwise_and', + '__eq__': 'equal', + '__floordiv__': 'floor_divide', + '__ge__': 'greater_equal', + '__gt__': 'greater', + '__le__': 'less_equal', + '__lshift__': 'bitwise_left_shift', + '__lt__': 'less', + # '__matmul__': 'matmul', # TODO: support matmul + '__mod__': 'remainder', + '__mul__': 'multiply', + '__ne__': 'not_equal', + '__or__': 'bitwise_or', + '__pow__': 'pow', + '__rshift__': 'bitwise_right_shift', + '__sub__': 'subtract', + '__truediv__': 'divide', + '__xor__': 'bitwise_xor', + '__invert__': 'bitwise_invert', + '__neg__': 'negative', + '__pos__': 'positive', +} + + +for op, elwise_func in _op_to_func.items(): + func_in_dtypes[op] = func_in_dtypes[elwise_func] + func_returns_bool[op] = func_returns_bool[elwise_func] + + +inplace_op_to_symbol = {} +for op, symbol in binary_op_to_symbol.items(): + if op == '__matmul__' or func_returns_bool[op]: + continue + iop = f'__i{op[2:]}' + inplace_op_to_symbol[iop] = f'{symbol}=' + func_in_dtypes[iop] = func_in_dtypes[op] + func_returns_bool[iop] = func_returns_bool[op] diff --git a/array_api_tests/hypothesis_helpers.py b/array_api_tests/hypothesis_helpers.py index d0df1890..ae96482c 100644 --- a/array_api_tests/hypothesis_helpers.py +++ b/array_api_tests/hypothesis_helpers.py @@ -2,18 +2,17 @@ from operator import mul from math import sqrt import itertools +from typing import Tuple from hypothesis import assume from hypothesis.strategies import (lists, integers, sampled_from, shared, floats, just, composite, one_of, none, booleans) +from hypothesis.strategies._internal.strategies import SearchStrategy from .pytest_helpers import nargs -from .array_helpers import (dtype_ranges, integer_dtype_objects, - floating_dtype_objects, numeric_dtype_objects, - boolean_dtype_objects, - integer_or_boolean_dtype_objects, dtype_objects, - ndindex) +from .array_helpers import ndindex +from . import dtype_helpers as dh from ._array_module import (full, float32, float64, bool as bool_dtype, _UndefinedStub, eye, broadcast_to) from . import _array_module as xp @@ -29,12 +28,12 @@ # places in the tests. FILTER_UNDEFINED_DTYPES = True -integer_dtypes = sampled_from(integer_dtype_objects) -floating_dtypes = sampled_from(floating_dtype_objects) -numeric_dtypes = sampled_from(numeric_dtype_objects) -integer_or_boolean_dtypes = sampled_from(integer_or_boolean_dtype_objects) -boolean_dtypes = sampled_from(boolean_dtype_objects) -dtypes = sampled_from(dtype_objects) +integer_dtypes = sampled_from(dh.all_int_dtypes) +floating_dtypes = sampled_from(dh.float_dtypes) +numeric_dtypes = sampled_from(dh.numeric_dtypes) +integer_or_boolean_dtypes = sampled_from(dh.bool_and_all_int_dtypes) +boolean_dtypes = just(xp.bool) +dtypes = sampled_from(dh.all_dtypes) if FILTER_UNDEFINED_DTYPES: integer_dtypes = integer_dtypes.filter(lambda x: not isinstance(x, _UndefinedStub)) @@ -48,32 +47,40 @@ shared_dtypes = shared(dtypes, key="dtype") shared_floating_dtypes = shared(floating_dtypes, key="dtype") -# TODO: Importing things from test_type_promotion should be replaced by -# something that won't cause a circular import. Right now we use @st.composite -# only because it returns a lazy-evaluated strategy - in the future this method -# should remove the composite wrapper, just returning sampled_from(dtype_pairs) -# instead of drawing from it. -@composite -def mutually_promotable_dtypes(draw, dtype_objects=dtype_objects): - from .test_type_promotion import dtype_mapping, promotion_table - # sort for shrinking (sampled_from shrinks to the earlier elements in the - # list). Give pairs of the same dtypes first, then smaller dtypes, - # preferring float, then int, then unsigned int. Note, this might not - # always result in examples shrinking to these pairs because strategies - # that draw from dtypes might not draw the same example from different - # pairs (XXX: Can we redesign the strategies so that they can prefer - # shrinking dtypes over values?) - sorted_table = sorted(promotion_table) - sorted_table = sorted( - sorted_table, key=lambda ij: -1 if ij[0] == ij[1] else sorted_table.index(ij) +_dtype_categories = [(xp.bool,), dh.uint_dtypes, dh.int_dtypes, dh.float_dtypes] +_sorted_dtypes = [d for category in _dtype_categories for d in category] + +def _dtypes_sorter(dtype_pair): + dtype1, dtype2 = dtype_pair + if dtype1 == dtype2: + return _sorted_dtypes.index(dtype1) + key = len(_sorted_dtypes) + rank1 = _sorted_dtypes.index(dtype1) + rank2 = _sorted_dtypes.index(dtype2) + for category in _dtype_categories: + if dtype1 in category and dtype2 in category: + break + else: + key += len(_sorted_dtypes) ** 2 + key += 2 * (rank1 + rank2) + if rank1 > rank2: + key += 1 + return key + +promotable_dtypes = sorted(dh.promotion_table.keys(), key=_dtypes_sorter) + +if FILTER_UNDEFINED_DTYPES: + promotable_dtypes = [ + (i, j) for i, j in promotable_dtypes + if not isinstance(i, _UndefinedStub) + and not isinstance(j, _UndefinedStub) + ] + + +def mutually_promotable_dtypes(dtype_objs=dh.all_dtypes): + return sampled_from( + [(i, j) for i, j in promotable_dtypes if i in dtype_objs and j in dtype_objs] ) - dtype_pairs = [(dtype_mapping[i], dtype_mapping[j]) for i, j in sorted_table] - if FILTER_UNDEFINED_DTYPES: - dtype_pairs = [(i, j) for i, j in dtype_pairs - if not isinstance(i, _UndefinedStub) - and not isinstance(j, _UndefinedStub)] - dtype_pairs = [(i, j) for i, j in dtype_pairs if i in dtype_objects and j in dtype_objects] - return draw(sampled_from(dtype_pairs)) # shared() allows us to draw either the function or the function name and they # will both correspond to the same function. @@ -123,9 +130,16 @@ def matrix_shapes(draw, stack_shapes=shapes): square_matrix_shapes = matrix_shapes().filter(lambda shape: shape[-1] == shape[-2]) -two_mutually_broadcastable_shapes = xps.mutually_broadcastable_shapes(num_shapes=2)\ - .map(lambda S: S.input_shapes)\ - .filter(lambda S: all(prod(i for i in shape if i) < MAX_ARRAY_SIZE for shape in S)) +def mutually_broadcastable_shapes(num_shapes: int) -> SearchStrategy[Tuple[Tuple]]: + return ( + xps.mutually_broadcastable_shapes(num_shapes) + .map(lambda BS: BS.input_shapes) + .filter(lambda shapes: all( + prod(i for i in s if i > 0) < MAX_ARRAY_SIZE for s in shapes + )) + ) + +two_mutually_broadcastable_shapes = mutually_broadcastable_shapes(2) # Note: This should become hermitian_matrices when complex dtypes are added @composite @@ -196,8 +210,8 @@ def scalars(draw, dtypes, finite=False): dtypes should be one of the shared_* dtypes strategies. """ dtype = draw(dtypes) - if dtype in dtype_ranges: - m, M = dtype_ranges[dtype] + if dtype in dh.dtype_ranges: + m, M = dh.dtype_ranges[dtype] return draw(integers(m, M)) elif dtype == bool_dtype: return draw(booleans()) @@ -229,7 +243,7 @@ def integer_indices(draw, sizes): # Return either a Python integer or a 0-D array with some integer dtype idx = draw(python_integer_indices(sizes)) dtype = draw(integer_dtypes) - m, M = dtype_ranges[dtype] + m, M = dh.dtype_ranges[dtype] if m <= idx <= M: return draw(one_of(just(idx), just(full((), idx, dtype=dtype)))) @@ -298,9 +312,10 @@ def multiaxis_indices(draw, shapes): return tuple(res) -def two_mutual_arrays(dtype_objects=dtype_objects, - two_shapes=two_mutually_broadcastable_shapes): - mutual_dtypes = shared(mutually_promotable_dtypes(dtype_objects)) +def two_mutual_arrays( + dtype_objs=dh.all_dtypes, two_shapes=two_mutually_broadcastable_shapes +): + mutual_dtypes = shared(mutually_promotable_dtypes(dtype_objs)) mutual_shapes = shared(two_shapes) arrays1 = xps.arrays( dtype=mutual_dtypes.map(lambda pair: pair[0]), diff --git a/array_api_tests/meta_tests/test_array_helpers.py b/array_api_tests/meta_tests/test_array_helpers.py index 7d1d3b3c..6a6b4849 100644 --- a/array_api_tests/meta_tests/test_array_helpers.py +++ b/array_api_tests/meta_tests/test_array_helpers.py @@ -1,33 +1,33 @@ -from ..array_helpers import exactly_equal, notequal, int_to_dtype -from ..hypothesis_helpers import integer_dtypes -from ..test_type_promotion import dtype_nbits, dtype_signed -from .._array_module import asarray, nan, equal, all - from hypothesis import given, assume from hypothesis.strategies import integers +from ..array_helpers import exactly_equal, notequal, int_to_dtype +from ..hypothesis_helpers import integer_dtypes +from ..dtype_helpers import dtype_nbits, dtype_signed +from .. import _array_module as xp + # TODO: These meta-tests currently only work with NumPy def test_exactly_equal(): - a = asarray([0, 0., -0., -0., nan, nan, 1, 1]) - b = asarray([0, -1, -0., 0., nan, 1, 1, 2]) + a = xp.asarray([0, 0., -0., -0., xp.nan, xp.nan, 1, 1]) + b = xp.asarray([0, -1, -0., 0., xp.nan, 1, 1, 2]) - res = asarray([True, False, True, False, True, False, True, False]) - assert all(equal(exactly_equal(a, b), res)) + res = xp.asarray([True, False, True, False, True, False, True, False]) + assert xp.all(xp.equal(exactly_equal(a, b), res)) def test_notequal(): - a = asarray([0, 0., -0., -0., nan, nan, 1, 1]) - b = asarray([0, -1, -0., 0., nan, 1, 1, 2]) + a = xp.asarray([0, 0., -0., -0., xp.nan, xp.nan, 1, 1]) + b = xp.asarray([0, -1, -0., 0., xp.nan, 1, 1, 2]) - res = asarray([False, True, False, False, False, True, False, True]) - assert all(equal(notequal(a, b), res)) + res = xp.asarray([False, True, False, False, False, True, False, True]) + assert xp.all(xp.equal(notequal(a, b), res)) @given(integers(), integer_dtypes) def test_int_to_dtype(x, dtype): - n = dtype_nbits(dtype) - signed = dtype_signed(dtype) + n = dtype_nbits[dtype] + signed = dtype_signed[dtype] try: - d = asarray(x, dtype=dtype) + d = xp.asarray(x, dtype=dtype) except OverflowError: assume(False) assert int_to_dtype(x, n, signed) == d diff --git a/array_api_tests/meta_tests/test_hypothesis_helpers.py b/array_api_tests/meta_tests/test_hypothesis_helpers.py index 6af43a26..93a63f8e 100644 --- a/array_api_tests/meta_tests/test_hypothesis_helpers.py +++ b/array_api_tests/meta_tests/test_hypothesis_helpers.py @@ -6,11 +6,12 @@ from .. import _array_module as xp from .._array_module import _UndefinedStub from .. import array_helpers as ah +from .. import dtype_helpers as dh from .. import hypothesis_helpers as hh from ..test_broadcasting import broadcast_shapes from ..test_elementwise_functions import sanity_check -UNDEFINED_DTYPES = any(isinstance(d, _UndefinedStub) for d in ah.dtype_objects) +UNDEFINED_DTYPES = any(isinstance(d, _UndefinedStub) for d in dh.all_dtypes) pytestmark = [pytest.mark.skipif(UNDEFINED_DTYPES, reason="undefined dtypes")] @given(hh.mutually_promotable_dtypes([xp.float32, xp.float64])) diff --git a/array_api_tests/test_broadcasting.py b/array_api_tests/test_broadcasting.py index 0c43f79e..1a1c9c47 100644 --- a/array_api_tests/test_broadcasting.py +++ b/array_api_tests/test_broadcasting.py @@ -10,8 +10,7 @@ from .hypothesis_helpers import shapes, FILTER_UNDEFINED_DTYPES from .pytest_helpers import raises, doesnt_raise, nargs -from .test_type_promotion import (elementwise_function_input_types, - input_types, dtype_mapping) +from .dtype_helpers import func_in_dtypes from .function_stubs import elementwise_functions from . import _array_module from ._array_module import ones, _UndefinedStub @@ -111,14 +110,14 @@ def test_broadcast_shapes_explicit_spec(): @pytest.mark.parametrize('func_name', [i for i in elementwise_functions.__all__ if nargs(i) > 1]) -@given(shape1=shapes, shape2=shapes, dtype=data()) -def test_broadcasting_hypothesis(func_name, shape1, shape2, dtype): +@given(shape1=shapes, shape2=shapes, data=data()) +def test_broadcasting_hypothesis(func_name, shape1, shape2, data): # Internal consistency checks assert nargs(func_name) == 2 - dtype = dtype_mapping[dtype.draw(sampled_from(input_types[elementwise_function_input_types[func_name]]))] - if FILTER_UNDEFINED_DTYPES and isinstance(dtype, _UndefinedStub): - assume(False) + dtype = data.draw(sampled_from(func_in_dtypes[func_name])) + if FILTER_UNDEFINED_DTYPES: + assume(not isinstance(dtype, _UndefinedStub)) func = getattr(_array_module, func_name) if isinstance(func, _array_module._UndefinedStub): diff --git a/array_api_tests/test_creation_functions.py b/array_api_tests/test_creation_functions.py index 91b93882..302bef49 100644 --- a/array_api_tests/test_creation_functions.py +++ b/array_api_tests/test_creation_functions.py @@ -4,11 +4,11 @@ full_like, equal, all, linspace, ones, ones_like, zeros, zeros_like, isnan) from . import _array_module as xp -from .array_helpers import (is_integer_dtype, dtype_ranges, - assert_exactly_equal, isintegral, is_float_dtype) +from .array_helpers import assert_exactly_equal, isintegral from .hypothesis_helpers import (numeric_dtypes, dtypes, MAX_ARRAY_SIZE, shapes, sizes, sqrt_sizes, shared_dtypes, scalars, kwargs) +from . import dtype_helpers as dh from . import xps from hypothesis import assume, given @@ -24,8 +24,8 @@ and (abs(x) > 0.01 if isinstance(x, float) else True)), one_of(none(), numeric_dtypes)) def test_arange(start, stop, step, dtype): - if dtype in dtype_ranges: - m, M = dtype_ranges[dtype] + if dtype in dh.dtype_ranges: + m, M = dh.dtype_ranges[dtype] if (not (m <= start <= M) or isinstance(stop, int) and not (m <= stop <= M) or isinstance(step, int) and not (m <= step <= M)): @@ -33,7 +33,7 @@ def test_arange(start, stop, step, dtype): kwargs = {} if dtype is None else {'dtype': dtype} - all_int = (is_integer_dtype(dtype) + all_int = (dh.is_int_dtype(dtype) and isinstance(start, int) and (stop is None or isinstance(stop, int)) and (step is None or isinstance(step, int))) @@ -75,7 +75,7 @@ def test_empty(shape, kw): out = empty(shape, **kw) dtype = kw.get("dtype", None) or xp.float64 if kw.get("dtype", None) is None: - assert is_float_dtype(out.dtype), f"empty() returned an array with dtype {out.dtype}, but should be the default float dtype" + assert dh.is_float_dtype(out.dtype), f"empty() returned an array with dtype {out.dtype}, but should be the default float dtype" else: assert out.dtype == dtype, f"{dtype=!s}, but empty() returned an array with dtype {out.dtype}" if isinstance(shape, int): @@ -110,7 +110,7 @@ def test_eye(n_rows, n_cols, k, dtype): else: a = eye(n_rows, n_cols, **kwargs) if dtype is None: - assert is_float_dtype(a.dtype), "eye() should return an array with the default floating point dtype" + assert dh.is_float_dtype(a.dtype), "eye() should return an array with the default floating point dtype" else: assert a.dtype == dtype, "eye() did not produce the correct dtype" @@ -152,7 +152,7 @@ def test_full(shape, fill_value, kw): dtype = xp.float64 if kw.get("dtype", None) is None: if dtype == xp.float64: - assert is_float_dtype(out.dtype), f"full() returned an array with dtype {out.dtype}, but should be the default float dtype" + assert dh.is_float_dtype(out.dtype), f"full() returned an array with dtype {out.dtype}, but should be the default float dtype" elif dtype == xp.int64: assert out.dtype == xp.int32 or out.dtype == xp.int64, f"full() returned an array with dtype {out.dtype}, but should be the default integer dtype" else: @@ -160,7 +160,7 @@ def test_full(shape, fill_value, kw): else: assert out.dtype == dtype assert out.shape == shape, f"{shape=}, but full() returned an array with shape {out.shape}" - if is_float_dtype(out.dtype) and math.isnan(fill_value): + if dh.is_float_dtype(out.dtype) and math.isnan(fill_value): assert all(isnan(out)), "full() array did not equal the fill value" else: assert all(equal(out, asarray(fill_value, dtype=dtype))), "full() array did not equal the fill value" @@ -186,7 +186,7 @@ def test_full_like(x, fill_value, kw): else: assert out.dtype == dtype, f"{dtype=!s}, but full_like() returned an array with dtype {out.dtype}" assert out.shape == x.shape, "{x.shape=}, but full_like() returned an array with shape {out.shape}" - if is_float_dtype(dtype) and math.isnan(fill_value): + if dh.is_float_dtype(dtype) and math.isnan(fill_value): assert all(isnan(out)), "full_like() array did not equal the fill value" else: assert all(equal(out, asarray(fill_value, dtype=dtype))), "full_like() array did not equal the fill value" @@ -200,7 +200,7 @@ def test_full_like(x, fill_value, kw): def test_linspace(start, stop, num, dtype, endpoint): # Skip on int start or stop that cannot be exactly represented as a float, # since we do not have good approx_equal helpers yet. - if ((dtype is None or is_float_dtype(dtype)) + if ((dtype is None or dh.is_float_dtype(dtype)) and ((isinstance(start, int) and not isintegral(asarray(start, dtype=dtype))) or (isinstance(stop, int) and not isintegral(asarray(stop, dtype=dtype))))): assume(False) @@ -210,7 +210,7 @@ def test_linspace(start, stop, num, dtype, endpoint): a = linspace(start, stop, num, **kwargs) if dtype is None: - assert is_float_dtype(a.dtype), "linspace() should return an array with the default floating point dtype" + assert dh.is_float_dtype(a.dtype), "linspace() should return an array with the default floating point dtype" else: assert a.dtype == dtype, "linspace() did not produce the correct dtype" @@ -237,9 +237,9 @@ def test_linspace(start, stop, num, dtype, endpoint): def make_one(dtype): - if kwargs is None or is_float_dtype(dtype): + if kwargs is None or dh.is_float_dtype(dtype): return 1.0 - elif is_integer_dtype(dtype): + elif dh.is_int_dtype(dtype): return 1 else: return True @@ -250,7 +250,7 @@ def test_ones(shape, kw): out = ones(shape, **kw) dtype = kw.get("dtype", None) or xp.float64 if kw.get("dtype", None) is None: - assert is_float_dtype(out.dtype), f"ones() returned an array with dtype {out.dtype}, but should be the default float dtype" + assert dh.is_float_dtype(out.dtype), f"ones() returned an array with dtype {out.dtype}, but should be the default float dtype" else: assert out.dtype == dtype, f"{dtype=!s}, but ones() returned an array with dtype {out.dtype}" assert out.shape == shape, f"{shape=}, but empty() returned an array with shape {out.shape}" @@ -273,9 +273,9 @@ def test_ones_like(x, kw): def make_zero(dtype): - if is_float_dtype(dtype): + if dh.is_float_dtype(dtype): return 0.0 - elif is_integer_dtype(dtype): + elif dh.is_int_dtype(dtype): return 0 else: return False @@ -286,7 +286,7 @@ def test_zeros(shape, kw): out = zeros(shape, **kw) dtype = kw.get("dtype", None) or xp.float64 if kw.get("dtype", None) is None: - assert is_float_dtype(out.dtype), "zeros() returned an array with dtype {out.dtype}, but should be the default float dtype" + assert dh.is_float_dtype(out.dtype), "zeros() returned an array with dtype {out.dtype}, but should be the default float dtype" else: assert out.dtype == dtype, f"{dtype=!s}, but zeros() returned an array with dtype {out.dtype}" assert out.shape == shape, "zeros() produced an array with incorrect shape" diff --git a/array_api_tests/test_elementwise_functions.py b/array_api_tests/test_elementwise_functions.py index 70c69f99..ee714458 100644 --- a/array_api_tests/test_elementwise_functions.py +++ b/array_api_tests/test_elementwise_functions.py @@ -17,6 +17,7 @@ from . import _array_module as xp from . import array_helpers as ah from . import hypothesis_helpers as hh +from . import dtype_helpers as dh from . import xps # We might as well use this implementation rather than requiring # mod.broadcast_shapes(). See test_equal() and others. @@ -28,11 +29,11 @@ integer_or_boolean_scalars = hh.array_scalars(hh.integer_or_boolean_dtypes) boolean_scalars = hh.array_scalars(hh.boolean_dtypes) -two_integer_dtypes = hh.mutually_promotable_dtypes(hh.integer_dtype_objects) -two_floating_dtypes = hh.mutually_promotable_dtypes(hh.floating_dtype_objects) -two_numeric_dtypes = hh.mutually_promotable_dtypes(hh.numeric_dtype_objects) -two_integer_or_boolean_dtypes = hh.mutually_promotable_dtypes(hh.integer_or_boolean_dtype_objects) -two_boolean_dtypes = hh.mutually_promotable_dtypes(hh.boolean_dtype_objects) +two_integer_dtypes = hh.mutually_promotable_dtypes(dh.all_int_dtypes) +two_floating_dtypes = hh.mutually_promotable_dtypes(dh.float_dtypes) +two_numeric_dtypes = hh.mutually_promotable_dtypes(dh.numeric_dtypes) +two_integer_or_boolean_dtypes = hh.mutually_promotable_dtypes(dh.bool_and_all_int_dtypes) +two_boolean_dtypes = hh.mutually_promotable_dtypes((xp.bool,)) two_any_dtypes = hh.mutually_promotable_dtypes() @st.composite @@ -44,14 +45,14 @@ def two_array_scalars(draw, dtype1, dtype2): # TODO: refactor this into dtype_helpers.py, see https://github.com/data-apis/array-api-tests/pull/26 def sanity_check(x1, x2): try: - ah.promote_dtypes(x1.dtype, x2.dtype) + dh.promotion_table[x1.dtype, x2.dtype] except ValueError: raise RuntimeError("Error in test generation (probably a bug in the test suite") @given(xps.arrays(dtype=xps.numeric_dtypes(), shape=hh.shapes)) def test_abs(x): - if ah.is_integer_dtype(x.dtype): - minval = ah.dtype_ranges[x.dtype][0] + if dh.is_int_dtype(x.dtype): + minval = dh.dtype_ranges[x.dtype][0] if minval < 0: # abs of the smallest representable negative integer is not defined mask = xp.not_equal(x, ah.full(x.shape, minval, dtype=x.dtype)) @@ -91,7 +92,7 @@ def test_acosh(x): # to nan, which is already tested in the special cases. ah.assert_exactly_equal(domain, codomain) -@given(*hh.two_mutual_arrays(hh.numeric_dtype_objects)) +@given(*hh.two_mutual_arrays(dh.numeric_dtypes)) def test_add(x1, x2): sanity_check(x1, x2) a = xp.add(x1, x2) @@ -133,7 +134,7 @@ def test_atan(x): # mapped to nan, which is already tested in the special cases. ah.assert_exactly_equal(domain, codomain) -@given(*hh.two_mutual_arrays(hh.floating_dtype_objects)) +@given(*hh.two_mutual_arrays(dh.float_dtypes)) def test_atan2(x1, x2): sanity_check(x1, x2) a = xp.atan2(x1, x2) @@ -180,9 +181,8 @@ def test_atanh(x): # mapped to nan, which is already tested in the special cases. ah.assert_exactly_equal(domain, codomain) -@given(*hh.two_mutual_arrays(ah.integer_or_boolean_dtype_objects)) +@given(*hh.two_mutual_arrays(dh.bool_and_all_int_dtypes)) def test_bitwise_and(x1, x2): - from .test_type_promotion import dtype_nbits, dtype_signed sanity_check(x1, x2) out = xp.bitwise_and(x1, x2) @@ -205,13 +205,11 @@ def test_bitwise_and(x1, x2): val2 = int(_x2[idx]) res = int(out[idx]) vals_and = val1 & val2 - vals_and = ah.int_to_dtype(vals_and, dtype_nbits(out.dtype), dtype_signed(out.dtype)) + vals_and = ah.int_to_dtype(vals_and, dh.dtype_nbits[out.dtype], dh.dtype_signed[out.dtype]) assert vals_and == res - -@given(*hh.two_mutual_arrays(ah.integer_dtype_objects)) +@given(*hh.two_mutual_arrays(dh.all_int_dtypes)) def test_bitwise_left_shift(x1, x2): - from .test_type_promotion import dtype_nbits, dtype_signed sanity_check(x1, x2) assume(not ah.any(ah.isnegative(x2))) out = xp.bitwise_left_shift(x1, x2) @@ -228,15 +226,13 @@ def test_bitwise_left_shift(x1, x2): val2 = int(_x2[idx]) res = int(out[idx]) # We avoid shifting very large ints - vals_shift = val1 << val2 if val2 < dtype_nbits(out.dtype) else 0 - vals_shift = ah.int_to_dtype(vals_shift, dtype_nbits(out.dtype), dtype_signed(out.dtype)) + vals_shift = val1 << val2 if val2 < dh.dtype_nbits[out.dtype] else 0 + vals_shift = ah.int_to_dtype(vals_shift, dh.dtype_nbits[out.dtype], dh.dtype_signed[out.dtype]) assert vals_shift == res @given(xps.arrays(dtype=hh.integer_or_boolean_dtypes, shape=hh.shapes)) def test_bitwise_invert(x): - from .test_type_promotion import dtype_nbits, dtype_signed out = xp.bitwise_invert(x) - # Compare against the Python ~ operator. if out.dtype == xp.bool: for idx in ah.ndindex(out.shape): @@ -248,12 +244,11 @@ def test_bitwise_invert(x): val = int(x[idx]) res = int(out[idx]) val_invert = ~val - val_invert = ah.int_to_dtype(val_invert, dtype_nbits(out.dtype), dtype_signed(out.dtype)) + val_invert = ah.int_to_dtype(val_invert, dh.dtype_nbits[out.dtype], dh.dtype_signed[out.dtype]) assert val_invert == res -@given(*hh.two_mutual_arrays(ah.integer_or_boolean_dtype_objects)) +@given(*hh.two_mutual_arrays(dh.bool_and_all_int_dtypes)) def test_bitwise_or(x1, x2): - from .test_type_promotion import dtype_nbits, dtype_signed sanity_check(x1, x2) out = xp.bitwise_or(x1, x2) @@ -276,12 +271,11 @@ def test_bitwise_or(x1, x2): val2 = int(_x2[idx]) res = int(out[idx]) vals_or = val1 | val2 - vals_or = ah.int_to_dtype(vals_or, dtype_nbits(out.dtype), dtype_signed(out.dtype)) + vals_or = ah.int_to_dtype(vals_or, dh.dtype_nbits[out.dtype], dh.dtype_signed[out.dtype]) assert vals_or == res -@given(*hh.two_mutual_arrays(ah.integer_dtype_objects)) +@given(*hh.two_mutual_arrays(dh.all_int_dtypes)) def test_bitwise_right_shift(x1, x2): - from .test_type_promotion import dtype_nbits, dtype_signed sanity_check(x1, x2) assume(not ah.any(ah.isnegative(x2))) out = xp.bitwise_right_shift(x1, x2) @@ -298,12 +292,11 @@ def test_bitwise_right_shift(x1, x2): val2 = int(_x2[idx]) res = int(out[idx]) vals_shift = val1 >> val2 - vals_shift = ah.int_to_dtype(vals_shift, dtype_nbits(out.dtype), dtype_signed(out.dtype)) + vals_shift = ah.int_to_dtype(vals_shift, dh.dtype_nbits[out.dtype], dh.dtype_signed[out.dtype]) assert vals_shift == res -@given(*hh.two_mutual_arrays(ah.integer_or_boolean_dtype_objects)) +@given(*hh.two_mutual_arrays(dh.bool_and_all_int_dtypes)) def test_bitwise_xor(x1, x2): - from .test_type_promotion import dtype_nbits, dtype_signed sanity_check(x1, x2) out = xp.bitwise_xor(x1, x2) @@ -326,7 +319,7 @@ def test_bitwise_xor(x1, x2): val2 = int(_x2[idx]) res = int(out[idx]) vals_xor = val1 ^ val2 - vals_xor = ah.int_to_dtype(vals_xor, dtype_nbits(out.dtype), dtype_signed(out.dtype)) + vals_xor = ah.int_to_dtype(vals_xor, dh.dtype_nbits[out.dtype], dh.dtype_signed[out.dtype]) assert vals_xor == res @given(xps.arrays(dtype=xps.numeric_dtypes(), shape=hh.shapes)) @@ -361,7 +354,7 @@ def test_cosh(x): # mapped to nan, which is already tested in the special cases. ah.assert_exactly_equal(domain, codomain) -@given(*hh.two_mutual_arrays(hh.floating_dtype_objects)) +@given(*hh.two_mutual_arrays(dh.float_dtypes)) def test_divide(x1, x2): sanity_check(x1, x2) xp.divide(x1, x2) @@ -401,14 +394,15 @@ def test_equal(x1, x2): # test_elementwise_function_two_arg_bool_type_promotion() in # test_type_promotion.py. The type promotion for ah.equal() is not *really* # tested in that file, because doing so requires doing the consistency - # check we do here rather than st.just checking the res dtype. - promoted_dtype = ah.promote_dtypes(x1.dtype, x2.dtype) + + # check we do here rather than just checking the result dtype. + promoted_dtype = dh.promotion_table[x1.dtype, x2.dtype] _x1 = ah.asarray(_x1, dtype=promoted_dtype) _x2 = ah.asarray(_x2, dtype=promoted_dtype) - if ah.is_integer_dtype(promoted_dtype): + if dh.is_int_dtype(promoted_dtype): scalar_func = int - elif ah.is_float_dtype(promoted_dtype): + elif dh.is_float_dtype(promoted_dtype): scalar_func = float else: scalar_func = bool @@ -453,10 +447,10 @@ def test_floor(x): integers = ah.isintegral(x) ah.assert_exactly_equal(a[integers], x[integers]) -@given(*hh.two_mutual_arrays(hh.numeric_dtype_objects)) +@given(*hh.two_mutual_arrays(dh.numeric_dtypes)) def test_floor_divide(x1, x2): sanity_check(x1, x2) - if ah.is_integer_dtype(x1.dtype): + if dh.is_int_dtype(x1.dtype): # The spec does not specify the behavior for division by 0 for integer # dtypes. A library may choose to raise an exception in this case, so # we avoid passing it in entirely. @@ -477,7 +471,7 @@ def test_floor_divide(x1, x2): # TODO: Test the exact output for floor_divide. -@given(*hh.two_mutual_arrays(hh.numeric_dtype_objects)) +@given(*hh.two_mutual_arrays(dh.numeric_dtypes)) def test_greater(x1, x2): sanity_check(x1, x2) a = xp.greater(x1, x2) @@ -488,13 +482,13 @@ def test_greater(x1, x2): _x1 = xp.broadcast_to(x1, shape) _x2 = xp.broadcast_to(x2, shape) - promoted_dtype = ah.promote_dtypes(x1.dtype, x2.dtype) + promoted_dtype = dh.promotion_table[x1.dtype, x2.dtype] _x1 = ah.asarray(_x1, dtype=promoted_dtype) _x2 = ah.asarray(_x2, dtype=promoted_dtype) - if ah.is_integer_dtype(promoted_dtype): + if dh.is_int_dtype(promoted_dtype): scalar_func = int - elif ah.is_float_dtype(promoted_dtype): + elif dh.is_float_dtype(promoted_dtype): scalar_func = float else: scalar_func = bool @@ -506,7 +500,7 @@ def test_greater(x1, x2): assert aidx.shape == x1idx.shape == x2idx.shape assert bool(aidx) == (scalar_func(x1idx) > scalar_func(x2idx)) -@given(*hh.two_mutual_arrays(hh.numeric_dtype_objects)) +@given(*hh.two_mutual_arrays(dh.numeric_dtypes)) def test_greater_equal(x1, x2): sanity_check(x1, x2) a = xp.greater_equal(x1, x2) @@ -517,13 +511,13 @@ def test_greater_equal(x1, x2): _x1 = xp.broadcast_to(x1, shape) _x2 = xp.broadcast_to(x2, shape) - promoted_dtype = ah.promote_dtypes(x1.dtype, x2.dtype) + promoted_dtype = dh.promotion_table[x1.dtype, x2.dtype] _x1 = ah.asarray(_x1, dtype=promoted_dtype) _x2 = ah.asarray(_x2, dtype=promoted_dtype) - if ah.is_integer_dtype(promoted_dtype): + if dh.is_int_dtype(promoted_dtype): scalar_func = int - elif ah.is_float_dtype(promoted_dtype): + elif dh.is_float_dtype(promoted_dtype): scalar_func = float else: scalar_func = bool @@ -539,14 +533,14 @@ def test_greater_equal(x1, x2): def test_isfinite(x): a = ah.isfinite(x) TRUE = ah.true(x.shape) - if ah.is_integer_dtype(x.dtype): + if dh.is_int_dtype(x.dtype): ah.assert_exactly_equal(a, TRUE) # Test that isfinite, isinf, and isnan are self-consistent. inf = ah.logical_or(xp.isinf(x), ah.isnan(x)) ah.assert_exactly_equal(a, ah.logical_not(inf)) # Test the exact value by comparing to the math version - if ah.is_float_dtype(x.dtype): + if dh.is_float_dtype(x.dtype): for idx in ah.ndindex(x.shape): s = float(x[idx]) assert bool(a[idx]) == math.isfinite(s) @@ -555,13 +549,13 @@ def test_isfinite(x): def test_isinf(x): a = xp.isinf(x) FALSE = ah.false(x.shape) - if ah.is_integer_dtype(x.dtype): + if dh.is_int_dtype(x.dtype): ah.assert_exactly_equal(a, FALSE) finite_or_nan = ah.logical_or(ah.isfinite(x), ah.isnan(x)) ah.assert_exactly_equal(a, ah.logical_not(finite_or_nan)) # Test the exact value by comparing to the math version - if ah.is_float_dtype(x.dtype): + if dh.is_float_dtype(x.dtype): for idx in ah.ndindex(x.shape): s = float(x[idx]) assert bool(a[idx]) == math.isinf(s) @@ -570,18 +564,18 @@ def test_isinf(x): def test_isnan(x): a = ah.isnan(x) FALSE = ah.false(x.shape) - if ah.is_integer_dtype(x.dtype): + if dh.is_int_dtype(x.dtype): ah.assert_exactly_equal(a, FALSE) finite_or_inf = ah.logical_or(ah.isfinite(x), xp.isinf(x)) ah.assert_exactly_equal(a, ah.logical_not(finite_or_inf)) # Test the exact value by comparing to the math version - if ah.is_float_dtype(x.dtype): + if dh.is_float_dtype(x.dtype): for idx in ah.ndindex(x.shape): s = float(x[idx]) assert bool(a[idx]) == math.isnan(s) -@given(*hh.two_mutual_arrays(hh.numeric_dtype_objects)) +@given(*hh.two_mutual_arrays(dh.numeric_dtypes)) def test_less(x1, x2): sanity_check(x1, x2) a = ah.less(x1, x2) @@ -592,13 +586,13 @@ def test_less(x1, x2): _x1 = xp.broadcast_to(x1, shape) _x2 = xp.broadcast_to(x2, shape) - promoted_dtype = ah.promote_dtypes(x1.dtype, x2.dtype) + promoted_dtype = dh.promotion_table[x1.dtype, x2.dtype] _x1 = ah.asarray(_x1, dtype=promoted_dtype) _x2 = ah.asarray(_x2, dtype=promoted_dtype) - if ah.is_integer_dtype(promoted_dtype): + if dh.is_int_dtype(promoted_dtype): scalar_func = int - elif ah.is_float_dtype(promoted_dtype): + elif dh.is_float_dtype(promoted_dtype): scalar_func = float else: scalar_func = bool @@ -610,7 +604,7 @@ def test_less(x1, x2): assert aidx.shape == x1idx.shape == x2idx.shape assert bool(aidx) == (scalar_func(x1idx) < scalar_func(x2idx)) -@given(*hh.two_mutual_arrays(hh.numeric_dtype_objects)) +@given(*hh.two_mutual_arrays(dh.numeric_dtypes)) def test_less_equal(x1, x2): sanity_check(x1, x2) a = ah.less_equal(x1, x2) @@ -621,13 +615,13 @@ def test_less_equal(x1, x2): _x1 = xp.broadcast_to(x1, shape) _x2 = xp.broadcast_to(x2, shape) - promoted_dtype = ah.promote_dtypes(x1.dtype, x2.dtype) + promoted_dtype = dh.promotion_table[x1.dtype, x2.dtype] _x1 = ah.asarray(_x1, dtype=promoted_dtype) _x2 = ah.asarray(_x2, dtype=promoted_dtype) - if ah.is_integer_dtype(promoted_dtype): + if dh.is_int_dtype(promoted_dtype): scalar_func = int - elif ah.is_float_dtype(promoted_dtype): + elif dh.is_float_dtype(promoted_dtype): scalar_func = float else: scalar_func = bool @@ -683,7 +677,7 @@ def test_log10(x): # mapped to nan, which is already tested in the special cases. ah.assert_exactly_equal(domain, codomain) -@given(*hh.two_mutual_arrays(hh.floating_dtype_objects)) +@given(*hh.two_mutual_arrays(dh.float_dtypes)) def test_logaddexp(x1, x2): sanity_check(x1, x2) xp.logaddexp(x1, x2) @@ -737,7 +731,7 @@ def test_logical_xor(x1, x2): for idx in ah.ndindex(shape): assert a[idx] == (bool(_x1[idx]) ^ bool(_x2[idx])) -@given(*hh.two_mutual_arrays(hh.numeric_dtype_objects)) +@given(*hh.two_mutual_arrays(dh.numeric_dtypes)) def test_multiply(x1, x2): sanity_check(x1, x2) a = xp.multiply(x1, x2) @@ -754,8 +748,8 @@ def test_negative(x): ah.assert_exactly_equal(x, ah.negative(out)) mask = ah.isfinite(x) - if ah.is_integer_dtype(x.dtype): - minval = ah.dtype_ranges[x.dtype][0] + if dh.is_int_dtype(x.dtype): + minval = dh.dtype_ranges[x.dtype][0] if minval < 0: # negative of the smallest representable negative integer is not defined mask = xp.not_equal(x, ah.full(x.shape, minval, dtype=x.dtype)) @@ -777,13 +771,13 @@ def test_not_equal(x1, x2): _x1 = xp.broadcast_to(x1, shape) _x2 = xp.broadcast_to(x2, shape) - promoted_dtype = ah.promote_dtypes(x1.dtype, x2.dtype) + promoted_dtype = dh.promotion_table[x1.dtype, x2.dtype] _x1 = ah.asarray(_x1, dtype=promoted_dtype) _x2 = ah.asarray(_x2, dtype=promoted_dtype) - if ah.is_integer_dtype(promoted_dtype): + if dh.is_int_dtype(promoted_dtype): scalar_func = int - elif ah.is_float_dtype(promoted_dtype): + elif dh.is_float_dtype(promoted_dtype): scalar_func = float else: scalar_func = bool @@ -802,7 +796,7 @@ def test_positive(x): # Positive does nothing ah.assert_exactly_equal(out, x) -@given(*hh.two_mutual_arrays(hh.floating_dtype_objects)) +@given(*hh.two_mutual_arrays(dh.float_dtypes)) def test_pow(x1, x2): sanity_check(x1, x2) xp.pow(x1, x2) @@ -812,7 +806,7 @@ def test_pow(x1, x2): # numbers. We could test that this does implement IEEE 754 pow, but we # don't yet have those sorts in general for this module. -@given(*hh.two_mutual_arrays(hh.numeric_dtype_objects)) +@given(*hh.two_mutual_arrays(dh.numeric_dtypes)) def test_remainder(x1, x2): assume(len(x1.shape) <= len(x2.shape)) # TODO: rework same sign testing below to remove this sanity_check(x1, x2) @@ -892,7 +886,7 @@ def test_trunc(x): out = xp.trunc(x) assert out.dtype == x.dtype, f"{x.dtype=!s} but {out.dtype=!s}" assert out.shape == x.shape, f"{x.shape=} but {out.shape=}" - if x.dtype in hh.integer_dtype_objects: + if x.dtype in dh.all_int_dtypes: assert ah.all(ah.equal(x, out)), f"{x=!s} but {out=!s}" else: finite = ah.isfinite(x) diff --git a/array_api_tests/test_linalg.py b/array_api_tests/test_linalg.py index 7349276c..bc0b3ae5 100644 --- a/array_api_tests/test_linalg.py +++ b/array_api_tests/test_linalg.py @@ -17,9 +17,7 @@ from hypothesis.strategies import (booleans, composite, none, tuples, integers, shared, sampled_from) -from .array_helpers import (assert_exactly_equal, ndindex, asarray, - numeric_dtype_objects, promote_dtypes, equal, - zero, infinity) +from .array_helpers import assert_exactly_equal, ndindex, asarray, equal, zero, infinity from .hypothesis_helpers import (xps, dtypes, shapes, kwargs, matrix_shapes, square_matrix_shapes, symmetric_matrices, positive_definite_matrices, MAX_ARRAY_SIZE, @@ -27,6 +25,7 @@ mutually_promotable_dtypes, one_d_shapes, two_mutually_broadcastable_shapes, SQRT_MAX_ARRAY_SIZE) from .pytest_helpers import raises +from . import dtype_helpers as dh from .test_broadcasting import broadcast_shapes @@ -94,7 +93,7 @@ def test_cholesky(x, kw): @composite -def cross_args(draw, dtype_objects=numeric_dtype_objects): +def cross_args(draw, dtype_objects=dh.numeric_dtypes): """ cross() requires two arrays with a size 3 in the 'axis' dimension @@ -135,7 +134,7 @@ def test_cross(x1_x2_kw): res = linalg.cross(x1, x2, **kw) - assert res.dtype == promote_dtypes(x1, x2), "cross() did not return the correct dtype" + assert res.dtype == dh.promotion_table[x1, x2], "cross() did not return the correct dtype" assert res.shape == shape, "cross() did not return the correct shape" # cross is too different from other functions to use _test_stacks, and it @@ -254,7 +253,7 @@ def test_inv(x): # TODO: Test that the result is actually the inverse @given( - *two_mutual_arrays(numeric_dtype_objects) + *two_mutual_arrays(dh.numeric_dtypes) ) def test_matmul(x1, x2): # TODO: Make this also test the @ operator @@ -271,7 +270,7 @@ def test_matmul(x1, x2): else: res = linalg.matmul(x1, x2) - assert res.dtype == promote_dtypes(x1, x2), "matmul() did not return the correct dtype" + assert res.dtype == dh.promotion_table[x1, x2], "matmul() did not return the correct dtype" if len(x1.shape) == len(x2.shape) == 1: assert res.shape == () @@ -342,7 +341,7 @@ def test_matrix_transpose(x): _test_stacks(linalg.matrix_transpose, x, res=res, true_val=true_val) @given( - *two_mutual_arrays(dtype_objects=numeric_dtype_objects, + *two_mutual_arrays(dtype_objs=dh.numeric_dtypes, two_shapes=tuples(one_d_shapes, one_d_shapes)) ) def test_outer(x1, x2): @@ -352,7 +351,7 @@ def test_outer(x1, x2): shape = (x1.shape[0], x2.shape[0]) assert res.shape == shape, "outer() did not return the correct shape" - assert res.dtype == promote_dtypes(x1, x2), "outer() did not return the correct dtype" + assert res.dtype == dh.promotion_table[x1, x2], "outer() did not return the correct dtype" if 0 in shape: true_res = _array_module.empty(shape, dtype=res.dtype) diff --git a/array_api_tests/test_signatures.py b/array_api_tests/test_signatures.py index 10bcd4ba..e8106985 100644 --- a/array_api_tests/test_signatures.py +++ b/array_api_tests/test_signatures.py @@ -4,7 +4,7 @@ from ._array_module import mod, mod_name, ones, eye, float64, bool, int64 from .pytest_helpers import raises, doesnt_raise -from .test_type_promotion import elementwise_function_input_types, operators_to_functions +from . import dtype_helpers as dh from . import function_stubs @@ -160,12 +160,13 @@ def test_function_positional_args(name): dtype = None if (name.startswith('__i') and name not in ['__int__', '__invert__', '__index__'] or name.startswith('__r') and name != '__rshift__'): - n = operators_to_functions[name[:2] + name[3:]] + n = f'__{name[3:]}' else: - n = operators_to_functions.get(name, name) - if 'boolean' in elementwise_function_input_types.get(n, 'floating'): + n = name + in_dtypes = dh.func_in_dtypes.get(n, dh.float_dtypes) + if bool in in_dtypes: dtype = bool - elif 'integer' in elementwise_function_input_types.get(n, 'floating'): + elif all(d in in_dtypes for d in dh.all_int_dtypes): dtype = int64 if array_method(name): diff --git a/array_api_tests/test_type_promotion.py b/array_api_tests/test_type_promotion.py index 50ba2e77..ea4e6eb8 100644 --- a/array_api_tests/test_type_promotion.py +++ b/array_api_tests/test_type_promotion.py @@ -1,777 +1,291 @@ """ https://data-apis.github.io/array-api/latest/API_specification/type_promotion.html """ +from collections import defaultdict +from typing import Tuple, Type, Union, List import pytest - -from hypothesis import given -from hypothesis.strategies import from_type, data, integers, just - -from .hypothesis_helpers import (shapes, two_mutually_broadcastable_shapes, - two_broadcastable_shapes, scalars) +from hypothesis import assume, given, reject +from hypothesis import strategies as st + +from . import _array_module as xp +from . import array_helpers as ah +from . import dtype_helpers as dh +from . import hypothesis_helpers as hh +from . import xps +from .function_stubs import elementwise_functions from .pytest_helpers import nargs -from .array_helpers import assert_exactly_equal, dtype_ranges -from .function_stubs import elementwise_functions -from ._array_module import (full, int8, int16, int32, int64, uint8, - uint16, uint32, uint64, float32, float64, bool as - bool_dtype) -from . import _array_module - -dtype_mapping = { - 'i1': int8, - 'i2': int16, - 'i4': int32, - 'i8': int64, - 'u1': uint8, - 'u2': uint16, - 'u4': uint32, - 'u8': uint64, - 'f4': float32, - 'f8': float64, - 'b': bool_dtype, -} - -reverse_dtype_mapping = {v: k for k, v in dtype_mapping.items()} - -def dtype_nbits(dtype): - if dtype == int8: - return 8 - elif dtype == int16: - return 16 - elif dtype == int32: - return 32 - elif dtype == int64: - return 64 - elif dtype == uint8: - return 8 - elif dtype == uint16: - return 16 - elif dtype == uint32: - return 32 - elif dtype == uint64: - return 64 - elif dtype == float32: - return 32 - elif dtype == float64: - return 64 - else: - raise ValueError(f"dtype_nbits is not defined for {dtype}") - -def dtype_signed(dtype): - if dtype in [int8, int16, int32, int64]: - return True - elif dtype in [uint8, uint16, uint32, uint64]: - return False - raise ValueError("dtype_signed is only defined for integer dtypes") - -signed_integer_promotion_table = { - ('i1', 'i1'): 'i1', - ('i1', 'i2'): 'i2', - ('i1', 'i4'): 'i4', - ('i1', 'i8'): 'i8', - ('i2', 'i1'): 'i2', - ('i2', 'i2'): 'i2', - ('i2', 'i4'): 'i4', - ('i2', 'i8'): 'i8', - ('i4', 'i1'): 'i4', - ('i4', 'i2'): 'i4', - ('i4', 'i4'): 'i4', - ('i4', 'i8'): 'i8', - ('i8', 'i1'): 'i8', - ('i8', 'i2'): 'i8', - ('i8', 'i4'): 'i8', - ('i8', 'i8'): 'i8', -} - -unsigned_integer_promotion_table = { - ('u1', 'u1'): 'u1', - ('u1', 'u2'): 'u2', - ('u1', 'u4'): 'u4', - ('u1', 'u8'): 'u8', - ('u2', 'u1'): 'u2', - ('u2', 'u2'): 'u2', - ('u2', 'u4'): 'u4', - ('u2', 'u8'): 'u8', - ('u4', 'u1'): 'u4', - ('u4', 'u2'): 'u4', - ('u4', 'u4'): 'u4', - ('u4', 'u8'): 'u8', - ('u8', 'u1'): 'u8', - ('u8', 'u2'): 'u8', - ('u8', 'u4'): 'u8', - ('u8', 'u8'): 'u8', -} - -mixed_signed_unsigned_promotion_table = { - ('i1', 'u1'): 'i2', - ('i1', 'u2'): 'i4', - ('i1', 'u4'): 'i8', - ('i2', 'u1'): 'i2', - ('i2', 'u2'): 'i4', - ('i2', 'u4'): 'i8', - ('i4', 'u1'): 'i4', - ('i4', 'u2'): 'i4', - ('i4', 'u4'): 'i8', - ('i8', 'u1'): 'i8', - ('i8', 'u2'): 'i8', - ('i8', 'u4'): 'i8', -} - -flipped_mixed_signed_unsigned_promotion_table = {(u, i): p for (i, u), p in mixed_signed_unsigned_promotion_table.items()} - -float_promotion_table = { - ('f4', 'f4'): 'f4', - ('f4', 'f8'): 'f8', - ('f8', 'f4'): 'f8', - ('f8', 'f8'): 'f8', -} - -boolean_promotion_table = { - ('b', 'b'): 'b', -} - -promotion_table = { - **signed_integer_promotion_table, - **unsigned_integer_promotion_table, - **mixed_signed_unsigned_promotion_table, - **flipped_mixed_signed_unsigned_promotion_table, - **float_promotion_table, - **boolean_promotion_table, -} - -input_types = { - 'any': sorted(set(promotion_table.values())), - 'boolean': sorted(set(boolean_promotion_table.values())), - 'floating': sorted(set(float_promotion_table.values())), - 'integer': sorted(set({**signed_integer_promotion_table, - **unsigned_integer_promotion_table}.values())), - 'integer_or_boolean': sorted(set({**signed_integer_promotion_table, - **unsigned_integer_promotion_table, - **boolean_promotion_table}.values())), - 'numeric': sorted(set({**float_promotion_table, - **signed_integer_promotion_table, - **unsigned_integer_promotion_table}.values())), -} - -binary_operators = { - '__add__': '+', - '__and__': '&', - '__eq__': '==', - '__floordiv__': '//', - '__ge__': '>=', - '__gt__': '>', - '__le__': '<=', - '__lshift__': '<<', - '__lt__': '<', - '__matmul__': '@', - '__mod__': '%', - '__mul__': '*', - '__ne__': '!=', - '__or__': '|', - '__pow__': '**', - '__rshift__': '>>', - '__sub__': '-', - '__truediv__': '/', - '__xor__': '^', -} - -unary_operators = { - '__abs__': 'abs()', - '__invert__': '~', - '__neg__': '-', - '__pos__': '+', -} - - -operators_to_functions = { - '__abs__': 'abs', - '__add__': 'add', - '__and__': 'bitwise_and', - '__eq__': 'equal', - '__floordiv__': 'floor_divide', - '__ge__': 'greater_equal', - '__gt__': 'greater', - '__le__': 'less_equal', - '__lshift__': 'bitwise_left_shift', - '__lt__': 'less', - '__matmul__': 'matmul', - '__mod__': 'remainder', - '__mul__': 'multiply', - '__ne__': 'not_equal', - '__or__': 'bitwise_or', - '__pow__': 'pow', - '__rshift__': 'bitwise_right_shift', - '__sub__': 'subtract', - '__truediv__': 'divide', - '__xor__': 'bitwise_xor', - '__invert__': 'bitwise_invert', - '__neg__': 'negative', - '__pos__': 'positive', -} - -dtypes_to_scalars = { - 'b': [bool], - 'i1': [int], - 'i2': [int], - 'i4': [int], - 'i8': [int], - # Note: unsigned int dtypes only correspond to positive integers - 'u1': [int], - 'u2': [int], - 'u4': [int], - 'u8': [int], - 'f4': [int, float], - 'f8': [int, float], -} - -elementwise_function_input_types = { - 'abs': 'numeric', - 'acos': 'floating', - 'acosh': 'floating', - 'add': 'numeric', - 'asin': 'floating', - 'asinh': 'floating', - 'atan': 'floating', - 'atan2': 'floating', - 'atanh': 'floating', - 'bitwise_and': 'integer_or_boolean', - 'bitwise_invert': 'integer_or_boolean', - 'bitwise_left_shift': 'integer', - 'bitwise_or': 'integer_or_boolean', - 'bitwise_right_shift': 'integer', - 'bitwise_xor': 'integer_or_boolean', - 'ceil': 'numeric', - 'cos': 'floating', - 'cosh': 'floating', - 'divide': 'floating', - 'equal': 'any', - 'exp': 'floating', - 'expm1': 'floating', - 'floor': 'numeric', - 'floor_divide': 'numeric', - 'greater': 'numeric', - 'greater_equal': 'numeric', - 'isfinite': 'numeric', - 'isinf': 'numeric', - 'isnan': 'numeric', - 'less': 'numeric', - 'less_equal': 'numeric', - 'log': 'floating', - 'logaddexp': 'floating', - 'log10': 'floating', - 'log1p': 'floating', - 'log2': 'floating', - 'logical_and': 'boolean', - 'logical_not': 'boolean', - 'logical_or': 'boolean', - 'logical_xor': 'boolean', - 'multiply': 'numeric', - 'negative': 'numeric', - 'not_equal': 'any', - 'positive': 'numeric', - 'pow': 'floating', - 'remainder': 'numeric', - 'round': 'numeric', - 'sign': 'numeric', - 'sin': 'floating', - 'sinh': 'floating', - 'sqrt': 'floating', - 'square': 'numeric', - 'subtract': 'numeric', - 'tan': 'floating', - 'tanh': 'floating', - 'trunc': 'numeric', -} - -elementwise_function_output_types = { - 'abs': 'promoted', - 'acos': 'promoted', - 'acosh': 'promoted', - 'add': 'promoted', - 'asin': 'promoted', - 'asinh': 'promoted', - 'atan': 'promoted', - 'atan2': 'promoted', - 'atanh': 'promoted', - 'bitwise_and': 'promoted', - 'bitwise_invert': 'promoted', - 'bitwise_left_shift': 'promoted', - 'bitwise_or': 'promoted', - 'bitwise_right_shift': 'promoted', - 'bitwise_xor': 'promoted', - 'ceil': 'promoted', - 'cos': 'promoted', - 'cosh': 'promoted', - 'divide': 'promoted', - 'equal': 'bool', - 'exp': 'promoted', - 'expm1': 'promoted', - 'floor': 'promoted', - 'floor_divide': 'promoted', - 'greater': 'bool', - 'greater_equal': 'bool', - 'isfinite': 'bool', - 'isinf': 'bool', - 'isnan': 'bool', - 'less': 'bool', - 'less_equal': 'bool', - 'log': 'promoted', - 'logaddexp': 'promoted', - 'log10': 'promoted', - 'log1p': 'promoted', - 'log2': 'promoted', - 'logical_and': 'bool', - 'logical_not': 'bool', - 'logical_or': 'bool', - 'logical_xor': 'bool', - 'multiply': 'promoted', - 'negative': 'promoted', - 'not_equal': 'bool', - 'positive': 'promoted', - 'pow': 'promoted', - 'remainder': 'promoted', - 'round': 'promoted', - 'sign': 'promoted', - 'sin': 'promoted', - 'sinh': 'promoted', - 'sqrt': 'promoted', - 'square': 'promoted', - 'subtract': 'promoted', - 'tan': 'promoted', - 'tanh': 'promoted', - 'trunc': 'promoted', -} - -elementwise_function_two_arg_func_names = [func_name for func_name in - elementwise_functions.__all__ if - nargs(func_name) > 1] - -elementwise_function_two_arg_func_names_bool = [func_name for func_name in - elementwise_function_two_arg_func_names - if - elementwise_function_output_types[func_name] - == 'bool'] - -elementwise_function_two_arg_bool_parametrize_inputs = [(func_name, dtypes) - for func_name in elementwise_function_two_arg_func_names_bool - for dtypes in promotion_table.keys() if all(d in - input_types[elementwise_function_input_types[func_name]] - for d in dtypes) - ] - -elementwise_function_two_arg_bool_parametrize_ids = ['-'.join((n, d1, d2)) for n, (d1, d2) - in elementwise_function_two_arg_bool_parametrize_inputs] - -# TODO: These functions should still do type promotion internally, but we do -# not test this here (it is tested in the corresponding tests in -# test_elementwise_functions.py). This can affect the resulting values if not -# done correctly. For example, greater_equal(array(1.0, dtype=float32), -# array(1.00000001, dtype=float64)) will be wrong if the float64 array is -# downcast to float32. See for instance -# https://github.com/numpy/numpy/issues/10322. -@pytest.mark.parametrize('func_name,dtypes', - elementwise_function_two_arg_bool_parametrize_inputs, - ids=elementwise_function_two_arg_bool_parametrize_ids) -# The spec explicitly requires type promotion to work for shape 0 -# Unfortunately, data(), isn't compatible with @example, so this is commented -# out for now. -# @example(shape=(0,)) -@given(two_shapes=two_mutually_broadcastable_shapes, fillvalues=data()) -def test_elementwise_function_two_arg_bool_type_promotion(func_name, - two_shapes, dtypes, - fillvalues): - assert nargs(func_name) == 2 - func = getattr(_array_module, func_name) - - type1, type2 = dtypes - dtype1 = dtype_mapping[type1] - dtype2 = dtype_mapping[type2] - - fillvalue1 = fillvalues.draw(scalars(just(dtype1))) - if func_name in ['bitwise_left_shift', 'bitwise_right_shift']: - fillvalue2 = fillvalues.draw(scalars(just(dtype2)).filter(lambda x: x > 0)) - else: - fillvalue2 = fillvalues.draw(scalars(just(dtype2))) - - - for i in [func, dtype1, dtype2]: - if isinstance(i, _array_module._UndefinedStub): - i._raise() - - shape1, shape2 = two_shapes - a1 = full(shape1, fillvalue1, dtype=dtype1) - a2 = full(shape2, fillvalue2, dtype=dtype2) - res = func(a1, a2) - - assert res.dtype == bool_dtype, f"{func_name}({dtype1}, {dtype2}) promoted to {res.dtype}, should have promoted to bool (shapes={shape1, shape2})" - -elementwise_function_two_arg_func_names_promoted = [func_name for func_name in - elementwise_function_two_arg_func_names - if - elementwise_function_output_types[func_name] - == 'promoted'] - -elementwise_function_two_arg_promoted_parametrize_inputs = [(func_name, dtypes) - for func_name in elementwise_function_two_arg_func_names_promoted - for dtypes in promotion_table.items() if all(d in - input_types[elementwise_function_input_types[func_name]] - for d in dtypes[0]) - ] - -elementwise_function_two_arg_promoted_parametrize_ids = ['-'.join((n, d1, d2)) for n, ((d1, d2), _) - in elementwise_function_two_arg_promoted_parametrize_inputs] - -# TODO: Extend this to all functions (not just elementwise), and handle -# functions that take more than 2 args -@pytest.mark.parametrize('func_name,dtypes', - elementwise_function_two_arg_promoted_parametrize_inputs, - ids=elementwise_function_two_arg_promoted_parametrize_ids) -# The spec explicitly requires type promotion to work for shape 0 -# Unfortunately, data(), isn't compatible with @example, so this is commented -# out for now. -# @example(shape=(0,)) -@given(two_shapes=two_mutually_broadcastable_shapes, fillvalues=data()) -def test_elementwise_function_two_arg_promoted_type_promotion(func_name, - two_shapes, dtypes, - fillvalues): - assert nargs(func_name) == 2 - func = getattr(_array_module, func_name) - - (type1, type2), res_type = dtypes - dtype1 = dtype_mapping[type1] - dtype2 = dtype_mapping[type2] - res_dtype = dtype_mapping[res_type] - fillvalue1 = fillvalues.draw(scalars(just(dtype1))) - if func_name in ['bitwise_left_shift', 'bitwise_right_shift']: - fillvalue2 = fillvalues.draw(scalars(just(dtype2)).filter(lambda x: x > 0)) - else: - fillvalue2 = fillvalues.draw(scalars(just(dtype2))) - - - for i in [func, dtype1, dtype2, res_dtype]: - if isinstance(i, _array_module._UndefinedStub): - i._raise() - - shape1, shape2 = two_shapes - a1 = full(shape1, fillvalue1, dtype=dtype1) - a2 = full(shape2, fillvalue2, dtype=dtype2) - res = func(a1, a2) - - assert res.dtype == res_dtype, f"{func_name}({dtype1}, {dtype2}) promoted to {res.dtype}, should have promoted to {res_dtype} (shapes={shape1, shape2})" - -elementwise_function_one_arg_func_names = [func_name for func_name in - elementwise_functions.__all__ if - nargs(func_name) == 1] - -elementwise_function_one_arg_func_names_bool = [func_name for func_name in - elementwise_function_one_arg_func_names - if - elementwise_function_output_types[func_name] - == 'bool'] - -elementwise_function_one_arg_bool_parametrize_inputs = [(func_name, dtypes) - for func_name in elementwise_function_one_arg_func_names_bool - for dtypes in input_types[elementwise_function_input_types[func_name]]] -elementwise_function_one_arg_bool_parametrize_ids = ['-'.join((n, d)) for n, d - in elementwise_function_one_arg_bool_parametrize_inputs] - -# TODO: Extend this to all functions (not just elementwise), and handle -# functions that take more than 2 args -@pytest.mark.parametrize('func_name,dtype_name', - elementwise_function_one_arg_bool_parametrize_inputs, - ids=elementwise_function_one_arg_bool_parametrize_ids) -# The spec explicitly requires type promotion to work for shape 0 -# Unfortunately, data(), isn't compatible with @example, so this is commented -# out for now. -# @example(shape=(0,)) -@given(shape=shapes, fillvalues=data()) -def test_elementwise_function_one_arg_bool(func_name, shape, - dtype_name, fillvalues): - assert nargs(func_name) == 1 - func = getattr(_array_module, func_name) - - dtype = dtype_mapping[dtype_name] - fillvalue = fillvalues.draw(scalars(just(dtype))) - - for i in [func, dtype]: - if isinstance(i, _array_module._UndefinedStub): - i._raise() - - x = full(shape, fillvalue, dtype=dtype) - res = func(x) - - assert res.dtype == bool_dtype, f"{func_name}({dtype}) returned to {res.dtype}, should have promoted to bool (shape={shape})" - -elementwise_function_one_arg_func_names_promoted = [func_name for func_name in - elementwise_function_one_arg_func_names - if - elementwise_function_output_types[func_name] - == 'promoted'] - -elementwise_function_one_arg_promoted_parametrize_inputs = [(func_name, dtypes) - for func_name in elementwise_function_one_arg_func_names_promoted - for dtypes in input_types[elementwise_function_input_types[func_name]]] -elementwise_function_one_arg_promoted_parametrize_ids = ['-'.join((n, d)) for n, d - in elementwise_function_one_arg_promoted_parametrize_inputs] - -# TODO: Extend this to all functions (not just elementwise), and handle -# functions that take more than 2 args -@pytest.mark.parametrize('func_name,dtype_name', - elementwise_function_one_arg_promoted_parametrize_inputs, - ids=elementwise_function_one_arg_promoted_parametrize_ids) -# The spec explicitly requires type promotion to work for shape 0 -# Unfortunately, data(), isn't compatible with @example, so this is commented -# out for now. -# @example(shape=(0,)) -@given(shape=shapes, fillvalues=data()) -def test_elementwise_function_one_arg_type_promotion(func_name, shape, - dtype_name, fillvalues): - assert nargs(func_name) == 1 - func = getattr(_array_module, func_name) - - dtype = dtype_mapping[dtype_name] - fillvalue = fillvalues.draw(scalars(just(dtype))) - - for i in [func, dtype]: - if isinstance(i, _array_module._UndefinedStub): - i._raise() - - x = full(shape, fillvalue, dtype=dtype) - res = func(x) - - assert res.dtype == dtype, f"{func_name}({dtype}) returned to {res.dtype}, should have promoted to {dtype} (shape={shape})" - -unary_operators_promoted = [unary_op_name for unary_op_name in sorted(unary_operators) - if elementwise_function_output_types[operators_to_functions[unary_op_name]] == 'promoted'] -operator_one_arg_promoted_parametrize_inputs = [(unary_op_name, dtypes) - for unary_op_name in unary_operators_promoted - for dtypes in input_types[elementwise_function_input_types[operators_to_functions[unary_op_name]]] - ] -operator_one_arg_promoted_parametrize_ids = ['-'.join((n, d)) for n, d - in operator_one_arg_promoted_parametrize_inputs] - - -# TODO: Extend this to all functions (not just elementwise), and handle -# functions that take more than 2 args -@pytest.mark.parametrize('unary_op_name,dtype_name', - operator_one_arg_promoted_parametrize_inputs, - ids=operator_one_arg_promoted_parametrize_ids) -# The spec explicitly requires type promotion to work for shape 0 -# Unfortunately, data(), isn't compatible with @example, so this is commented -# out for now. -# @example(shape=(0,)) -@given(shape=shapes, fillvalues=data()) -def test_operator_one_arg_type_promotion(unary_op_name, shape, dtype_name, fillvalues): - unary_op = unary_operators[unary_op_name] - - dtype = dtype_mapping[dtype_name] - fillvalue = fillvalues.draw(scalars(just(dtype))) - - if isinstance(dtype, _array_module._UndefinedStub): - dtype._raise() - - a = full(shape, fillvalue, dtype=dtype) - - get_locals = lambda: dict(a=a) - - if unary_op_name == '__abs__': - # This is the built-in abs(), not the array module's abs() - expression = 'abs(a)' + +bitwise_shift_funcs = [ + 'bitwise_left_shift', + 'bitwise_right_shift', + '__lshift__', + '__rshift__', + '__ilshift__', + '__irshift__', +] + + +DT = Type +ScalarType = Union[Type[bool], Type[int], Type[float]] + + +# We apply filters to xps.arrays() so we don't generate array elements that +# are erroneous or undefined for a function/operator. +filters = defaultdict( + lambda: lambda _: True, + {func: lambda x: ah.all(x > 0) for func in bitwise_shift_funcs}, +) + + +def make_id( + func_name: str, in_dtypes: Tuple[Union[DT, ScalarType], ...], out_dtype: DT +) -> str: + f_in_dtypes = [] + for dtype in in_dtypes: + try: + f_in_dtypes.append(dh.dtype_to_name[dtype]) + except KeyError: + # i.e. dtype is bool, int, or float + f_in_dtypes.append(dtype.__name__) + f_args = ', '.join(f_in_dtypes) + f_out_dtype = dh.dtype_to_name[out_dtype] + return f'{func_name}({f_args}) -> {f_out_dtype}' + + +func_params: List[Tuple[str, Tuple[DT, ...], DT]] = [] +for func_name in elementwise_functions.__all__: + valid_in_dtypes = dh.func_in_dtypes[func_name] + ndtypes = nargs(func_name) + if ndtypes == 1: + for in_dtype in valid_in_dtypes: + out_dtype = xp.bool if dh.func_returns_bool[func_name] else in_dtype + p = pytest.param( + func_name, + (in_dtype,), + out_dtype, + id=make_id(func_name, (in_dtype,), out_dtype), + ) + func_params.append(p) + elif ndtypes == 2: + for (in_dtype1, in_dtype2), promoted_dtype in dh.promotion_table.items(): + if in_dtype1 in valid_in_dtypes and in_dtype2 in valid_in_dtypes: + out_dtype = ( + xp.bool if dh.func_returns_bool[func_name] else promoted_dtype + ) + p = pytest.param( + func_name, + (in_dtype1, in_dtype2), + out_dtype, + id=make_id(func_name, (in_dtype1, in_dtype2), out_dtype), + ) + func_params.append(p) else: - expression = f'{unary_op} a' - res = eval(expression, get_locals()) - - assert res.dtype == dtype, f"{unary_op}({dtype}) returned to {res.dtype}, should have promoted to {dtype} (shape={shape})" - -# Note: the boolean binary operators do not have reversed or in-place variants -binary_operators_bool = [binary_op_name for binary_op_name in sorted(set(binary_operators) - {'__matmul__'}) - if elementwise_function_output_types[operators_to_functions[binary_op_name]] == 'bool'] -operator_two_arg_bool_parametrize_inputs = [(binary_op_name, dtypes) - for binary_op_name in binary_operators_bool - for dtypes in promotion_table.keys() - if all(d in input_types[elementwise_function_input_types[operators_to_functions[binary_op_name]]] for d in dtypes) - ] -operator_two_arg_bool_parametrize_ids = ['-'.join((n, d1, d2)) for n, (d1, d2) - in operator_two_arg_bool_parametrize_inputs] - -@pytest.mark.parametrize('binary_op_name,dtypes', - operator_two_arg_bool_parametrize_inputs, - ids=operator_two_arg_bool_parametrize_ids) -@given(two_shapes=two_mutually_broadcastable_shapes, fillvalues=data()) -def test_operator_two_arg_bool_promotion(binary_op_name, dtypes, two_shapes, - fillvalues): - binary_op = binary_operators[binary_op_name] - - type1, type2 = dtypes - dtype1 = dtype_mapping[type1] - dtype2 = dtype_mapping[type2] - fillvalue1 = fillvalues.draw(scalars(just(dtype1))) - fillvalue2 = fillvalues.draw(scalars(just(dtype2))) - - for i in [dtype1, dtype2]: - if isinstance(i, _array_module._UndefinedStub): - i._raise() - - shape1, shape2 = two_shapes - a1 = full(shape1, fillvalue1, dtype=dtype1) - a2 = full(shape2, fillvalue2, dtype=dtype2) - - get_locals = lambda: dict(a1=a1, a2=a2) - expression = f'a1 {binary_op} a2' - res = eval(expression, get_locals()) - - assert res.dtype == bool_dtype, f"{dtype1} {binary_op} {dtype2} promoted to {res.dtype}, should have promoted to bool (shape={shape1, shape2})" - -binary_operators_promoted = [binary_op_name for binary_op_name in sorted(set(binary_operators) - {'__matmul__'}) - if elementwise_function_output_types[operators_to_functions[binary_op_name]] == 'promoted'] -operator_two_arg_promoted_parametrize_inputs = [(binary_op_name, dtypes) - for binary_op_name in binary_operators_promoted - for dtypes in promotion_table.items() - if all(d in input_types[elementwise_function_input_types[operators_to_functions[binary_op_name]]] for d in dtypes[0]) - ] -operator_two_arg_promoted_parametrize_ids = ['-'.join((n, d1, d2)) for n, ((d1, d2), _) - in operator_two_arg_promoted_parametrize_inputs] - -@pytest.mark.parametrize('binary_op_name,dtypes', - operator_two_arg_promoted_parametrize_inputs, - ids=operator_two_arg_promoted_parametrize_ids) -@given(two_shapes=two_mutually_broadcastable_shapes, fillvalues=data()) -def test_operator_two_arg_promoted_promotion(binary_op_name, dtypes, two_shapes, - fillvalues): - binary_op = binary_operators[binary_op_name] - - (type1, type2), res_type = dtypes - dtype1 = dtype_mapping[type1] - dtype2 = dtype_mapping[type2] - res_dtype = dtype_mapping[res_type] - fillvalue1 = fillvalues.draw(scalars(just(dtype1))) - if binary_op_name in ['>>', '<<']: - fillvalue2 = fillvalues.draw(scalars(just(dtype2)).filter(lambda x: x > 0)) + raise NotImplementedError() + + +@pytest.mark.parametrize('func_name, in_dtypes, out_dtype', func_params) +@given(data=st.data()) +def test_func_promotion(func_name, in_dtypes, out_dtype, data): + func = getattr(xp, func_name) + x_filter = filters[func_name] + if len(in_dtypes) == 1: + x = data.draw( + xps.arrays(dtype=in_dtypes[0], shape=hh.shapes).filter(x_filter), label='x' + ) + out = func(x) else: - fillvalue2 = fillvalues.draw(scalars(just(dtype2))) - - - for i in [dtype1, dtype2, res_dtype]: - if isinstance(i, _array_module._UndefinedStub): - i._raise() - - shape1, shape2 = two_shapes - a1 = full(shape1, fillvalue1, dtype=dtype1) - a2 = full(shape2, fillvalue2, dtype=dtype2) - - get_locals = lambda: dict(a1=a1, a2=a2) - expression = f'a1 {binary_op} a2' - res = eval(expression, get_locals()) - - assert res.dtype == res_dtype, f"{dtype1} {binary_op} {dtype2} promoted to {res.dtype}, should have promoted to {res_dtype} (shape={shape1, shape2})" - -operator_inplace_two_arg_promoted_parametrize_inputs = [(binary_op, dtypes) for binary_op, dtypes in operator_two_arg_promoted_parametrize_inputs - if dtypes[0][0] == dtypes[1]] -operator_inplace_two_arg_promoted_parametrize_ids = ['-'.join((n[:2] + 'i' + n[2:], d1, d2)) for n, ((d1, d2), _) - in operator_inplace_two_arg_promoted_parametrize_inputs] - -@pytest.mark.parametrize('binary_op_name,dtypes', - operator_inplace_two_arg_promoted_parametrize_inputs, - ids=operator_inplace_two_arg_promoted_parametrize_ids) -@given(two_shapes=two_broadcastable_shapes(), fillvalues=data()) -def test_operator_inplace_two_arg_promoted_promotion(binary_op_name, dtypes, two_shapes, - fillvalues): - binary_op = binary_operators[binary_op_name] - - (type1, type2), res_type = dtypes - dtype1 = dtype_mapping[type1] - dtype2 = dtype_mapping[type2] - res_dtype = dtype_mapping[res_type] - fillvalue1 = fillvalues.draw(scalars(just(dtype1))) - if binary_op_name in ['>>', '<<']: - fillvalue2 = fillvalues.draw(scalars(just(dtype2)).filter(lambda x: x > 0)) + arrays = [] + shapes = data.draw( + hh.mutually_broadcastable_shapes(len(in_dtypes)), label='shapes' + ) + for i, (dtype, shape) in enumerate(zip(in_dtypes, shapes), 1): + x = data.draw( + xps.arrays(dtype=dtype, shape=shape).filter(x_filter), label=f'x{i}' + ) + arrays.append(x) + try: + out = func(*arrays) + except OverflowError: + reject() + assert out.dtype == out_dtype, f'{out.dtype=!s}, but should be {out_dtype}' + + +op_params: List[Tuple[str, str, Tuple[DT, ...], DT]] = [] +op_to_symbol = {**dh.unary_op_to_symbol, **dh.binary_op_to_symbol} +for op, symbol in op_to_symbol.items(): + if op == '__matmul__': + continue + valid_in_dtypes = dh.func_in_dtypes[op] + ndtypes = nargs(op) + if ndtypes == 1: + for in_dtype in valid_in_dtypes: + out_dtype = xp.bool if dh.func_returns_bool[op] else in_dtype + p = pytest.param( + op, + f'{symbol}x', + (in_dtype,), + out_dtype, + id=make_id(op, (in_dtype,), out_dtype), + ) + op_params.append(p) else: - fillvalue2 = fillvalues.draw(scalars(just(dtype2))) - - for i in [dtype1, dtype2, res_dtype]: - if isinstance(i, _array_module._UndefinedStub): - i._raise() - - shape1, shape2 = two_shapes - a1 = full(shape1, fillvalue1, dtype=dtype1) - a2 = full(shape2, fillvalue2, dtype=dtype2) - - get_locals = lambda: dict(a1=a1, a2=a2) - - res_locals = get_locals() - expression = f'a1 {binary_op}= a2' - exec(expression, res_locals) - res = res_locals['a1'] - - assert res.dtype == res_dtype, f"{dtype1} {binary_op}= {dtype2} promoted to {res.dtype}, should have promoted to {res_dtype} (shape={shape1, shape2})" - -scalar_promotion_parametrize_inputs = [(binary_op_name, dtype_name, scalar_type) - for binary_op_name in sorted(set(binary_operators) - {'__matmul__'}) - for dtype_name in input_types[elementwise_function_input_types[operators_to_functions[binary_op_name]]] - for scalar_type in dtypes_to_scalars[dtype_name]] - -@pytest.mark.parametrize('binary_op_name,dtype_name,scalar_type', - scalar_promotion_parametrize_inputs) -@given(shape=shapes, python_scalars=data(), fillvalues=data()) -def test_operator_scalar_promotion(binary_op_name, dtype_name, scalar_type, - shape, python_scalars, fillvalues): - """ - See https://data-apis.github.io/array-api/latest/API_specification/type_promotion.html#mixing-arrays-with-python-scalars - """ - binary_op = binary_operators[binary_op_name] - if binary_op == '@': - pytest.skip("matmul (@) is not supported for scalars") - dtype = dtype_mapping[dtype_name] - - if dtype_name in input_types['integer']: - s = python_scalars.draw(integers(*dtype_ranges[dtype])) + for (in_dtype1, in_dtype2), promoted_dtype in dh.promotion_table.items(): + if in_dtype1 in valid_in_dtypes and in_dtype2 in valid_in_dtypes: + out_dtype = xp.bool if dh.func_returns_bool[op] else promoted_dtype + p = pytest.param( + op, + f'x1 {symbol} x2', + (in_dtype1, in_dtype2), + out_dtype, + id=make_id(op, (in_dtype1, in_dtype2), out_dtype), + ) + op_params.append(p) +# We generate params for abs seperately as it does not have an associated symbol +for in_dtype in dh.func_in_dtypes['__abs__']: + p = pytest.param( + '__abs__', + 'abs(x)', + (in_dtype,), + in_dtype, + id=make_id('__abs__', (in_dtype,), in_dtype), + ) + op_params.append(p) + + +@pytest.mark.parametrize('op, expr, in_dtypes, out_dtype', op_params) +@given(data=st.data()) +def test_op_promotion(op, expr, in_dtypes, out_dtype, data): + x_filter = filters[op] + if len(in_dtypes) == 1: + x = data.draw( + xps.arrays(dtype=in_dtypes[0], shape=hh.shapes).filter(x_filter), label='x' + ) + out = eval(expr, {'x': x}) else: - s = python_scalars.draw(from_type(scalar_type)) - scalar_as_array = _array_module.asarray(s, dtype=dtype) - get_locals = lambda: dict(a=a, s=s, scalar_as_array=scalar_as_array) - - fillvalue = fillvalues.draw(scalars(just(dtype))) - a = full(shape, fillvalue, dtype=dtype) - - # As per the spec: - - # The expected behavior is then equivalent to: - # - # 1. Convert the scalar to a 0-D array with the same dtype as that of the - # array used in the expression. - # - # 2. Execute the operation for `array 0-D array` (or `0-D array - # array` if `scalar` was the left-hand argument). - - array_scalar = f'a {binary_op} s' - array_scalar_expected = f'a {binary_op} scalar_as_array' - res = eval(array_scalar, get_locals()) - expected = eval(array_scalar_expected, get_locals()) - assert_exactly_equal(res, expected) - - scalar_array = f's {binary_op} a' - scalar_array_expected = f'scalar_as_array {binary_op} a' - res = eval(scalar_array, get_locals()) - expected = eval(scalar_array_expected, get_locals()) - assert_exactly_equal(res, expected) - - # Test in-place operators - if binary_op in ['==', '!=', '<', '>', '<=', '>=']: - return - array_scalar = f'a {binary_op}= s' - array_scalar_expected = f'a {binary_op}= scalar_as_array' - a = full(shape, fillvalue, dtype=dtype) - res_locals = get_locals() - exec(array_scalar, get_locals()) - res = res_locals['a'] - a = full(shape, fillvalue, dtype=dtype) - expected_locals = get_locals() - exec(array_scalar_expected, get_locals()) - expected = expected_locals['a'] - assert_exactly_equal(res, expected) + locals_ = {} + shapes = data.draw( + hh.mutually_broadcastable_shapes(len(in_dtypes)), label='shapes' + ) + for i, (dtype, shape) in enumerate(zip(in_dtypes, shapes), 1): + locals_[f'x{i}'] = data.draw( + xps.arrays(dtype=dtype, shape=shape).filter(x_filter), label=f'x{i}' + ) + try: + out = eval(expr, locals_) + except OverflowError: + reject() + assert out.dtype == out_dtype, f'{out.dtype=!s}, but should be {out_dtype}' + + +inplace_params: List[Tuple[str, str, Tuple[DT, ...], DT]] = [] +for op, symbol in dh.inplace_op_to_symbol.items(): + if op == '__imatmul__': + continue + valid_in_dtypes = dh.func_in_dtypes[op] + for (in_dtype1, in_dtype2), promoted_dtype in dh.promotion_table.items(): + if ( + in_dtype1 == promoted_dtype + and in_dtype1 in valid_in_dtypes + and in_dtype2 in valid_in_dtypes + ): + p = pytest.param( + op, + f'x1 {symbol} x2', + (in_dtype1, in_dtype2), + promoted_dtype, + id=make_id(op, (in_dtype1, in_dtype2), promoted_dtype), + ) + inplace_params.append(p) + + +@pytest.mark.parametrize('op, expr, in_dtypes, out_dtype', inplace_params) +@given(shapes=hh.mutually_broadcastable_shapes(2), data=st.data()) +def test_inplace_op_promotion(op, expr, in_dtypes, out_dtype, shapes, data): + assume(len(shapes[0]) >= len(shapes[1])) + x_filter = filters[op] + x1 = data.draw( + xps.arrays(dtype=in_dtypes[0], shape=shapes[0]).filter(x_filter), label='x1' + ) + x2 = data.draw( + xps.arrays(dtype=in_dtypes[1], shape=shapes[1]).filter(x_filter), label='x2' + ) + locals_ = {'x1': x1, 'x2': x2} + try: + exec(expr, locals_) + except OverflowError: + reject() + x1 = locals_['x1'] + assert x1.dtype == out_dtype, f'{x1.dtype=!s}, but should be {out_dtype}' + + +op_scalar_params: List[Tuple[str, str, DT, ScalarType, DT]] = [] +for op, symbol in dh.binary_op_to_symbol.items(): + if op == '__matmul__': + continue + for in_dtype in dh.func_in_dtypes[op]: + out_dtype = xp.bool if dh.func_returns_bool[op] else in_dtype + for in_stype in dh.dtype_to_scalars[in_dtype]: + p = pytest.param( + op, + f'x {symbol} s', + in_dtype, + in_stype, + out_dtype, + id=make_id(op, (in_dtype, in_stype), out_dtype), + ) + op_scalar_params.append(p) + + +@pytest.mark.parametrize('op, expr, in_dtype, in_stype, out_dtype', op_scalar_params) +@given(data=st.data()) +def test_op_scalar_promotion(op, expr, in_dtype, in_stype, out_dtype, data): + x_filter = filters[op] + kw = {k: in_stype is float for k in ('allow_nan', 'allow_infinity')} + s = data.draw(xps.from_dtype(in_dtype, **kw).map(in_stype), label='scalar') + x = data.draw( + xps.arrays(dtype=in_dtype, shape=hh.shapes).filter(x_filter), label='x' + ) + try: + out = eval(expr, {'x': x, 's': s}) + except OverflowError: + reject() + assert out.dtype == out_dtype, f'{out.dtype=!s}, but should be {out_dtype}' + + +inplace_scalar_params: List[Tuple[str, str, DT, ScalarType]] = [] +for op, symbol in dh.inplace_op_to_symbol.items(): + if op == '__imatmul__': + continue + for dtype in dh.func_in_dtypes[op]: + for in_stype in dh.dtype_to_scalars[dtype]: + p = pytest.param( + op, + f'x {symbol} s', + dtype, + in_stype, + id=make_id(op, (dtype, in_stype), dtype), + ) + inplace_scalar_params.append(p) + + +@pytest.mark.parametrize('op, expr, dtype, in_stype', inplace_scalar_params) +@given(data=st.data()) +def test_inplace_op_scalar_promotion(op, expr, dtype, in_stype, data): + x_filter = filters[op] + kw = {k: in_stype is float for k in ('allow_nan', 'allow_infinity')} + s = data.draw(xps.from_dtype(dtype, **kw).map(in_stype), label='scalar') + x = data.draw(xps.arrays(dtype=dtype, shape=hh.shapes).filter(x_filter), label='x') + locals_ = {'x': x, 's': s} + try: + exec(expr, locals_) + except OverflowError: + reject() + x = locals_['x'] + assert x.dtype == dtype, f'{x.dtype=!s}, but should be {dtype}' if __name__ == '__main__': - for (i, j), p in promotion_table.items(): - print(f"({i}, {j}) -> {p}") + for (i, j), p in dh.promotion_table.items(): + print(f'({i}, {j}) -> {p}')